
fix(metrics): group pass-rate by real sample identity#1326

Open
EazyReal wants to merge 1 commit into
radixark:mainfrom
EazyReal:upstream-pr/metric-ragged-passrate

@EazyReal EazyReal commented Jun 12, 2026

Problem

compute_pass_rate (in miles/utils/metric_utils.py) hard-codes a rigid shape: it runs assert len(flat_rewards) == num_groups * group_size and then np.array(flat_rewards).reshape(num_groups, group_size). Both call sites can feed it a flat reward count that is not an exact multiple of group_size, so the assert (or reshape) raises:

  • Train: log_passrate in miles/backends/training_utils/log_utils.py pinned num_groups=args.rollout_batch_size. Under over-sampling, dynamic-sampling filtering, or aborted/partial groups, the flat reward count is not rollout_batch_size * n_samples_per_prompt, so the assert raises before the reshape is reached.
  • Eval: log_eval_rollout_data in miles/ray/rollout/metrics.py passed num_groups=None, which makes the function infer num_groups = len(flat_rewards) // group_size. Multi-turn / tool eval list-expands a single prompt into a variable number of samples, so the reward count is not a clean multiple of n_samples_per_eval_prompt and the assert raises.

And even when the count happens to divide cleanly, positional grouping can silently mis-group: list expansion shifts every later sample's position, and a per-dataset n_samples_per_eval_prompt override changes the true group stride away from args.n_samples_per_eval_prompt.
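
The silent mis-grouping case can be illustrated with a small sketch. It does not call the real compute_pass_rate; the data and the helper name bucket_by_id are made up for illustration. Prompt 0 is list-expanded into 3 samples and prompt 1 has 1, so the total of 4 happens to divide by group_size=2. The positional reshape therefore succeeds, but it pairs prompt 0's third sample with prompt 1's only sample.

```python
import numpy as np

# Hypothetical data: prompt 0 was list-expanded into 3 samples, prompt 1 has 1.
rewards = [0.0, 0.0, 1.0, 0.0]
group_ids = [0, 0, 0, 1]
group_size = 2

# Positional grouping: 4 is a multiple of 2, so reshape does not raise,
# but the second row mixes samples from two different prompts.
positional = np.array(rewards).reshape(-1, group_size)
positional_pass1 = float(positional.mean(axis=1).mean())


def bucket_by_id(rewards, group_ids):
    """Group rewards by their recorded group key (illustrative helper)."""
    buckets = {}
    for reward, gid in zip(rewards, group_ids):
        buckets.setdefault(gid, []).append(reward)
    return buckets


by_identity = bucket_by_id(rewards, group_ids)
identity_pass1 = sum(sum(rs) / len(rs) for rs in by_identity.values()) / len(by_identity)

print(positional.tolist())  # [[0.0, 0.0], [1.0, 0.0]]
print(by_identity)          # {0: [0.0, 0.0, 1.0], 1: [0.0]}
print(positional_pass1, identity_pass1)
```

The two per-group means differ (0.25 positionally versus 1/6 by identity) even though nothing crashed, which is the failure mode that is hardest to notice in logs.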

Before vs After

Same input, train over-sampled shape — 9 rewards with group_size=4 (not a multiple of 4):

rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]

# BEFORE (legacy): num_groups pinned to 2 -> assert 9 == 2 * 4 fails
compute_pass_rate(rewards, group_size=4, num_groups=2)
# AssertionError: len(flat_rewards)=9 num_groups=2 group_size=4
# AFTER: chunked into [1,0,1,1] (c=3,n=4), [0,0,1,0] (c=1,n=4), [1] (c=1,n=1)
compute_pass_rate(rewards, group_size=4)
# {'pass@1': 0.666...,   # mean(3/4, 1/4, 1.0) over all 3 groups
#  'pass@2': 0.75,       # eligible groups (size>=2): [1,0,1,1]->1.0, [0,0,1,0]->0.5
#  'pass@4': 1.0}        # eligible groups (size>=4): both have >=1 correct among 4 samples -> 1.0; the singleton [1] is excluded

Eval list-expanded shape — 5 rewards from 2 prompts (prompt 0 expanded to 3 samples), bucketed by the per-prompt group_index each sample now carries:

rewards   = [1.0, 0.0, 1.0, 0.0, 1.0]
group_ids = [0,   0,   0,   1,   1]   # s.group_index, assigned at sample creation

# BEFORE: 5 is not a multiple of group_size=2 -> assert raises
# AFTER:
compute_pass_rate(rewards, group_size=2, group_ids=group_ids)
# prompt 0 -> [1,0,1] (c=2,n=3), prompt 1 -> [0,1] (c=1,n=2)
# {'pass@1': 0.583...,  # mean(2/3, 1/2)
#  'pass@2': 1.0}       # every group has >=1 correct
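
The numbers in both examples are consistent with the standard unbiased pass@k estimator. Whether _estimate_pass_at_k has exactly this form is an assumption: the PR leaves that function untouched and does not show it. With n samples and c correct samples in a group, the eval example works out as follows.

```latex
\[
\widehat{\mathrm{pass@}k}(n, c) \;=\; 1 - \frac{\binom{n-c}{k}}{\binom{n}{k}},
\qquad \binom{a}{k} = 0 \ \text{when } a < k .
\]

% Eval example: prompt 0 has (n, c) = (3, 2), prompt 1 has (n, c) = (2, 1).
\[
\mathrm{pass@}1 = \tfrac{1}{2}\Bigl[\bigl(1 - \tfrac{1}{3}\bigr) + \bigl(1 - \tfrac{1}{2}\bigr)\Bigr]
              = \tfrac{7}{12} \approx 0.583,
\qquad
\mathrm{pass@}2 = \tfrac{1}{2}\Bigl[\bigl(1 - \tfrac{0}{3}\bigr) + \bigl(1 - \tfrac{0}{1}\bigr)\Bigr] = 1 .
\]
```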

Fix

Replace the num_groups parameter with an optional group_ids and group rewards by real per-sample identity instead of positional arithmetic:

  • Utility: when group_ids is given, bucket flat_rewards by the per-sample group key; groups may have any number of samples. When group_ids is None, chunk flat_rewards into contiguous group_size blocks, keeping a trailing partial block as a smaller group instead of asserting an exact multiple. The docstring states what the fallback assumes: a group-major layout (each group's rewards adjacent, every group full-size) that tolerates only trailing raggedness; interior raggedness needs group_ids.
  • Eval: both eval rollout implementations now assign sample.group_index = <prompt index> at sample creation, exactly as the train data source already does for train samples; list-expanded multi-turn/tool samples inherit it through deepcopy. log_eval_rollout_data passes group_ids = [s.group_index for s in samples] when every sample carries one, so each sample lands in its true prompt's group regardless of expansion or per-dataset sampling overrides. Samples without group identity (e.g. from custom eval rollout functions) fall back to chunking.
  • Train: convert_samples_to_train_data emits a group_indices entry built from sample.group_index, aligned 1:1 with raw_reward and propagated full-batch (un-partitioned) by split_train_data_by_dp exactly like raw_reward. log_passrate drops the num_groups pin and buckets by group_indices; rewards without group identity (custom conversion functions, old debug dumps) fall back to chunking.
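
The two grouping modes described for the utility can be sketched as below. This is a minimal illustration, not the code in miles/utils/metric_utils.py: the function name group_rewards and the length check are assumptions made for the example.

```python
from collections import defaultdict
from typing import Hashable, List, Optional, Sequence


def group_rewards(
    flat_rewards: Sequence[float],
    group_size: int,
    group_ids: Optional[Sequence[Hashable]] = None,
) -> List[List[float]]:
    """Split flat rewards into groups (hypothetical helper for illustration)."""
    if group_ids is not None:
        if len(group_ids) != len(flat_rewards):
            raise ValueError("group_ids must align 1:1 with flat_rewards")
        # Bucket by recorded identity; dict insertion order keeps the
        # first-seen order of groups, and groups may have any size.
        buckets = defaultdict(list)
        for reward, gid in zip(flat_rewards, group_ids):
            buckets[gid].append(reward)
        return list(buckets.values())
    # Fallback: contiguous group_size blocks, keeping a trailing partial block.
    return [
        list(flat_rewards[i : i + group_size])
        for i in range(0, len(flat_rewards), group_size)
    ]


train_groups = group_rewards([1, 0, 1, 1, 0, 0, 1, 0, 1], group_size=4)
eval_groups = group_rewards([1, 0, 1, 0, 1], group_size=2, group_ids=[0, 0, 0, 1, 1])
print(train_groups)  # [[1, 0, 1, 1], [0, 0, 1, 0], [1]]
print(eval_groups)   # [[1, 0, 1], [0, 1]]
```

Bucketing also handles interleaved identities (for example ids a, b, a), which the chunk fallback cannot recover from position alone.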

For each k, pass@k is averaged only over groups with at least k samples; a rung is dropped entirely when no group qualifies, so undersized/partial groups never crash or distort the metric.
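
The rung-eligibility rule can be sketched as follows. The names pass_at_k and pass_rates_by_rung, and the explicit ks list, are hypothetical; the estimator is assumed to be the standard unbiased form, since the PR does not show _estimate_pass_at_k.

```python
from math import comb
from typing import Dict, List, Sequence


def pass_at_k(n: int, c: int, k: int) -> float:
    """Assumed unbiased estimator: 1 - C(n-c, k) / C(n, k), valid for n >= k."""
    return 1.0 - comb(n - c, k) / comb(n, k)


def pass_rates_by_rung(groups: Sequence[Sequence[float]], ks: Sequence[int]) -> Dict[str, float]:
    """Average pass@k over groups with at least k samples; drop empty rungs."""
    sizes: List[int] = [len(g) for g in groups]
    correct: List[int] = [sum(1 for r in g if r == 1) for g in groups]
    out: Dict[str, float] = {}
    for k in ks:
        eligible = [(n, c) for n, c in zip(sizes, correct) if n >= k]
        if not eligible:
            continue  # no group can represent this rung, so it is not reported
        out[f"pass@{k}"] = sum(pass_at_k(n, c, k) for n, c in eligible) / len(eligible)
    return out


# The train over-sampled shape from the example: two full groups and a singleton.
groups = [[1, 0, 1, 1], [0, 0, 1, 0], [1]]
result = pass_rates_by_rung(groups, ks=[1, 2, 4, 8])
print(result)
```

Here pass@1 averages over all three groups, pass@2 and pass@4 over the two full groups only, and pass@8 is absent from the result because no group has 8 samples.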

Why this is the right fix

  • Group identity is recorded at the source, not reconstructed downstream. Train samples already carry group_index from the data source; eval now assigns it at the same point (sample creation). Both call sites consume that recorded identity, so grouping stays correct under list expansion, per-dataset n_samples_per_eval_prompt overrides, and buffer reordering — everything that breaks positional or index-derived grouping.
  • Default-path safety (bit-identical). For well-formed exact-multiple input the group_ids=None chunking path is numerically identical to the legacy reshape(num_groups, group_size), so the common train layout (exact-multiple contiguous groups) is unchanged. A test pins this equivalence against a copy of the pre-fix implementation.
  • No metric distortion. A singleton/partial group can satisfy pass@1 but cannot represent pass@4; restricting each rung to groups of size >= k (and dropping empty rungs) keeps every reported pass@k honest instead of padding or skewing it.
  • Graceful degradation. Callers whose samples or rewards carry no group identity keep working through the chunk fallback, which never crashes and is exact for well-formed input.
  • Minimal, no new abstraction. One function's grouping logic changes, the two call sites each change by a few lines, the eval rollout implementations gain a one-line group_index assignment, and the train conversion gains one aligned entry; the pass@k estimator (_estimate_pass_at_k) is untouched.
  • CI-verifiable. Adds tests/fast/utils/test_metric_utils.py and call-site tests (auto-collected into the stage-a-cpu CPU suite by the existing tests/fast/** discovery convention). They cover: (1) bit-identical output vs the legacy reshape on exact-multiple input, (2) no crash on the train over-sampled and eval list-expanded ragged shapes, (3) rung eligibility, (4) log_eval_rollout_data bucketing by group_index rather than position (a positional pairing would report a different pass@2), and (5) the group_indices full-batch alignment contract through convert_samples_to_train_data / split_train_data_by_dp. The existing eval integration test now pins that eval samples carry their per-prompt group_index.
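
As a rough illustration of test (1), a pytest-style check of the grouping equivalence might look like the sketch below. These are not the tests added in tests/fast/utils/test_metric_utils.py, which pin the metric output against a copy of the pre-fix implementation; legacy_groups and chunked_groups are stand-ins written for this example.

```python
import numpy as np


def legacy_groups(flat_rewards, group_size):
    """Stand-in for the pre-fix contract: exact multiple, then reshape."""
    num_groups = len(flat_rewards) // group_size
    assert len(flat_rewards) == num_groups * group_size
    return np.array(flat_rewards).reshape(num_groups, group_size)


def chunked_groups(flat_rewards, group_size):
    """Stand-in for the fallback path: contiguous blocks, trailing partial kept."""
    arr = np.array(flat_rewards)
    return [arr[i : i + group_size] for i in range(0, len(arr), group_size)]


def test_chunking_matches_legacy_reshape_on_exact_multiple():
    rng = np.random.default_rng(0)
    rewards = rng.integers(0, 2, size=12).astype(float).tolist()
    legacy = legacy_groups(rewards, 4)
    chunked = chunked_groups(rewards, 4)
    assert len(chunked) == legacy.shape[0]
    for row, chunk in zip(legacy, chunked):
        assert np.array_equal(row, chunk)


def test_chunking_keeps_trailing_partial_block():
    chunked = chunked_groups([1.0] * 9, 4)
    assert [len(c) for c in chunked] == [4, 4, 1]
```

Because the groups are element-for-element identical on exact-multiple input, any estimator applied per group produces the same values on both paths.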

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the compute_pass_rate function to handle ragged inputs (such as over-sampled, dynamic-sampled, or list-expanded evaluations) without crashing, supporting both explicit group_ids bucketing and contiguous chunking. The training and evaluation logging call sites are updated accordingly, and comprehensive unit tests are added. The reviewer suggested a performance optimization to precompute group sizes and correct counts outside the loop rather than recalculating them for each k value.

Comment on lines 50 to +60
  log_dict = {}
  for k in pass_rate_name_list:
-     num_correct = np.sum(rewards_of_group == 1, axis=1)
-     num_samples = np.full(num_groups, group_size)
+     eligible = [g for g in groups if len(g) >= k]
+     if not eligible:
+         continue

-     pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k)
+     num_samples = np.array([len(g) for g in eligible])
+     num_correct = np.array([int(np.sum(g == 1)) for g in eligible])

-     pass_k = np.mean(pass_k_estimates)
-     log_dict[f"pass@{k}"] = pass_k
+     pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k)
+     log_dict[f"pass@{k}"] = np.mean(pass_k_estimates)

medium

The current implementation recalculates num_samples and num_correct (including the expensive np.sum(g == 1) operation on numpy arrays) from scratch for every single k in pass_rate_name_list. Since the groups themselves do not change across different k values, we can precompute the sizes and correct counts for all groups once before the loop. This avoids redundant list comprehensions and numpy operations, significantly improving performance when there are many groups or multiple k values.

Suggested change
group_sizes = np.array([len(g) for g in groups])
group_correct = np.array([np.sum(g == 1) for g in groups])
log_dict = {}
for k in pass_rate_name_list:
    eligible_mask = group_sizes >= k
    if not np.any(eligible_mask):
        continue
    num_samples = group_sizes[eligible_mask]
    num_correct = group_correct[eligible_mask]
    pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k)
    log_dict[f"pass@{k}"] = np.mean(pass_k_estimates)

Author

Good catch, applied. The per-group quantities len(g) and np.sum(g == 1) are k-invariant -- only the eligibility filter (size >= k) and the estimator depend on k -- so I hoisted them out of the loop:

group_sizes = np.array([len(g) for g in groups])
group_correct = np.array([int(np.sum(g == 1)) for g in groups])

for k in pass_rate_name_list:
    eligible = group_sizes >= k
    if not eligible.any():
        continue
    pass_k_estimates = _estimate_pass_at_k(group_sizes[eligible], group_correct[eligible], k)
    log_dict[f"pass@{k}"] = np.mean(pass_k_estimates)

I kept the boolean-mask form (group_sizes[eligible]) rather than a Python list comprehension so the rung-eligibility (>= k), the per-rung averaging, and the rung-drop-when-none-eligible (not eligible.any()) semantics stay byte-for-byte what they were -- group order is preserved, so each (num_samples, num_correct) pair feeding _estimate_pass_at_k is unchanged. I confirmed the output dict is identical to the previous per-k recompute across exact-multiple chunked input, ragged group_ids, a trailing partial chunk, all-zero/all-one rewards, singleton groups that drop higher rungs, and group_size=2. Thanks!

@EazyReal EazyReal force-pushed the upstream-pr/metric-ragged-passrate branch from c49ef56 to 2dcd8a9 Compare June 12, 2026 09:14
@EazyReal EazyReal changed the title from "fix: make compute_pass_rate ragged-safe at both train and eval call sites" to "fix(metrics): group pass-rate by real sample identity" Jun 13, 2026
compute_pass_rate inherited a rigid `assert len(flat_rewards) == num_groups
* group_size` + reshape contract, but both call sites feed it ragged input
and crash on the assertion:

* Train (log_passrate) pinned num_groups=rollout_batch_size, so any
  over-sampling, dynamic-sampling filtering, or aborted/partial group makes
  the flat reward count differ from rollout_batch_size * n_samples_per_prompt
  and the reshape raises.
* Eval (log_eval_rollout_data) passed num_groups=None; multi-turn / tool eval
  list-expands one prompt into a variable number of samples, so the reward
  count is not a clean multiple of n_samples_per_eval_prompt and the reshape
  raises.

Fix: replace num_groups with an optional group_ids argument and group rewards
by the per-prompt group_index that samples already carry (or now carry):

* Utility: when group_ids is given, bucket flat_rewards by group key (groups
  may have any size); when None, chunk into contiguous group_size blocks,
  keeping a trailing partial block instead of asserting an exact multiple.
  The chunk fallback assumes a group-major layout and tolerates only trailing
  raggedness, as its docstring now states.
* Eval: both eval rollout implementations now assign
  sample.group_index = prompt index at sample creation (mirroring the train
  data source); list-expanded samples inherit it via deepcopy. The metrics
  site buckets by that real identity instead of index arithmetic, which
  silently mis-grouped whenever a per-dataset n_samples_per_eval_prompt
  override changed the stride.
* Train: convert_samples_to_train_data emits a full-batch group_indices entry
  aligned with raw_reward (propagated un-partitioned by
  split_train_data_by_dp, like raw_reward), and log_passrate buckets by it.
  Rewards without group identity (custom conversion/rollout functions) fall
  back to chunking.

pass@k is averaged only over groups with at least k samples; a rung is
dropped when no group qualifies, so undersized groups never distort or crash
the metric. For well-formed exact-multiple input the chunking path is
numerically identical to the legacy reshape, so the common train layout is
unchanged.

CI-verifiable: adds tests/fast/utils/test_metric_utils.py plus call-site
tests (auto-collected into the stage-a-cpu suite) pinning (1) bit-identical
output to the legacy reshape on exact-multiple input, (2) no crash on the
train over-sampled and eval list-expanded ragged shapes, (3) rung
eligibility, (4) eval bucketing by group_index rather than position, and
(5) the group_indices full-batch alignment contract through
convert_samples_to_train_data / split_train_data_by_dp.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@EazyReal EazyReal force-pushed the upstream-pr/metric-ragged-passrate branch from 2dcd8a9 to 878de0d Compare June 13, 2026 04:13