fix(qwen3-vl): per-segment mRoPE + vision under CP + THD packing #1308
Code Review
This pull request introduces support for Qwen3-VL context parallelism (CP) and THD packed mRoPE reconstruction. It adds patches to handle CP-local vision embeddings, bypasses redundant re-sharding in preprocess_packed_seqs, and implements a shim for AllGatherVisionEmbeddings to accept cp_group as a keyword argument. Additionally, a comprehensive CPU unit test suite is added to verify the correctness of the zigzag reconstruction. The reviewer feedback suggests improving the robustness of the patches by having the _AllGatherVisionEmbeddingsKwargShim inherit from the original class to preserve its type hierarchy, and using *args and **kwargs in the preprocess_packed_seqs wrapper to guard against future signature changes.
    class _AllGatherVisionEmbeddingsKwargShim:
        _miles_kwarg_shim = True

        @staticmethod
        def apply(input, seqlens_on_cp_ranks, cp_group=None):
            return orig.apply(input, seqlens_on_cp_ranks, cp_group)
To ensure that _AllGatherVisionEmbeddingsKwargShim behaves identically to the original AllGatherVisionEmbeddings class (e.g., preserving class attributes, static methods, or satisfying issubclass / isinstance checks in downstream code), it is safer to have the shim inherit from orig instead of being a completely separate, plain class.
    -class _AllGatherVisionEmbeddingsKwargShim:
    -    _miles_kwarg_shim = True
    -    @staticmethod
    -    def apply(input, seqlens_on_cp_ranks, cp_group=None):
    -        return orig.apply(input, seqlens_on_cp_ranks, cp_group)
    +class _AllGatherVisionEmbeddingsKwargShim(orig):
    +    _miles_kwarg_shim = True
    +    @staticmethod
    +    def apply(input, seqlens_on_cp_ranks, cp_group=None):
    +        return orig.apply(input, seqlens_on_cp_ranks, cp_group)
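The subclassing suggestion above can be tried out without the bridge installed. The following is a minimal sketch, assuming a hypothetical stand-in class `PositionalOnlyOp` whose `apply` rejects keyword arguments; it is not the real `AllGatherVisionEmbeddings`, and `make_kwarg_shim` is an illustrative helper name. It only shows that a shim inheriting from the original keeps `issubclass` checks working while accepting `cp_group` as a keyword; how this interacts with a real autograd function is not verified here.

```python
class PositionalOnlyOp:
    """Hypothetical stand-in for the bridge class; not the real implementation."""

    @staticmethod
    def apply(*args):
        # Positional-only on purpose: calling apply(..., cp_group=...) raises TypeError.
        input, seqlens_on_cp_ranks, cp_group = args
        return (input, seqlens_on_cp_ranks, cp_group)


def make_kwarg_shim(orig):
    """Build a subclass of `orig` whose apply() also accepts cp_group as a keyword."""

    class KwargShim(orig):  # inheriting keeps issubclass/attribute lookups intact
        _miles_kwarg_shim = True

        @staticmethod
        def apply(input, seqlens_on_cp_ranks, cp_group=None):
            # Forward everything positionally to the original implementation.
            return orig.apply(input, seqlens_on_cp_ranks, cp_group)

    return KwargShim


Shim = make_kwarg_shim(PositionalOnlyOp)
result = Shim.apply("emb", [4, 4], cp_group="cp0")
```

The `_miles_kwarg_shim` marker mirrors the attribute in the reviewed code; a patch installer could check it to avoid wrapping the class twice.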
    def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None):
        ctx = getattr(_tls, "cp_local", None)
        if ctx is not None:
            # already-local CP path: do not re-shard; return the data unchanged together with
            # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
            # uses the packed_seq_params passed into forward, which already has the full cu).
            return input_ids, ctx["psp"]
        return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection)
To make the monkeypatched wrapped function more robust against future signature changes in preprocess_packed_seqs (e.g., if the bridge library adds or reorders arguments), it is highly recommended to use *args and **kwargs and extract input_ids dynamically. This prevents potential TypeError exceptions due to signature mismatches.
    -def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None):
    -    ctx = getattr(_tls, "cp_local", None)
    -    if ctx is not None:
    -        # already-local CP path: do not re-shard; return the data unchanged together with
    -        # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
    -        # uses the packed_seq_params passed into forward, which already has the full cu).
    -        return input_ids, ctx["psp"]
    -    return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection)
    +def wrapped(*args, **kwargs):
    +    ctx = getattr(_tls, "cp_local", None)
    +    if ctx is not None:
    +        # already-local CP path: do not re-shard; return the data unchanged together with
    +        # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
    +        # uses the packed_seq_params passed into forward, which already has the full cu).
    +        input_ids = kwargs.get("input_ids") if "input_ids" in kwargs else args[0]
    +        return input_ids, ctx["psp"]
    +    return orig(*args, **kwargs)
db3376c to 8ef25fd
82f3215 to 0788969
8ef25fd to 9149284
de1c21b to e37e614
…acking Extend the miles-side Qwen3-VL patch to context parallelism: reassemble each rank zigzag THD row into the full natural-order row, rebuild per-segment MRoPE positions, and re-slice them back to the CP layout. Also monkeypatch the bridge AllGatherVisionEmbeddings kwarg bug and set calculate_per_token_loss for CP>1 VL configs that skip core_transformer_config_from_args. Adds a CPU unit test for the CP+THD reconstruction round-trip.
e37e614 to 702a850
yueming-yuan left a comment
approved to unblock, and we may need to unify all the CP utils in the training backend refactor
6a1db39 into zhichen/qwen3-vl-thd-miles-hijack
Follow-up to #1272 (stacked on its branch, so this diff shows only the #1296 changes — merge after #1272).
What
Makes Qwen3-VL train end-to-end under context parallelism + THD sequence packing in bridge mode.
- _build_packed_positions all-gathers the per-rank rows, de-interleaves to the full row (_reassemble_full_row, unit-tested in tests/fast/test_qwen3_vl_cp_mrope.py), rebuilds per-segment MRoPE, and re-slices into this rank's zigzag layout.
- preprocess_packed_seqs is wrapped to an identity that returns miles' full-cu packed_seq_params (CP attention still sees the full cu; the data isn't re-split).
- select_local_vision_embeds maps each rank's local vision tokens to the matching slice of the full vision-tower output (and deepstack). Cooperates with a small hook in megatron-bridge (separate PR to radixark/Megatron-Bridge).
- Sets calculate_per_token_loss on the bridge provider (Qwen3-VL asserts on it under CP) and adds a defensive AllGatherVisionEmbeddings.apply kwarg shim.
Validation (e2e)
Qwen3-VL-2B geo3k, CP=2 TP=4, 8×H200, THD packed, bridge mode: stable over 3 steps, train_rollout_logprob_abs_diff 0.0141 → 0.0146 (healthy, matching the non-CP range of 0.011–0.016), rollout/raw_reward ~0.4, no crashes.
Depends on
Fixes #1296
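The reassemble/re-slice round trip described under "What" can be sketched in plain Python. This is an illustration only, assuming the common zigzag CP layout in which each packed segment is cut into 2 × cp_size equal chunks and rank r holds chunk r plus its mirror chunk 2 × cp_size − 1 − r, and assuming every segment length is already padded to a multiple of 2 × cp_size. The function names `zigzag_slice` and `reassemble_zigzag` are hypothetical and are not the PR's `_reassemble_full_row`.

```python
def zigzag_slice(row, cu_seqlens, cp_size, rank):
    """This rank's local tokens: per segment, chunk `rank` plus its mirror chunk."""
    local = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)
        lo = start + rank * chunk
        hi = start + (2 * cp_size - 1 - rank) * chunk
        local += row[lo:lo + chunk] + row[hi:hi + chunk]
    return local


def reassemble_zigzag(rank_rows, cu_seqlens, cp_size):
    """Inverse of zigzag_slice: rebuild the natural-order row from all ranks' rows."""
    full = [None] * cu_seqlens[-1]
    offsets = [0] * cp_size  # read cursor into each rank's local row
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)
        for rank in range(cp_size):
            o = offsets[rank]
            lo = start + rank * chunk
            hi = start + (2 * cp_size - 1 - rank) * chunk
            full[lo:lo + chunk] = rank_rows[rank][o:o + chunk]
            full[hi:hi + chunk] = rank_rows[rank][o + chunk:o + 2 * chunk]
            offsets[rank] = o + 2 * chunk
    return full


cp_size = 2
cu_seqlens = [0, 8, 12]  # two packed segments, lengths 8 and 4
row = list(range(12))
rank_rows = [zigzag_slice(row, cu_seqlens, cp_size, r) for r in range(cp_size)]
restored = reassemble_zigzag(rank_rows, cu_seqlens, cp_size)
```

With these inputs rank 0 holds `[0, 1, 6, 7, 8, 11]` and rank 1 holds `[2, 3, 4, 5, 9, 10]`, and `restored` equals the original row. Positions computed on the restored row can then be passed through `zigzag_slice` again to return to each rank's layout.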
Update: cleanup pass + re-validation
- self→model (they are free functions, not methods)
- preprocess_packed_seqs identity wrapper forwards *args/**kwargs instead of hard-coding the upstream signature
- _miles_select_local_vision_embeds hook (instead of silently mis-placing vision embeddings under CP); points at the matching Megatron-Bridge patch
Re-validated after cleanup: unit tests 6/6; Qwen3-VL-2B CP2 TP4 THD geo3k RL, with train_rollout_logprob_abs_diff 0.0127–0.0131 (same healthy band as the original validation), coherent rollouts, no hook warning with the bridge patch installed.
Note on stacking: this branch contains #1272; the cleanup lives here because this PR owns the final shape of qwen3_vl_packed_mrope.py.
Update: vision-embed plumbing removed (−53 lines)
Megatron-Bridge PR #9 now selects CP-local vision embeddings natively inside Qwen3VLModel (no override hook), so this PR no longer carries select_local_vision_embeds, the local→full position mapping, or any hook installation. What remains on the miles side: per-segment mRoPE position reconstruction and the preprocess_packed_seqs identity wrapper (both are still needed because they concern positions, not embeddings), plus a warning when the running bridge lacks the native support.
Re-validated end-to-end after the removal (Qwen3-VL-2B, CP2 TP2, THD, geo3k RL): train_rollout_logprob_abs_diff 0.0130, coherent rollouts, no warnings, no crashes.
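As a rough illustration of what "per-segment" position reconstruction means, the sketch below builds position rows that restart at 0 in every packed segment. It covers only the text-only case, where a token has the same index on all three mRoPE axes; vision tokens, whose height and width indices depend on the image grid, are deliberately left out. The function name `per_segment_text_positions` is hypothetical and this is not the PR's `_build_packed_positions`.

```python
def per_segment_text_positions(cu_seqlens):
    """Return 3 position rows (t, h, w) that restart at 0 in each packed segment."""
    positions = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # Each segment is an independent sequence, so its positions start from 0.
        positions.extend(range(end - start))
    # Text-only tokens share the same index on all three mRoPE axes.
    return [list(positions) for _ in range(3)]


pos = per_segment_text_positions([0, 3, 5])
```

A token's position depends on the tokens that precede it in its own segment, which is one reason to compute positions on the full natural-order row rather than on a single rank's zigzag shard.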