fix(grpo): count each rollout once under fan-out#1325
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the group_relative_advantages utility to compute GRPO group-relative advantages using the rollout as the unit, integrating it into the training data conversion pipeline and adding fast unit tests. The reviewer identified critical issues in the implementation, including potential NaN values during standard deviation normalization when the group size is 1, silent grouping of None indices, and a lack of input validation for empty lists or mismatched lengths. To address these, the reviewer recommended adding robust validation checks, using biased standard deviation, and expanding the test suite to cover these edge cases.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| # Reduce to one reward per rollout, preserving first-occurrence (prompt-major) order. | ||
| # Segments of a rollout share its outcome reward; take the first occurrence. Dict insertion | ||
| # order (3.7+) keeps the group_ids in prompt-major order, so no parallel order list is needed. | ||
| reward_of_group: dict = {} | ||
| for gid, reward in zip(group_ids, raw_rewards, strict=True): | ||
| reward_of_group.setdefault(gid, reward) | ||
|
|
||
| group_rewards = torch.tensor(list(reward_of_group.values()), dtype=torch.float) | ||
| if group_rewards.numel() == n_samples_per_prompt * rollout_batch_size: | ||
| grouped = group_rewards.reshape(-1, n_samples_per_prompt) | ||
| else: | ||
| # uneven rollout counts per prompt (e.g. dynamic sampling): one global group | ||
| grouped = group_rewards.view(-1, group_rewards.shape[-1]) | ||
| grouped = grouped - grouped.mean(dim=-1, keepdim=True) | ||
| if std_normalization: | ||
| grouped = grouped / (grouped.std(dim=-1, keepdim=True) + 1e-6) | ||
|
|
||
| advantage_of_group = dict(zip(reward_of_group.keys(), grouped.flatten().tolist(), strict=True)) | ||
| return [advantage_of_group[gid] for gid in group_ids] |
There was a problem hiding this comment.
There are four critical issues in the current implementation of group_relative_advantages:
- List length mismatch: We should validate that the lengths of
group_idsandraw_rewardsmatch up front to prevent silent mismatches or broadcast failures in RL pipelines. - Silent grouping of
Noneindices: Python'sdictallowsNoneas a valid key. If anysample.indexisNone(which is the default value inSample), they will all be silently grouped together under the keyNoneinstead of failing loudly as intended. We should explicitly check forNoneingroup_idsand raise aValueError. - NaN values with standard deviation normalization: If
std_normalizationisTrueand the group size is 1 (e.g., whenn_samples_per_promptis 1 or falling back to a single global group of size 1),grouped.std(dim=-1, keepdim=True)will returnNaNbecause the default isunbiased=True(division byN-1 = 0). This will propagateNaNvalues to all advantages and break training. Usingunbiased=False(orcorrection=0) avoids this issue. - Empty inputs: If
raw_rewardsis empty, the function should return early to avoid PyTorch operations on empty tensors.
if len(group_ids) != len(raw_rewards):
raise ValueError(f"Length mismatch: group_ids has length {len(group_ids)}, but raw_rewards has length {len(raw_rewards)}.")
if not raw_rewards:
return []
if any(gid is None for gid in group_ids):
raise ValueError("group_ids contains None. All samples must have a valid, non-None group_id (sample.index).")
# Reduce to one reward per rollout, preserving first-occurrence (prompt-major) order.
# Segments of a rollout share its outcome reward; take the first occurrence. Dict insertion
# order (3.7+) keeps the group_ids in prompt-major order, so no parallel order list is needed.
reward_of_group: dict = {}
for gid, reward in zip(group_ids, raw_rewards, strict=True):
reward_of_group.setdefault(gid, reward)
group_rewards = torch.tensor(list(reward_of_group.values()), dtype=torch.float)
if group_rewards.numel() == n_samples_per_prompt * rollout_batch_size:
grouped = group_rewards.reshape(-1, n_samples_per_prompt)
else:
# uneven rollout counts per prompt (e.g. dynamic sampling): one global group
grouped = group_rewards.view(-1, group_rewards.shape[-1])
grouped = grouped - grouped.mean(dim=-1, keepdim=True)
if std_normalization:
# Use unbiased=False to avoid NaN when the group size is 1
grouped = grouped / (grouped.std(dim=-1, keepdim=True, unbiased=False) + 1e-6)
advantage_of_group = dict(zip(reward_of_group.keys(), grouped.flatten().tolist(), strict=True))
return [advantage_of_group[gid] for gid in group_ids]References
- When processing lists of sequence-level or token-level tensors (such as advantages, student log probabilities, and teacher log probabilities) in RL or distillation pipelines, validate that the list lengths match up front, and perform per-sample shape checks to prevent silent mismatches or broadcast failures (e.g., GRPO-style scalar-broadcast traps).
There was a problem hiding this comment.
Thanks for the thorough pass. I went through all four against the actual code and the upstream original, and I'm going to keep the function as-is. Walking through each:
(1) Validate len(group_ids) == len(raw_rewards) up front — already enforced. The reduction loop uses zip(group_ids, raw_rewards, strict=True) (math_utils.py:448), and the broadcast uses zip(reward_of_group.keys(), grouped.flatten().tolist(), strict=True) (line 461). strict=True raises ValueError on any length mismatch at the exact point of use, so an explicit up-front len(...) check would be redundant with the same exception type and message intent.
(2) None group_ids silently grouped — None cannot reach this function on any real path, so a ValueError guard would be dead code, and the design decision deliberately lives upstream. The call site is group_ids = [sample.index for sample in samples] (train_data_conversion.py:114). Before samples ever reach _post_process_rewards, the rollout sorts them by sample.index with no None guard: data = sorted(data, key=lambda group: ... group[0].index) (sglang_rollout.py:455, train path) and data.sort(key=lambda sample: sample.index) (sglang_rollout.py:590). A None index makes that sort raise TypeError: '<' not supported between instances of 'NoneType' and 'int' first. On every real train path sample.index is assigned (data_source.py:109: sample.index = self.sample_index). This is intentional fail-loud behavior documented in the call-site comment (train_data_conversion.py:111-113): a custom rollout that leaves index unset crashes at the sort rather than a positional fallback silently merging two unrelated groups. Adding a guard here would only mask that upstream contract.
(3) torch.std unbiased=False to avoid group-size-1 NaN — this would change numerics on the default path and is out of scope. Upstream uses rewards.std(dim=-1, keepdim=True) i.e. unbiased=True (Bessel-corrected) — confirmed via git show upstream/main:miles/ray/rollout/train_data_conversion.py. This PR's whole contract is that the default 1-sample-per-rollout path stays bit-identical to upstream group norm; switching to unbiased=False rescales every std-normalized advantage by sqrt((n-1)/n), which is a behavior change for all GRPO/GSPO std-normalized training, not a safe robustness fix. The singleton-group (group size 1) NaN is pre-existing upstream behavior and orthogonal to this fix. The existing test test_std_normalization_divides_by_per_prompt_group_std even pins the unbiased result (+-1/sqrt(2), see the comment at line 60), so flipping the flag would also break a deliberately-pinned numeric.
(4) Early-return on empty raw_rewards — upstream crashes identically on empty input (torch.tensor([]).view(-1, 0) raises RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0]), which I reproduced against both the upstream original and this helper. In practice the upstream assert len(data) == args.rollout_batch_size (sglang_rollout.py:453) guarantees a non-empty batch, so this path is unreachable. Adding an early return would change behavior (crash -> silent empty return) versus upstream with no real caller benefiting, so I'd rather not introduce that divergence in a bit-identity-scoped fix.
| def test_uneven_rollout_counts_fall_back_to_one_global_group(): | ||
| # 3 rollouts when 4 are expected -> single global group (documented fallback). | ||
| advantages = group_relative_advantages( | ||
| [1.0, 2.0, 3.0], | ||
| group_ids=[0, 1, 2], | ||
| n_samples_per_prompt=2, | ||
| rollout_batch_size=2, | ||
| std_normalization=False, | ||
| ) | ||
| assert advantages == pytest.approx([-1.0, 0.0, 1.0]) |
There was a problem hiding this comment.
Add unit tests to verify that group_relative_advantages raises a ValueError when group_ids contains None, and that it does not produce NaN when std_normalization is True and the group size is 1.
def test_uneven_rollout_counts_fall_back_to_one_global_group():
# 3 rollouts when 4 are expected -> single global group (documented fallback).
advantages = group_relative_advantages(
[1.0, 2.0, 3.0],
group_ids=[0, 1, 2],
n_samples_per_prompt=2,
rollout_batch_size=2,
std_normalization=False,
)
assert advantages == pytest.approx([-1.0, 0.0, 1.0])
def test_group_ids_containing_none_raises_value_error():
with pytest.raises(ValueError, match="group_ids contains None"):
group_relative_advantages(
[1.0, 2.0],
group_ids=[None, 1],
n_samples_per_prompt=2,
rollout_batch_size=2,
std_normalization=False,
)
def test_std_normalization_with_group_size_one_does_not_nan():
advantages = group_relative_advantages(
[5.0],
group_ids=[0],
n_samples_per_prompt=1,
rollout_batch_size=1,
std_normalization=True,
)
assert advantages == [0.0]There was a problem hiding this comment.
Following from the responses above: the two new tests suggested here are tied to concerns (2) and (3), both of which I'm keeping as-is, so a None -> ValueError test would assert behavior the function intentionally doesn't own (the failure is the upstream sort, not this helper), and a std/group-size-1 no-NaN test would lock in the unbiased=False numeric change that breaks the default-path bit-identity invariant.
The one test you flagged as genuinely useful — uneven rollout counts collapsing to a single global group — is already covered: test_uneven_rollout_counts_fall_back_to_one_global_group (test_group_relative_advantages.py:64-73) passes 3 rollouts when 4 are expected and asserts the documented global-group fallback [-1.0, 0.0, 1.0]. The suite also already pins the default-path == plain per-prompt group norm identity (line 19), the fan-out-counts-once contract (line 32), and the per-prompt unbiased std division (line 51), which together cover the cases worth pinning for this change.
2ee1e80 to
742bcf2
Compare
GRPO compares attempts per prompt: each rollout should contribute its single outcome reward once to the per-prompt baseline. But an agentic rollout can fan into several training segments -- multi_turn / agentic_tool_call return a list[Sample] of deepcopies that share one Sample.index (compaction / sub-agent / fork branches of one attempt). The old per-prompt group norm reshaped the flat sample reward vector by n_samples_per_prompt: any fan-out makes the flat length exceed n_samples_per_prompt * rollout_batch_size, so the reshape guard falls back to a single global group. That collapses all prompts into one baseline and destroys per-prompt centering -- the exact signal GRPO relies on -- and lets a rollout that happens to emit more segments dominate the advantage. Fix: make the *rollout* the unit and the *prompt id* the group key. A new pure helper, group_relative_advantages (math_utils), reduces raw rewards to one reward per rollout id (Sample.index), buckets the rollouts by prompt id (Sample.group_index) when that metadata is present, centers each bucket on its mean (optionally divides by the unbiased std + 1e-6), and broadcasts each rollout's advantage back to all of its segments. Without prompt metadata (custom rollout paths) the legacy positional grouping is preserved verbatim: reshape by n_samples_per_prompt when counts match, else one global group. Fail loud on missing rollout ids: Sample.index defaults to None, so an unset index would silently merge every such sample into a single rollout and zero out its advantages; the helper raises ValueError instead. Behavior fix for uneven batches: the data source sets group_index on every real path, so dynamic-sampling / uneven-count batches now get true per-prompt baselines instead of the old single-global-group fallback. On the rigid default path (every prompt exactly n_samples_per_prompt single-sample rollouts) the buckets equal the legacy reshape rows and the result is bit-identical to the previous reshape/mean/std group norm -- non-agentic training is unchanged. CI: adds tests/fast/backends/training_utils/loss/test_group_relative_advantages.py (CPU-only, torch on CPU, no GPU); it auto-registers into the stage-a-cpu suite by the tests/fast/ directory convention. The tests pin: rigid-path bit-identity vs the legacy computation; a fanned rollout counts once and its segments share the advantage; per-prompt std division; uneven groups get per-prompt baselines; missing prompt metadata falls back to the legacy global group; None rollout ids raise; a rollout id spanning two prompts raises. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
742bcf2 to
fbe308b
Compare
Problem
GRPO compares attempts within a prompt, so each rollout must contribute its single outcome reward exactly once to its own prompt's baseline.
_post_process_rewards(inmiles/ray/rollout/train_data_conversion.py) builds a flat reward vector over samples and reshapes it positionally:That breaks the invariant in three ways:
multi_turn/agentic_tool_callrollouts fan one agent attempt into several training samples (the compaction / sub-agent / fork branches), returned as alist[Sample]ofdeepcopysegments that all share the sameSample.index. Any fan-out makes the flat sample count exceedn_samples_per_prompt * rollout_batch_size, so the guard collapses every prompt into a single global group, and a rollout that happens to emit more segments is counted once per segment in that baseline.Sample.indexhas to confront thatSample.indexdefaults toNone. A rollout path that never assigns it would feed all-Noneids into the dedup, silently merging every sample into one "rollout": one reward survives, the group mean equals it, and every advantage comes out exactly0.0— training quietly learns nothing. The id contract must fail loudly, not fall back.Sample.group_index(assigned by the data source alongsideindex, and already used by the rollout metrics to group samples per prompt).Before vs After
Take a batch of
n_samples_per_prompt=2,rollout_batch_size=2(so 4 rollouts, 2 prompts), where rollout1fans into 2 segments that repeat its outcome reward3.0:Before — 5 samples !=
2 * 2 = 4, so the guard collapses all 5 into one global group and subtracts the global mean4.6:Prompt structure is gone, and rollout
1's reward is counted twice in the baseline.After — reduce to one reward per rollout
[1, 3, 5, 11], bucket by prompt[[1, 3], [5, 11]], center within each prompt, then broadcast each rollout's advantage back to its segments:Rollout
1counts once (its two segments both carry+1.0), and prompt 0's mean is unaffected by the fan-out. Pinned bytests/fast/.../test_group_relative_advantages.py::test_fanned_rollout_counts_once_and_segments_share_the_advantage.Fix
Make the rollout the unit of the baseline and the prompt id the group key. A new pure helper
group_relative_advantages(raw_rewards, rollout_ids, prompt_ids, *, ...)(inloss_hub/math_utils.py):ValueErrorif any rollout id isNone— the silent-collapse hazard above, named in the error message;None): buckets the deduped rollouts by prompt id, validating each rollout id maps to exactly one prompt (raises on conflict), centers each bucket on its mean and optionally divides by the unbiased std +1e-6— the identical ops/eps/dtype as the old reshape rows; otherwise (metadata absent — custom rollout paths): the legacy positional behavior verbatim, reshape when counts match, else one global group;_post_process_rewardspassesrollout_ids = [s.index for s in samples]andprompt_ids = [s.group_index for s in samples](collapsed toNoneif any sample lacks it).Why this is the right fix
n_samples_per_promptsingle-sample rollouts, prompt-major order) the prompt buckets are exactly the old reshape rows, so non-agentic training is unchanged —test_rigid_input_is_bit_identical_to_legacy_reshape_group_normasserts exact float equality against the previous inline computation, std normalization included.test_none_rollout_id_raises), and a rollout id spanning two prompts raises instead of being mis-bucketed (test_rollout_id_spanning_two_prompts_raises).group_indexon every real path, so dynamic-sampling / uneven-count batches now get true per-prompt baselines instead of the old global-mean fallback —test_uneven_groups_get_per_prompt_baselinespins the per-prompt centering and that it differs from the old global result.group_indexthe helper reproduces the legacy positional grouping verbatim, including the documented global-group fallback (test_without_prompt_ids_uneven_counts_fall_back_to_one_global_group).tests/fast/, so the collector auto-assigns them to thestage-a-cpusuite with no workflow change.