 from collections import defaultdict
-from typing import Any
-import numpy as np
 import torch
-from torch import nn
 import torch.distributed as dist
+from torch import nn
 from torch.utils.hooks import RemovableHandle
+from typing_extensions import TypedDict
 
-from xtuner.v1.module import (
-    RMSNorm,
-    MultiHeadAttention,
-    MultiLatentAttention,
-    LMHead
-)
-from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEGate, MoEBlock, MoEDecoderLayer
-from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
+from xtuner.v1.engine.train_engine import TrainEngine
 from xtuner.v1.model import MoE
 from xtuner.v1.model.base import ModelItem
-from xtuner.v1.engine.train_engine import TrainEngine
+from xtuner.v1.module import LMHead, MultiHeadAttention, MultiLatentAttention
+from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
+from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEDecoderLayer
 from xtuner.v1.utils.grad_norm import group_tensors_by_device_mesh_and_placements, cal_total_norm
 
-from typing_extensions import TypedDict
-
 
 class InternalMetrics(TypedDict):
     weight_rms: dict[str, float]
@@ -45,6 +37,7 @@ class InternalMetrics(TypedDict):
 ATTN_MAX_LSE: dict[str, torch.Tensor] = {}
 ATTN_MAX_LOGITS: dict[str, torch.Tensor] = {}
 
+
 class InternalMetricsRecorder:
     def __init__(self, engine: TrainEngine):
         self.model = engine.model
@@ -60,8 +53,10 @@ def __init__(self, engine: TrainEngine):
6053 "attn_max_logits" : {},
6154 }
6255
56+ @torch .no_grad ()
6357 def calculate_module_weight_rms (self , module : nn .Module , layer_name : str , dtype : torch .dtype = torch .float32 ):
64- all_params = [param for param in module .parameters () if param .requires_grad ]
58+ """Calculate the RMS of the module's parameters"""
59+ all_params = [param .data for param in module .parameters () if param .requires_grad ]
6560 if not all_params :
6661 return
6762 grouped_params = group_tensors_by_device_mesh_and_placements (all_params )
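A note on the quantity computed here: the weight RMS is the L2 norm of all trainable parameters divided by the square root of their total element count. A minimal single-device sketch of the same statistic (plain PyTorch, without the device-mesh grouping used above; names are illustrative):

```python
import torch
from torch import nn


def weight_rms(module: nn.Module) -> float:
    # Collect trainable parameter data, as the method above does.
    params = [p.data for p in module.parameters() if p.requires_grad]
    if not params:
        return 0.0
    # Overall L2 norm = norm of the per-parameter norms;
    # rms = ||W||_2 / sqrt(N) == sqrt(mean(W**2)) over all elements.
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(p.float(), ord=2.0) for p in params]),
        ord=2.0,
    )
    total_numel = sum(p.numel() for p in params)
    return (total_norm / total_numel**0.5).item()


# e.g. weight_rms(nn.Linear(4, 4)) -> a small positive float
```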
@@ -73,16 +68,14 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
             total_numel += sum(p.numel() for p in params)
         param_l2_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=dtype)
         param_rms = param_l2_norm / total_numel ** 0.5
-        self.metrics['weight_rms'][layer_name] = param_rms.item()
+        self.metrics["weight_rms"][layer_name] = param_rms.item()
 
     def register_attn_extra_info_hook(self, module: nn.Module, layer_name: str):
-        """
-        Register attention extra info hook as a forward hook
-        """
+        """Register attention extra info hook as a forward hook"""
         def hook(module, input, output):
             extra_info = output[1]
             if extra_info.get("softmax_lse", None) is not None:
-                if layer_name not in ATTN_MAX_LSE:
+                if layer_name not in ATTN_MAX_LSE:
                     # original shape: [n_head, seq]
                     ATTN_MAX_LSE[layer_name] = extra_info["softmax_lse"].max()
                 else:
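For readers unfamiliar with the pattern above: `register_forward_hook` attaches a callback that sees a module's output after every forward pass, which is how the recorder captures the maximum log-sum-exp without touching the attention code itself. A minimal self-contained sketch with a toy module (the module and its output layout are illustrative, not xtuner's):

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    # Mimics a module whose forward returns (output, extra_info_dict).
    def forward(self, x):
        extra_info = {"softmax_lse": torch.logsumexp(x, dim=-1)}
        return x, extra_info


max_lse: dict[str, torch.Tensor] = {}


def make_hook(layer_name: str):
    def hook(module, args, output):
        lse = output[1]["softmax_lse"].max()
        # Keep a running maximum across forward passes, like ATTN_MAX_LSE.
        prev = max_lse.get(layer_name)
        max_lse[layer_name] = lse if prev is None else torch.maximum(prev, lse)
    return hook


attn = ToyAttention()
handle = attn.register_forward_hook(make_hook("layers.0.attn"))  # RemovableHandle
attn(torch.randn(2, 8))
handle.remove()
```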
@@ -101,6 +94,7 @@ def hook(module, input, output):
 
     @torch.no_grad()
     def get_metrics(self, data_batches: list[ModelItem]):
+        """Run a dummy forward to get metrics"""
         additional_kwargs = {}
         if isinstance(self.model, MoE):
             # for MoE model, add additional kwargs to return necessary stats
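A sketch of how the recorder is presumably driven from a training step, based only on the constructor and method signatures in this diff (the surrounding engine/dataloader wiring is illustrative, and it is assumed that `get_metrics` returns the populated metrics dict):

```python
# Hypothetical wiring; `engine` is an xtuner TrainEngine and `data_batches`
# a list[ModelItem], as required by the signatures above.
recorder = InternalMetricsRecorder(engine)
metrics = recorder.get_metrics(data_batches)  # runs a dummy forward pass
print(metrics["weight_rms"], metrics["drop_ratio"])
```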
@@ -140,7 +134,6 @@ def get_metrics(self, data_batches: list[ModelItem]):
                 else:
                     tokens_per_expert_global += output["tokens_per_expert_global"].float()
 
-
             if output.get("router_logits", None) is not None:
                 for layer_name, router_logits in output["router_logits"].items():
                     # [bsz, packed_len, num_experts]
@@ -151,7 +144,9 @@ def get_metrics(self, data_batches: list[ModelItem]):
         avg_count_load = tokens_per_expert_global.mean(1)
         max_load_i = torch.amax(tokens_per_expert_global, dim=1)
         maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load
-        drop_ratio_all_layers = (tokens_per_expert_global - avg_count_load[:, None]).abs().mean(dim=1) / avg_count_load
+        drop_ratio_all_layers = (
+            tokens_per_expert_global - avg_count_load[:, None]
+        ).abs().mean(dim=1) / avg_count_load
         drop_ratio = drop_ratio_all_layers.mean()
         self.metrics["drop_ratio"].update(
             {f"layer{idx}": drop_ratio_all_layers[idx].item() for idx in range(drop_ratio_all_layers.shape[0])}