@@ -29,29 +29,28 @@ def __init__(
         self,
         balancing_loss_alpha: float,
         balancing_loss_global_average: bool,
-        router_scoring_func: Literal["sigmoid", "softmax"],
     ) -> None:
         super().__init__()
         self.loss_weight = balancing_loss_alpha
         self.global_average = balancing_loss_global_average
 
-    def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
+    def forward(
+        self,
+        router_weights: torch.Tensor,
+        n_routed_experts: int,
+        num_experts_per_tok: int,
+        router_n_groups: int,
+    ):
         if self.loss_weight == 0:
             return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32)
 
-        num_layers = router_weights.shape[0]
         router_weights = router_weights.float()  # (nlayers, seq, ne)
-        _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
-        selected_experts_flat = selected_experts.view(num_layers, -1)
-        offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts
-        selected_experts_offset = selected_experts_flat + offset
-        tokens_per_expert_flat = torch.histc(
-            selected_experts_offset.view(-1),
-            bins=num_layers * n_routed_experts,
-            min=0,
-            max=num_layers * n_routed_experts,
-        )
-        tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)  # (nlayers, ne)
+        tokens_per_expert = self._get_tokens_per_experts(
+            router_weights,
+            n_routed_experts,
+            num_experts_per_tok,
+            router_n_groups,
+        )  # (nlayers, ne)
 
         tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype)  # (nlayers, ne)
         if self.global_average and dist.is_initialized():
@@ -74,6 +73,32 @@ def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
         # ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global")
         return loss * self.loss_weight
 
+    def _get_tokens_per_experts(
+        self,
+        router_weights: torch.Tensor,  # (nlayers, seq, ne)
+        n_routed_experts: int,
+        num_experts_per_tok: int,
+        n_groups: int,
+    ):
+        num_layers, seq, n_routed_experts = router_weights.shape
+        group_size = max(1, n_routed_experts // n_groups)
+
+        scores_for_choice = router_weights.view(num_layers, seq, n_groups, group_size)
+        _, group_local_max_idx = torch.topk(
+            scores_for_choice, k=num_experts_per_tok // n_groups, dim=3)  # (nlayers, seq, n_groups, top_k_per_group)
+        group_offsets = torch.arange(num_layers * n_groups, device=router_weights.device) * group_size
+        group_offsets = group_offsets.view(num_layers, 1, n_groups, 1)
+
+        topk_ids = (group_local_max_idx + group_offsets).to(torch.long)  # (nlayers, seq, n_groups, top_k_per_group)
+        tokens_per_expert_flat = torch.histc(
+            topk_ids.view(-1),
+            bins=num_layers * n_routed_experts,
+            min=0,
+            max=num_layers * n_routed_experts,
+        )
+        tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)
+        return tokens_per_expert
+
 
 def z_loss(router_logits: torch.Tensor, global_average: bool = False):
     router_logits = router_logits.float()  # (nlayers, seq, ne)
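For reference, a minimal standalone sketch (not part of the commit) of what the new `_get_tokens_per_experts` helper computes: each layer's expert scores are split into `n_groups` contiguous groups, the per-group top-k is taken, and group-local indices are shifted into one flat id space so a single `torch.histc` call counts tokens per expert across all layers at once. The toy sizes are illustrative and assume `n_routed_experts` and `num_experts_per_tok` are divisible by `n_groups`.

```python
import torch

num_layers, seq, n_routed_experts = 2, 3, 8
n_groups, num_experts_per_tok = 2, 4
group_size = n_routed_experts // n_groups  # 4 experts per group

router_weights = torch.rand(num_layers, seq, n_routed_experts)
scores = router_weights.view(num_layers, seq, n_groups, group_size)
# Pick num_experts_per_tok // n_groups experts inside every group.
_, local_idx = torch.topk(scores, k=num_experts_per_tok // n_groups, dim=3)

# Expert e of group g in layer l maps to flat id l * ne + g * group_size + e.
offsets = (torch.arange(num_layers * n_groups) * group_size).view(num_layers, 1, n_groups, 1)
flat_ids = (local_idx + offsets).view(-1).float()  # histc needs float input on CPU

counts = torch.histc(flat_ids, bins=num_layers * n_routed_experts,
                     min=0, max=num_layers * n_routed_experts)
tokens_per_expert = counts.view(num_layers, n_routed_experts)

# Every token contributes num_experts_per_tok selections per layer.
assert tokens_per_expert.sum() == num_layers * seq * num_experts_per_tok
```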