A comprehensive Julia library for Mixture of Experts (MoE) architectures with seamless integration to existing models, particularly Llama2.jl. This library implements state-of-the-art MoE techniques from recent research including Switch Transformer, DeepSeek V3, and Stanford CS336 methodologies.
- Multiple Gating Mechanisms: TopK, Switch, Expert Choice, Soft MoE, Hash routing
- Expert Architectures: Standard FFN, Gated FFN (Llama-style), CUR compressed experts
- Load Balancing: Switch Transformer loss, DeepSeek variants, auxiliary-free balancing
- Llama2 Integration: Convert existing Llama2 models to MoE with preserved functionality
- Advanced Features: Shared experts, expert virtualization, routing analysis
```julia
using MixtureOfExperts

# Configure an 8-expert MoE layer with top-2 routing
config = create_moe_config(
    num_experts=8,
    expert_type=:gated,
    input_dim=512,
    hidden_dim=2048,
    output_dim=512,
    top_k=2,
    gate_type=TopKGating(2),
    balance_loss=SwitchTransformerLoss(0.01f0)
)

moe_layer = MoELayer(config)
input = randn(Float32, 512, 32)  # (input_dim, batch)
output, balance_loss = moe_layer(input; training=true)
```

```julia
using Llama2
using MixtureOfExperts

# Load a pretrained model and convert layers 2, 4, and 6 to MoE
original_model = Llama2.load_karpathy_model("stories42M.bin", "tokenizer.bin")
moe_model = convert_to_moe(
    original_model,
    [2, 4, 6];
    num_experts=8,
    top_k=2,
    expert_init_strategy=:perturb,
    expert_init_noise=0.01f0,
    gate_type=TopKGating(2),
    balance_loss=SwitchTransformerLoss(0.01f0),
    expert_type=:gated
)

println("Original parameters: $(count_llama_parameters(original_model))")
println("MoE total parameters: $(count_parameters(moe_model))")
println("MoE active parameters: $(count_active_parameters(moe_model))")
```

```julia
# Compare generations from the original and converted models
prompt = "Once upon a time"
original_output = Llama2.sample(original_model, prompt; temperature=0.9f0)
moe_output = sample_moe(moe_model, prompt;
    temperature=0.9f0,
    show_expert_stats=true,
    show_routing_entropy=true)
```

```julia
test_prompts = [
    "The dragon flew",
    "In the forest",
    "The magic spell",
    "Once upon a time"
]

for prompt in test_prompts
    println("Prompt: \"$prompt\"")
    result = sample_moe(moe_model, prompt;
        temperature=0.3f0,
        max_seq_len=50,
        show_expert_stats=true,
        expert_usage_threshold=0.05f0)
    println("Generated: \"$result\"\n")
end
```

Convert an existing Llama2 model to MoE by replacing the specified layers.
Arguments:
- `model::Llama2.LanguageModel`: Original model to convert
- `moe_layers::Vector{Int}`: Layer indices to convert to MoE (1-based)

Key Options:
- `num_experts::Int=8`: Number of experts per MoE layer
- `top_k::Int=2`: Number of experts to activate per token
- `expert_init_strategy::Symbol=:perturb`: Weight initialization (`:copy`, `:perturb`, `:split`, `:random`)
- `expert_type::Symbol=:gated`: Expert architecture (`:standard`, `:gated`, `:cur`)
- `gate_type::GatingMechanism=TopKGating(top_k)`: Routing mechanism
- `balance_loss::LoadBalancingLoss=SwitchTransformerLoss(0.01f0)`: Load balancing

Returns: `MoELanguageModel`
Generate text using MoE model with expert tracking.
Arguments:
- `model::MoELanguageModel`: MoE model for generation
- `prompt::String=""`: Input text prompt

Key Options:
- `temperature::Float32=0.9f0`: Sampling temperature
- `max_seq_len::Int=typemax(Int)`: Maximum sequence length
- `show_expert_stats::Bool=false`: Display expert usage statistics
- `show_routing_entropy::Bool=false`: Show routing entropy analysis
- `expert_usage_threshold::Float32=0.01f0`: Threshold for reporting expert usage

Returns: `String` (generated text)
Count total parameters in MoE model.
Count parameters active during inference (considering top-k routing).
Analyze expert usage patterns for given token sequence.
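The gap between the two parameter counts above follows from top-k routing: all experts contribute to the total, but only k of them run per token. A back-of-envelope sketch of that arithmetic (the dimensions and the accounting here are illustrative assumptions, not the library's exact bookkeeping):

```julia
# Illustrative active-parameter arithmetic for one MoE layer.
# A gated (Llama-style) expert has three weight matrices: w1, w3 of
# size hidden×input and w2 of size input×hidden.
input_dim, hidden_dim = 512, 2048
num_experts, top_k = 8, 2

expert_params = 3 * input_dim * hidden_dim      # per-expert FFN weights
router_params = input_dim * num_experts         # linear router

total  = num_experts * expert_params + router_params
active = top_k * expert_params + router_params  # only k experts run per token

println("total = $total, active = $active")     # active ≈ total * k/num_experts
```

With these numbers the layer holds about 25M parameters but activates only about 6M per token, which is the efficiency MoE trades against memory.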
Create MoE layer configuration with sensible defaults.
Key Options:
- `num_experts::Int=8`: Number of experts
- `expert_type::Symbol=:standard`: Expert architecture
- `top_k::Int=2`: Experts to activate
- `gate_type::GatingMechanism=TopKGating(2)`: Routing mechanism
- `balance_loss::LoadBalancingLoss=SwitchTransformerLoss(0.01f0)`: Load balancing
Stanford CS336 implementation with softmax renormalization:

```julia
gate = TopKGating(k=2)
```

Switch Transformer (k=1 special case):

```julia
gate = SwitchGating()
```

Experts select tokens instead of tokens selecting experts:

```julia
gate = ExpertChoiceGating(capacity_factor=1.25f0)
```

```julia
gate = SoftMoEGating(k=2, λ=1.0f0)
gate = HashGating(k=2, num_experts=8)
gate = SharedExpertGating(num_shared=2, base_gate=TopKGating(2))
```

Basic 2-layer FFN with configurable activation:

```julia
expert = StandardExpert(input_dim, hidden_dim, output_dim, gelu; dropout=0.1f0)
```

Llama-style gated FFN, `w2(silu(w1(x)) * w3(x))`:

```julia
expert = GatedExpert(input_dim, hidden_dim, output_dim, silu)
```

Compressed expert using CUR decomposition:

```julia
expert = CURExpert(input_dim, hidden_dim, output_dim, gelu; rank=64)
```

Original Switch Transformer auxiliary loss:

```julia
loss = SwitchTransformerLoss(α=0.01f0)
```

DeepSeek V1/V2 variants with device-aware balancing:

```julia
loss = DeepSeekLoss(α=0.01f0, balance_type=:device)
```

DeepSeek V3 innovation with online bias learning:

```julia
loss = AuxiliaryFreeLoss(num_experts=8, learning_rate=0.01f0)
```

```
src/
├── MixtureOfExperts.jl    # Main module
├── gating/                # Routing mechanisms
│   ├── base.jl            # Abstract types and interfaces
│   ├── simple.jl          # RandomGating (testing/baseline)
│   ├── topk.jl            # TopKGating (Stanford CS336)
│   ├── switch.jl          # SwitchGating, JitterGating
│   ├── expert_choice.jl   # ExpertChoiceGating
│   └── advanced.jl        # SoftMoE, HashGating, SharedExpert
├── experts/               # Expert architectures
│   ├── standard.jl        # Basic 2-layer FFN
│   ├── gated.jl           # Llama-style gated FFN
│   └── cur.jl             # CUR decomposition experts
├── balancing/             # Load balancing losses
│   ├── losses.jl          # Switch, DeepSeek, Z-loss
│   └── auxiliary_free.jl  # DeepSeek V3 innovation
├── core/                  # Core MoE components
│   ├── router.jl          # Neural routing network
│   ├── moe_layer.jl       # Main MoE layer implementation
│   └── utils.jl           # Utility functions
└── llama2/                # Llama2.jl integration
    ├── types.jl           # MoE wrapper types
    ├── conversion.jl      # convert_to_moe functionality
    ├── inference.jl       # MoE transformer forward pass
    ├── attention.jl       # Attention with RoPE support
    ├── generation.jl      # sample_moe and text generation
    └── utils.jl           # Save/load, analysis utilities
```
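For intuition about the losses under `balancing/`, the Switch Transformer auxiliary loss penalizes the dot product between the fraction of tokens dispatched to each expert and the mean router probability per expert, scaled by α·N. A minimal sketch of that formula (illustrative only, not the library's internal code):

```julia
# Switch-style auxiliary loss: α · N · Σᵢ fᵢ · Pᵢ
# fᵢ = fraction of tokens routed (top-1) to expert i
# Pᵢ = mean router probability assigned to expert i
function switch_aux_loss(router_probs::AbstractMatrix, α::Float32)
    N, T = size(router_probs)    # (num_experts, num_tokens)
    assignments = [argmax(view(router_probs, :, t)) for t in 1:T]
    f = [count(==(i), assignments) / T for i in 1:N]
    P = vec(sum(router_probs, dims=2)) ./ T
    return α * N * sum(f .* P)
end

probs = fill(0.25f0, 4, 8)       # uniform router probabilities over 4 experts
switch_aux_loss(probs, 0.01f0)   # with uniform P, the loss reduces to α
```

The loss is minimized when both the dispatch fractions and the mean probabilities are uniform, which is what pushes the router toward balanced expert usage.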
Define a custom gating mechanism by subtyping `GatingMechanism` and implementing `compute_gates`:

```julia
using NNlib: softmax  # softmax over the expert dimension

struct MyGating <: GatingMechanism
    k::Int
    temperature::Float32
end

function compute_gates(gate::MyGating, router_logits::AbstractMatrix)
    scaled_logits = router_logits ./ gate.temperature
    router_probs = softmax(scaled_logits; dims=1)
    num_tokens = size(router_logits, 2)
    expert_indices = zeros(Int, gate.k, num_tokens)
    expert_gates = zeros(Float32, gate.k, num_tokens)
    for i in 1:num_tokens
        # Select the k most probable experts and renormalize their weights
        topk_indices = partialsortperm(router_probs[:, i], 1:gate.k, rev=true)
        expert_indices[:, i] = topk_indices
        expert_gates[:, i] = router_probs[topk_indices, i] ./ sum(router_probs[topk_indices, i])
    end
    return expert_indices, expert_gates, router_probs
end
```

```julia
prompts = [
    "The dragon",
    "In the castle",
    "Magic spell",
    "Forest adventure"
]

results = sample_moe_batch(moe_model, prompts;
    temperature=0.7f0,
    max_seq_len=100,
    show_progress=true)

for (prompt, result) in zip(prompts, results)
    println("\"$prompt\" → \"$result\"")
end
```

```julia
save_moe_model(moe_model, "my_moe_model.jls")
loaded_model = load_moe_model("my_moe_model.jls")
metadata = model_info(loaded_model)
println("Model info: $metadata")
```

```julia
comparison = compare_models(original_model, moe_model, "The brave knight")
for (i, comp) in enumerate(comparison["comparisons"])
    println("Run $i:")
    println("  Original: $(comp["original_output"])")
    println("  MoE: $(comp["moe_output"])")
    println()
end
```

This library implements techniques from:
- Switch Transformer (Fedus et al., 2022): Core MoE architecture and load balancing
- DeepSeek V1-V3 (2024): Shared experts, auxiliary-free balancing, advanced routing
- Stanford CS336 (2024): Mathematical formulations and routing algorithms
- Expert Choice Routing (Zhou et al., 2022): Alternative routing paradigm
- CUR Decomposition: Memory-efficient expert compression
The library provides full GPU acceleration using CUDA for significantly improved performance on NVIDIA GPUs.
```julia
using MixtureOfExperts
using CUDA

# Create a GPU-resident MoE layer
gpu_moe = create_cuda_moe(
    num_experts = 8,
    expert_type = :gated,
    input_dim = 768,
    hidden_dim = 3072,
    output_dim = 768,
    top_k = 2
)

gpu_input = CUDA.randn(Float32, 768, 32)
gpu_output, balance_loss = cuda_moe_forward!(gpu_moe, gpu_input; training=true)
```

```julia
# Build matching CPU and GPU layers with synchronized weights
cpu_moe, gpu_moe, sync_success = create_synchronized_moe_pair(
    CudaMoEConfig(
        num_experts = 8,
        expert_type = :gated,
        input_dim = 768,
        hidden_dim = 3072,
        output_dim = 768,
        top_k = 2
    )
)

# Move tensors between devices
gpu_tensor = to_cuda(cpu_tensor)
cpu_tensor = to_cpu(gpu_tensor)

# Per-expert routing statistics and realistic test inputs
expert_counts, usage_percentages = get_cuda_expert_stats(expert_indices, num_experts)
test_input = generate_realistic_input(input_dim=768, batch_size=32)
```

```julia
config = CudaMoEConfig(
    num_experts = 16,
    expert_type = :gated,
    input_dim = 1024,
    hidden_dim = 4096,
    output_dim = 1024,
    top_k = 2,
    noise_scale = 0.0f0,
    use_noise_network = false,
    balance_weight = 0.01f0
)
gpu_moe = CudaMoELayer(config)
```

A drop-in replacement for GPT-2 feedforward layers with Mixture of Experts (MoE), supporting both TopK and Expert Choice routing with native transformer integration.
```julia
using MixtureOfExperts

# Create a MoE-enabled GPT-2 transformer
config = moe_gpt2_base_config(
    vocab_size = 50257,
    num_experts = 8,
    expert_top_k = 2
)
model = create_moe_transformer_model(config; lm_head=true)

# Use like any transformer
input_ids = rand(1:config.vocab_size, 4, 32)  # (batch, sequence)
input_nt = prepare_transformer_inputs(input_ids)
output = model(input_nt)  # NamedTuple with `logit` and `aux_loss`
```

TopK routing:
- Method: Tokens choose their favorite K experts
- Usage: Default in `moe_gpt2_*_config()` functions
- Best for: General purpose, research, proven performance

Expert Choice routing:
- Method: Experts choose their favorite tokens (up to capacity)
- Superior load balancing: Typically 20–60% lower auxiliary loss
- Better training stability: More consistent expert utilization
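To make the difference between the two paradigms concrete, here is a minimal sketch of the expert-choice selection step in plain Julia (illustrative only; the function and variable names are not this library's internals): each expert keeps its top-capacity tokens by router score, so per-expert load is fixed by construction.

```julia
# Expert-choice routing sketch: scores is (num_experts, num_tokens).
# Each expert keeps its `capacity` highest-scoring tokens; a token can be
# picked by several experts, or dropped by all of them.
function expert_choice_select(scores::AbstractMatrix, capacity::Int)
    num_experts, num_tokens = size(scores)
    c = min(capacity, num_tokens)
    # For each expert, indices of its top-c tokens by routing score
    return [partialsortperm(scores[e, :], 1:c, rev=true) for e in 1:num_experts]
end

scores = rand(Float32, 4, 16)            # 4 experts, 16 tokens
sel = expert_choice_select(scores, 4)
@assert all(length(s) == 4 for s in sel) # every expert processes exactly 4 tokens
```

Because each expert's token count is capped at capacity, load balancing is enforced structurally rather than encouraged by an auxiliary loss, which is why the auxiliary term tends to be smaller under this routing.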
```julia
# Method 1: Via routing_type parameter
config = moe_gpt2_base_config(
    routing_type = :expert_choice,
    capacity_factor = 1.25f0,
    specialization_strength = 1.0f0
)
model = create_moe_transformer_model(config; lm_head=true)

# Method 2: Via convenience function
config = moe_gpt2_expert_choice_config(size=:base)
model = create_moe_transformer_model(config; lm_head=true)

# Method 3: Compare both routing methods
topk_config, ec_config = compare_moe_gpt2_routing_configs(num_experts=8)
topk_model = create_moe_transformer_model(topk_config; lm_head=true)
ec_model = create_moe_transformer_model(ec_config; lm_head=true)

# Compare load balancing
input_nt = prepare_transformer_inputs(rand(1:50257, 4, 32))
topk_output = topk_model(merge(input_nt, (training = true,)))
ec_output = ec_model(merge(input_nt, (training = true,)))
println("TopK aux loss: $(topk_output.aux_loss)")
println("Expert Choice aux loss: $(ec_output.aux_loss)")  # Typically 20-60% lower
```

```julia
# Setup training (works with both routing types)
optimizer = Flux.Adam(3e-4)
opt_state = Flux.setup(optimizer, model)
targets = rand(1:config.vocab_size, 4, 32)

# Training step with auxiliary loss
function train_step!(model, opt_state, input_nt, targets)
    loss, grads = Flux.withgradient(model) do m
        output = m(merge(input_nt, (training = true,)))
        loss_info = create_moe_training_loss(output, targets)
        return loss_info.total_loss
    end
    Flux.update!(opt_state, model, grads[1])
    return loss
end

# Train with load-balancing monitoring
for epoch in 1:10
    loss = train_step!(model, opt_state, input_nt, targets)
    if epoch % 5 == 0
        output = model(merge(input_nt, (training = true,)))
        println("Epoch $epoch: Loss = $loss, Aux Loss = $(output.aux_loss)")
    end
end
```

```julia
# Small GPT-2 model for experiments
small_config = moe_gpt2_base_config(
    hidden_size = 512,
    num_experts = 4,
    expert_top_k = 2,
    routing_type = :expert_choice  # Better for smaller models
)

# Medium GPT-2 model for production with Expert Choice
medium_config = moe_gpt2_medium_config(
    routing_type = :expert_choice,
    capacity_factor = 1.25f0,  # Conservative capacity for stability
    specialization_strength = 1.0f0
)

# Large GPT-2 model with many experts
large_config = moe_gpt2_large_config(
    num_experts = 32,
    expert_top_k = 2,
    routing_type = :expert_choice,
    capacity_factor = 1.25f0,
    specialization_strength = 1.0f0
)

# Custom GPT-2 configuration with fine-tuned routing
custom_config = MoETransformerConfig(
    vocab_size = 32000,  # Custom vocabulary
    hidden_size = 1024,
    num_experts = 16,
    expert_type = :gated,  # :standard or :gated
    routing_type = :expert_choice,
    balance_loss_weight = 0.005f0,
    capacity_factor = 1.5f0,
    specialization_strength = 1.2f0
)
```

```julia
config = moe_gpt2_topk_config(
    size = :base,                  # :base, :medium, or :large
    num_experts = 8,
    expert_top_k = 2,              # Number of experts per token
    balance_loss_weight = 0.01f0,  # Load balancing strength
    z_loss_weight = 0.001f0        # Logit regularization
)
```

```julia
config = moe_gpt2_expert_choice_config(
    size = :base,                     # :base, :medium, or :large
    num_experts = 8,
    capacity_factor = 1.25f0,         # How many tokens each expert can process
    specialization_strength = 1.0f0,  # Balance vs. specialization trade-off
    balance_loss_weight = 0.005f0     # Usually lower than TopK
)
```

```julia
# Inference mode (no auxiliary loss computation)
input_nt = prepare_transformer_inputs(input_ids)
output = model(input_nt)  # aux_loss will be 0.0

# Get predictions
predictions = Flux.softmax(output.logit, dims=1)
next_tokens = argmax(predictions, dims=1)
```

```julia
# Check model size and expert utilization
params = count_moe_parameters(model)
println("Total parameters: $(params.total)")
println("Expert parameters: $(params.expert)")
println("Parameter efficiency: $(round(params.expert_ratio * 100, digits=1))%")

# Analyze expert usage during training
output = model(merge(input_nt, (training = true,)))
println("Auxiliary loss: $(output.aux_loss)")  # Lower values = better load balancing
```

```julia
# Switch routing method based on training phase
function adaptive_gpt2_routing_config(epoch, total_epochs)
    if epoch < total_epochs * 0.3  # Early training: focus on exploration
        return moe_gpt2_topk_config(expert_top_k = 2)
    else                           # Later training: focus on efficiency
        return moe_gpt2_expert_choice_config(
            capacity_factor = 1.25f0,
            specialization_strength = 1.1f0
        )
    end
end
```

```julia
research_config = moe_gpt2_expert_choice_config(
    size = :medium,
    capacity_factor = 2.0f0,
    specialization_strength = 0.8f0
)

production_config = moe_gpt2_expert_choice_config(
    size = :base,
    capacity_factor = 1.1f0,
    specialization_strength = 1.3f0,
    balance_loss_weight = 0.003f0
)

new_config = moe_gpt2_base_config()
new_medium = moe_gpt2_medium_config()
new_large = moe_gpt2_large_config()
```

- Better Load Balancing – Expert Choice typically reduces aux loss by 20–60%
- Training Stability – More consistent expert utilization across batches
- Efficient Inference – Predictable computational load per expert
- Production Ready – Handles edge cases and varying batch sizes robustly
- Research Flexible – Both routing paradigms in the same codebase
- GPT-2 Optimized – Configurations tuned specifically for GPT-2
Choose TopK for research & interpretability
Choose Expert Choice for production & training efficiency