GRPOMetric

GRPOMetric tracks policy optimization diagnostics during GRPO training: KL divergence, clipping rates, entropy, and log-probability statistics.

Usage

from twinkle.metric import GRPOMetric

# device_mesh and process_group are assumed to come from your distributed setup.
metric = GRPOMetric(
    device_mesh=device_mesh,
    process_group=process_group,
    epsilon=0.2,          # PPO clip range
    temperature=1.0,      # Sampling temperature for logp rescaling
    top_k_kl=10,          # Track top-K high-KL tokens per step
)

# During training loop
metric.accumulate(inputs, outputs, old_logps=old_logps, advantages=advantages)

# At log interval
results = metric.calculate()
# results: {
#   'train/policy_confidence': 0.85,
#   'train/mean_new_logp': -1.23,
#   'train/mean_old_logp': -1.30,
#   'train/logp_diff_mean': 0.07,
#   'train/approx_kl': 0.003,
#   'train/token_kl_max': 0.15,
#   'train/entropy': 2.1,
#   'train/clip_ratio': 0.02,
#   'train/clip_ratio_low': 0.01,
#   'train/clip_ratio_high': 0.01,
# }

Reported Metrics

Metric                     Description
train/policy_confidence    exp(mean_new_logp); higher means the model is more confident
train/mean_new_logp        Average log-probability of generated tokens under the current policy
train/mean_old_logp        Average log-probability of generated tokens under the old policy (the old_logps passed to accumulate)
train/logp_diff_mean       Mean (new - old) log-probability difference
train/approx_kl            Schulman K3 estimator of KL(old || new)
train/token_kl_max         Maximum per-token KL across all ranks
train/token_ratio_max      Maximum importance weight (new/old probability ratio) across all ranks
train/entropy              Average token-level entropy
train/clip_ratio           Fraction of tokens clipped (low + high)
train/clip_ratio_low       Fraction clipped below (ratio < 1-ε, negative advantage)
train/clip_ratio_high      Fraction clipped above (ratio > 1+ε, positive advantage)
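
The definitions in this table can be sketched in plain Python. This is only an illustration of the formulas as described here, not the twinkle implementation: `grpo_diagnostics` is a hypothetical helper, it works on flat lists for a single process (no cross-rank reduction or masking), and it omits train/entropy, which needs the full token distribution rather than per-token log-probabilities.

```python
import math


def grpo_diagnostics(new_logps, old_logps, advantages, epsilon=0.2, epsilon_high=None):
    """Illustrative per-token diagnostics (hypothetical helper, not twinkle code)."""
    if epsilon_high is None:
        epsilon_high = epsilon  # symmetric clip range by default
    n = len(new_logps)
    log_ratios = [new - old for new, old in zip(new_logps, old_logps)]
    ratios = [math.exp(d) for d in log_ratios]  # importance weights new/old
    # K3 estimator per token: (r - 1) - log r, which is always >= 0.
    token_kl = [r - 1.0 - d for r, d in zip(ratios, log_ratios)]
    # Clipping is counted only when the advantage sign makes the clip active.
    low = sum(1 for r, a in zip(ratios, advantages) if r < 1.0 - epsilon and a < 0)
    high = sum(1 for r, a in zip(ratios, advantages) if r > 1.0 + epsilon_high and a > 0)
    mean_new = sum(new_logps) / n
    return {
        "train/policy_confidence": math.exp(mean_new),
        "train/mean_new_logp": mean_new,
        "train/mean_old_logp": sum(old_logps) / n,
        "train/logp_diff_mean": sum(log_ratios) / n,
        "train/approx_kl": sum(token_kl) / n,
        "train/token_kl_max": max(token_kl),
        "train/token_ratio_max": max(ratios),
        "train/clip_ratio": (low + high) / n,
        "train/clip_ratio_low": low / n,
        "train/clip_ratio_high": high / n,
    }


# Toy batch of four tokens: one is clipped high, one is clipped low.
stats = grpo_diagnostics(
    new_logps=[-1.0, -0.5, -2.0, -1.0],
    old_logps=[-1.0, -1.0, -1.5, -1.0],
    advantages=[1.0, 1.0, -1.0, -1.0],
)
```

In this toy batch the second token has a ratio above 1+ε with a positive advantage and the third has a ratio below 1-ε with a negative advantage, so each clip fraction covers one token out of four.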

Variants

  • GSPOMetric — Computes clip rate at sequence level (geometric-mean ratio per sequence)
  • CISPOMetric — Unconditional clip rate (not gated by advantage sign)
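
The two variants differ only in how the clip rate is counted, which the sketch below illustrates. Both functions are hypothetical helpers, not the library's code. In particular, it is an assumption here that the sequence-level variant keeps the same advantage-sign gating as GRPOMetric and uses one advantage per sequence.

```python
import math


def sequence_clip_rate(new_logps_per_seq, old_logps_per_seq, advantages, epsilon=0.2):
    """Fraction of sequences whose geometric-mean ratio falls outside the clip range."""
    clipped = 0
    for new, old, adv in zip(new_logps_per_seq, old_logps_per_seq, advantages):
        # The mean of token log-ratios is the log of the geometric-mean ratio.
        mean_log_ratio = sum(n - o for n, o in zip(new, old)) / len(new)
        seq_ratio = math.exp(mean_log_ratio)
        if (seq_ratio < 1.0 - epsilon and adv < 0) or (seq_ratio > 1.0 + epsilon and adv > 0):
            clipped += 1
    return clipped / len(advantages)


def unconditional_clip_rate(new_logps, old_logps, epsilon=0.2):
    """Fraction of tokens outside [1 - epsilon, 1 + epsilon], ignoring advantage sign."""
    ratios = [math.exp(n - o) for n, o in zip(new_logps, old_logps)]
    outside = sum(1 for r in ratios if r < 1.0 - epsilon or r > 1.0 + epsilon)
    return outside / len(ratios)


# Two sequences: the first drifts far from the old policy, the second barely moves.
seq_rate = sequence_clip_rate(
    new_logps_per_seq=[[-1.0, -1.0], [-1.0, -2.0]],
    old_logps_per_seq=[[-1.5, -1.5], [-1.0, -1.9]],
    advantages=[1.0, 1.0],
)
token_rate = unconditional_clip_rate(
    new_logps=[-1.0, -0.5, -2.0],
    old_logps=[-1.0, -1.0, -1.5],
)
```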

Parameters

Parameter       Type     Default    Description
epsilon         float    0.2        Lower clip boundary
epsilon_high    float    None       Upper clip boundary (defaults to epsilon)
temperature     float    1.0        Rescale logps to T=1 before computing KL
top_k_kl        int      0          If > 0, record top-K high-KL token details
ignore_index    int      -100       Label value to mask out
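
As a rough illustration of what masking with ignore_index means for the averaged metrics, the sketch below drops positions whose label equals ignore_index before taking a mean. `masked_mean` is a hypothetical helper written for this example, and the assumption that masked positions are simply excluded from the averages is not confirmed by the library.

```python
def masked_mean(values, labels, ignore_index=-100):
    """Average values over positions whose label is not ignore_index."""
    kept = [v for v, label in zip(values, labels) if label != ignore_index]
    # Return 0.0 for a fully masked input instead of dividing by zero.
    return sum(kept) / len(kept) if kept else 0.0


# The third position is masked, so only the first two log-probs are averaged.
token_logps = [-0.5, -1.5, -9.0]
labels = [42, 7, -100]
mean_logp = masked_mean(token_logps, labels)
```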