riteshbhirud/MixtureOfExperts
MixtureOfExperts.jl

A comprehensive Julia library for Mixture of Experts (MoE) architectures that integrates seamlessly with existing models, particularly Llama2.jl. The library implements state-of-the-art MoE techniques from recent research, including the Switch Transformer, DeepSeek V3, and Stanford CS336 methodologies.

Features

  • Multiple Gating Mechanisms: TopK, Switch, Expert Choice, Soft MoE, Hash routing
  • Expert Architectures: Standard FFN, Gated FFN (Llama-style), CUR compressed experts
  • Load Balancing: Switch Transformer loss, DeepSeek variants, auxiliary-free balancing
  • Llama2 Integration: Convert existing Llama2 models to MoE with preserved functionality
  • Advanced Features: Shared experts, expert virtualization, routing analysis

Quick Start

Basic MoE Layer

using MixtureOfExperts

config = create_moe_config(
    num_experts=8,
    expert_type=:gated,
    input_dim=512,
    hidden_dim=2048,
    output_dim=512,
    top_k=2,
    gate_type=TopKGating(2),
    balance_loss=SwitchTransformerLoss(0.01f0)
)

moe_layer = MoELayer(config)

input = randn(Float32, 512, 32)
output, balance_loss = moe_layer(input; training=true)

Llama2 Model Conversion

using Llama2
using MixtureOfExperts

original_model = Llama2.load_karpathy_model("stories42M.bin", "tokenizer.bin")

moe_model = convert_to_moe(
    original_model,
    [2, 4, 6];
    num_experts=8,
    top_k=2,
    expert_init_strategy=:perturb,
    expert_init_noise=0.01f0,
    gate_type=TopKGating(2),
    balance_loss=SwitchTransformerLoss(0.01f0),
    expert_type=:gated
)

println("Original parameters: $(count_llama_parameters(original_model))")
println("MoE total parameters: $(count_parameters(moe_model))")
println("MoE active parameters: $(count_active_parameters(moe_model))")

Text Generation with MoE

prompt = "Once upon a time"

original_output = Llama2.sample(original_model, prompt; temperature=0.9f0)

moe_output = sample_moe(moe_model, prompt; 
                       temperature=0.9f0,
                       show_expert_stats=true,
                       show_routing_entropy=true)

Expert Usage Analysis

test_prompts = [
    "The dragon flew",
    "In the forest", 
    "The magic spell",
    "Once upon a time"
]

for prompt in test_prompts
    println("Prompt: \"$prompt\"")
    
    result = sample_moe(moe_model, prompt;
                       temperature=0.3f0,
                       max_seq_len=50,
                       show_expert_stats=true,
                       expert_usage_threshold=0.05f0)
    
    println("Generated: \"$result\"\n")
end

Core API Reference

Model Conversion

convert_to_moe(model, moe_layers; kwargs...)

Convert existing Llama2 model to MoE by replacing specified layers.

Arguments:

  • model::Llama2.LanguageModel: Original model to convert
  • moe_layers::Vector{Int}: Layer indices to convert to MoE (1-based)

Key Options:

  • num_experts::Int=8: Number of experts per MoE layer
  • top_k::Int=2: Number of experts to activate per token
  • expert_init_strategy::Symbol=:perturb: Weight initialization (:copy, :perturb, :split, :random)
  • expert_type::Symbol=:gated: Expert architecture (:standard, :gated, :cur)
  • gate_type::GatingMechanism=TopKGating(top_k): Routing mechanism
  • balance_loss::LoadBalancingLoss=SwitchTransformerLoss(0.01f0): Load balancing

Returns: MoELanguageModel
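The `:perturb` strategy copies the dense FFN weights into every expert and adds small Gaussian noise so the experts can diverge during training. A minimal sketch of the idea (`perturb_init` is a hypothetical helper, not the library's actual conversion code):

```julia
# Illustrative sketch of :perturb expert initialization (hypothetical helper):
# each expert starts as a noisy copy of the original dense weight matrix.
function perturb_init(w::Matrix{Float32}, num_experts::Int, noise::Float32)
    return [w .+ noise .* randn(Float32, size(w)) for _ in 1:num_experts]
end

experts = perturb_init(randn(Float32, 2048, 512), 8, 0.01f0)
length(experts)  # 8 near-identical experts that diverge as training proceeds
```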

Text Generation

sample_moe(model, prompt; kwargs...)

Generate text using MoE model with expert tracking.

Arguments:

  • model::MoELanguageModel: MoE model for generation
  • prompt::String="": Input text prompt

Key Options:

  • temperature::Float32=0.9f0: Sampling temperature
  • max_seq_len::Int=typemax(Int): Maximum sequence length
  • show_expert_stats::Bool=false: Display expert usage statistics
  • show_routing_entropy::Bool=false: Show routing entropy analysis
  • expert_usage_threshold::Float32=0.01f0: Threshold for reporting expert usage

Returns: String (generated text)

Model Analysis

count_parameters(model::MoELanguageModel)

Count total parameters in MoE model.

count_active_parameters(model::MoELanguageModel)

Count parameters active during inference (considering top-k routing).

get_expert_stats(model, tokens)

Analyze expert usage patterns for given token sequence.
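With top-k routing, only k of the N expert FFNs run for each token, so the active parameter count is roughly the shared (non-expert) parameters plus k/N of the expert parameters. Back-of-envelope arithmetic with hypothetical counts:

```julia
# Rough active-parameter estimate under top-k routing (arithmetic sketch only).
active_params(shared, expert_total, num_experts, top_k) =
    shared + (expert_total * top_k) ÷ num_experts

# e.g. 42M shared + 64M expert weights, 8 experts, top-2 routing:
active_params(42_000_000, 64_000_000, 8, 2)  # 58_000_000 active per token
```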

Configuration

create_moe_config(; kwargs...)

Create MoE layer configuration with sensible defaults.

Key Options:

  • num_experts::Int=8: Number of experts
  • expert_type::Symbol=:standard: Expert architecture
  • top_k::Int=2: Experts to activate
  • gate_type::GatingMechanism=TopKGating(2): Routing mechanism
  • balance_loss::LoadBalancingLoss=SwitchTransformerLoss(0.01f0): Load balancing

Gating Mechanisms

TopKGating

Stanford CS336 implementation with softmax renormalization:

gate = TopKGating(k=2)

SwitchGating

Switch Transformer (k=1 special case):

gate = SwitchGating()

ExpertChoiceGating

Experts select tokens instead of tokens selecting experts:

gate = ExpertChoiceGating(capacity_factor=1.25f0)

Advanced Routing

gate = SoftMoEGating(k=2, λ=1.0f0)
gate = HashGating(k=2, num_experts=8)
gate = SharedExpertGating(num_shared=2, base_gate=TopKGating(2))

Expert Types

StandardExpert

Basic 2-layer FFN with configurable activation:

expert = StandardExpert(input_dim, hidden_dim, output_dim, gelu; dropout=0.1f0)

GatedExpert

Llama-style gated FFN computing w2(silu(w1(x)) * w3(x)):

expert = GatedExpert(input_dim, hidden_dim, output_dim, silu)
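The formula above can be written out with plain matrices. A self-contained sketch (the weight shapes and the `silu` definition are spelled out here; this is not the library's internal code):

```julia
# Sketch of the Llama-style gated FFN y = w2 * (silu(w1*x) .* (w3*x)).
silu(v) = v .* (1f0 ./ (1f0 .+ exp.(-v)))   # SiLU / swish activation

gated_ffn(w1, w2, w3, x) = w2 * (silu(w1 * x) .* (w3 * x))

w1 = randn(Float32, 2048, 512)   # up projection
w3 = randn(Float32, 2048, 512)   # gate projection
w2 = randn(Float32, 512, 2048)   # down projection
x  = randn(Float32, 512)
length(gated_ffn(w1, w2, w3, x)) # 512, same as the input dimension
```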

CURExpert

Compressed expert using CUR decomposition:

expert = CURExpert(input_dim, hidden_dim, output_dim, gelu; rank=64)
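CUR factorization approximates a weight matrix W by a subset of its columns C, a subset of its rows R, and a small linking matrix U, so W ≈ C·U·R can be stored and applied at reduced cost. A generic sketch of the decomposition (index selection here is uniform random; the library may use a smarter selection scheme):

```julia
using LinearAlgebra, Random

# Generic CUR sketch: W ≈ C * U * R with sampled columns and rows.
function cur_approx(W::AbstractMatrix, rank::Int)
    cols = randperm(size(W, 2))[1:rank]
    rows = randperm(size(W, 1))[1:rank]
    C, R = W[:, cols], W[rows, :]
    U = pinv(C) * W * pinv(R)   # small rank × rank linking matrix
    return C, U, R
end

C, U, R = cur_approx(randn(10, 8), 4)
size(C * U * R)  # (10, 8): same shape as W, stored with far fewer entries
```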

Load Balancing

SwitchTransformerLoss

Original Switch Transformer auxiliary loss:

loss = SwitchTransformerLoss(0.01f0)
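For reference, the Switch Transformer auxiliary loss is α · N · Σᵢ fᵢ · Pᵢ, where fᵢ is the fraction of tokens dispatched to expert i and Pᵢ is the mean router probability assigned to expert i. A standalone sketch of that formula (the layout and names are assumptions, not the library's internals):

```julia
# Sketch of the Switch auxiliary loss α * N * Σ_i f_i * P_i.
# router_probs is experts × tokens; assignments holds each token's chosen expert.
function switch_aux_loss(router_probs::AbstractMatrix, assignments::Vector{Int}, α)
    N, T = size(router_probs)
    f = [count(==(i), assignments) / T for i in 1:N]   # dispatch fractions
    P = vec(sum(router_probs; dims=2)) ./ T            # mean router probabilities
    return α * N * sum(f .* P)
end

# Perfectly balanced 2-expert case: the loss reduces to α.
switch_aux_loss(fill(0.5, 2, 4), [1, 2, 1, 2], 0.01f0)  # ≈ 0.01
```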

DeepSeekLoss

DeepSeek V1/V2 variants with device-aware balancing:

loss = DeepSeekLoss(0.01f0, balance_type=:device)

AuxiliaryFreeLoss

DeepSeek V3 innovation with online bias learning:

loss = AuxiliaryFreeLoss(num_experts=8, learning_rate=0.01f0)
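Instead of an auxiliary loss term, the auxiliary-free scheme maintains a per-expert bias that is added to the routing logits only when selecting experts: after each batch, biases of underloaded experts are nudged up and those of overloaded experts down. An illustrative update rule (the exact DeepSeek V3 rule may differ in detail):

```julia
# Illustrative auxiliary-free balancing update: nudge routing biases toward
# underused experts rather than adding a penalty to the training loss.
function update_bias!(bias::Vector{Float32}, counts::Vector{Int}, lr::Float32)
    avg = sum(counts) / length(counts)
    for i in eachindex(bias)
        bias[i] += counts[i] < avg ? lr : -lr   # boost underloaded, damp the rest
    end
    return bias
end

update_bias!(zeros(Float32, 2), [1, 3], 0.01f0)  # expert 1 boosted, expert 2 damped
```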

File Structure

src/
├── MixtureOfExperts.jl          # Main module
├── gating/                      # Routing mechanisms
│   ├── base.jl                  # Abstract types and interfaces
│   ├── simple.jl                # RandomGating (testing/baseline)
│   ├── topk.jl                  # TopKGating (Stanford CS336)
│   ├── switch.jl                # SwitchGating, JitterGating
│   ├── expert_choice.jl         # ExpertChoiceGating
│   └── advanced.jl              # SoftMoE, HashGating, SharedExpert
├── experts/                     # Expert architectures
│   ├── standard.jl              # Basic 2-layer FFN
│   ├── gated.jl                 # Llama-style gated FFN
│   └── cur.jl                   # CUR decomposition experts
├── balancing/                   # Load balancing losses
│   ├── losses.jl                # Switch, DeepSeek, Z-loss
│   └── auxiliary_free.jl        # DeepSeek V3 innovation
├── core/                        # Core MoE components
│   ├── router.jl                # Neural routing network
│   ├── moe_layer.jl             # Main MoE layer implementation
│   └── utils.jl                 # Utility functions
└── llama2/                      # Llama2.jl integration
    ├── types.jl                 # MoE wrapper types
    ├── conversion.jl            # convert_to_moe functionality
    ├── inference.jl             # MoE transformer forward pass
    ├── attention.jl             # Attention with RoPE support
    ├── generation.jl            # sample_moe and text generation
    └── utils.jl                 # Save/load, analysis utilities

Advanced Usage

Custom Gating Mechanism

using NNlib: softmax  # `softmax` is not in Base; assumed here from NNlib

struct MyGating <: GatingMechanism
    k::Int
    temperature::Float32
end

function compute_gates(gate::MyGating, router_logits::AbstractMatrix)
    scaled_logits = router_logits ./ gate.temperature
    router_probs = softmax(scaled_logits; dims=1)
    
    expert_indices = zeros(Int, gate.k, size(router_logits, 2))
    expert_gates = zeros(Float32, gate.k, size(router_logits, 2))
    
    for i in 1:size(router_logits, 2)
        topk_indices = partialsortperm(router_probs[:, i], 1:gate.k, rev=true)
        expert_indices[:, i] = topk_indices
        expert_gates[:, i] = router_probs[topk_indices, i] ./ sum(router_probs[topk_indices, i])
    end
    
    return expert_indices, expert_gates, router_probs
end

Batch Generation

prompts = [
    "The dragon",
    "In the castle", 
    "Magic spell",
    "Forest adventure"
]

results = sample_moe_batch(moe_model, prompts;
                          temperature=0.7f0,
                          max_seq_len=100,
                          show_progress=true)

for (prompt, result) in zip(prompts, results)
    println("\"$prompt\" => \"$result\"")
end

Model Saving and Loading

save_moe_model(moe_model, "my_moe_model.jls")

loaded_model = load_moe_model("my_moe_model.jls")

metadata = model_info(loaded_model)
println("Model info: $metadata")

Comparative Analysis

comparison = compare_models(original_model, moe_model, "The brave knight")

for (i, comp) in enumerate(comparison["comparisons"])
    println("Run $i:")
    println("  Original: $(comp["original_output"])")
    println("  MoE:      $(comp["moe_output"])")
    println()
end

Research Implementation Notes

This library implements techniques from:

  • Switch Transformer (Fedus et al., 2022): Core MoE architecture and load balancing
  • DeepSeek V1-V3 (2024): Shared experts, auxiliary-free balancing, advanced routing
  • Stanford CS336 (2024): Mathematical formulations and routing algorithms
  • Expert Choice Routing (Zhou et al., 2022): Alternative routing paradigm
  • CUR Decomposition: Memory-efficient expert compression

GPU Acceleration (CUDA)

The library provides full GPU acceleration using CUDA for significantly improved performance on NVIDIA GPUs.

Quick Start

using CUDA
using MixtureOfExperts

gpu_moe = create_cuda_moe(
    num_experts = 8,
    expert_type = :gated,
    input_dim = 768,
    hidden_dim = 3072,
    output_dim = 768,
    top_k = 2
)

gpu_input = CUDA.randn(Float32, 768, 32)  
gpu_output, balance_loss = cuda_moe_forward!(gpu_moe, gpu_input; training=true)

cpu_moe, gpu_moe, sync_success = create_synchronized_moe_pair(
    CudaMoEConfig(
        num_experts = 8,
        expert_type = :gated,
        input_dim = 768,
        hidden_dim = 3072,
        output_dim = 768,
        top_k = 2
    )
)

gpu_tensor = to_cuda(cpu_tensor)
cpu_tensor = to_cpu(gpu_tensor)

expert_counts, usage_percentages = get_cuda_expert_stats(expert_indices, num_experts)

test_input = generate_realistic_input(input_dim=768, batch_size=32)
config = CudaMoEConfig(
    num_experts = 16,
    expert_type = :gated,          
    input_dim = 1024,
    hidden_dim = 4096,
    output_dim = 1024,
    top_k = 2,
    noise_scale = 0.0f0,           
    use_noise_network = false,    
    balance_weight = 0.01f0        
)

gpu_moe = CudaMoELayer(config)

MoE-GPT2 Integration

Drop-in replacement for GPT-2 feedforward layers with Mixture of Experts (MoE).
Supports both TopK and Expert Choice routing with native transformer integration.


Quick Start

using MixtureOfExperts

# Create MoE-enabled GPT-2 transformer
config = moe_gpt2_base_config(
    vocab_size = 50257,
    num_experts = 8,
    expert_top_k = 2
)
model = create_moe_transformer_model(config; lm_head=true)

# Use like any transformer
input_ids = rand(1:config.vocab_size, 4, 32)  # (batch, sequence)
input_nt = prepare_transformer_inputs(input_ids)
output = model(input_nt)  # Returns NamedTuple with logit and aux_loss

Available Routing Mechanisms

TopK Routing (Default)

  • Method: Tokens choose their favorite K experts
  • Usage: Default in moe_gpt2_*_config() functions
  • Best for: General purpose, research, proven performance

Expert Choice Routing (Native Support)

  • Method: Experts choose their favorite tokens (up to capacity)
  • Superior load balancing: Typically 20–60% lower auxiliary loss
  • Better training stability: More consistent expert utilization
# Method 1: Via routing_type parameter
config = moe_gpt2_base_config(
    routing_type = :expert_choice,
    capacity_factor = 1.25f0,
    specialization_strength = 1.0f0
)
model = create_moe_transformer_model(config; lm_head=true)

# Method 2: Via convenience function
config = moe_gpt2_expert_choice_config(size=:base)
model = create_moe_transformer_model(config; lm_head=true)

# Method 3: Compare both routing methods
topk_config, ec_config = compare_moe_gpt2_routing_configs(num_experts=8)
topk_model = create_moe_transformer_model(topk_config; lm_head=true)
ec_model = create_moe_transformer_model(ec_config; lm_head=true)

# Compare load balancing
input_nt = prepare_transformer_inputs(rand(1:50257, 4, 32))
topk_output = topk_model(merge(input_nt, (training = true,)))
ec_output = ec_model(merge(input_nt, (training = true,)))
println("TopK aux loss: $(topk_output.aux_loss)")
println("Expert Choice aux loss: $(ec_output.aux_loss)")  # Typically 20-60% lower

Training

# Setup training (works with both routing types)
optimizer = Flux.Adam(3e-4)
opt_state = Flux.setup(optimizer, model)
targets = rand(1:config.vocab_size, 4, 32)

# Training step with auxiliary loss
function train_step!(model, opt_state, input_nt, targets)
    loss, grads = Flux.withgradient(model) do m
        output = m(merge(input_nt, (training = true,)))
        loss_info = create_moe_training_loss(output, targets)
        return loss_info.total_loss
    end
    Flux.update!(opt_state, model, grads[1])
    return loss
end

# Train with load balancing monitoring
for epoch in 1:10
    loss = train_step!(model, opt_state, input_nt, targets)
    
    # Monitor expert load balancing
    if epoch % 5 == 0
        output = model(merge(input_nt, (training = true,)))
        println("Epoch $epoch: Loss = $loss, Aux Loss = $(output.aux_loss)")
    end
end

Different GPT-2 Model Sizes

# Small GPT-2 model for experiments  
small_config = moe_gpt2_base_config(
    hidden_size = 512,
    num_experts = 4,
    expert_top_k = 2,
    routing_type = :expert_choice  # Better for smaller models
)

# Medium GPT-2 model for production with Expert Choice
medium_config = moe_gpt2_medium_config(
    routing_type = :expert_choice,
    capacity_factor = 1.25f0,  # Conservative capacity for stability
    specialization_strength = 1.0f0
)

# Large GPT-2 model with many experts
large_config = moe_gpt2_large_config(
    num_experts = 32,
    expert_top_k = 2,
    routing_type = :expert_choice,
    capacity_factor = 1.25f0,
    specialization_strength = 1.0f0
)

# Custom GPT-2 configuration with fine-tuned routing
custom_config = MoETransformerConfig(
    vocab_size = 32000,  # Custom vocabulary
    hidden_size = 1024,
    num_experts = 16,
    expert_type = :gated,  # :standard or :gated
    routing_type = :expert_choice,
    balance_loss_weight = 0.005f0,
    capacity_factor = 1.5f0,
    specialization_strength = 1.2f0
)

Routing Configuration Options

TopK Routing

config = moe_gpt2_topk_config(
    size = :base,                   # :base, :medium, or :large
    num_experts = 8,
    expert_top_k = 2,               # Number of experts per token
    balance_loss_weight = 0.01f0,   # Load balancing strength
    z_loss_weight = 0.001f0         # Logit regularization
)

Expert Choice Routing

config = moe_gpt2_expert_choice_config(
    size = :base,                         # :base, :medium, or :large
    num_experts = 8,
    capacity_factor = 1.25f0,             # How many tokens each expert can process
    specialization_strength = 1.0f0,      # Balance vs specialization trade-off
    balance_loss_weight = 0.005f0         # Usually lower than TopK
)
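The capacity_factor translates into a hard per-expert token budget: each expert processes at most capacity_factor × num_tokens / num_experts tokens per batch. The arithmetic:

```julia
# Per-expert token budget implied by capacity_factor (arithmetic sketch).
expert_capacity(num_tokens, num_experts, capacity_factor) =
    ceil(Int, capacity_factor * num_tokens / num_experts)

expert_capacity(128, 8, 1.25f0)  # 20 tokens per expert for a 128-token batch
```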

Inference Only

# Inference mode (no auxiliary loss computation)
input_nt = prepare_transformer_inputs(input_ids)
output = model(input_nt)  # aux_loss will be 0.0

# Get predictions
predictions = Flux.softmax(output.logit, dims=1)
next_tokens = argmax(predictions, dims=1)

Model Analysis & Monitoring

# Check model size and expert utilization
params = count_moe_parameters(model)
println("Total parameters: $(params.total)")
println("Expert parameters: $(params.expert)")
println("Parameter efficiency: $(round(params.expert_ratio * 100, digits=1))%")

# Analyze expert usage during training
output = model(merge(input_nt, (training = true,)))
println("Auxiliary loss: $(output.aux_loss)")  # Lower values = better load balancing

Advanced Usage

Adaptive Routing

# Switch routing method based on training phase
function adaptive_gpt2_routing_config(epoch, total_epochs)
    if epoch < total_epochs * 0.3  # Early training: focus on exploration
        return moe_gpt2_topk_config(expert_top_k = 2)
    else  # Later training: focus on efficiency  
        return moe_gpt2_expert_choice_config(
            capacity_factor = 1.25f0,
            specialization_strength = 1.1f0
        )
    end
end

Custom Expert Choice Configurations

research_config = moe_gpt2_expert_choice_config(
    size = :medium,
    capacity_factor = 2.0f0,
    specialization_strength = 0.8f0
)

production_config = moe_gpt2_expert_choice_config(
    size = :base,
    capacity_factor = 1.1f0,
    specialization_strength = 1.3f0,
    balance_loss_weight = 0.003f0
)

Convenience functions

new_config = moe_gpt2_base_config()
new_medium = moe_gpt2_medium_config()
new_large = moe_gpt2_large_config()

Key Benefits

  • Better Load Balancing – Expert Choice typically reduces aux loss by 20–60%
  • Training Stability – More consistent expert utilization across batches
  • Efficient Inference – Predictable computational load per expert
  • Production Ready – Handles edge cases and varying batch sizes robustly
  • Research Flexible – Both routing paradigms in the same codebase
  • GPT-2 Optimized – Configurations tuned specifically for GPT-2

Choose TopK for research & interpretability
Choose Expert Choice for production & training efficiency
