6 changes: 3 additions & 3 deletions xtuner/v1/model/compose/intern_s1/modeling_vision.py
@@ -102,7 +102,7 @@ def forward(
dtype=torch.int32,
device=hidden_states.device)

attn_output: torch.Tensor = self.attn_impl_func( # type: ignore
attn_output, extra_info = self.attn_impl_func( # type: ignore
Collaborator comment:

It's not a good idea to change the attention function signature like this. Instead, we should define an AttnOutput type (using TypedDict or namedtuple) to represent the attention result.
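A minimal sketch of the suggested container type (field names are hypothetical and not part of this PR), which would let forward() keep a single return value:

from typing import TypedDict

import torch


class AttnOutput(TypedDict, total=False):
    """Hypothetical bundle for attention results, as suggested above."""

    output: torch.Tensor        # projected attention output
    softmax_lse: torch.Tensor   # optional log-sum-exp returned by some kernels
    attn_logits: torch.Tensor   # optional raw attention logits (eager path only)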

query_states[None].transpose(1, 2), # [b, n_head, seq, head_dim]
key_states[None].transpose(1, 2),
value_states[None].transpose(1, 2),
@@ -118,7 +118,7 @@ def forward(
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
output = self.projection_layer(attn_output)
output = self.projection_dropout(output)
return output
return output, extra_info


class InternS1VisionMLP(nn.Module):
@@ -188,7 +188,7 @@ def init_weights(self):

@maybe_compile(fullgraph=True)
def attention_pre_forward(self, hidden_states):
attention_output = self.attention(self.layernorm_before(hidden_states))
attention_output, _ = self.attention(self.layernorm_before(hidden_states))
attention_output = self.lambda_1 * attention_output
return attention_output

6 changes: 3 additions & 3 deletions xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
@@ -137,7 +137,7 @@ def forward(
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)

attn_output: torch.Tensor = self.attn_impl_func( # type: ignore
attn_output, extra_info = self.attn_impl_func( # type: ignore
query_states, # [b, n_head, seq, head_dim]
key_states,
value_states,
@@ -153,7 +153,7 @@

attn_output = attn_output[0].reshape(seq_length, -1).contiguous() # s, d
attn_output = self.proj(attn_output)
return attn_output
return attn_output, extra_info


class Qwen3VLVisionLayer(nn.Module):
@@ -177,7 +177,7 @@ def forward(
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_embeddings=position_embeddings
)
)[0]
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states

39 changes: 22 additions & 17 deletions xtuner/v1/model/moe/moe.py
@@ -259,6 +259,7 @@ def forward(
self,
seq_ctx: list[SequenceContext] | SequenceContext,
loss_ctx: list[CELossContext] | CELossContext | None,
return_router_logits: bool = False,
):
# TODO: caoweihan: Recover this assertion after the refactor of LossContext
if isinstance(seq_ctx, SequenceContext):
@@ -269,6 +270,7 @@
return self._forward(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx, # type: ignore
return_router_logits=return_router_logits,
)
else:
assert isinstance(loss_ctx, list) and len(loss_ctx) == len(seq_ctx), (
@@ -280,12 +282,14 @@
return self._micro_batch_forward(
seq_ctx_list=seq_ctx,
loss_ctx_list=loss_ctx,
return_router_logits=return_router_logits,
)

def _micro_batch_forward(
self,
seq_ctx_list: list[SequenceContext],
loss_ctx_list: list[CELossContext],
return_router_logits: bool = False,
) -> MoEModelOutputs:
"""Micro-batch forward pass for MoE model.

@@ -454,29 +458,30 @@ def _micro_batch_forward(
else:
final_logits = None

if self.config.return_router_results:
raise NotImplementedError
if self.config.return_router_results or return_router_logits:

# TODO: Returning router logits is costly

# router_logits_dict: dict[str, torch.Tensor] = {}
# layer_names = list(router_logits_list[0].keys())
#
# for layer_name in layer_names:
# layer_router_logits_list: list[torch.Tensor] = []
# for micro_batch_idx in range(len(seq_ctx_list)):
# layer_router_logits_list.append(router_logits_list[micro_batch_idx][layer_name].clone().detach())
# router_logits = torch.stack(layer_router_logits_list, dim=0).unsqueeze(0)
# router_logits_dict["router_logits"] = router_logits
#
# output["router_logits"] = router_logits_dict
router_logits_dict: dict[str, torch.Tensor] = {}
layer_names = list(router_logits_list[0].keys())

for layer_name in layer_names:
layer_router_logits_list: list[torch.Tensor] = []
for micro_batch_idx in range(len(seq_ctx_list)):
layer_router_logits_list.append(router_logits_list[micro_batch_idx][layer_name].clone().detach())
router_logits = torch.stack(layer_router_logits_list, dim=0).unsqueeze(0)
router_logits_dict[layer_name] = router_logits  # key by layer so each layer's logits are preserved

output["router_logits"] = router_logits_dict

return MoEModelOutputs(**output, logits=final_logits) # type: ignore[typeddict-item]

def _forward(
self,
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
loss_ctx: CELossContext | None,
return_router_logits: bool = False,
) -> MoEModelOutputs:
input_ids = seq_ctx.input_ids
position_ids = seq_ctx.position_ids
@@ -561,11 +566,11 @@

del router_logits

if self.config.return_router_results:
raise NotImplementedError
if self.config.return_router_results or return_router_logits:
# TODO: Moving router logits to CPU is costly; keep them on device for now
# for layer_name, router_logits in output["router_logits"].items():
# output["router_logits"][layer_name] = router_logits.detach().cpu().unsqueeze(0)
for layer_name, router_logits in output["router_logits"].items():
output["router_logits"][layer_name] = router_logits.detach().unsqueeze(0)
else:
output["router_logits"] = None
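A hedged usage sketch of the new return_router_logits flag (variable names are illustrative and assume MoEModelOutputs can be indexed like the output dict built above):

outputs = model(
    seq_ctx=seq_ctx,              # SequenceContext or a list of them
    loss_ctx=loss_ctx,            # CELossContext or a matching list
    return_router_logits=True,    # new flag added in this diff
)
router_logits = outputs["router_logits"]  # detached per-layer logits when enabled, else None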

3 changes: 2 additions & 1 deletion xtuner/v1/model/moe/qwen3vl_text.py
@@ -111,6 +111,7 @@ def _forward(
self,
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
loss_ctx: CELossContext | None,
return_router_logits: bool = False,
) -> MoEModelOutputs:
input_ids = seq_ctx.input_ids
position_ids = seq_ctx.position_ids
@@ -210,7 +211,7 @@

del router_logits

if self.config.return_router_results:
if self.config.return_router_results or return_router_logits:
raise NotImplementedError
# TODO: Moving router logits to CPU is costly
# for layer_name, router_logits in output["router_logits"].items():
4 changes: 2 additions & 2 deletions xtuner/v1/module/attention/mha.py
@@ -378,7 +378,7 @@ def forward(
sinks = self.sinks
kwargs["s_aux"] = sinks
# [b, n_head, seq, head_dim]
attn_output: torch.Tensor = self.attn_impl_func( # type: ignore
attn_output, extra_info = self.attn_impl_func( # type: ignore
query_states,
key_states,
value_states,
@@ -404,7 +404,7 @@

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
return attn_output, extra_info

def build_kv_cache(
self, max_batch_size: int | None = None, max_length: int | None = None, block_size: int | None = None
13 changes: 10 additions & 3 deletions xtuner/v1/module/attention/mla.py
@@ -421,7 +421,7 @@ def prefilling(
block_offsets=seq_ctx.block_table,
) # type: ignore[assignment]

attn_output: torch.Tensor = flash_attn_varlen_func(
attn_output, extra_info = flash_attn_varlen_func( # type: ignore
query_states.squeeze(0),
key_states.squeeze(0),
value_states.squeeze(0),
@@ -599,7 +599,7 @@ def forward(
assert query_states.size(0) == 1
assert key_states.size(0) == 1
assert value_states.size(0) == 1
attn_output = flash_attn_varlen_func(
attn_outputs = flash_attn_varlen_func(
query_states.transpose(1, 2).squeeze(0),
key_states.transpose(1, 2).squeeze(0),
value_states.transpose(1, 2).squeeze(0),
@@ -611,7 +611,14 @@
softmax_scale=self.softmax_scale,
causal=True,
deterministic=XTUNER_DETERMINISTIC,
return_attn_probs=True,
)
extra_info = {}
if isinstance(attn_outputs, tuple):
attn_output = attn_outputs[0]
extra_info["softmax_lse"] = attn_outputs[1].detach()
else:
attn_output = attn_outputs
attn_output = cast(torch.Tensor, attn_output)
if self.q_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, : self.v_head_dim]
@@ -620,7 +627,7 @@

attn_output = self.o_proj(attn_output)

return attn_output
return attn_output, extra_info

def build_kv_cache(
self, max_batch_size: int | None = None, max_length: int | None = None, block_size: int | None = None
3 changes: 2 additions & 1 deletion xtuner/v1/module/decoder_layer/dense_decoder_layer.py
@@ -62,6 +62,7 @@ def __init__(
generate_config=generate_config,
float8_cfg=float8_cfg,
)
self.self_attn.name = f"layers.{layer_idx}.self_attn"
self.mlp = DenseMLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
@@ -84,7 +85,7 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
seq_ctx=seq_ctx,
3 changes: 2 additions & 1 deletion xtuner/v1/module/decoder_layer/moe_decoder_layer.py
@@ -214,6 +214,7 @@ def __init__(
layer_type=layer_type,
float8_cfg=float8_cfg,
)
self.self_attn.name = f"layers.{layer_idx}.self_attn"
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.shared_experts: MoEMLP | None
self.layer_idx = layer_idx
@@ -540,7 +541,7 @@ def _pre_moe_forward(

# Self Attention
if state == ForwardState.TRAINING:
hidden_states = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
seq_ctx=seq_ctx,
39 changes: 25 additions & 14 deletions xtuner/v1/ops/attn_imp.py
@@ -56,12 +56,13 @@ def compile_friendly_flex_attention(
score_mod=None,
enable_gqa: bool = False,
scale: float | None = None,
return_lse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
global flex_attention_compiled
if flex_attention_compiled is None:
flex_attention_compiled = get_flex_attention_compiled()
return flex_attention_compiled( # type: ignore
q, k, v, block_mask=block_mask, score_mod=score_mod, scale=scale, enable_gqa=enable_gqa
q, k, v, block_mask=block_mask, score_mod=score_mod, scale=scale, enable_gqa=enable_gqa, return_lse=return_lse
)


@@ -126,7 +127,7 @@ def mask_mod(b, h, q_idx, kv_idx):

def eager_attention(
q, k, v, cu_seqlens_q, softmax_scale, window_size=(-1, -1), dropout_p=0.0, s_aux=None, **kwargs
) -> torch.Tensor:
) -> tuple[torch.Tensor, dict]:
# TODO(HHA): Currently, the mask is recalculated each time, which is quite time-consuming.
# It should be refactored to be calculated only once.

@@ -150,6 +151,7 @@
causal_mask = attention_mask[:, :, :, : k.shape[-2]]
attn_weights = attn_weights + causal_mask

extra_info = {}
if s_aux is not None:
# This was not in the original implementation and slightly affects results; it prevents overflow in BF16/FP16
# when training with bsz>1 we clamp max values.
@@ -159,18 +161,20 @@
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
scores = probs[..., :-1] # we drop the sink here
extra_info["attn_logits"] = combined_logits.detach()
else:
scores = torch.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
extra_info["attn_logits"] = attn_weights.detach()

attn_weights = nn.functional.dropout(scores, p=dropout_p, training=True)
attn_output = torch.matmul(attn_weights, v)
attn_scores = nn.functional.dropout(scores, p=dropout_p, training=True)
attn_output = torch.matmul(attn_scores, v)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output
return attn_output, extra_info


def flex_attention(
q, k, v, cu_seqlens_q, softmax_scale=None, window_size=(-1, -1), dropout_p=0.0, s_aux=None, causal=True, **kwargs
) -> torch.Tensor:
) -> tuple[torch.Tensor, dict]:
# q, k, v: [b, n_head, seq, head_dim]
assert dropout_p == 0.0, "Dropout is not supported in flex attention"

@@ -187,40 +191,47 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
mask = create_packing_block_causal_mask(cu_seqlens_q, window_size=window_size, causal=causal)
enable_gqa = k.size(1) != q.size(1)

attention_output = compile_friendly_flex_attention(
attention_output, softmax_lse = compile_friendly_flex_attention(
q,
k,
v,
block_mask=mask,
score_mod=score_mod_fn,
scale=softmax_scale,
enable_gqa=enable_gqa,
return_lse=True,
)
attention_output = attention_output.transpose(1, 2).contiguous()
return attention_output
extra_info = {"softmax_lse": softmax_lse.detach()}
return attention_output, extra_info


def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torch.Tensor:
def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> tuple[torch.Tensor, dict]:
# q, k, v: [b, n_head, seq , head_dim]
assert q.size(0) == 1, "Only support batch size 1 for flash attention"
q = q.transpose(1, 2).squeeze(0) # [seq, head, dim]
k = k.transpose(1, 2).squeeze(0)
v = v.transpose(1, 2).squeeze(0)

attention_output: torch.Tensor

extra_info = {}
if s_aux is None:
if flash_attn_exception is not None:
traceback.print_exception(flash_attn_exception)
raise flash_attn_exception
attention_output = flash_attn_varlen_func(q, k, v, **kwargs)
attention_outputs = flash_attn_varlen_func(q, k, v, return_attn_probs=True, **kwargs) # type: ignore
if isinstance(attention_outputs, tuple):
attention_output = attention_outputs[0]
extra_info["softmax_lse"] = attention_outputs[1].detach()
else: # npu fused attn doesn't support softmax_lse
attention_output = attention_outputs
else:
if flash_sink_attn_exception is not None:
traceback.print_exception(flash_sink_attn_exception)
raise flash_sink_attn_exception
cu_seqlens_q = kwargs["cu_seqlens_q"]
attention_output = flash_sink_attn_varlen_func(q, k, v, s_aux, cu_seqlens_q, window_size[0])
return attention_output[None]
attention_output, softmax_lse = flash_sink_attn_varlen_func(q, k, v, s_aux, cu_seqlens_q, window_size[0])
extra_info["softmax_lse"] = softmax_lse.detach()
return attention_output[None], extra_info


attn_impl_mapping = {
4 changes: 2 additions & 2 deletions xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py
@@ -496,10 +496,10 @@ def forward(
ctx.window_size = window_size
ctx.cu_seqlen = cu_seqlen

return o
return o, lse
Collaborator comment:

The lse should only be returned when specific flags are passed in. Otherwise, only the attention output should be returned. This makes for a cleaner attention interface.
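A minimal sketch of the flag-gated interface this comment suggests (wrapper name and flag are illustrative, not part of this PR): the autograd Function can keep returning (o, lse) internally while a thin wrapper decides what callers see.

def attn_with_optional_lse(attn_fn, *args, return_lse: bool = False):
    # attn_fn is assumed to return (output, lse), like the forward() above.
    out, lse = attn_fn(*args)
    # Expose the log-sum-exp only when the caller explicitly asks for it.
    return (out, lse) if return_lse else out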


@staticmethod
def backward(ctx, do):
def backward(ctx, do, dlse):
q, k, v, o, lse = ctx.saved_tensors

dq = torch.zeros_like(q)