
fix(rollout): apply rollout sample filter in the manager #1324

Open. EazyReal wants to merge 1 commit into radixark:main from EazyReal:upstream-pr/manager-side-sample-filter

@EazyReal EazyReal commented Jun 12, 2026


Problem

The documented --rollout-sample-filter-path hook is silently ignored for any custom --rollout-function-path. The filter is applied only inside the two built-in rollout functions: miles/rollout/sglang_rollout.py (generate_rollout_async) and miles/rollout/inference_rollout/inference_rollout_train.py (generate_rollout_async). Only the producer side is non-generic; the consumer side is already generic: train_data_conversion.py does if sample.remove_sample: sample.loss_mask = [0] * response_length. So when a user swaps in a custom rollout function that does not itself call the filter, the filter callable is never invoked, remove_sample is never set, and every sample reaches the loss unfiltered, with no error and no warning.
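To make the consumer side concrete, here is a minimal sketch of how a flagged sample ends up excluded from the loss. The `Sample` dataclass and the `mask_removed_samples` helper are hypothetical stand-ins, not the actual train_data_conversion.py code; only the rule "remove_sample set means an all-zero loss_mask" is taken from the description above.

```python
from dataclasses import dataclass, field


@dataclass
class Sample:
    """Hypothetical stand-in for the real Sample; only the fields used here."""

    index: int
    reward: float
    response_length: int = 4
    remove_sample: bool = False
    loss_mask: list[int] = field(default_factory=list)


def mask_removed_samples(samples: list[Sample]) -> None:
    """Sketch of the consumer rule: flagged samples get an all-zero loss mask."""
    for sample in samples:
        if sample.remove_sample:
            sample.loss_mask = [0] * sample.response_length
        elif not sample.loss_mask:
            # Assumption for this sketch: unflagged samples train on every token.
            sample.loss_mask = [1] * sample.response_length


samples = [Sample(index=0, reward=1.0), Sample(index=1, reward=0.0)]
# Nothing has set remove_sample, which is the situation when the filter never runs.
mask_removed_samples(samples)
unfiltered_masks = [s.loss_mask for s in samples]
```

Because the consumer only reads `remove_sample`, it stays correct no matter which component sets the flag; the bug is purely that nothing sets it for custom rollout functions.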

Before vs After

Setup: --rollout-sample-filter-path points at a drop_zero_reward filter, used with a custom --rollout-function-path that emits the documented grouped list[list[Sample]] shape:

# custom rollout fn output: 2 prompts x 2 samples, rewards [1.0, 0.0] per group
data = [
    [Sample(index=0, reward=1.0), Sample(index=1, reward=0.0)],
    [Sample(index=2, reward=1.0), Sample(index=3, reward=0.0)],
]

def drop_zero_reward(args, data):          # fn(args, data) contract
    for group in data:
        for s in group:
            if s.reward == 0.0:
                s.remove_sample = True

Before — the filter is never called for a custom rollout fn; nothing is flagged, so the zero-reward samples still contribute to the loss:

{s.index for s in flat if s.remove_sample}   # == set()   (filter never ran)
# samples 1 and 3 (reward 0.0) keep loss_mask = [1, 1, ...] and train as positives

After — the same input flows through the manager choke point, the filter runs, and the zero-reward samples are flagged → train_data_conversion zeros their loss_mask:

{s.index for s in flat if s.remove_sample}   # == {1, 3}
# samples 1 and 3 get loss_mask = [0, 0, ...] and are excluded from the loss

(This is exactly test_custom_rollout_fn_now_honors_filter_via_manager_path in the new test file.)

Fix

Hoist the filter to the single generic choke point, postprocess_rollout_data in miles/ray/rollout/rollout_data_conversion.py, which is reached by RolloutManager._get_rollout_data for every rollout function:

# Generic choke point: every rollout fn honors --rollout-sample-filter-path here,
# while `data` is still grouped list[list[Sample]] (before the flatten below).
if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:
    filter_func(args, data)

while isinstance(data[0], list):
    data = list(itertools.chain.from_iterable(data))

The filter runs while data is still the documented grouped list[list[Sample]] — the fn(args, data) contract — before the existing flatten/trim, and flags Sample.remove_sample in place. The redundant in-function applications are removed from both built-in rollout functions (to avoid double-filtering), with a one-line comment left at each former call site pointing to the new owner.
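The ordering described above (filter first, flatten second, no-op when the path is unset) can be exercised in isolation. The sketch below is an assumption-laden miniature, not the code in rollout_data_conversion.py: `postprocess_sketch` is a hypothetical name, samples are plain dicts, and `load_function` is replaced by a dict lookup instead of an import by dotted path. It only mirrors the property the PR relies on, namely that a None path yields None.

```python
import itertools
from types import SimpleNamespace

calls = []  # records the group sizes the filter sees on each invocation


def record_shape(args, data):
    """Example fn(args, data) filter: flags the last sample of each group in place."""
    calls.append([len(group) for group in data])
    for group in data:
        group[-1]["remove_sample"] = True


# Hypothetical stand-in for load_function: a registry lookup, not a real import.
_REGISTRY = {"my_filters.record_shape": record_shape}


def load_function(path):
    return None if path is None else _REGISTRY[path]


def postprocess_sketch(args, data):
    # Filter while data is still grouped list[list[...]].
    if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:
        filter_func(args, data)
    # Then flatten, as in the snippet above.
    while isinstance(data[0], list):
        data = list(itertools.chain.from_iterable(data))
    return data


groups = [[{"id": 0}, {"id": 1}], [{"id": 2}, {"id": 3}]]
args = SimpleNamespace(rollout_sample_filter_path="my_filters.record_shape")
flat = postprocess_sketch(args, groups)
flagged_ids = [d["id"] for d in flat if d.get("remove_sample")]
```

Since the filter mutates the sample objects in place and the flatten only rebuilds the outer lists, the flags set on the grouped data are still visible on the flattened result.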

Scope is kept deliberately tight:

  • --rollout-all-samples-process-path stays inside the rollout functions: it needs all_samples / data_source, which the manager does not have.
  • The load_debug_rollout_data path does not flow through postprocess_rollout_data, so the filter applies on the normal post-rollout path only.

Why this is the right fix

  • Default-path safe / bit-identical when unset. load_function(None) returns None, so the walrus-guarded block is a no-op when --rollout-sample-filter-path is not set; output is identical to before. Covered by test_no_filter_path_is_a_no_op_default.
  • Matches the documented-but-only-partially-implemented contract. The arg help already specifies fn(args, data) with data as list[list[Sample]] grouped by n_samples_per_prompt, and set sample.remove_sample = True to exclude from loss. The new call site applies it on exactly that shape, before the flatten — generalizing the existing contract rather than changing it. No new flag is introduced; only where the existing argument is applied changes.
  • No new abstraction. The choke point already existed and already owned the flatten/trim of grouped rollout data; the filter is one guarded line at the natural seam, and the consumer side (remove_sample → loss_mask = 0) is untouched.
  • CI-verifiable, torch-free. New CPU tests in tests/fast/ray/rollout/test_rollout_data_conversion.py assert, via the manager path, that (1) the filter receives grouped list[list[Sample]] before the flatten, is invoked exactly once per batch, and its in-place mutation survives flattening, (2) a custom rollout fn's grouped output now honors the filter (the previously-broken case above), and (3) rollout_sample_filter_path=None flags nothing. The integration test (tests/fast/rollout/inference_rollout/integration/test_sample_filter.py) asserts the filter is no longer applied in-function — parametrized over both built-in rollout fns (the modular InferenceRolloutFn and the legacy sglang_rollout.generate_rollout), so reintroducing the in-function filter in either one fails as double-filtering — while all-samples-process is still applied in-function, asserted against the real over-sampling invariant (all_samples strictly exceeds the trained set). All run in the existing tests/fast/ stage.
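The grouped shape in the contract matters for filters that need per-prompt statistics. As a hedged illustration, the hypothetical filter below flags every sample in a group whose rewards are all identical, which a flattened list could not express without extra bookkeeping. `Sample` is a stand-in dataclass and `drop_uniform_groups` is not part of this PR.

```python
from dataclasses import dataclass


@dataclass
class Sample:  # hypothetical stand-in; the real class has more fields
    index: int
    reward: float
    remove_sample: bool = False


def drop_uniform_groups(args, data):
    """fn(args, data) filter: flag groups whose rewards are all the same value."""
    for group in data:
        if len({s.reward for s in group}) <= 1:
            for s in group:
                s.remove_sample = True


data = [
    [Sample(0, 1.0), Sample(1, 0.0)],  # mixed rewards: left untouched
    [Sample(2, 0.0), Sample(3, 0.0)],  # identical rewards: whole group flagged
]
drop_uniform_groups(None, data)
flagged = sorted(s.index for group in data for s in group if s.remove_sample)
```

A filter like this depends on receiving `data` before the flatten, which is why the call site sits ahead of the flatten loop.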

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request hoists the application of --rollout-sample-filter-path to the generic postprocess_rollout_data manager function, ensuring that all rollout functions (including custom ones) consistently honor the filter and avoiding double-filtering in individual rollout implementations. Feedback suggests adding an empty check for data to prevent an IndexError and using getattr to safely retrieve the filter path from args to avoid an AttributeError.


Comment on lines 11 to +16

# Apply the rollout sample filter at this generic manager choke point so every
# rollout fn honors --rollout-sample-filter-path, while `data` is still the
# documented grouped `list[list[Sample]]` (before the flatten/trim below).
if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:
    filter_func(args, data)

Contributor


medium

To prevent a potential AttributeError if args does not contain the rollout_sample_filter_path attribute, and to avoid an IndexError on data[0] if data is empty, we should add an empty check for data and use getattr to safely retrieve the filter path.

Suggested change
- # Apply the rollout sample filter at this generic manager choke point so every
- # rollout fn honors --rollout-sample-filter-path, while `data` is still the
- # documented grouped `list[list[Sample]]` (before the flatten/trim below).
- if (filter_func := load_function(args.rollout_sample_filter_path)) is not None:
-     filter_func(args, data)
+ if not data:
+     return data, metadata
+ # Apply the rollout sample filter at this generic manager choke point so every
+ # rollout fn honors --rollout-sample-filter-path, while `data` is still the
+ # documented grouped `list[list[Sample]]` (before the flatten/trim below).
+ if (filter_func := load_function(getattr(args, "rollout_sample_filter_path", None))) is not None:
+     filter_func(args, data)

Author


Thanks for the careful read. I'm going to leave both of these as-is, with reasoning:

  1. getattr fallback: --rollout-sample-filter-path is declared unconditionally in arguments.py (with default=None), so args.rollout_sample_filter_path is always present on the parsed args namespace and can't raise AttributeError. load_function also returns None for a None path, so the existing if (filter_func := load_function(args.rollout_sample_filter_path)) is not None: walrus already handles the unset case. Using getattr(args, 'rollout_sample_filter_path', None) here would actually be misleading, since it suggests the attribute might be missing when it never is.

  2. Empty-data / IndexError: the line this PR adds (filter_func(args, data)) doesn't index data[0]. The data[0] access is in the pre-existing flatten loop while isinstance(data[0], list), which is unchanged from upstream/main. So an empty data would already raise the same IndexError upstream, independent of this change — it's pre-existing and orthogonal to routing the sample filter through the manager. Adding an empty-data guard here would also change default-path behavior (it would silently return empty data instead of surfacing the existing error), which is outside the scope of this fix. If empty rollout batches turn out to be a real, reachable condition, that's worth its own dedicated change rather than folding it into this one.
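Both points in this reply can be checked with the standard library alone. The sketch below is illustrative: the parser is a stand-in for arguments.py (the `type=str` is an assumption), and `flatten` copies only the loop quoted in the PR description rather than the full postprocess_rollout_data.

```python
import argparse
import itertools

parser = argparse.ArgumentParser()
# Declared unconditionally with default=None, mirroring the reply above.
parser.add_argument("--rollout-sample-filter-path", type=str, default=None)
args = parser.parse_args([])  # flag not passed

# Point 1: the attribute is present on the namespace even when the flag is unset.
has_attr = hasattr(args, "rollout_sample_filter_path")
value = args.rollout_sample_filter_path


def flatten(data):
    """The pre-existing flatten loop, with no sample filter involved."""
    while isinstance(data[0], list):
        data = list(itertools.chain.from_iterable(data))
    return data


# Point 2: an empty batch fails in the flatten loop, independent of the filter.
try:
    flatten([])
    empty_error = None
except IndexError as exc:
    empty_error = type(exc).__name__
```

argparse derives the attribute name from the long option by replacing dashes with underscores, so `args.rollout_sample_filter_path` resolves to None here rather than raising AttributeError.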

@EazyReal EazyReal force-pushed the upstream-pr/manager-side-sample-filter branch from c984ebc to cebecf1 Compare June 12, 2026 09:02
@EazyReal EazyReal changed the title from "fix(rollout): apply --rollout-sample-filter-path generically in the manager" to "fix(rollout): apply rollout sample filter in the manager" Jun 13, 2026
The documented --rollout-sample-filter-path hook is silently ignored for any
custom --rollout-function-path. It was applied only INSIDE the two built-in
rollout fns (miles/rollout/sglang_rollout.py and
miles/rollout/inference_rollout/inference_rollout_train.py), so a user who swaps
in a custom rollout function got no sample filtering at all even though the flag
was set and the consumer (train_data_conversion: remove_sample -> loss_mask=0)
is fully generic. Only the PRODUCER side was non-generic.

Fix: hoist the filter to the single generic choke point — postprocess_rollout_data,
reached by RolloutManager._get_rollout_data for EVERY rollout fn. The filter runs
while `data` is still the documented grouped `list[list[Sample]]` (the
`fn(args, data)` contract), i.e. before the existing flatten/trim. The redundant
in-fn applications are removed from both built-in fns to avoid double-filtering.

Scope kept tight:
- --rollout-all-samples-process-path stays inside the rollout fns: it needs
  all_samples / data_source that the manager does not have.
- The load_debug_rollout_data path does not flow through postprocess_rollout_data,
  so the filter applies on the normal post-rollout path only.

Default-path safety: walrus-guarded load_function(None) -> None, so with
rollout_sample_filter_path unset the choke point is a no-op and output is
bit-identical to before.

Tests: add CPU tests in tests/fast/ray/rollout/test_rollout_data_conversion.py
asserting a custom rollout fn's grouped output now honors the filter via the
manager path (grouped-before-flatten contract, in-place remove_sample, and the
None default no-op). The inference-rollout integration test asserts the filter
is no longer applied in-fn (manager owns it now) — parametrized over BOTH
built-in rollout fns (modular InferenceRolloutFn and the legacy
sglang_rollout.generate_rollout), so reintroducing the in-fn filter in either
one fails as double-filtering — while all-samples-process is still applied
in-fn, asserted against the real over-sampling invariant (all_samples strictly
exceeds the trained set).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@EazyReal EazyReal force-pushed the upstream-pr/manager-side-sample-filter branch from cebecf1 to 20e78a4 Compare June 13, 2026 04:13