fix(rollout): apply rollout sample filter in the manager #1324
Code Review
This pull request hoists the application of --rollout-sample-filter-path to the generic postprocess_rollout_data manager function, ensuring that all rollout functions (including custom ones) consistently honor the filter and avoiding double-filtering in individual rollout implementations. Feedback suggests adding an empty check for data to prevent an IndexError and using getattr to safely retrieve the filter path from args to avoid an AttributeError.
    # Apply the rollout sample filter at this generic manager choke point so every
    # rollout fn honors --rollout-sample-filter-path, while `data` is still the
    # documented grouped `list[list[Sample]]` (before the flatten/trim below).
    if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:
        filter_func(args, data)
To prevent a potential AttributeError if args does not contain the rollout_sample_filter_path attribute, and to avoid an IndexError on data[0] if data is empty, we should add an empty check for data and use getattr to safely retrieve the filter path.
Suggested change:

    -# Apply the rollout sample filter at this generic manager choke point so every
    -# rollout fn honors --rollout-sample-filter-path, while `data` is still the
    -# documented grouped `list[list[Sample]]` (before the flatten/trim below).
    -if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:
    -    filter_func(args, data)
    +if not data:
    +    return data, metadata
    +# Apply the rollout sample filter at this generic manager choke point so every
    +# rollout fn honors --rollout-sample-filter-path, while data is still the
    +# documented grouped list[list[Sample]] (before the flatten/trim below).
    +if (filter_func := load_function(getattr(args, "rollout_sample_filter_path", None))) is not None:
    +    filter_func(args, data)
Thanks for the careful read. I'm going to leave both of these as-is, with reasoning:
- getattr fallback: `--rollout-sample-filter-path` is declared unconditionally in `arguments.py` (with `default=None`), so `args.rollout_sample_filter_path` is always present on the parsed args namespace and can't raise AttributeError. `load_function` also returns `None` for a `None` path, so the existing `if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:` walrus already handles the unset case. Using `getattr(args, 'rollout_sample_filter_path', None)` here would actually be misleading, since it suggests the attribute might be missing when it never is.
- Empty-data / IndexError: the line this PR adds (`filter_func(args, data)`) doesn't index `data[0]`. The `data[0]` access is in the pre-existing flatten loop `while isinstance(data[0], list)`, which is unchanged from upstream/main. So an empty `data` would already raise the same IndexError upstream, independent of this change; it's pre-existing and orthogonal to routing the sample filter through the manager. Adding an empty-data guard here would also change default-path behavior (it would silently return empty data instead of surfacing the existing error), which is outside the scope of this fix. If empty rollout batches turn out to be a real, reachable condition, that's worth its own dedicated change rather than folding it into this one.
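A minimal sketch of the getattr point, for illustration only. The parser below is a stand-in for `arguments.py`, and `load_function_sketch` is a hypothetical helper that only mimics the behavior described above (a `None` path yields `None`); neither is the project's actual code.

```python
import argparse
import importlib


def load_function_sketch(path):
    """Hypothetical stand-in for load_function: a None path yields None."""
    if path is None:
        return None
    # Assumed convention for this sketch: "package.module.attr" dotted paths.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# Stand-in parser: the flag is declared unconditionally with default=None.
parser = argparse.ArgumentParser()
parser.add_argument("--rollout-sample-filter-path", type=str, default=None)

args = parser.parse_args([])  # the flag is not passed on the command line

# argparse still creates the attribute, so plain attribute access is safe.
attribute_present = hasattr(args, "rollout_sample_filter_path")

# The walrus guard skips the body when no filter path is configured.
filter_called = False
if (filter_func := load_function_sketch(args.rollout_sample_filter_path)) is not None:
    filter_called = True
```

Because argparse populates every declared destination with its default, the guard alone covers the unset case without a `getattr` fallback.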
The documented --rollout-sample-filter-path hook is silently ignored for any custom --rollout-function-path. It was applied only INSIDE the two built-in rollout fns (miles/rollout/sglang_rollout.py and miles/rollout/inference_rollout/inference_rollout_train.py), so a user who swaps in a custom rollout function got no sample filtering at all even though the flag was set and the consumer (train_data_conversion: remove_sample -> loss_mask=0) is fully generic. Only the PRODUCER side was non-generic.

Fix: hoist the filter to the single generic choke point, postprocess_rollout_data, reached by RolloutManager._get_rollout_data for EVERY rollout fn. The filter runs while `data` is still the documented grouped `list[list[Sample]]` (the `fn(args, data)` contract), i.e. before the existing flatten/trim. The redundant in-fn applications are removed from both built-in fns to avoid double-filtering.

Scope kept tight:
- --rollout-all-samples-process-path stays inside the rollout fns: it needs all_samples / data_source that the manager does not have.
- The load_debug_rollout_data path does not flow through postprocess_rollout_data, so the filter applies on the normal post-rollout path only.

Default-path safety: walrus-guarded load_function(None) -> None, so with rollout_sample_filter_path unset the choke point is a no-op and output is bit-identical to before.

Tests: add CPU tests in tests/fast/ray/rollout/test_rollout_data_conversion.py asserting a custom rollout fn's grouped output now honors the filter via the manager path (grouped-before-flatten contract, in-place remove_sample, and the None default no-op). The inference-rollout integration test asserts the filter is no longer applied in-fn (manager owns it now). It is parametrized over BOTH built-in rollout fns (modular InferenceRolloutFn and the legacy sglang_rollout.generate_rollout), so reintroducing the in-fn filter in either one fails as double-filtering, while all-samples-process is still applied in-fn, asserted against the real over-sampling invariant (all_samples strictly exceeds the trained set).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Problem
The documented `--rollout-sample-filter-path` hook is silently ignored for any custom `--rollout-function-path`. The filter is applied only inside the two built-in rollout functions: `miles/rollout/sglang_rollout.py` (`generate_rollout_async`) and `miles/rollout/inference_rollout/inference_rollout_train.py` (`generate_rollout_async`). The producer side is non-generic, but the consumer side is already generic: `train_data_conversion.py` does `if sample.remove_sample: sample.loss_mask = [0] * response_length`. So when a user swaps in a custom rollout function that does not itself call the filter, the filter callable is never invoked, `remove_sample` is never set, and every sample reaches the loss unfiltered, with no error and no warning.

Before vs After
Setup: `--rollout-sample-filter-path` points at a `drop_zero_reward` filter, used with a custom `--rollout-function-path` that emits the documented grouped `list[list[Sample]]` shape.

Before: the filter is never called for a custom rollout fn; nothing is flagged, so the zero-reward samples still contribute to the loss:

    {s.index for s in flat if s.remove_sample}  # == set() (filter never ran)
    # samples 1 and 3 (reward 0.0) keep loss_mask = [1, 1, ...] and train as positives

After: the same input flows through the manager choke point, the filter runs, and the zero-reward samples are flagged, so `train_data_conversion` zeros their `loss_mask`:

    {s.index for s in flat if s.remove_sample}  # == {1, 3}
    # samples 1 and 3 get loss_mask = [0, 0, ...] and are excluded from the loss

(This is exactly `test_custom_rollout_fn_now_honors_filter_via_manager_path` in the new test file.)

Fix
Hoist the filter to the single generic choke point, `postprocess_rollout_data` in `miles/ray/rollout/rollout_data_conversion.py`, which is reached by `RolloutManager._get_rollout_data` for every rollout function.

The filter runs while `data` is still the documented grouped `list[list[Sample]]` (the `fn(args, data)` contract), before the existing flatten/trim, and flags `Sample.remove_sample` in place. The redundant in-function applications are removed from both built-in rollout functions (to avoid double-filtering), with a one-line comment left at each former call site pointing to the new owner.

Scope is kept deliberately tight:
- `--rollout-all-samples-process-path` stays inside the rollout functions: it needs `all_samples` / `data_source`, which the manager does not have.
- The `load_debug_rollout_data` path does not flow through `postprocess_rollout_data`, so the filter applies on the normal post-rollout path only.

Why this is the right fix
- `load_function(None)` returns `None`, so the walrus-guarded block is a no-op when `--rollout-sample-filter-path` is not set; output is identical to before. Covered by `test_no_filter_path_is_a_no_op_default`.
- The documented filter contract is `fn(args, data)` with `data` as `list[list[Sample]]` grouped by `n_samples_per_prompt`, where the filter sets `sample.remove_sample = True` to exclude a sample from the loss. The new call site applies it on exactly that shape, before the flatten, generalizing the existing contract rather than changing it. No new flag is introduced; only where the existing argument is applied changes.
- The consumer side (`remove_sample → loss_mask = 0`) is untouched.
- The tests in `tests/fast/ray/rollout/test_rollout_data_conversion.py` assert, via the manager path, that (1) the filter receives grouped `list[list[Sample]]` before the flatten, is invoked exactly once per batch, and its in-place mutation survives flattening, (2) a custom rollout fn's grouped output now honors the filter (the previously broken case above), and (3) `rollout_sample_filter_path=None` flags nothing. The integration test (`tests/fast/rollout/inference_rollout/integration/test_sample_filter.py`) asserts the filter is no longer applied in-function. It is parametrized over both built-in rollout fns (the modular `InferenceRolloutFn` and the legacy `sglang_rollout.generate_rollout`), so reintroducing the in-function filter in either one fails as double-filtering, while all-samples-process is still applied in-function, asserted against the real over-sampling invariant (`all_samples` strictly exceeds the trained set). All run in the existing `tests/fast/` stage.
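The filter-before-flatten ordering can be sketched end to end as follows. This is an illustrative, simplified model and not the project's code: `Sample` here is a hypothetical minimal dataclass, `postprocess_sketch` stands in for `postprocess_rollout_data` (the real function also trims and returns metadata), and `response_length = 4` is an assumed value.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    """Hypothetical minimal stand-in for the project's Sample type."""
    index: int
    reward: float
    remove_sample: bool = False


def drop_zero_reward(args, data):
    """Example filter following the fn(args, data) contract: flags in place."""
    for group in data:
        for sample in group:
            if sample.reward == 0.0:
                sample.remove_sample = True


def postprocess_sketch(args, data, filter_func=None):
    """Simplified choke point: filter the grouped data, then flatten it."""
    if filter_func is not None:
        filter_func(args, data)  # data is still list[list[Sample]] here
    return [sample for group in data for sample in group]


# Two prompt groups, as a custom rollout function might emit them.
grouped = [
    [Sample(0, 1.0), Sample(1, 0.0)],
    [Sample(2, 0.5), Sample(3, 0.0)],
]

flat = postprocess_sketch(None, grouped, drop_zero_reward)
flagged = {s.index for s in flat if s.remove_sample}

# Consumer side: flagged samples get an all-zero loss mask.
response_length = 4  # assumed value for illustration
loss_masks = {
    s.index: [0] * response_length if s.remove_sample else [1] * response_length
    for s in flat
}
```

Because the filter mutates the `Sample` objects in place, the flags survive the flatten, and passing no filter leaves every sample unflagged.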