Commit e557bee

[Enhance] resolve some lint issues
[Fix] fix rms_norm no_grad
1 parent 4fc5fe4 commit e557bee

7 files changed: +37 −38 lines changed


xtuner/v1/float8/float8_ops.py
Lines changed: 1 addition & 1 deletion

@@ -165,4 +165,4 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None):
         args[0]._orig_dtype,
         args[0]._scaling_granularity,
         args[0]._group_size,
-    )
+    )

(Whitespace-only change to the closing-parenthesis line; the exact indentation is not visible in the rendered diff.)

xtuner/v1/model/moe/moe.py
Lines changed: 2 additions & 2 deletions

@@ -465,14 +465,14 @@ def _micro_batch_forward(
 
         router_logits_dict: dict[str, torch.Tensor] = {}
         layer_names = list(router_logits_list[0].keys())
-
+
         for layer_name in layer_names:
            layer_router_logits_list: list[torch.Tensor] = []
            for micro_batch_idx in range(len(seq_ctx_list)):
                layer_router_logits_list.append(router_logits_list[micro_batch_idx][layer_name].clone().detach())
            router_logits = torch.stack(layer_router_logits_list, dim=0).unsqueeze(0)
            router_logits_dict["router_logits"] = router_logits
-
+
         output["router_logits"] = router_logits_dict
 
         return MoEModelOutputs(**output, logits=final_logits)  # type: ignore[typeddict-item]

(Whitespace-only change to two blank lines; not visible in the rendered diff.)
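For context, the surrounding loop collects the detached per-layer router logits from every micro-batch and stacks them into one tensor. A minimal sketch of that step; the layer name, shapes, and values are assumptions for illustration, not code from the repository:

import torch

# Pretend there are two micro-batches, each contributing per-layer router logits
# of shape [packed_len, num_experts]; names and sizes are made up.
num_micro_batches, packed_len, num_experts = 2, 8, 4
router_logits_list = [
    {"layers.0.gate": torch.randn(packed_len, num_experts)} for _ in range(num_micro_batches)
]

layer_router_logits_list = [
    router_logits_list[i]["layers.0.gate"].clone().detach() for i in range(num_micro_batches)
]
# Stack over micro-batches, then add a leading axis, as in the hunk above.
router_logits = torch.stack(layer_router_logits_list, dim=0).unsqueeze(0)
print(router_logits.shape)  # torch.Size([1, 2, 8, 4])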

xtuner/v1/model/moe/qwen3vl_text.py
Lines changed: 2 additions & 1 deletion

@@ -111,6 +111,7 @@ def _forward(
         self,
         seq_ctx: SequenceContext,  # todo(@yehaochen): support intra layer micro-batch
         loss_ctx: CELossContext | None,
+        return_router_logits: bool = False,
     ) -> MoEModelOutputs:
         input_ids = seq_ctx.input_ids
         position_ids = seq_ctx.position_ids
@@ -210,7 +211,7 @@ def _forward(
 
         del router_logits
 
-        if self.config.return_router_results:
+        if self.config.return_router_results or return_router_logits:
            raise NotImplementedError
            # TODO: Move router logits to CPU is cost
            # for layer_name, router_logits in output["router_logits"].items():

xtuner/v1/ops/attn_imp.py
Lines changed: 5 additions & 7 deletions

@@ -33,8 +33,6 @@
     flash_sink_attn_varlen_func = None  # type: ignore[assignment]
     flash_sink_attn_exception = e
 
-from typing import List
-
 
 def get_flex_attention_compiled():
     torch._dynamo.config.cache_size_limit = 128
@@ -129,7 +127,7 @@ def mask_mod(b, h, q_idx, kv_idx):
 
 def eager_attention(
     q, k, v, cu_seqlens_q, softmax_scale, window_size=(-1, -1), dropout_p=0.0, s_aux=None, **kwargs
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, dict]:
     # TODO(HHA): Currently, the mask is recalculated each time, which is quite time-consuming.
     # It should be refactored to be calculated only once.
 
@@ -176,7 +174,7 @@ def eager_attention(
 
 def flex_attention(
     q, k, v, cu_seqlens_q, softmax_scale=None, window_size=(-1, -1), dropout_p=0.0, s_aux=None, causal=True, **kwargs
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, dict]:
     # q, k, v: [b, n_head, seq, head_dim]
     assert dropout_p == 0.0, "Dropout is not supported in flex attention"
 
@@ -208,7 +206,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     return attention_output, extra_info
 
 
-def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torch.Tensor:
+def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> tuple[torch.Tensor, dict]:
     # q, k, v: [b, n_head, seq , head_dim]
     assert q.size(0) == 1, "Only support batch size 1 for flash attention"
     q = q.transpose(1, 2).squeeze(0)  # [seq, head, dim]
@@ -220,11 +218,11 @@ def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torc
         if flash_attn_exception is not None:
             traceback.print_exception(flash_attn_exception)
             raise flash_attn_exception
-        attention_outputs = flash_attn_varlen_func(q, k, v, return_attn_probs=True, **kwargs)  # type: ignore
+        attention_outputs = flash_attn_varlen_func(q, k, v, return_attn_probs=True, **kwargs)  # type: ignore
        if isinstance(attention_outputs, tuple):
            attention_output = attention_outputs[0]
            extra_info["softmax_lse"] = attention_outputs[1].detach()
-        else:  # npu fused attn doesn't support softmax_lse
+        else:  # npu fused attn doesn't support softmax_lse
            attention_output = attention_outputs
    else:
        if flash_sink_attn_exception is not None:

(In the last hunk, the replaced lines differ only in whitespace, which is not visible in the rendered diff.)
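The return-type annotations above are corrected to match what these functions already return: an (attention_output, extra_info) pair, where extra_info may carry "softmax_lse". A minimal caller-side sketch of that contract; the helper name and the monitoring logic are assumptions, not code from the repository:

from typing import Callable

import torch


def run_attention(attn_fn: Callable[..., tuple[torch.Tensor, dict]], q, k, v, **kwargs):
    # Every backend (eager, flex, flash) is expected to return a tuple now.
    attention_output, extra_info = attn_fn(q, k, v, **kwargs)
    # softmax_lse is optional; e.g. the NPU fused kernel does not provide it.
    softmax_lse = extra_info.get("softmax_lse")
    if softmax_lse is not None:
        print("max LSE:", softmax_lse.max().item())
    return attention_output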

xtuner/v1/train/trainer.py
Lines changed: 9 additions & 5 deletions

@@ -41,13 +41,13 @@
 from xtuner.v1.profiler.prober_utils import setup_prober_list
 from xtuner.v1.utils import (
     XTUNER_DETERMINISTIC,
+    InternalMetrics,
+    InternalMetricsRecorder,
     ParallelConfigException,
     get_logger,
     is_hf_model_path,
     log_format,
     record_git_info,
-    InternalMetricsRecorder,
-    InternalMetrics,
 )
 from xtuner.v1.utils.device import get_device, get_torch_device_module
 
@@ -344,7 +344,9 @@ def __init__(
         self._hf_interval = hf_interval
         self._internal_metrics_interval = internal_metrics_interval
         if self._internal_metrics_interval is not None:
-            torch._dynamo.config.skip_nnmodule_hook_guards = False  # otherwise the hook will be ignored for compiled modules
+            torch._dynamo.config.skip_nnmodule_hook_guards = (
+                False  # otherwise the hook will be ignored for compiled modules
+            )
 
         if tokenizer_path is not None:
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
@@ -1438,7 +1440,7 @@ def _setup_env(self):
         logger.info(log_str)
 
 
-def _flatten_nested_metrics(metrics: InternalMetrics, sep: str = '/') -> dict:
+def _flatten_nested_metrics(metrics: InternalMetrics, sep: str = "/") -> dict:
    items = []
    for name, sub_metrics in metrics.items():
        if isinstance(sub_metrics, dict):
@@ -1448,5 +1450,7 @@ def _flatten_nested_metrics(metrics: InternalMetrics, sep: str = '/') -> dict:
                else:
                    raise ValueError(f"Unsupported metric value type: expected float or int, but got {type(v)}")
        else:
-            raise ValueError(f"Unsupported metric type for internal metrics: expected dict, but got {type(sub_metrics)}")
+            raise ValueError(
+                f"Unsupported metric type for internal metrics: expected dict, but got {type(sub_metrics)}"
+            )
    return dict(items)
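For reference, the behaviour _flatten_nested_metrics implements, as inferred from the hunks above: one level of nesting is flattened into "group/name" keys, and anything that is not a dict of numbers raises. A small standalone sketch (not the repository's code):

def flatten_nested_metrics(metrics: dict, sep: str = "/") -> dict:
    # Mirrors the structure shown in the diff: {group: {name: number}} -> {"group/name": number}.
    items = []
    for name, sub_metrics in metrics.items():
        if isinstance(sub_metrics, dict):
            for k, v in sub_metrics.items():
                if isinstance(v, (float, int)):
                    items.append((f"{name}{sep}{k}", v))
                else:
                    raise ValueError(f"Unsupported metric value type: expected float or int, but got {type(v)}")
        else:
            raise ValueError(f"Unsupported metric type for internal metrics: expected dict, but got {type(sub_metrics)}")
    return dict(items)


print(flatten_nested_metrics({"weight_rms": {"layers.0": 0.02, "lm_head": 0.01}}))
# {'weight_rms/layers.0': 0.02, 'weight_rms/lm_head': 0.01}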

xtuner/v1/utils/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 
 from .internal_metrics import InternalMetricsRecorder, InternalMetrics
 
+
 IGNORE_INDEX = -100
 
 __all__ = [

xtuner/v1/utils/internal_metrics.py
Lines changed: 17 additions & 22 deletions

@@ -1,26 +1,18 @@
 from collections import defaultdict
-from typing import Any
-import numpy as np
 import torch
-from torch import nn
 import torch.distributed as dist
+from torch import nn
 from torch.utils.hooks import RemovableHandle
+from typing_extensions import TypedDict
 
-from xtuner.v1.module import (
-    RMSNorm,
-    MultiHeadAttention,
-    MultiLatentAttention,
-    LMHead
-)
-from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEGate, MoEBlock, MoEDecoderLayer
-from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
+from xtuner.v1.engine.train_engine import TrainEngine
 from xtuner.v1.model import MoE
 from xtuner.v1.model.base import ModelItem
-from xtuner.v1.engine.train_engine import TrainEngine
+from xtuner.v1.module import LMHead, MultiHeadAttention, MultiLatentAttention
+from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
+from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEDecoderLayer
 from xtuner.v1.utils.grad_norm import group_tensors_by_device_mesh_and_placements, cal_total_norm
 
-from typing_extensions import TypedDict
-
 
 class InternalMetrics(TypedDict):
     weight_rms: dict[str, float]
@@ -45,6 +37,7 @@ class InternalMetrics(TypedDict):
 ATTN_MAX_LSE: dict[str, torch.Tensor] = {}
 ATTN_MAX_LOGITS: dict[str, torch.Tensor] = {}
 
+
 class InternalMetricsRecorder:
     def __init__(self, engine: TrainEngine):
         self.model = engine.model
@@ -60,8 +53,10 @@ def __init__(self, engine: TrainEngine):
            "attn_max_logits": {},
        }
 
+    @torch.no_grad()
     def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype: torch.dtype = torch.float32):
-        all_params = [param for param in module.parameters() if param.requires_grad]
+        """Calculate the RMS of the module's parameters"""
+        all_params = [param.data for param in module.parameters() if param.requires_grad]
         if not all_params:
            return
         grouped_params = group_tensors_by_device_mesh_and_placements(all_params)
@@ -73,16 +68,14 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
            total_numel += sum(p.numel() for p in params)
         param_l2_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=dtype)
         param_rms = param_l2_norm / total_numel**0.5
-        self.metrics['weight_rms'][layer_name] = param_rms.item()
+        self.metrics["weight_rms"][layer_name] = param_rms.item()
 
     def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
-        """
-        Register attention extra info hook as a forward hook
-        """
+        """Register attention extra info hook as a forward hook"""
        def hook(module, input, output):
            extra_info = output[1]
            if extra_info.get("softmax_lse", None) is not None:
-                if layer_name not in ATTN_MAX_LSE:
+                if layer_name not in ATTN_MAX_LSE:
                    # original shape: [n_head, seq]
                    ATTN_MAX_LSE[layer_name] = extra_info["softmax_lse"].max()
                else:
@@ -101,6 +94,7 @@ def hook(module, input, output):
 
     @torch.no_grad()
     def get_metrics(self, data_batches: list[ModelItem]):
+        """Run a dummy forward to get metrics"""
         additional_kwargs = {}
         if isinstance(self.model, MoE):
            # for MoE model, add additional kwargs to return necessary stats
@@ -140,7 +134,6 @@ def get_metrics(self, data_batches: list[ModelItem]):
            else:
                tokens_per_expert_global += output["tokens_per_expert_global"].float()
 
-
            if output.get("router_logits", None) is not None:
                for layer_name, router_logits in output["router_logits"].items():
                    # [bsz, packed_len, num_experts]
@@ -151,7 +144,9 @@ def get_metrics(self, data_batches: list[ModelItem]):
         avg_count_load = tokens_per_expert_global.mean(1)
         max_load_i = torch.amax(tokens_per_expert_global, dim=1)
         maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load
-        drop_ratio_all_layers = (tokens_per_expert_global - avg_count_load[:,None]).abs().mean(dim=1) / avg_count_load
+        drop_ratio_all_layers = (
+            tokens_per_expert_global - avg_count_load[:,None]
+        ).abs().mean(dim=1) / avg_count_load
         drop_ratio = drop_ratio_all_layers.mean()
         self.metrics["drop_ratio"].update(
            {f"layer{idx}": drop_ratio_all_layers[idx].item() for idx in range(drop_ratio_all_layers.shape[0])}

(The replaced "if layer_name not in ATTN_MAX_LSE:" lines differ only in whitespace, which is not visible in the rendered diff.)
