-from typing import Literal
+from typing import cast
 
 import torch
 import torch.nn as nn
 from torch import distributed as dist
+from torch.distributed import ProcessGroup
 from torch.distributed._functional_collectives import all_reduce
 
 
@@ -29,33 +30,34 @@ def __init__(
         self,
         balancing_loss_alpha: float,
         balancing_loss_global_average: bool,
-        router_scoring_func: Literal["sigmoid", "softmax"],
     ) -> None:
         super().__init__()
         self.loss_weight = balancing_loss_alpha
         self.global_average = balancing_loss_global_average
 
-    def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
+    def forward(
+        self,
+        router_weights: torch.Tensor,
+        n_routed_experts: int,
+        num_experts_per_tok: int,
+        router_n_groups: int,
+    ):
         if self.loss_weight == 0:
             return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32)
 
-        num_layers = router_weights.shape[0]
         router_weights = router_weights.float()  # (nlayers, seq, ne)
-        _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
-        selected_experts_flat = selected_experts.view(num_layers, -1)
-        offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts
-        selected_experts_offset = selected_experts_flat + offset
-        tokens_per_expert_flat = torch.histc(
-            selected_experts_offset.view(-1),
-            bins=num_layers * n_routed_experts,
-            min=0,
-            max=num_layers * n_routed_experts,
-        )
-        tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)  # (nlayers, ne)
+        tokens_per_expert = self._get_tokens_per_experts(
+            router_weights,
+            n_routed_experts,
+            num_experts_per_tok,
+            router_n_groups,
+        )  # (nlayers, ne)
 
         tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype)  # (nlayers, ne)
         if self.global_average and dist.is_initialized():
-            tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD)  # (nlayers, ne)
+            tokens_per_expert_global = all_reduce(  # (nlayers, ne)
+                tokens_per_expert_global, "sum", cast(ProcessGroup, dist.group.WORLD)
+            )
         tokens_global = tokens_per_expert_global.sum(-1)  # (nlayers, )
         seqlen_global = tokens_global // num_experts_per_tok
         routing_weights_sum_global = all_reduce_autograd(
@@ -74,6 +76,33 @@ def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
         # ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global")
         return loss * self.loss_weight
 
+    def _get_tokens_per_experts(
+        self,
+        router_weights: torch.Tensor,  # (nlayers, seq, ne)
+        n_routed_experts: int,
+        num_experts_per_tok: int,
+        n_groups: int,
+    ):
+        num_layers, seq, n_routed_experts = router_weights.shape  # n_routed_experts re-derived from the shape
+        group_size = max(1, n_routed_experts // n_groups)
+
+        scores_for_choice = router_weights.view(num_layers, seq, n_groups, group_size)
+        _, group_local_max_idx = torch.topk(
+            scores_for_choice, k=num_experts_per_tok // n_groups, dim=3
+        )  # (nlayers, seq, n_groups, top_k_per_group)
+        group_offsets = torch.arange(num_layers * n_groups, device=router_weights.device) * group_size
+        group_offsets = group_offsets.view(num_layers, 1, n_groups, 1)
+
+        topk_ids = (group_local_max_idx + group_offsets).to(torch.long)  # (nlayers, seq, n_groups, top_k_per_group)
+        tokens_per_expert_flat = torch.histc(
+            topk_ids.view(-1),
+            bins=num_layers * n_routed_experts,
+            min=0,
+            max=num_layers * n_routed_experts,
+        )
+        tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)
+        return tokens_per_expert
+
 
 def z_loss(router_logits: torch.Tensor, global_average: bool = False):
     router_logits = router_logits.float()  # (nlayers, seq, ne)
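As a quick illustration of what the new `_get_tokens_per_experts` path computes (a minimal standalone sketch with toy shapes, not code from this repo; names like `top_k` are assumed): with `router_n_groups > 1`, each token takes its top `num_experts_per_tok // n_groups` experts separately inside every group of `n_routed_experts // n_groups` experts, instead of one global top-k over all experts as in the removed code, and the per-expert counts are then recovered with the same flattened-histc trick.

import torch

# Toy setup: 2 layers, 3 tokens, 8 experts in 2 groups of 4, top-4 routing.
nlayers, seq, ne, n_groups, top_k = 2, 3, 8, 2, 4
router_weights = torch.rand(nlayers, seq, ne)

# Group-limited top-k: top_k // n_groups experts chosen within each group.
group_size = ne // n_groups
scores = router_weights.view(nlayers, seq, n_groups, group_size)
_, local_idx = torch.topk(scores, k=top_k // n_groups, dim=3)

# Shift group-local indices to global (layer, expert) ids, then histogram.
offsets = (torch.arange(nlayers * n_groups) * group_size).view(nlayers, 1, n_groups, 1)
ids = (local_idx + offsets).view(-1).float()  # histc needs floats on CPU
counts = torch.histc(ids, bins=nlayers * ne, min=0, max=nlayers * ne)
tokens_per_expert = counts.view(nlayers, ne)

# Every token contributes exactly top_k assignments per layer ...
assert tokens_per_expert.sum() == nlayers * seq * top_k
# ... and each group receives exactly seq * (top_k // n_groups) of them.
assert (tokens_per_expert.view(nlayers, n_groups, group_size).sum(-1) == seq * (top_k // n_groups)).all()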