
Training: Extremely slow DeepSpeed engine initialization and unknown error causing NCCL timeout #115

@zhyang2226

Description

I am attempting to train Qwen3-Omni Thinker using DeepSpeed-ZeRO Stage 2. During this process, I noticed that the DeepSpeed engine initialization is extremely slow, which may be attributed to the discrete expert architecture within each layer. This design results in approximately 10,000 modules that DeepSpeed must initialize. Although the initialization eventually completes, NCCL timeout errors frequently occur during the training phase, typically within 1-20 steps.
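
For reference, the module explosion can be confirmed directly before any conversion; a minimal sketch, assuming the Thinker has already been loaded as the variable model (exact counts depend on the checkpoint):

from collections import Counter

# Count the nn.Module instances DeepSpeed has to walk during engine
# initialization. With separate gate/up/down Linear layers per expert per
# decoder layer, the per-expert modules dominate the total (on the order of
# the ~10,000 modules mentioned above).
num_modules = sum(1 for _ in model.named_modules())
print(f"total modules: {num_modules}")

# Breakdown by class name shows where the bulk comes from.
by_type = Counter(type(m).__name__ for _, m in model.named_modules())
print(by_type.most_common(10))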

Initially, I suspected these errors were data-related; however, the same dataset can successfully train Qwen3-VL-30BA3B without issues. Upon further investigation, I discovered that Qwen3-VL-30BA3B concatenates the experts within each layer, thereby enabling parallel computation. This architectural difference appears to resolve all the aforementioned issues. Consequently, I created a script to replace the Qwen3OmniMoeThinkerTextSparseMoeBlock module.

Following this replacement, the DeepSpeed engine initialization became approximately 10× faster, and the NCCL timeout errors were eliminated entirely. Additionally, this modification accelerated inference using the transformers model.generate method by approximately 3×.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SpeedupMoeExperts(nn.Module):
    """All experts of one MoE layer fused into two 3D parameters
    (gate_up_proj and down_proj) so the projections run as batched matmuls."""

    def __init__(self, num_experts, hidden_size, intermediate_size, act_fn):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        
        # gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size)
        self.gate_up_proj = nn.Parameter(
            torch.empty(num_experts, hidden_size, 2 * intermediate_size)
        )
        # down_proj: (num_experts, intermediate_size, hidden_size)
        self.down_proj = nn.Parameter(
            torch.empty(num_experts, intermediate_size, hidden_size)
        )
        
    def forward(
        self, 
        hidden_states: torch.Tensor, 
        routing_weights: torch.Tensor, 
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * seq_len, hidden_size)
            routing_weights: (batch_size * seq_len, top_k)
            selected_experts: (batch_size * seq_len, top_k)
        """
        batch_size, hidden_dim = hidden_states.shape
        
        if self.training:
            final_hidden_states = torch.zeros_like(hidden_states)
            
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, batch_size)
            
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero().squeeze(-1)
            
            for expert_idx in expert_hit:
                idx, top_x = torch.where(expert_mask[expert_idx])
                
                current_state = hidden_states[top_x]  # (num_tokens, hidden_size)
                
                gate_up = current_state @ self.gate_up_proj[expert_idx]  # (num_tokens, 2*intermediate)
                gate, up = gate_up.chunk(2, dim=-1)
                gated_output = up * self.act_fn(gate)
                expert_out = gated_output @ self.down_proj[expert_idx]  # (num_tokens, hidden_size)
                
                weighted_output = expert_out * routing_weights[top_x, idx, None]
                
                final_hidden_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype))
                
        else:
            hidden_states_expanded = hidden_states.repeat(self.num_experts, 1, 1)
            # (num_experts, batch_size, hidden_size)
            
            gate_up = torch.bmm(hidden_states_expanded, self.gate_up_proj)
            gate, up = gate_up.chunk(2, dim=-1)
            gated_output = up * self.act_fn(gate)
            expert_outputs = torch.bmm(gated_output, self.down_proj)
            # (num_experts, batch_size, hidden_size)
            
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            # (batch_size, top_k, num_experts)
            expert_mask = expert_mask.permute(2, 0, 1).to(routing_weights.dtype)  # (num_experts, batch_size, top_k)
            
            routing_weights_expanded = routing_weights.unsqueeze(0)  # (1, batch_size, top_k)
            weighted_mask = expert_mask * routing_weights_expanded  # (num_experts, batch_size, top_k)
            weighted_mask = weighted_mask.sum(dim=-1, keepdim=True)  # (num_experts, batch_size, 1)
            
            final_hidden_states = (expert_outputs * weighted_mask).sum(dim=0)
            # (batch_size, hidden_size)
        
        return final_hidden_states


class SpeedupMoeSparseMoeBlock(nn.Module):
    """Drop-in replacement for Qwen3OmniMoeThinkerTextSparseMoeBlock: the
    original router is reused, while the experts run as fused SpeedupMoeExperts."""

    def __init__(self, config_or_original_moe, original_experts=None):
        super().__init__()

        if original_experts is not None:
            # Build from an existing Qwen3OmniMoeThinkerTextSparseMoeBlock:
            # reuse its router and fuse the per-expert weights below.
            original_moe = config_or_original_moe
            self.num_experts = original_moe.num_experts
            self.top_k = original_moe.top_k
            self.norm_topk_prob = original_moe.norm_topk_prob
            
            self.gate = original_moe.gate
            
            first_expert = original_experts[0]
            self.experts = SpeedupMoeExperts(
                num_experts=self.num_experts,
                hidden_size=first_expert.hidden_size,
                intermediate_size=first_expert.intermediate_size,
                act_fn=first_expert.act_fn
            )
            
            self._copy_weights_from_original(original_experts)
        else:
            # Build an empty shell from a plain config dict; the gate and the
            # fused experts are expected to be attached afterwards (e.g. when
            # loading an already converted checkpoint).
            config = config_or_original_moe
            self.num_experts = config['num_experts']
            self.top_k = config['top_k']
            self.norm_topk_prob = config['norm_topk_prob']
            
            self.gate = None
            self.experts = None
        
    def _copy_weights_from_original(self, original_experts):
        with torch.no_grad():
            for expert_idx, expert in enumerate(original_experts):
                gate_weight = expert.gate_proj.weight.data.t()  # (hidden, intermediate)
                up_weight = expert.up_proj.weight.data.t()      # (hidden, intermediate)
                gate_up_weight = torch.cat([gate_weight, up_weight], dim=1)  # (hidden, 2*intermediate)
                self.experts.gate_up_proj.data[expert_idx].copy_(gate_up_weight)
                
                down_weight = expert.down_proj.weight.data.t()  # (intermediate, hidden)
                self.experts.down_proj.data[expert_idx].copy_(down_weight)
    
    def forward(self, hidden_states: torch.Tensor) -> tuple:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        
        # Router logits
        router_logits = self.gate(hidden_states_flat)
        
        # Routing weights
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        
        if self.norm_topk_prob:
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        
        routing_weights = routing_weights.to(hidden_states.dtype)
        
        # Expert forward
        final_hidden_states = self.experts(hidden_states_flat, routing_weights, selected_experts)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        
        return final_hidden_states, router_logits


def apply_speedup_moe_block(model, update_record_config=True):
    """Replace every Qwen3OmniMoeThinkerTextSparseMoeBlock in model.model.layers
    with a fused SpeedupMoeSparseMoeBlock, copying the expert weights over."""
    converted_count = 0
    
    for layer_idx, layer in enumerate(model.model.layers):
        if type(layer.mlp).__name__ == "Qwen3OmniMoeThinkerTextSparseMoeBlock":
            print(f"[SpeedUp] Converting MoE block in layer {layer_idx}...")
            original_moe = layer.mlp
            
            original_experts = original_moe.experts
            
            speedup_moe = SpeedupMoeSparseMoeBlock(original_moe, original_experts)
            
            layer.mlp = speedup_moe
            
            del original_moe.experts
            del original_moe
            del original_experts
            
            converted_count += 1
            
            if layer_idx % 4 == 0:
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    if update_record_config and hasattr(model.model, '_can_record_outputs'):
        old_config = model.model._can_record_outputs
        
        if 'router_logits' in old_config:
            from transformers.utils.generic import OutputRecorder
            model.model._can_record_outputs['router_logits'] = OutputRecorder(
                SpeedupMoeSparseMoeBlock, 
                index=1)
            print(f"[SpeedUp] Updated _can_record_outputs for router_logits")

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print(f"[SpeedUp] Successfully converted {converted_count} MoE blocks")
    print(f"[SpeedUp] Original expert parameters have been deleted")
    return model
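
A minimal sketch of how the conversion might be wired into a training script. The loader call below is a placeholder, and the attribute path model.model.layers mirrors what apply_speedup_moe_block expects; both should be adapted to the actual wrapper in use.

import torch

# Hypothetical loader: replace with however the Thinker is constructed in
# the training script. The conversion must happen before deepspeed.initialize
# (or Trainer construction) so the engine only ever sees the fused experts.
model = load_qwen3_omni_thinker("path/to/qwen3-omni-thinker",
                                torch_dtype=torch.bfloat16)

# Fuse the per-expert Linear layers into one 3D parameter per projection.
model = apply_speedup_moe_block(model)

# Optionally persist the converted weights so later runs can skip the
# (still slow) conversion step; see the note and patched classes below.
model.save_pretrained("./qwen3-omni-thinker-fused-moe")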

The replacement procedure above still takes considerable time. To avoid repeating it on every run, one can save the converted model and directly replace the Qwen3OmniMoeThinkerTextSparseMoeBlock module in the source code of modeling_qwen3_omni_moe.py with the following classes:

class Qwen3OmniMoeThinkerTextExperts(nn.Module):
    """Fused experts (two 3D parameters per layer); forward is identical to
    SpeedupMoeExperts above."""

    def __init__(self, num_experts, hidden_size, intermediate_size, act_fn):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        
        # gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size)
        self.gate_up_proj = nn.Parameter(
            torch.empty(num_experts, hidden_size, 2 * intermediate_size)
        )
        # down_proj: (num_experts, intermediate_size, hidden_size)
        self.down_proj = nn.Parameter(
            torch.empty(num_experts, intermediate_size, hidden_size)
        )
        
    def forward(
        self, 
        hidden_states: torch.Tensor, 
        routing_weights: torch.Tensor, 
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size * seq_len, hidden_size)
            routing_weights: (batch_size * seq_len, top_k)
            selected_experts: (batch_size * seq_len, top_k)
        """
        batch_size, hidden_dim = hidden_states.shape
        
        if self.training:
            final_hidden_states = torch.zeros_like(hidden_states)
            
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, batch_size)
            
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero().squeeze(-1)
            
            for expert_idx in expert_hit:
                idx, top_x = torch.where(expert_mask[expert_idx])
                
                current_state = hidden_states[top_x]  # (num_tokens, hidden_size)
                
                gate_up = current_state @ self.gate_up_proj[expert_idx]  # (num_tokens, 2*intermediate)
                gate, up = gate_up.chunk(2, dim=-1)
                gated_output = up * self.act_fn(gate)
                expert_out = gated_output @ self.down_proj[expert_idx]  # (num_tokens, hidden_size)
                
                weighted_output = expert_out * routing_weights[top_x, idx, None]
                
                final_hidden_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype))
                
        else:
            hidden_states_expanded = hidden_states.repeat(self.num_experts, 1, 1)
            # (num_experts, batch_size, hidden_size)
            
            gate_up = torch.bmm(hidden_states_expanded, self.gate_up_proj)
            gate, up = gate_up.chunk(2, dim=-1)
            gated_output = up * self.act_fn(gate)
            expert_outputs = torch.bmm(gated_output, self.down_proj)
            # (num_experts, batch_size, hidden_size)
            
            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
            # (batch_size, top_k, num_experts)
            expert_mask = expert_mask.permute(2, 0, 1).to(routing_weights.dtype)  # (num_experts, batch_size, top_k)
            
            routing_weights_expanded = routing_weights.unsqueeze(0)  # (1, batch_size, top_k)
            weighted_mask = expert_mask * routing_weights_expanded  # (num_experts, batch_size, top_k)
            weighted_mask = weighted_mask.sum(dim=-1, keepdim=True)  # (num_experts, batch_size, 1)
            
            final_hidden_states = (expert_outputs * weighted_mask).sum(dim=0)
            # (batch_size, hidden_size)
        
        return final_hidden_states


class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        self.experts = Qwen3OmniMoeThinkerTextExperts(
                num_experts=self.num_experts,
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size,
                act_fn=ACT2FN[config.hidden_act]
            )

    
    def forward(self, hidden_states: torch.Tensor) -> tuple:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        
        # Router logits
        router_logits = self.gate(hidden_states_flat)
        
        # Routing weights
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        
        if self.norm_topk_prob:
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        
        routing_weights = routing_weights.to(hidden_states.dtype)
        
        # Expert forward
        final_hidden_states = self.experts(hidden_states_flat, routing_weights, selected_experts)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        
        return final_hidden_states, router_logits
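
Once modeling_qwen3_omni_moe.py carries the fused classes above, a checkpoint saved after conversion can be loaded without repeating the conversion, because the fused parameter names (mlp.experts.gate_up_proj and mlp.experts.down_proj) are identical in both implementations. A minimal sketch; ThinkerModelClass is a placeholder for whichever class the training script normally instantiates, and the weight-key matching is worth verifying on the first load:

import torch

# With the patched modeling file on the import path, from_pretrained maps the
# saved fused weights straight onto the new experts module.
model = ThinkerModelClass.from_pretrained(
    "./qwen3-omni-thinker-fused-moe",
    torch_dtype=torch.bfloat16,
)

# Sanity check: every decoder layer should now carry the fused experts.
assert all(
    type(layer.mlp.experts).__name__ == "Qwen3OmniMoeThinkerTextExperts"
    for layer in model.model.layers
)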
