
fix(qwen3-vl): per-segment mRoPE + vision under CP + THD packing#1308

Merged
Zhichenzzz merged 1 commit into zhichen/qwen3-vl-thd-miles-hijack from fix/1296-qwen3vl-cp-mrope
Jun 19, 2026

@Zhichenzzz Zhichenzzz commented Jun 8, 2026

Contributor

Follow-up to #1272 (stacked on its branch, so this diff shows only the #1296 changes — merge after #1272).

What

Makes Qwen3-VL train end-to-end under context parallelism + THD sequence packing in bridge mode.

  1. Per-segment mRoPE under CP: when the THD row is CP-sharded (zigzag), _build_packed_positions all-gathers the per-rank rows, de-interleaves them into the full row (_reassemble_full_row, unit-tested in tests/fast/test_qwen3_vl_cp_mrope.py), rebuilds the per-segment mRoPE positions, and re-slices them into this rank's zigzag layout.
  2. Don't double-shard: when the input is already CP-local, the bridge's internal preprocess_packed_seqs is wrapped into an identity that returns miles' full-cu packed_seq_params (CP attention still sees the full cumulative sequence lengths; the data isn't re-split).
  3. CP-local vision embeds: select_local_vision_embeds maps each rank's local vision tokens to the matching slice of the full vision-tower output (and deepstack). It cooperates with a small hook in megatron-bridge (separate PR to radixark/Megatron-Bridge).
  4. Wires calculate_per_token_loss into the bridge provider (Qwen3-VL asserts on it under CP) and a defensive AllGatherVisionEmbeddings.apply kwarg shim.
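
The round trip in item 1 can be sketched on plain Python lists. This is only an illustration, assuming the common zigzag layout in which a row is cut into 2 × cp_size equal chunks and rank r holds chunk r plus its mirror chunk 2 × cp_size − 1 − r. The names `zigzag_shard` and `reassemble_full_row` are hypothetical (they are not the PR's `_reassemble_full_row`), and the all-gather is stood in for by a list of per-rank shards.

```python
from typing import List, Sequence


def zigzag_shard(row: Sequence[int], cp_size: int, rank: int) -> List[int]:
    """Return the part of `row` held by `rank` under the assumed zigzag layout."""
    n_chunks = 2 * cp_size
    if len(row) % n_chunks != 0:
        raise ValueError("row length must be divisible by 2 * cp_size")
    chunk = len(row) // n_chunks
    mirror = n_chunks - 1 - rank  # index of the chunk paired with `rank`
    front = list(row[rank * chunk:(rank + 1) * chunk])
    back = list(row[mirror * chunk:(mirror + 1) * chunk])
    return front + back


def reassemble_full_row(shards: Sequence[Sequence[int]]) -> List[int]:
    """Undo zigzag_shard: place every half-shard back at its natural position."""
    cp_size = len(shards)
    chunk = len(shards[0]) // 2
    chunks: List[List[int]] = [[] for _ in range(2 * cp_size)]
    for rank, shard in enumerate(shards):
        chunks[rank] = list(shard[:chunk])
        chunks[2 * cp_size - 1 - rank] = list(shard[chunk:])
    return [tok for c in chunks for tok in c]


# Toy row of 16 token indices, split across 2 CP ranks.
row = list(range(16))
cp_size = 2
shards = [zigzag_shard(row, cp_size, r) for r in range(cp_size)]
full = reassemble_full_row(shards)
# Anything computed on `full` (e.g. positions) can be re-sliced per rank
# with the same zigzag_shard call.
resliced = [zigzag_shard(full, cp_size, r) for r in range(cp_size)]
```

Re-slicing with the same function that produced the shards keeps the rebuilt positions aligned token-for-token with each rank's local data.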

Validation (e2e)

Qwen3-VL-2B on geo3k, CP=2 TP=4, 8×H200, THD packed, bridge mode: stable over 3 steps, train_rollout_logprob_abs_diff 0.0141 → 0.0146 (healthy, within the non-CP range of 0.011–0.016), rollout/raw_reward ~0.4, no crashes.

Depends on

Fixes #1296


Update: cleanup pass + re-validation

  • helpers' first parameter renamed self → model (they are free functions, not methods)
  • the preprocess_packed_seqs identity wrapper forwards *args/**kwargs instead of hard-coding the upstream signature
  • the patch warns once at install when the running megatron-bridge lacks the _miles_select_local_vision_embeds hook (instead of silently mis-placing vision embeddings under CP); points at the matching Megatron-Bridge patch
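
The warn-once behaviour in the last bullet could look roughly like the sketch below. Only the attribute name `_miles_select_local_vision_embeds` comes from this PR. Where the hook lives (here: as an attribute on a model class), the function name `warn_if_hook_missing`, and the stand-in classes are assumptions made for illustration.

```python
import warnings

_HOOK_NAME = "_miles_select_local_vision_embeds"
_warned = False  # module-level flag so the warning fires at most once


def warn_if_hook_missing(model_cls: type) -> bool:
    """Return True if the hook exists; otherwise warn (once) and return False."""
    global _warned
    has_hook = hasattr(model_cls, _HOOK_NAME)
    if not has_hook and not _warned:
        warnings.warn(
            f"{model_cls.__name__} lacks {_HOOK_NAME}; vision embeddings may be "
            "mis-placed under CP. Install the matching Megatron-Bridge patch.",
            RuntimeWarning,
            stacklevel=2,
        )
        _warned = True
    return has_hook


class _BridgeWithHook:  # hypothetical stand-in for a patched bridge model
    _miles_select_local_vision_embeds = None  # only presence matters here


class _BridgeWithoutHook:  # hypothetical stand-in for an unpatched bridge model
    pass


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    first = warn_if_hook_missing(_BridgeWithoutHook)   # warns
    second = warn_if_hook_missing(_BridgeWithoutHook)  # silent, already warned
    third = warn_if_hook_missing(_BridgeWithHook)      # hook present
```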

Re-validated after cleanup: unit tests 6/6; Qwen3-VL-2B CP2 TP4 THD geo3k RL — train_rollout_logprob_abs_diff 0.0127–0.0131 (same healthy band as the original validation), coherent rollouts, no hook warning with the bridge patch installed.

Note on stacking: this branch contains #1272; the cleanup lives here because this PR owns the final shape of qwen3_vl_packed_mrope.py.


Update: vision-embed plumbing removed (−53 lines)

Megatron-Bridge PR #9 now selects CP-local vision embeddings natively inside Qwen3VLModel (no override hook), so this PR no longer carries select_local_vision_embeds, the local→full position mapping, or any hook installation. What remains on the miles side: per-segment mRoPE position reconstruction and the preprocess_packed_seqs identity wrapper (both still needed — they are about positions, not embeddings), plus a warning when the running bridge lacks the native support.
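
As a rough illustration of what "per-segment" means for the position reconstruction that stays on the miles side: positions restart at zero at every packed-sequence boundary instead of running across the whole THD row. The sketch below covers text-only segments, assuming that text tokens use the same index on all three mRoPE axes. The function names are hypothetical, and vision tokens (which need distinct temporal/height/width indices) are deliberately left out.

```python
from typing import List


def per_segment_positions(cu_seqlens: List[int]) -> List[int]:
    """Positions that restart at 0 for each packed segment.

    `cu_seqlens` holds cumulative segment boundaries, e.g. [0, 3, 5, 9]
    for three segments of length 3, 2 and 4.
    """
    positions: List[int] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        positions.extend(range(end - start))
    return positions


def text_only_mrope(cu_seqlens: List[int]) -> List[List[int]]:
    """Three identical axes (t, h, w) for text-only tokens (assumption)."""
    base = per_segment_positions(cu_seqlens)
    return [list(base) for _ in range(3)]


cu = [0, 3, 5, 9]
pos = per_segment_positions(cu)
mrope = text_only_mrope(cu)
```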

Re-validated end-to-end after the removal (Qwen3-VL-2B, CP2 TP2, THD, geo3k RL): train_rollout_logprob_abs_diff 0.0130, coherent rollouts, no warnings, no crashes.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces support for Qwen3-VL context parallelism (CP) and THD packed mRoPE reconstruction. It adds patches to handle CP-local vision embeddings, bypasses redundant re-sharding in preprocess_packed_seqs, and implements a shim for AllGatherVisionEmbeddings to accept cp_group as a keyword argument. Additionally, a comprehensive CPU unit test suite is added to verify the correctness of the zigzag reconstruction. The reviewer feedback suggests improving the robustness of the patches by having the _AllGatherVisionEmbeddingsKwargShim inherit from the original class to preserve its type hierarchy, and using *args and **kwargs in the preprocess_packed_seqs wrapper to guard against future signature changes.

Comment on lines +43 to +48
class _AllGatherVisionEmbeddingsKwargShim:
    _miles_kwarg_shim = True

    @staticmethod
    def apply(input, seqlens_on_cp_ranks, cp_group=None):
        return orig.apply(input, seqlens_on_cp_ranks, cp_group)

medium

To ensure that _AllGatherVisionEmbeddingsKwargShim behaves identically to the original AllGatherVisionEmbeddings class (e.g., preserving class attributes, static methods, or satisfying issubclass / isinstance checks in downstream code), it is safer to have the shim inherit from orig instead of being a completely separate, plain class.

Suggested change
- class _AllGatherVisionEmbeddingsKwargShim:
-     _miles_kwarg_shim = True
-     @staticmethod
-     def apply(input, seqlens_on_cp_ranks, cp_group=None):
-         return orig.apply(input, seqlens_on_cp_ranks, cp_group)
+ class _AllGatherVisionEmbeddingsKwargShim(orig):
+     _miles_kwarg_shim = True
+     @staticmethod
+     def apply(input, seqlens_on_cp_ranks, cp_group=None):
+         return orig.apply(input, seqlens_on_cp_ranks, cp_group)
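
The practical difference between the two variants can be seen with a toy stand-in. This is a sketch only: `Original`, `PlainShim` and `InheritingShim` are hypothetical plain classes, and the sketch does not model whatever framework base class the real AllGatherVisionEmbeddings may have.

```python
class Original:
    """Stand-in for the upstream class: takes the group positionally."""
    name = "original"  # some class attribute downstream code might read

    @staticmethod
    def apply(x, seqlens, group):
        return (x, seqlens, group)


class PlainShim:
    """Separate class: accepts cp_group= but shares nothing with Original."""

    @staticmethod
    def apply(x, seqlens, cp_group=None):
        return Original.apply(x, seqlens, cp_group)


class InheritingShim(Original):
    """Subclass: same keyword shim, keeps attributes and type checks."""

    @staticmethod
    def apply(x, seqlens, cp_group=None):
        return Original.apply(x, seqlens, cp_group)


# Both shims accept the keyword form and forward it positionally.
plain_out = PlainShim.apply(1, [2], cp_group="g")
inherit_out = InheritingShim.apply(1, [2], cp_group="g")

# Only the inheriting shim still looks like the original to callers.
plain_is_sub = issubclass(PlainShim, Original)
inherit_is_sub = issubclass(InheritingShim, Original)
```

Both shims fix the keyword call; inheritance only matters if some caller inspects the class itself rather than just calling `apply`.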

Comment on lines +132 to +139
def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None):
    ctx = getattr(_tls, "cp_local", None)
    if ctx is not None:
        # already-local CP path: do not re-shard; return the data unchanged together with
        # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
        # uses the packed_seq_params passed into forward, which already has the full cu).
        return input_ids, ctx["psp"]
    return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection)

medium

To make the monkeypatched wrapped function more robust against future signature changes in preprocess_packed_seqs (e.g., if the bridge library adds or reorders arguments), it is highly recommended to use *args and **kwargs and extract input_ids dynamically. This prevents potential TypeError exceptions due to signature mismatches.

Suggested change
- def wrapped(input_ids, attention_mask, pre_process=True, pg_collection=None):
-     ctx = getattr(_tls, "cp_local", None)
-     if ctx is not None:
-         # already-local CP path: do not re-shard; return the data unchanged together with
-         # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
-         # uses the packed_seq_params passed into forward, which already has the full cu).
-         return input_ids, ctx["psp"]
-     return orig(input_ids, attention_mask, pre_process=pre_process, pg_collection=pg_collection)
+ def wrapped(*args, **kwargs):
+     ctx = getattr(_tls, "cp_local", None)
+     if ctx is not None:
+         # already-local CP path: do not re-shard; return the data unchanged together with
+         # miles' full-cu packed_seq_params (callers ignore the psp; the model's CP attention
+         # uses the packed_seq_params passed into forward, which already has the full cu).
+         input_ids = kwargs.get("input_ids") if "input_ids" in kwargs else args[0]
+         return input_ids, ctx["psp"]
+     return orig(*args, **kwargs)

@Zhichenzzz Zhichenzzz force-pushed the zhichen/qwen3-vl-thd-miles-hijack branch from db3376c to 8ef25fd Compare June 18, 2026 21:39
@Zhichenzzz Zhichenzzz force-pushed the fix/1296-qwen3vl-cp-mrope branch from 82f3215 to 0788969 Compare June 18, 2026 21:39
@Zhichenzzz Zhichenzzz force-pushed the zhichen/qwen3-vl-thd-miles-hijack branch from 8ef25fd to 9149284 Compare June 18, 2026 21:49
@Zhichenzzz Zhichenzzz force-pushed the fix/1296-qwen3vl-cp-mrope branch 2 times, most recently from de1c21b to e37e614 Compare June 18, 2026 21:54
…acking

Extend the miles-side Qwen3-VL patch to context parallelism: reassemble each
rank's zigzag THD row into the full natural-order row, rebuild per-segment mRoPE
positions, and re-slice them back to the CP layout. Also monkeypatch the bridge
AllGatherVisionEmbeddings kwarg bug and set calculate_per_token_loss for CP>1 VL
configs that skip core_transformer_config_from_args. Adds a CPU unit test for
the CP+THD reconstruction round-trip.
@Zhichenzzz Zhichenzzz force-pushed the fix/1296-qwen3vl-cp-mrope branch from e37e614 to 702a850 Compare June 18, 2026 22:29

@yueming-yuan yueming-yuan left a comment

Collaborator

Approved to unblock. We may need to unify all the CP utils in the training backend refactor.

@Zhichenzzz Zhichenzzz merged commit 6a1db39 into zhichen/qwen3-vl-thd-miles-hijack Jun 19, 2026
29 checks passed
@Zhichenzzz Zhichenzzz deleted the fix/1296-qwen3vl-cp-mrope branch June 19, 2026 17:43
2 participants