fix(metrics): group pass-rate by real sample identity#1326
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the compute_pass_rate function to handle ragged inputs (such as over-sampled, dynamic-sampled, or list-expanded evaluations) without crashing, supporting both explicit group_ids bucketing and contiguous chunking. The training and evaluation logging call sites are updated accordingly, and comprehensive unit tests are added. The reviewer suggested a performance optimization to precompute group sizes and correct counts outside the loop rather than recalculating them for each k value.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| log_dict = {} | ||
| for k in pass_rate_name_list: | ||
| num_correct = np.sum(rewards_of_group == 1, axis=1) | ||
| num_samples = np.full(num_groups, group_size) | ||
| eligible = [g for g in groups if len(g) >= k] | ||
| if not eligible: | ||
| continue | ||
|
|
||
| pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k) | ||
| num_samples = np.array([len(g) for g in eligible]) | ||
| num_correct = np.array([int(np.sum(g == 1)) for g in eligible]) | ||
|
|
||
| pass_k = np.mean(pass_k_estimates) | ||
| log_dict[f"pass@{k}"] = pass_k | ||
| pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k) | ||
| log_dict[f"pass@{k}"] = np.mean(pass_k_estimates) |
There was a problem hiding this comment.
The current implementation recalculates num_samples and num_correct (including the expensive np.sum(g == 1) operation on numpy arrays) from scratch for every single k in pass_rate_name_list. Since the groups themselves do not change across different k values, we can precompute the sizes and correct counts for all groups once before the loop. This avoids redundant list comprehensions and numpy operations, significantly improving performance when there are many groups or multiple k values.
| log_dict = {} | |
| for k in pass_rate_name_list: | |
| num_correct = np.sum(rewards_of_group == 1, axis=1) | |
| num_samples = np.full(num_groups, group_size) | |
| eligible = [g for g in groups if len(g) >= k] | |
| if not eligible: | |
| continue | |
| pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k) | |
| num_samples = np.array([len(g) for g in eligible]) | |
| num_correct = np.array([int(np.sum(g == 1)) for g in eligible]) | |
| pass_k = np.mean(pass_k_estimates) | |
| log_dict[f"pass@{k}"] = pass_k | |
| pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k) | |
| log_dict[f"pass@{k}"] = np.mean(pass_k_estimates) | |
| group_sizes = np.array([len(g) for g in groups]) | |
| group_correct = np.array([np.sum(g == 1) for g in groups]) | |
| log_dict = {} | |
| for k in pass_rate_name_list: | |
| eligible_mask = group_sizes >= k | |
| if not np.any(eligible_mask): | |
| continue | |
| num_samples = group_sizes[eligible_mask] | |
| num_correct = group_correct[eligible_mask] | |
| pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k) | |
| log_dict[f"pass@{k}"] = np.mean(pass_k_estimates) |
There was a problem hiding this comment.
Good catch, applied. The per-group quantities len(g) and np.sum(g == 1) are k-invariant -- only the eligibility filter (size >= k) and the estimator depend on k -- so I hoisted them out of the loop:
group_sizes = np.array([len(g) for g in groups])
group_correct = np.array([int(np.sum(g == 1)) for g in groups])
for k in pass_rate_name_list:
eligible = group_sizes >= k
if not eligible.any():
continue
pass_k_estimates = _estimate_pass_at_k(group_sizes[eligible], group_correct[eligible], k)
log_dict[f"pass@{k}"] = np.mean(pass_k_estimates)I kept the boolean-mask form (group_sizes[eligible]) rather than a Python list comprehension so the rung-eligibility (>= k), the per-rung averaging, and the rung-drop-when-none-eligible (not eligible.any()) semantics stay byte-for-byte what they were -- group order is preserved, so each (num_samples, num_correct) pair feeding _estimate_pass_at_k is unchanged. I confirmed the output dict is identical to the previous per-k recompute across exact-multiple chunked input, ragged group_ids, a trailing partial chunk, all-zero/all-one rewards, singleton groups that drop higher rungs, and group_size=2. Thanks!
c49ef56 to
2dcd8a9
Compare
compute_pass_rate inherited a rigid `assert len(flat_rewards) == num_groups * group_size` + reshape contract, but both call sites feed it ragged input and crash on the assertion: * Train (log_passrate) pinned num_groups=rollout_batch_size, so any over-sampling, dynamic-sampling filtering, or aborted/partial group makes the flat reward count differ from rollout_batch_size * n_samples_per_prompt and the reshape raises. * Eval (log_eval_rollout_data) passed num_groups=None; multi-turn / tool eval list-expands one prompt into a variable number of samples, so the reward count is not a clean multiple of n_samples_per_eval_prompt and the reshape raises. Fix: replace num_groups with an optional group_ids argument and group rewards by the per-prompt group_index that samples already carry (or now carry): * Utility: when group_ids is given, bucket flat_rewards by group key (groups may have any size); when None, chunk into contiguous group_size blocks, keeping a trailing partial block instead of asserting an exact multiple. The chunk fallback assumes a group-major layout and tolerates only trailing raggedness, as its docstring now states. * Eval: both eval rollout implementations now assign sample.group_index = prompt index at sample creation (mirroring the train data source); list-expanded samples inherit it via deepcopy. The metrics site buckets by that real identity instead of index arithmetic, which silently mis-grouped whenever a per-dataset n_samples_per_eval_prompt override changed the stride. * Train: convert_samples_to_train_data emits a full-batch group_indices entry aligned with raw_reward (propagated un-partitioned by split_train_data_by_dp, like raw_reward), and log_passrate buckets by it. Rewards without group identity (custom conversion/rollout functions) fall back to chunking. pass@k is averaged only over groups with at least k samples; a rung is dropped when no group qualifies, so undersized groups never distort or crash the metric. For well-formed exact-multiple input the chunking path is numerically identical to the legacy reshape, so the common train layout is unchanged. CI-verifiable: adds tests/fast/utils/test_metric_utils.py plus call-site tests (auto-collected into the stage-a-cpu suite) pinning (1) bit-identical output to the legacy reshape on exact-multiple input, (2) no crash on the train over-sampled and eval list-expanded ragged shapes, (3) rung eligibility, (4) eval bucketing by group_index rather than position, and (5) the group_indices full-batch alignment contract through convert_samples_to_train_data / split_train_data_by_dp. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2dcd8a9 to
878de0d
Compare
Problem
compute_pass_rate(inmiles/utils/metric_utils.py) hard-codes a rigid shape: it runsassert len(flat_rewards) == num_groups * group_sizeand thennp.array(flat_rewards).reshape(num_groups, group_size). Both call sites can feed it a flat reward count that is not an exact multiple ofgroup_size, so the assert (or reshape) raises:log_passrateinmiles/backends/training_utils/log_utils.pypinnednum_groups=args.rollout_batch_size. Under over-sampling, dynamic-sampling filtering, or aborted/partial groups, the flat reward count is notrollout_batch_size * n_samples_per_prompt, so the reshape raises.log_eval_rollout_datainmiles/ray/rollout/metrics.pypassednum_groups=None, which makes the function infernum_groups = len(flat_rewards) // group_size. Multi-turn / tool eval list-expands a single prompt into a variable number of samples, so the reward count is not a clean multiple ofn_samples_per_eval_promptand the assert raises.And even when the count happens to divide cleanly, positional grouping can silently mis-group: list expansion shifts every later sample's position, and a per-dataset
n_samples_per_eval_promptoverride changes the true group stride away fromargs.n_samples_per_eval_prompt.Before vs After
Same input, train over-sampled shape —
9rewards withgroup_size=4(not a multiple of 4):Eval list-expanded shape —
5rewards from2prompts (prompt 0 expanded to 3 samples), bucketed by the per-promptgroup_indexeach sample now carries:Fix
Replace the
num_groupsparameter with an optionalgroup_idsand group rewards by real per-sample identity instead of positional arithmetic:group_idsgiven: bucketflat_rewardsby the per-sample group key; groups may have any number of samples.group_ids is None: chunkflat_rewardsinto contiguousgroup_sizeblocks, keeping a trailing partial block as a smaller group instead of asserting an exact multiple. The docstring states what the fallback assumes: a group-major layout (each group's rewards adjacent, every group full-size), tolerating only trailing raggedness — interior raggedness needsgroup_ids.sample.group_index = <prompt index>at sample creation, exactly as the train data source already does for train samples; list-expanded multi-turn/tool samples inherit it throughdeepcopy.log_eval_rollout_datapassesgroup_ids = [s.group_index for s in samples]when every sample carries one, so each sample lands in its true prompt's group regardless of expansion or per-dataset sampling overrides. Samples without group identity (e.g. from custom eval rollout functions) fall back to chunking.convert_samples_to_train_dataemits agroup_indicesentry built fromsample.group_index, aligned 1:1 withraw_rewardand propagated full-batch (un-partitioned) bysplit_train_data_by_dpexactly likeraw_reward.log_passratedrops thenum_groupspin and buckets by it; rewards without group identity (custom conversion functions, old debug dumps) fall back to chunking.For each
k, pass@k is averaged only over groups with at leastksamples; a rung is dropped entirely when no group qualifies, so undersized/partial groups never crash or distort the metric.Why this is the right fix
group_indexfrom the data source; eval now assigns it at the same point (sample creation). Both call sites consume that recorded identity, so grouping stays correct under list expansion, per-datasetn_samples_per_eval_promptoverrides, and buffer reordering — everything that breaks positional or index-derived grouping.group_ids=Nonechunking path is numerically identical to the legacyreshape(num_groups, group_size), so the common train layout (exact-multiple contiguous groups) is unchanged. A test pins this equivalence against a copy of the pre-fix implementation.>= k(and dropping empty rungs) keeps every reported pass@k honest instead of padding or skewing it.group_indexassignment, and the train conversion gains one aligned entry; the pass@k estimator (_estimate_pass_at_k) is untouched.tests/fast/utils/test_metric_utils.pyand call-site tests (auto-collected into thestage-a-cpuCPU suite by the existingtests/fast/**discovery convention). They cover: (1) bit-identical output vs the legacy reshape on exact-multiple input, (2) no crash on the train over-sampled and eval list-expanded ragged shapes, (3) rung eligibility, (4)log_eval_rollout_databucketing bygroup_indexrather than position (a positional pairing would report a different pass@2), and (5) thegroup_indicesfull-batch alignment contract throughconvert_samples_to_train_data/split_train_data_by_dp. The existing eval integration test now pins that eval samples carry their per-promptgroup_index.