
fix(grpo): count each rollout once under fan-out #1325

Open
EazyReal wants to merge 1 commit into radixark:main from EazyReal:upstream-pr/grpo-rollout-baseline

Conversation

@EazyReal EazyReal commented Jun 12, 2026


Problem

GRPO compares attempts within a prompt, so each rollout must contribute its single outcome reward exactly once to its own prompt's baseline. _post_process_rewards (in miles/ray/rollout/train_data_conversion.py) builds a flat reward vector over samples and reshapes it positionally:

rewards = torch.tensor(raw_rewards, dtype=torch.float)
if rewards.shape[-1] == args.n_samples_per_prompt * args.rollout_batch_size:
    rewards = rewards.reshape(-1, args.n_samples_per_prompt)
else:
    # when sample count is not equal in each group
    rewards = rewards.view(-1, rewards.shape[-1])  # one global group

That breaks the invariant in three ways:

  1. Fan-out double counting. multi_turn / agentic_tool_call rollouts fan one agent attempt into several training samples (the compaction / sub-agent / fork branches), returned as a list[Sample] of deepcopy segments that all share the same Sample.index. Any fan-out makes the flat sample count exceed n_samples_per_prompt * rollout_batch_size, so the guard collapses every prompt into a single global group, and a rollout that happens to emit more segments is counted once per segment in that baseline.
  2. Silent None collapse. Any fix that re-keys the baseline on Sample.index has to confront that Sample.index defaults to None. A rollout path that never assigns it would feed all-None ids into the dedup, silently merging every sample into one "rollout": one reward survives, the group mean equals it, and every advantage comes out exactly 0.0 — training quietly learns nothing. The id contract must fail loudly, not fall back.
  3. Global-mean fallback for uneven groups. When rollout counts per prompt are genuinely uneven (e.g. dynamic sampling), position alone cannot recover the prompt structure, so the old code could only center everything against one global mean — even though every sample already carries its prompt id in Sample.group_index (assigned by the data source alongside index, and already used by the rollout metrics to group samples per prompt).
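The silent None collapse in item 2 can be reproduced in a few lines. This is a minimal sketch in plain Python, not the repository's code: the name naive_dedup_advantages is hypothetical and stands in for any first-occurrence dedup keyed on Sample.index that has no None check.

```python
def naive_dedup_advantages(raw_rewards, rollout_ids):
    """Hypothetical dedup keyed on rollout id, with no None check."""
    # Keep the first reward seen for each rollout id; None is a valid dict key.
    reward_of_rollout = {}
    for rid, reward in zip(rollout_ids, raw_rewards):
        reward_of_rollout.setdefault(rid, reward)

    # Center the surviving rewards on their mean (one global group here).
    rewards = list(reward_of_rollout.values())
    mean = sum(rewards) / len(rewards)
    advantage_of_rollout = {rid: r - mean for rid, r in reward_of_rollout.items()}

    # Broadcast each rollout's advantage back to its samples.
    return [advantage_of_rollout[rid] for rid in rollout_ids]


# Every sample left its index at the default None: all four merge into one "rollout".
collapsed = naive_dedup_advantages([1.0, 3.0, 5.0, 11.0], [None, None, None, None])
print(collapsed)  # [0.0, 0.0, 0.0, 0.0]
```

Only the first reward survives the dedup, the mean equals it, and every sample receives an advantage of exactly 0.0 with no error raised.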

Before vs After

Take a batch of n_samples_per_prompt=2, rollout_batch_size=2 (so 4 rollouts, 2 prompts), where rollout 1 fans into 2 segments that repeat its outcome reward 3.0:

raw_rewards = [1.0, 3.0, 3.0, 5.0, 11.0]
rollout_ids = [  0,   1,   1,   2,    3]   # Sample.index; rollout 1 emitted 2 segments
prompt_ids  = [  0,   0,   0,   1,    1]   # Sample.group_index

Before — 5 samples != 2 * 2 = 4, so the guard collapses all 5 into one global group and subtracts the global mean 4.6:

advantages = [-3.6, -1.6, -1.6, 0.4, 6.4]

Prompt structure is gone, and rollout 1's reward is counted twice in the baseline.

After — reduce to one reward per rollout [1, 3, 5, 11], bucket by prompt [[1, 3], [5, 11]], center within each prompt, then broadcast each rollout's advantage back to its segments:

advantages = [-1.0, 1.0, 1.0, -3.0, 3.0]

Rollout 1 counts once (its two segments both carry +1.0), and prompt 0's mean is unaffected by the fan-out. Pinned by tests/fast/.../test_group_relative_advantages.py::test_fanned_rollout_counts_once_and_segments_share_the_advantage.
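The arithmetic in this example can be checked with a short script. It is a sketch in plain Python rather than torch, and the variable names are illustrative rather than taken from the PR's code.

```python
from collections import defaultdict

raw_rewards = [1.0, 3.0, 3.0, 5.0, 11.0]
rollout_ids = [0, 1, 1, 2, 3]   # Sample.index
prompt_ids = [0, 0, 0, 1, 1]    # Sample.group_index

# Before: one global group over all five samples.
global_mean = sum(raw_rewards) / len(raw_rewards)
before = [round(r - global_mean, 6) for r in raw_rewards]

# After: one reward per rollout (first occurrence), remembering its prompt.
reward_of_rollout = {}
prompt_of_rollout = {}
for rid, pid, reward in zip(rollout_ids, prompt_ids, raw_rewards):
    reward_of_rollout.setdefault(rid, reward)
    prompt_of_rollout.setdefault(rid, pid)

# Bucket the deduped rewards by prompt and take each bucket's mean.
buckets = defaultdict(list)
for rid, reward in reward_of_rollout.items():
    buckets[prompt_of_rollout[rid]].append(reward)
prompt_mean = {pid: sum(rs) / len(rs) for pid, rs in buckets.items()}

# Broadcast each rollout's centered reward back to its segments.
after = [reward_of_rollout[rid] - prompt_mean[prompt_of_rollout[rid]] for rid in rollout_ids]

print(before)  # [-3.6, -1.6, -1.6, 0.4, 6.4]
print(after)   # [-1.0, 1.0, 1.0, -3.0, 3.0]
```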

Fix

Make the rollout the unit of the baseline and the prompt id the group key. A new pure helper group_relative_advantages(raw_rewards, rollout_ids, prompt_ids, *, ...) (in loss_hub/math_utils.py):

  1. raises ValueError if any rollout id is None — the silent-collapse hazard above, named in the error message;
  2. reduces raw rewards to one reward per rollout id, keeping the first occurrence (dict insertion order preserves prompt-major sample order);
  3. when prompt ids are present (every value non-None): buckets the deduped rollouts by prompt id, validating each rollout id maps to exactly one prompt (raises on conflict), centers each bucket on its mean and optionally divides by the unbiased std + 1e-6 — the identical ops/eps/dtype as the old reshape rows; otherwise (metadata absent — custom rollout paths): the legacy positional behavior verbatim, reshape when counts match, else one global group;
  4. broadcasts each rollout's advantage back to all of its segments.

_post_process_rewards passes rollout_ids = [s.index for s in samples] and prompt_ids = [s.group_index for s in samples] (collapsed to None if any sample lacks it).

Why this is the right fix

  • Bit-identical default path. On the rigid path (every prompt exactly n_samples_per_prompt single-sample rollouts, prompt-major order) the prompt buckets are exactly the old reshape rows, so non-agentic training is unchanged — test_rigid_input_is_bit_identical_to_legacy_reshape_group_norm asserts exact float equality against the previous inline computation, std normalization included.
  • Fails loudly, never silently wrong. Missing rollout ids raise instead of zeroing every advantage (test_none_rollout_id_raises), and a rollout id spanning two prompts raises instead of being mis-bucketed (test_rollout_id_spanning_two_prompts_raises).
  • Deliberate behavior fix for uneven batches. The data source assigns group_index on every real path, so dynamic-sampling / uneven-count batches now get true per-prompt baselines instead of the old global-mean fallback — test_uneven_groups_get_per_prompt_baselines pins the per-prompt centering and that it differs from the old global result.
  • Metadata-less custom rollouts see zero change. Without group_index the helper reproduces the legacy positional grouping verbatim, including the documented global-group fallback (test_without_prompt_ids_uneven_counts_fall_back_to_one_global_group).
  • Minimal, CI-verifiable, no new abstraction. One pure function plus a small call-site change; the new tests are CPU-only (torch on CPU, no GPU) and live under tests/fast/, so the collector auto-assigns them to the stage-a-cpu suite with no workflow change.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces the group_relative_advantages utility to compute GRPO group-relative advantages using the rollout as the unit, integrating it into the training data conversion pipeline and adding fast unit tests. The reviewer identified critical issues in the implementation, including potential NaN values during standard deviation normalization when the group size is 1, silent grouping of None indices, and a lack of input validation for empty lists or mismatched lengths. To address these, the reviewer recommended adding robust validation checks, using biased standard deviation, and expanding the test suite to cover these edge cases.


Comment on lines +444 to +462
    # Reduce to one reward per rollout, preserving first-occurrence (prompt-major) order.
    # Segments of a rollout share its outcome reward; take the first occurrence. Dict insertion
    # order (3.7+) keeps the group_ids in prompt-major order, so no parallel order list is needed.
    reward_of_group: dict = {}
    for gid, reward in zip(group_ids, raw_rewards, strict=True):
        reward_of_group.setdefault(gid, reward)

    group_rewards = torch.tensor(list(reward_of_group.values()), dtype=torch.float)
    if group_rewards.numel() == n_samples_per_prompt * rollout_batch_size:
        grouped = group_rewards.reshape(-1, n_samples_per_prompt)
    else:
        # uneven rollout counts per prompt (e.g. dynamic sampling): one global group
        grouped = group_rewards.view(-1, group_rewards.shape[-1])
    grouped = grouped - grouped.mean(dim=-1, keepdim=True)
    if std_normalization:
        grouped = grouped / (grouped.std(dim=-1, keepdim=True) + 1e-6)

    advantage_of_group = dict(zip(reward_of_group.keys(), grouped.flatten().tolist(), strict=True))
    return [advantage_of_group[gid] for gid in group_ids]

Contributor


high

There are four critical issues in the current implementation of group_relative_advantages:

  1. List length mismatch: We should validate that the lengths of group_ids and raw_rewards match up front to prevent silent mismatches or broadcast failures in RL pipelines.
  2. Silent grouping of None indices: Python's dict allows None as a valid key. If any sample.index is None (which is the default value in Sample), they will all be silently grouped together under the key None instead of failing loudly as intended. We should explicitly check for None in group_ids and raise a ValueError.
  3. NaN values with standard deviation normalization: If std_normalization is True and the group size is 1 (e.g., when n_samples_per_prompt is 1 or falling back to a single global group of size 1), grouped.std(dim=-1, keepdim=True) will return NaN because the default is unbiased=True (division by N-1 = 0). This will propagate NaN values to all advantages and break training. Using unbiased=False (or correction=0) avoids this issue.
  4. Empty inputs: If raw_rewards is empty, the function should return early to avoid PyTorch operations on empty tensors.
    if len(group_ids) != len(raw_rewards):
        raise ValueError(f"Length mismatch: group_ids has length {len(group_ids)}, but raw_rewards has length {len(raw_rewards)}.")

    if not raw_rewards:
        return []

    if any(gid is None for gid in group_ids):
        raise ValueError("group_ids contains None. All samples must have a valid, non-None group_id (sample.index).")

    # Reduce to one reward per rollout, preserving first-occurrence (prompt-major) order.
    # Segments of a rollout share its outcome reward; take the first occurrence. Dict insertion
    # order (3.7+) keeps the group_ids in prompt-major order, so no parallel order list is needed.
    reward_of_group: dict = {}
    for gid, reward in zip(group_ids, raw_rewards, strict=True):
        reward_of_group.setdefault(gid, reward)

    group_rewards = torch.tensor(list(reward_of_group.values()), dtype=torch.float)
    if group_rewards.numel() == n_samples_per_prompt * rollout_batch_size:
        grouped = group_rewards.reshape(-1, n_samples_per_prompt)
    else:
        # uneven rollout counts per prompt (e.g. dynamic sampling): one global group
        grouped = group_rewards.view(-1, group_rewards.shape[-1])
    grouped = grouped - grouped.mean(dim=-1, keepdim=True)
    if std_normalization:
        # Use unbiased=False to avoid NaN when the group size is 1
        grouped = grouped / (grouped.std(dim=-1, keepdim=True, unbiased=False) + 1e-6)

    advantage_of_group = dict(zip(reward_of_group.keys(), grouped.flatten().tolist(), strict=True))
    return [advantage_of_group[gid] for gid in group_ids]
References
  1. When processing lists of sequence-level or token-level tensors (such as advantages, student log probabilities, and teacher log probabilities) in RL or distillation pipelines, validate that the list lengths match up front, and perform per-sample shape checks to prevent silent mismatches or broadcast failures (e.g., GRPO-style scalar-broadcast traps).

Author


Thanks for the thorough pass. I went through all four against the actual code and the upstream original, and I'm going to keep the function as-is. Walking through each:

(1) Validate len(group_ids) == len(raw_rewards) up front — already enforced. The reduction loop uses zip(group_ids, raw_rewards, strict=True) (math_utils.py:448), and the broadcast uses zip(reward_of_group.keys(), grouped.flatten().tolist(), strict=True) (line 461). strict=True raises ValueError on any length mismatch at the exact point of use, so an explicit up-front len(...) check would be redundant with the same exception type and message intent.

(2) None group_ids silently grouped: None cannot reach this function on any real path, so a ValueError guard would be dead code, and the design decision deliberately lives upstream. The call site is group_ids = [sample.index for sample in samples] (train_data_conversion.py:114). Before samples ever reach _post_process_rewards, the rollout sorts them by sample.index with no None guard: data = sorted(data, key=lambda group: ... group[0].index) (sglang_rollout.py:455, train path) and data.sort(key=lambda sample: sample.index) (sglang_rollout.py:590). A None index makes that sort raise TypeError: '<' not supported between instances of 'NoneType' and 'int' first. On every real train path sample.index is assigned (data_source.py:109: sample.index = self.sample_index). This is intentional fail-loud behavior documented in the call-site comment (train_data_conversion.py:111-113): a custom rollout that leaves index unset crashes at the sort rather than a positional fallback silently merging two unrelated groups. Adding a guard here would only mask that upstream contract.
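The upstream sort failure described in this reply is easy to reproduce. This is a sketch with a hypothetical FakeSample dataclass standing in for Sample; only the index field and its None default are modeled.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSample:
    """Hypothetical stand-in for Sample; only the index field matters here."""
    index: Optional[int] = None


samples = [FakeSample(index=1), FakeSample(), FakeSample(index=0)]

try:
    # Mirrors the shape of the quoted sort: key on index, no None guard.
    samples.sort(key=lambda sample: sample.index)
except TypeError as exc:
    error = str(exc)
else:
    error = None

# A single-element list is never compared, so the same sort succeeds.
lone = [FakeSample()]
lone.sort(key=lambda sample: sample.index)

print(error)
```

The sort raises as soon as a None key is compared with another key. With a single sample no comparison happens, so in that case the sort alone would not catch an unset index.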

(3) torch.std unbiased=False to avoid group-size-1 NaN: this would change numerics on the default path and is out of scope. Upstream uses rewards.std(dim=-1, keepdim=True), i.e. unbiased=True (Bessel-corrected), confirmed via git show upstream/main:miles/ray/rollout/train_data_conversion.py. This PR's whole contract is that the default 1-sample-per-rollout path stays bit-identical to upstream group norm; switching to unbiased=False shrinks the std by sqrt((n-1)/n) and therefore rescales every std-normalized advantage by sqrt(n/(n-1)), which is a behavior change for all GRPO/GSPO std-normalized training, not a safe robustness fix. The singleton-group (group size 1) NaN is pre-existing upstream behavior and orthogonal to this fix. The existing test test_std_normalization_divides_by_per_prompt_group_std even pins the unbiased result (+-1/sqrt(2), see the comment at line 60), so flipping the flag would also break a deliberately-pinned numeric.
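The numeric effect of the two std conventions can be checked on a two-rollout group. This is a plain-Python sketch: the std helper below is hypothetical and only imitates the divisor convention of torch.std (n - correction, NaN when that divisor is 0); it is not torch itself.

```python
import math


def std(values, correction):
    """Standard deviation with divisor n - correction; NaN when the divisor is 0."""
    n = len(values)
    mean = sum(values) / n
    squared = sum((v - mean) ** 2 for v in values)
    divisor = n - correction
    return math.sqrt(squared / divisor) if divisor > 0 else math.nan


group = [1.0, 3.0]  # one prompt, two rollouts
centered = [v - sum(group) / len(group) for v in group]

unbiased = [c / (std(group, 1) + 1e-6) for c in centered]  # close to +-1/sqrt(2)
biased = [c / (std(group, 0) + 1e-6) for c in centered]    # close to +-1

ratio = biased[1] / unbiased[1]  # about sqrt(n / (n - 1)) = sqrt(2) for n = 2
singleton_std = std([5.0], 1)    # NaN: the divisor n - 1 is 0

print(unbiased, biased, ratio, singleton_std)
```

For n = 2 the biased std is smaller by sqrt(1/2), so the normalized advantages are larger by about sqrt(2). The singleton case shows where the NaN comes from under the unbiased convention.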

(4) Early-return on empty raw_rewards — upstream crashes identically on empty input (torch.tensor([]).view(-1, 0) raises RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0]), which I reproduced against both the upstream original and this helper. In practice the upstream assert len(data) == args.rollout_batch_size (sglang_rollout.py:453) guarantees a non-empty batch, so this path is unreachable. Adding an early return would change behavior (crash -> silent empty return) versus upstream with no real caller benefiting, so I'd rather not introduce that divergence in a bit-identity-scoped fix.

Comment on lines +64 to +73
def test_uneven_rollout_counts_fall_back_to_one_global_group():
    # 3 rollouts when 4 are expected -> single global group (documented fallback).
    advantages = group_relative_advantages(
        [1.0, 2.0, 3.0],
        group_ids=[0, 1, 2],
        n_samples_per_prompt=2,
        rollout_batch_size=2,
        std_normalization=False,
    )
    assert advantages == pytest.approx([-1.0, 0.0, 1.0])

Contributor


medium

Add unit tests to verify that group_relative_advantages raises a ValueError when group_ids contains None, and that it does not produce NaN when std_normalization is True and the group size is 1.

def test_uneven_rollout_counts_fall_back_to_one_global_group():
    # 3 rollouts when 4 are expected -> single global group (documented fallback).
    advantages = group_relative_advantages(
        [1.0, 2.0, 3.0],
        group_ids=[0, 1, 2],
        n_samples_per_prompt=2,
        rollout_batch_size=2,
        std_normalization=False,
    )
    assert advantages == pytest.approx([-1.0, 0.0, 1.0])


def test_group_ids_containing_none_raises_value_error():
    with pytest.raises(ValueError, match="group_ids contains None"):
        group_relative_advantages(
            [1.0, 2.0],
            group_ids=[None, 1],
            n_samples_per_prompt=2,
            rollout_batch_size=2,
            std_normalization=False,
        )


def test_std_normalization_with_group_size_one_does_not_nan():
    advantages = group_relative_advantages(
        [5.0],
        group_ids=[0],
        n_samples_per_prompt=1,
        rollout_batch_size=1,
        std_normalization=True,
    )
    assert advantages == [0.0]

Author


Following from the responses above: the two new tests suggested here are tied to concerns (2) and (3), both of which I'm keeping as-is, so a None -> ValueError test would assert behavior the function intentionally doesn't own (the failure is the upstream sort, not this helper), and a std/group-size-1 no-NaN test would lock in the unbiased=False numeric change that breaks the default-path bit-identity invariant.

The one test you flagged as genuinely useful — uneven rollout counts collapsing to a single global group — is already covered: test_uneven_rollout_counts_fall_back_to_one_global_group (test_group_relative_advantages.py:64-73) passes 3 rollouts when 4 are expected and asserts the documented global-group fallback [-1.0, 0.0, 1.0]. The suite also already pins the default-path == plain per-prompt group norm identity (line 19), the fan-out-counts-once contract (line 32), and the per-prompt unbiased std division (line 51), which together cover the cases worth pinning for this change.

@EazyReal EazyReal force-pushed the upstream-pr/grpo-rollout-baseline branch from 2ee1e80 to 742bcf2 Compare June 12, 2026 08:34
@EazyReal EazyReal changed the title from "fix(rollout): count each rollout once in GRPO group baseline under fan-out" to "fix(grpo): count each rollout once under fan-out" Jun 13, 2026
GRPO compares attempts per prompt: each rollout should contribute its single
outcome reward once to the per-prompt baseline. But an agentic rollout can fan
into several training segments -- multi_turn / agentic_tool_call return a
list[Sample] of deepcopies that share one Sample.index (compaction / sub-agent
/ fork branches of one attempt). The old per-prompt group norm reshaped the flat
sample reward vector by n_samples_per_prompt: any fan-out makes the flat length
exceed n_samples_per_prompt * rollout_batch_size, so the reshape guard falls back
to a single global group. That collapses all prompts into one baseline and
destroys per-prompt centering -- the exact signal GRPO relies on -- and lets a
rollout that happens to emit more segments dominate the advantage.

Fix: make the *rollout* the unit and the *prompt id* the group key. A new pure
helper, group_relative_advantages (math_utils), reduces raw rewards to one
reward per rollout id (Sample.index), buckets the rollouts by prompt id
(Sample.group_index) when that metadata is present, centers each bucket on its
mean (optionally divides by the unbiased std + 1e-6), and broadcasts each
rollout's advantage back to all of its segments. Without prompt metadata
(custom rollout paths) the legacy positional grouping is preserved verbatim:
reshape by n_samples_per_prompt when counts match, else one global group.

Fail loud on missing rollout ids: Sample.index defaults to None, so an unset
index would silently merge every such sample into a single rollout and zero
out its advantages; the helper raises ValueError instead.

Behavior fix for uneven batches: the data source sets group_index on every
real path, so dynamic-sampling / uneven-count batches now get true per-prompt
baselines instead of the old single-global-group fallback. On the rigid
default path (every prompt exactly n_samples_per_prompt single-sample
rollouts) the buckets equal the legacy reshape rows and the result is
bit-identical to the previous reshape/mean/std group norm -- non-agentic
training is unchanged.

CI: adds tests/fast/backends/training_utils/loss/test_group_relative_advantages.py
(CPU-only, torch on CPU, no GPU); it auto-registers into the stage-a-cpu suite by
the tests/fast/ directory convention. The tests pin: rigid-path bit-identity vs
the legacy computation; a fanned rollout counts once and its segments share the
advantage; per-prompt std division; uneven groups get per-prompt baselines;
missing prompt metadata falls back to the legacy global group; None rollout ids
raise; a rollout id spanning two prompts raises.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@EazyReal EazyReal force-pushed the upstream-pr/grpo-rollout-baseline branch from 742bcf2 to fbe308b Compare June 13, 2026 04:13