Skip to content

Commit bb058ee

Browse files
committed
[Feature] Support group-router-based balancing loss in BalancingLoss
1 parent cf3bb40 commit bb058ee

File tree

2 files changed

+46
-17
lines changed

2 files changed

+46
-17
lines changed

xtuner/v1/loss/moe_loss.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Literal
1+
from typing import cast
22

33
import torch
44
import torch.nn as nn
55
from torch import distributed as dist
6+
from torch.distributed import ProcessGroup
67
from torch.distributed._functional_collectives import all_reduce
78

89

@@ -29,33 +30,34 @@ def __init__(
2930
self,
3031
balancing_loss_alpha: float,
3132
balancing_loss_global_average: bool,
32-
router_scoring_func: Literal["sigmoid", "softmax"],
3333
) -> None:
3434
super().__init__()
3535
self.loss_weight = balancing_loss_alpha
3636
self.global_average = balancing_loss_global_average
3737

38-
def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
38+
def forward(
39+
self,
40+
router_weights: torch.Tensor,
41+
n_routed_experts: int,
42+
num_experts_per_tok: int,
43+
router_n_groups: int,
44+
):
3945
if self.loss_weight == 0:
4046
return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32)
4147

42-
num_layers = router_weights.shape[0]
4348
router_weights = router_weights.float() # (nlayers, seq, ne)
44-
_, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
45-
selected_experts_flat = selected_experts.view(num_layers, -1)
46-
offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts
47-
selected_experts_offset = selected_experts_flat + offset
48-
tokens_per_expert_flat = torch.histc(
49-
selected_experts_offset.view(-1),
50-
bins=num_layers * n_routed_experts,
51-
min=0,
52-
max=num_layers * n_routed_experts,
53-
)
54-
tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts) # (nlayers, ne)
49+
tokens_per_expert = self._get_tokens_per_experts(
50+
router_weights,
51+
n_routed_experts,
52+
num_experts_per_tok,
53+
router_n_groups,
54+
) # (nlayers, ne)
5555

5656
tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne)
5757
if self.global_average and dist.is_initialized():
58-
tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) # (nlayers, ne)
58+
tokens_per_expert_global = all_reduce( # (nlayers, ne)
59+
tokens_per_expert_global, "sum", cast(ProcessGroup, dist.group.WORLD)
60+
)
5961
tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, )
6062
seqlen_global = tokens_global // num_experts_per_tok
6163
routing_weights_sum_global = all_reduce_autograd(
@@ -74,6 +76,33 @@ def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
7476
# ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global")
7577
return loss * self.loss_weight
7678

79+
def _get_tokens_per_experts(
80+
self,
81+
router_weights: torch.Tensor, # (nlayers, seq, ne)
82+
n_routed_experts: int,
83+
num_experts_per_tok: int,
84+
n_groups: int,
85+
):
86+
num_layers, seq, n_routed_experts = router_weights.shape
87+
group_size = max(1, n_routed_experts // n_groups)
88+
89+
scores_for_choice = router_weights.view(num_layers, seq, n_groups, group_size)
90+
_, group_local_max_idx = torch.topk(
91+
scores_for_choice, k=num_experts_per_tok // n_groups, dim=3
92+
) # nlayers, seq, n_groups, top_k_per_group
93+
group_offsets = torch.arange(num_layers * n_groups, device=router_weights.device) * group_size
94+
group_offsets = group_offsets.view(num_layers, 1, n_groups, 1)
95+
96+
topk_ids = (group_local_max_idx + group_offsets).to(torch.long)  # (nlayers, seq, n_groups, top_k_per_group)
97+
tokens_per_expert_flat = torch.histc(
98+
topk_ids.view(-1),
99+
bins=num_layers * n_routed_experts,
100+
min=0,
101+
max=num_layers * n_routed_experts,
102+
)
103+
tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)
104+
return tokens_per_expert
105+
77106

78107
def z_loss(router_logits: torch.Tensor, global_average: bool = False):
79108
router_logits = router_logits.float() # (nlayers, seq, ne)

xtuner/v1/model/moe/moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def build(self, router_scoring_func) -> BalancingLoss:
6969
return BalancingLoss(
7070
self.balancing_loss_alpha,
7171
self.balancing_loss_global_average,
72-
router_scoring_func=router_scoring_func,
7372
)
7473

7574

@@ -549,6 +548,7 @@ def _forward(
549548
router_weights=router_weights,
550549
n_routed_experts=self.config.n_routed_experts,
551550
num_experts_per_tok=self.config.num_experts_per_tok,
551+
router_n_groups=self.config.router.router_n_groups or 1,
552552
)
553553
output["balancing_loss"] = balancing_loss
554554

0 commit comments

Comments (0)