Description
I am attempting to train Qwen3-Omni Thinker with DeepSpeed ZeRO Stage 2. The DeepSpeed engine initialization is extremely slow, which appears to be caused by the discrete per-expert architecture within each layer: every expert is its own set of nn.Linear modules, so DeepSpeed has to initialize roughly 10,000 modules. Although initialization eventually completes, NCCL timeout errors occur frequently during training, typically within 1-20 steps.
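For reference, the module count is easy to verify; the helper below is just a sketch (model loading is omitted):

```python
import torch.nn as nn

def count_modules(model: nn.Module) -> int:
    # Every nn.Module instance is visited by DeepSpeed during engine init,
    # so per-expert gate/up/down nn.Linear layers add up quickly.
    return sum(1 for _ in model.modules())

# e.g. print(count_modules(thinker)) before and after the conversion below
```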
Initially, I suspected these errors were data-related; however, the same dataset trains Qwen3-VL-30BA3B without issue. On closer inspection, I found that Qwen3-VL-30BA3B fuses the experts of each layer into single batched weight tensors, so all experts can be computed in parallel, and this architectural difference appears to resolve all of the issues above. I therefore wrote a script that replaces the Qwen3OmniMoeThinkerTextSparseMoeBlock module with an equivalent fused implementation.
Following this replacement, the DeepSpeed engine initialization became approximately 10× faster, and the NCCL timeout errors were eliminated entirely. Additionally, this modification accelerated inference using the transformers model.generate method by approximately 3×.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpeedupMoeExperts(nn.Module):
    def __init__(self, num_experts, hidden_size, intermediate_size, act_fn):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        # gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size)
        self.gate_up_proj = nn.Parameter(
            torch.empty(num_experts, hidden_size, 2 * intermediate_size)
        )
        # down_proj: (num_experts, intermediate_size, hidden_size)
        self.down_proj = nn.Parameter(
            torch.empty(num_experts, intermediate_size, hidden_size)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * seq_len, hidden_size)
            routing_weights: (batch_size * seq_len, top_k)
            selected_experts: (batch_size * seq_len, top_k)
        """
        batch_size, hidden_dim = hidden_states.shape
        if self.training:
            # Training: loop only over the experts that were actually selected.
            final_hidden_states = torch.zeros_like(hidden_states)
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, batch_size)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero().squeeze(-1)
            for expert_idx in expert_hit:
                idx, top_x = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[top_x]  # (num_tokens, hidden_size)
                gate_up = current_state @ self.gate_up_proj[expert_idx]  # (num_tokens, 2*intermediate)
                gate, up = gate_up.chunk(2, dim=-1)
                gated_output = up * self.act_fn(gate)
                expert_out = gated_output @ self.down_proj[expert_idx]  # (num_tokens, hidden_size)
                weighted_output = expert_out * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype))
        else:
            # Inference: run every expert as one batched matmul, then mask and mix.
            hidden_states_expanded = hidden_states.repeat(self.num_experts, 1, 1)
            # (num_experts, batch_size, hidden_size)
            gate_up = torch.bmm(hidden_states_expanded, self.gate_up_proj)
            gate, up = gate_up.chunk(2, dim=-1)
            gated_output = up * self.act_fn(gate)
            expert_outputs = torch.bmm(gated_output, self.down_proj)
            # (num_experts, batch_size, hidden_size)
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            # (batch_size, top_k, num_experts)
            # Keep the mask in the routing-weight dtype so the output stays in the model dtype.
            expert_mask = expert_mask.permute(2, 0, 1).to(routing_weights.dtype)  # (num_experts, batch_size, top_k)
            routing_weights_expanded = routing_weights.unsqueeze(0)  # (1, batch_size, top_k)
            weighted_mask = expert_mask * routing_weights_expanded  # (num_experts, batch_size, top_k)
            weighted_mask = weighted_mask.sum(dim=-1, keepdim=True)  # (num_experts, batch_size, 1)
            final_hidden_states = (expert_outputs * weighted_mask).sum(dim=0)
            # (batch_size, hidden_size)
        return final_hidden_states

class SpeedupMoeSparseMoeBlock(nn.Module):
    def __init__(self, config_or_original_moe, original_experts=None):
        super().__init__()
        if original_experts is not None:
            original_moe = config_or_original_moe
            self.num_experts = original_moe.num_experts
            self.top_k = original_moe.top_k
            self.norm_topk_prob = original_moe.norm_topk_prob
            self.gate = original_moe.gate
            first_expert = original_experts[0]
            self.experts = SpeedupMoeExperts(
                num_experts=self.num_experts,
                hidden_size=first_expert.hidden_size,
                intermediate_size=first_expert.intermediate_size,
                act_fn=first_expert.act_fn,
            )
            self._copy_weights_from_original(original_experts)
        else:
            config = config_or_original_moe
            self.num_experts = config['num_experts']
            self.top_k = config['top_k']
            self.norm_topk_prob = config['norm_topk_prob']
            self.gate = None
            self.experts = None

    def _copy_weights_from_original(self, original_experts):
        with torch.no_grad():
            for expert_idx, expert in enumerate(original_experts):
                gate_weight = expert.gate_proj.weight.data.t()  # (hidden, intermediate)
                up_weight = expert.up_proj.weight.data.t()  # (hidden, intermediate)
                gate_up_weight = torch.cat([gate_weight, up_weight], dim=1)  # (hidden, 2*intermediate)
                self.experts.gate_up_proj.data[expert_idx].copy_(gate_up_weight)
                down_weight = expert.down_proj.weight.data.t()  # (intermediate, hidden)
                self.experts.down_proj.data[expert_idx].copy_(down_weight)

    def forward(self, hidden_states: torch.Tensor) -> tuple:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        # Router logits
        router_logits = self.gate(hidden_states_flat)
        # Routing weights
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        # Expert forward
        final_hidden_states = self.experts(hidden_states_flat, routing_weights, selected_experts)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits


def apply_speedup_moe_block(model, update_record_config=True):
    converted_count = 0
    for layer_idx, layer in enumerate(model.model.layers):
        if type(layer.mlp).__name__ == "Qwen3OmniMoeThinkerTextSparseMoeBlock":
            print(f"[SpeedUp] Converting MoE block in layer {layer_idx}...")
            original_moe = layer.mlp
            original_experts = original_moe.experts
            speedup_moe = SpeedupMoeSparseMoeBlock(original_moe, original_experts)
            layer.mlp = speedup_moe
            del original_moe.experts
            del original_moe
            del original_experts
            converted_count += 1
            if layer_idx % 4 == 0:
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    if update_record_config and hasattr(model.model, '_can_record_outputs'):
        old_config = model.model._can_record_outputs
        if 'router_logits' in old_config:
            from transformers.utils.generic import OutputRecorder
            model.model._can_record_outputs['router_logits'] = OutputRecorder(
                SpeedupMoeSparseMoeBlock,
                index=1)
            print("[SpeedUp] Updated _can_record_outputs for router_logits")
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print(f"[SpeedUp] Successfully converted {converted_count} MoE blocks")
    print("[SpeedUp] Original expert parameters have been deleted")
    return model
```
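For what it's worth, here is a small self-contained check that the fused inference path matches a naive per-expert loop (tiny arbitrary sizes; F.silu stands in for the model's activation), followed by the one-line conversion call for the real model:

```python
# Sanity check: fused bmm path vs. a naive per-expert reference.
torch.manual_seed(0)
num_experts, hidden, inter, top_k, tokens = 4, 16, 32, 2, 8

experts = SpeedupMoeExperts(num_experts, hidden, inter, act_fn=F.silu)
with torch.no_grad():
    experts.gate_up_proj.normal_(std=0.02)
    experts.down_proj.normal_(std=0.02)
experts.eval()  # use the batched inference branch

x = torch.randn(tokens, hidden)
weights = torch.rand(tokens, top_k)
weights = weights / weights.sum(-1, keepdim=True)
selected = torch.stack([torch.randperm(num_experts)[:top_k] for _ in range(tokens)])

fused = experts(x, weights, selected)

# Naive reference: apply each selected expert to its token and accumulate.
ref = torch.zeros_like(x)
for t in range(tokens):
    for k in range(top_k):
        e = selected[t, k]
        gate, up = (x[t] @ experts.gate_up_proj[e]).chunk(2, dim=-1)
        ref[t] += weights[t, k] * ((up * F.silu(gate)) @ experts.down_proj[e])

print(torch.allclose(fused, ref, atol=1e-5))  # expect True

# For the real model, conversion is a single call before deepspeed.initialize:
#   thinker = apply_speedup_moe_block(thinker)
```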
The replacement procedure above still takes a noticeable amount of time at load. To avoid paying that cost on every run, you can save the converted model once (a sketch of saving follows the code below) and directly replace the Qwen3OmniMoeThinkerTextSparseMoeBlock module in the source code of modeling_qwen3_omni_moe.py with the following block, so the fused weights load directly:
```python
class Qwen3OmniMoeThinkerTextExperts(nn.Module):
    def __init__(self, num_experts, hidden_size, intermediate_size, act_fn):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        # gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size)
        self.gate_up_proj = nn.Parameter(
            torch.empty(num_experts, hidden_size, 2 * intermediate_size)
        )
        # down_proj: (num_experts, intermediate_size, hidden_size)
        self.down_proj = nn.Parameter(
            torch.empty(num_experts, intermediate_size, hidden_size)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        selected_experts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * seq_len, hidden_size)
            routing_weights: (batch_size * seq_len, top_k)
            selected_experts: (batch_size * seq_len, top_k)
        """
        batch_size, hidden_dim = hidden_states.shape
        if self.training:
            final_hidden_states = torch.zeros_like(hidden_states)
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, batch_size)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero().squeeze(-1)
            for expert_idx in expert_hit:
                idx, top_x = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[top_x]  # (num_tokens, hidden_size)
                gate_up = current_state @ self.gate_up_proj[expert_idx]  # (num_tokens, 2*intermediate)
                gate, up = gate_up.chunk(2, dim=-1)
                gated_output = up * self.act_fn(gate)
                expert_out = gated_output @ self.down_proj[expert_idx]  # (num_tokens, hidden_size)
                weighted_output = expert_out * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype))
        else:
            hidden_states_expanded = hidden_states.repeat(self.num_experts, 1, 1)
            # (num_experts, batch_size, hidden_size)
            gate_up = torch.bmm(hidden_states_expanded, self.gate_up_proj)
            gate, up = gate_up.chunk(2, dim=-1)
            gated_output = up * self.act_fn(gate)
            expert_outputs = torch.bmm(gated_output, self.down_proj)
            # (num_experts, batch_size, hidden_size)
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            # (batch_size, top_k, num_experts)
            expert_mask = expert_mask.permute(2, 0, 1).to(routing_weights.dtype)  # (num_experts, batch_size, top_k)
            routing_weights_expanded = routing_weights.unsqueeze(0)  # (1, batch_size, top_k)
            weighted_mask = expert_mask * routing_weights_expanded  # (num_experts, batch_size, top_k)
            weighted_mask = weighted_mask.sum(dim=-1, keepdim=True)  # (num_experts, batch_size, 1)
            final_hidden_states = (expert_outputs * weighted_mask).sum(dim=0)
            # (batch_size, hidden_size)
        return final_hidden_states


class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = Qwen3OmniMoeThinkerTextExperts(
            num_experts=self.num_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            act_fn=ACT2FN[config.hidden_act],
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        # Router logits
        router_logits = self.gate(hidden_states_flat)
        # Routing weights
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        # Expert forward
        final_hidden_states = self.experts(hidden_states_flat, routing_weights, selected_experts)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
```
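Saving the converted model is just the usual save_pretrained call; the path below is a placeholder, and later runs with the patched modeling_qwen3_omni_moe.py then load the fused weights directly:

```python
# Sketch: run once after apply_speedup_moe_block(thinker); the path is a placeholder.
thinker.save_pretrained("qwen3-omni-thinker-fused")  # weights already in the fused layout
# Subsequent runs (with the patched modeling_qwen3_omni_moe.py) load this checkpoint
# as usual and skip the per-layer conversion entirely.
```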