Skip to content

Commit 6cdb347

Browse files
committed
[Feature] Support group-router-based balancing loss in BalancingLoss
1 parent cf3bb40 commit 6cdb347

File tree

2 files changed

+40
-15
lines changed

2 files changed

+40
-15
lines changed

xtuner/v1/loss/moe_loss.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
def __init__(
    self,
    balancing_loss_alpha: float,
    balancing_loss_global_average: bool,
) -> None:
    """Set up the expert-balancing loss term.

    Args:
        balancing_loss_alpha: Scale applied to the computed loss; ``0``
            disables the loss entirely (``forward`` short-circuits).
        balancing_loss_global_average: When True and ``torch.distributed``
            is initialized, per-expert token statistics are aggregated
            across ranks inside ``forward`` before the loss is computed.
    """
    super().__init__()
    self.global_average = balancing_loss_global_average
    self.loss_weight = balancing_loss_alpha

38-
def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
37+
def forward(
38+
self,
39+
router_weights: torch.Tensor,
40+
n_routed_experts: int,
41+
num_experts_per_tok: int,
42+
router_n_groups: int,
43+
):
3944
if self.loss_weight == 0:
4045
return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32)
4146

42-
num_layers = router_weights.shape[0]
4347
router_weights = router_weights.float() # (nlayers, seq, ne)
44-
_, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
45-
selected_experts_flat = selected_experts.view(num_layers, -1)
46-
offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts
47-
selected_experts_offset = selected_experts_flat + offset
48-
tokens_per_expert_flat = torch.histc(
49-
selected_experts_offset.view(-1),
50-
bins=num_layers * n_routed_experts,
51-
min=0,
52-
max=num_layers * n_routed_experts,
53-
)
54-
tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts) # (nlayers, ne)
48+
tokens_per_expert = self._get_tokens_per_experts(
49+
router_weights,
50+
n_routed_experts,
51+
num_experts_per_tok,
52+
router_n_groups,
53+
) # (nlayers, ne)
5554

5655
tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne)
5756
if self.global_average and dist.is_initialized():
@@ -74,6 +73,32 @@ def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
7473
# ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global")
7574
return loss * self.loss_weight
7675

76+
def _get_tokens_per_experts(
77+
self,
78+
router_weights: torch.Tensor, # (nlayers, seq, ne)
79+
n_routed_experts: int,
80+
num_experts_per_tok: int,
81+
n_groups: int,
82+
):
83+
num_layers, seq, n_routed_experts = router_weights.shape
84+
group_size = max(1, n_routed_experts // n_groups)
85+
86+
scores_for_choice = router_weights.view(num_layers, seq, n_groups, group_size)
87+
_, group_local_max_idx = torch.topk(
88+
scores_for_choice, k=num_experts_per_tok // n_groups, dim=3) # nlayers, seq, n_groups, top_k_per_group
89+
group_offsets = (torch.arange(num_layers * n_groups, device=router_weights.device) * group_size)
90+
group_offsets = group_offsets.view(num_layers, 1, n_groups, 1)
91+
92+
topk_ids = (group_local_max_idx + group_offsets).to(torch.long) # [seq, n_groups, top_k_per_group]
93+
tokens_per_expert_flat = torch.histc(
94+
topk_ids.view(-1),
95+
bins=num_layers * n_routed_experts,
96+
min=0,
97+
max=num_layers * n_routed_experts,
98+
)
99+
tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)
100+
return tokens_per_expert
101+
77102

78103
def z_loss(router_logits: torch.Tensor, global_average: bool = False):
79104
router_logits = router_logits.float() # (nlayers, seq, ne)

xtuner/v1/model/moe/moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def build(self, router_scoring_func) -> BalancingLoss:
6969
return BalancingLoss(
7070
self.balancing_loss_alpha,
7171
self.balancing_loss_global_average,
72-
router_scoring_func=router_scoring_func,
7372
)
7473

7574

@@ -549,6 +548,7 @@ def _forward(
549548
router_weights=router_weights,
550549
n_routed_experts=self.config.n_routed_experts,
551550
num_experts_per_tok=self.config.num_experts_per_tok,
551+
router_n_groups=self.config.router.router_n_groups or 1,
552552
)
553553
output["balancing_loss"] = balancing_loss
554554

0 commit comments

Comments
 (0)