diff --git a/build/457.json b/build/457.json new file mode 100644 index 00000000..dddf7be0 --- /dev/null +++ b/build/457.json @@ -0,0 +1,54 @@ +{ + "id": "457", + "title": "Implement Sigmoid MoE Router with Bias Correction", + "difficulty": "hard", + "category": "Deep Learning", + "video": "", + "likes": "0", + "dislikes": "0", + "contributor": [], + "pytorch_difficulty": "hard", + "description": "## Problem\n\nWrite a Python function `sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k)` that implements the sigmoid-based Mixture of Experts routing used in MiniMax M2.5. The function takes:\n- `hidden_states`: array of shape `(num_tokens, hidden_dim)`\n- `gate_weight`: array of shape `(num_experts, hidden_dim)`\n- `score_bias`: array of shape `(num_experts,)` for expert load balancing\n- `top_k`: number of experts to select per token\n\nThe function should:\n1. Compute router logits via matrix multiplication\n2. Apply sigmoid activation (not softmax) to get routing weights\n3. Add the score bias to determine expert selection\n4. Select the top-k experts per token\n5. Gather the actual sigmoid weights (without bias) for selected experts\n6. Normalize the selected weights to sum to 1\n\nReturn a tuple of `(top_k_weights, top_k_indices)` where weights has shape `(num_tokens, top_k)` and indices has shape `(num_tokens, top_k)`. Only use numpy.", + "learn_section": "# **Sigmoid MoE Router with Bias Correction**\n\n## **1. Definition**\nThe Sigmoid MoE Router is the expert routing mechanism used in MiniMax M2.5. Unlike traditional MoE routers that use softmax scoring, this router uses **sigmoid activation** with a **learned bias correction** for expert selection, followed by weight normalization.\n\n## **2. Why Sigmoid Instead of Softmax?**\n- **Independent scoring:** Sigmoid scores each expert independently, while softmax creates competition between experts. 
This allows for more flexible expert utilization.\n- **Better load balancing:** The learned bias correction term helps distribute tokens more evenly across experts without auxiliary losses dominating training.\n- **Simpler gradient flow:** Sigmoid gradients don't depend on other experts' scores.\n\n## **3. Algorithm**\nGiven hidden states $H \\in \\mathbb{R}^{T \\times d}$, gate weights $W_g \\in \\mathbb{R}^{E \\times d}$, and bias correction $b \\in \\mathbb{R}^{E}$:\n\n**Step 1: Compute logits**\n$$\\text{logits} = H \\cdot W_g^T \\quad \\in \\mathbb{R}^{T \\times E}$$\n\n**Step 2: Apply sigmoid**\n$$w = \\sigma(\\text{logits}) = \\frac{1}{1 + e^{-\\text{logits}}} \\quad \\in \\mathbb{R}^{T \\times E}$$\n\n**Step 3: Bias-corrected scores for selection**\n$$s = w + b \\quad \\in \\mathbb{R}^{T \\times E}$$\n\n**Step 4: Top-k selection**\n$$\\text{indices} = \\text{argsort}(s, \\text{descending})[:, :k]$$\n\n**Step 5: Gather actual weights (without bias)**\n$$w_{\\text{selected}} = \\text{gather}(w, \\text{indices})$$\n\n**Step 6: Normalize**\n$$\\hat{w} = \\frac{w_{\\text{selected}}}{\\sum_{j=1}^{k} w_{\\text{selected}, j}}$$\n\n## **4. Key Design Choices**\n- The **bias is only used for selection**, not for the final weights. This decouples load balancing from the actual routing weights.\n- The bias $b$ is a **learned parameter** that adjusts which experts are preferred, compensating for natural imbalances.\n- The final weights are **normalized** so they sum to 1, making the weighted combination of expert outputs a proper convex combination.\n\n## **5. 
Role in MiniMax M2.5**\n- **256 experts** per layer, **top-8** selected per token\n- Gate: `Linear(3072, 256, bias=False)`\n- Bias correction: learned vector of size 256\n- Only ~3% of experts activated per token (8/256)", + "starter_code": "import numpy as np\n\n# Implement your function below.\ndef sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k):\n \"\"\"\n Implement sigmoid-based MoE routing with bias correction.\n\n Args:\n hidden_states (np.ndarray): Token representations, shape (num_tokens, hidden_dim).\n gate_weight (np.ndarray): Gate projection weights, shape (num_experts, hidden_dim).\n score_bias (np.ndarray): Learned bias for load balancing, shape (num_experts,).\n top_k (int): Number of experts to select per token.\n\n Returns:\n tuple: (top_k_weights, top_k_indices)\n - top_k_weights: Normalized routing weights, shape (num_tokens, top_k).\n - top_k_indices: Selected expert indices, shape (num_tokens, top_k).\n \"\"\"\n pass", + "solution": "import numpy as np\n\ndef sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k):\n logits = hidden_states @ gate_weight.T\n routing_weights = 1.0 / (1.0 + np.exp(-logits))\n scores_for_choice = routing_weights + score_bias\n\n top_k_indices = np.argsort(-scores_for_choice, axis=-1)[:, :top_k]\n\n top_k_weights = np.take_along_axis(routing_weights, top_k_indices, axis=-1)\n top_k_weights = top_k_weights / np.sum(top_k_weights, axis=-1, keepdims=True)\n\n return top_k_weights.astype(float), top_k_indices", + "example": { + "input": "import numpy as np\nhidden = np.array([[1.0, 0.0], [0.0, 1.0]])\ngate = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])\nbias = np.zeros(4)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(np.round(w, 4))\nprint(idx)", + "output": "[[0.5938 0.4062]\n [0.5938 0.4062]]\n[[0 1]\n [1 0]]", + "reasoning": "For token [1,0]: logits=[1,0,-1,0], sigmoid=[0.731,0.5,0.269,0.5]. Top-2 by score are experts 0,1 with weights [0.731,0.5]. 
Normalized: [0.731/1.231, 0.5/1.231] = [0.5938, 0.4062]." + }, + "test_cases": [ + { + "test": "import numpy as np\nhidden = np.array([[1.0, 0.0], [0.0, 1.0]])\ngate = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])\nbias = np.zeros(4)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[0.5938 0.4062]\n [0.5938 0.4062]]\n[[0 1]\n [1 0]]" + }, + { + "test": "import numpy as np\nhidden = np.array([[0.0, 0.0]])\ngate = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])\nbias = np.array([0.0, 0.0, 1.0])\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[0.5 0.5]]\n[[2 0]]" + }, + { + "test": "import numpy as np\nnp.random.seed(42)\nhidden = np.random.randn(2, 4)\ngate = np.random.randn(6, 4)\nbias = np.array([0.1, -0.1, 0.2, -0.2, 0.0, 0.0])\nw, idx = sigmoid_moe_router(hidden, gate, bias, 3)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[0.6001 0.2588 0.1412]\n [0.6644 0.2488 0.0868]]\n[[5 4 0]\n [5 0 2]]" + }, + { + "test": "import numpy as np\nhidden = np.array([[1.0, 1.0]])\ngate = np.array([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]])\nbias = np.zeros(3)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 1)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[1.]]\n[[1]]" + } + ], + "pytorch_starter_code": "import torch\n\ndef sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k):\n \"\"\"\n Implement sigmoid-based MoE routing with bias correction.\n\n Args:\n hidden_states (torch.Tensor): Token representations, shape (num_tokens, hidden_dim).\n gate_weight (torch.Tensor): Gate projection weights, shape (num_experts, hidden_dim).\n score_bias (torch.Tensor): Learned bias for load balancing, shape (num_experts,).\n top_k (int): Number of experts to select per token.\n\n Returns:\n tuple: (top_k_weights, top_k_indices)\n \"\"\"\n pass", + "pytorch_solution": "import torch\n\ndef sigmoid_moe_router(hidden_states, 
gate_weight, score_bias, top_k):\n logits = hidden_states @ gate_weight.T\n routing_weights = torch.sigmoid(logits)\n scores_for_choice = routing_weights + score_bias\n\n top_k_indices = torch.topk(scores_for_choice, top_k, dim=-1).indices\n\n top_k_weights = torch.gather(routing_weights, 1, top_k_indices)\n top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)\n\n return top_k_weights.float(), top_k_indices", + "pytorch_test_cases": [ + { + "test": "import torch\nhidden = torch.tensor([[1.0, 0.0], [0.0, 1.0]])\ngate = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])\nbias = torch.zeros(4)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(torch.round(w, decimals=4))\nprint(idx)", + "expected_output": "tensor([[0.5938, 0.4062],\n [0.5938, 0.4062]])\ntensor([[0, 1],\n [1, 0]])" + }, + { + "test": "import torch\nhidden = torch.tensor([[0.0, 0.0]])\ngate = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])\nbias = torch.tensor([0.0, 0.0, 1.0])\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(torch.round(w, decimals=4))\nprint(idx)", + "expected_output": "tensor([[0.5000, 0.5000]])\ntensor([[2, 0]])" + }, + { + "test": "import torch\nhidden = torch.tensor([[1.0, 1.0]])\ngate = torch.tensor([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]])\nbias = torch.zeros(3)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 1)\nprint(torch.round(w, decimals=4))\nprint(idx)", + "expected_output": "tensor([[1.]])\ntensor([[1]])" + } + ] +} \ No newline at end of file diff --git a/build/458.json b/build/458.json new file mode 100644 index 00000000..902eda05 --- /dev/null +++ b/build/458.json @@ -0,0 +1,58 @@ +{ + "id": "458", + "title": "Implement Lightning Attention (Linear Attention)", + "difficulty": "hard", + "category": "Deep Learning", + "video": "", + "likes": "0", + "dislikes": "0", + "contributor": [], + "pytorch_difficulty": "hard", + "description": "## Problem\n\nWrite a Python function `lightning_attention(Q, K, V, decay)` that 
implements causal linear attention with exponential decay, as used in Lightning Attention. The function takes:\n- `Q`: query array of shape `(seq_len, head_dim)`\n- `K`: key array of shape `(seq_len, head_dim)`\n- `V`: value array of shape `(seq_len, head_dim)`\n- `decay`: a float decay factor (lambda) between 0 and 1\n\nInstead of computing softmax(QK^T)V, compute the output using the recurrent form:\n- Maintain a state `S` of shape `(head_dim, head_dim)` initialized to zeros\n- At each timestep t: `S_t = decay * S_{t-1} + K_t^T @ V_t`, then `O_t = Q_t @ S_t`\n\nReturn the output array of shape `(seq_len, head_dim)` as floats. Only use numpy.", + "learn_section": "# **Lightning Attention (Linear Attention with Decay)**\n\n## **1. Definition**\nLightning Attention is a linear attention mechanism that replaces the quadratic softmax attention with a recurrent formulation. It achieves $O(nd^2)$ complexity instead of $O(n^2d)$, making it efficient for very long sequences. It was developed by the MiniMax team and used in MiniMax-01.\n\n## **2. Standard Attention vs Linear Attention**\n\n**Standard (Softmax) Attention:**\n$$O = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d}}\\right) V \\quad \\in O(n^2 d)$$\n\n**Linear Attention (kernel trick):**\n$$O_t = Q_t \\cdot \\sum_{s \\leq t} K_s^T V_s = Q_t \\cdot S_t \\quad \\in O(n d^2)$$\n\nWhere $S_t = \\sum_{s \\leq t} K_s^T V_s$ is a running state that accumulates key-value outer products.\n\n## **3. Adding Exponential Decay**\nTo prevent the state from growing unboundedly and to focus on more recent tokens, Lightning Attention adds an exponential decay factor $\\lambda$:\n\n$$S_t = \\lambda \\cdot S_{t-1} + K_t^T V_t$$\n$$O_t = Q_t \\cdot S_t$$\n\nWhere:\n- $S_t \\in \\mathbb{R}^{d \\times d}$ is the recurrent state\n- $\\lambda \\in (0, 1)$ is the decay factor\n- $K_t^T V_t$ is the outer product of key and value at position $t$\n\n## **4. 
Recurrent Interpretation**\nThe decay creates an exponentially weighted sum over history:\n\n$$S_t = \\sum_{s=1}^{t} \\lambda^{t-s} K_s^T V_s$$\n\nMore recent tokens contribute more strongly than distant ones, providing a natural notion of locality.\n\n## **5. Advantages**\n- **Linear complexity:** $O(nd^2)$ instead of $O(n^2d)$ — better for long sequences when $n \\gg d$\n- **Constant memory per step:** Only need to maintain the $d \\times d$ state matrix\n- **Streamable:** Can process tokens one at a time in a recurrent fashion\n- **Infinite context (in theory):** No fixed context window limitation\n\n## **6. Role in MiniMax Architecture**\nIn MiniMax-01, the predecessor to M2.5, Lightning Attention was used in a hybrid pattern: 7 linear attention layers followed by 1 softmax attention layer, repeated across the network. This combined the efficiency of linear attention with the expressiveness of softmax attention for long-range dependencies.", + "starter_code": "import numpy as np\n\n# Implement your function below.\ndef lightning_attention(Q, K, V, decay):\n \"\"\"\n Implement causal linear attention with exponential decay.\n\n Args:\n Q (np.ndarray): Query array of shape (seq_len, head_dim).\n K (np.ndarray): Key array of shape (seq_len, head_dim).\n V (np.ndarray): Value array of shape (seq_len, head_dim).\n decay (float): Exponential decay factor (lambda), between 0 and 1.\n\n Returns:\n np.ndarray: Output array of shape (seq_len, head_dim).\n \"\"\"\n pass", + "solution": "import numpy as np\n\ndef lightning_attention(Q, K, V, decay):\n seq_len, head_dim = Q.shape\n S = np.zeros((head_dim, head_dim))\n output = np.zeros((seq_len, head_dim))\n\n for t in range(seq_len):\n S = decay * S + np.outer(K[t], V[t])\n output[t] = Q[t] @ S\n\n return output.astype(float)", + "example": { + "input": "import numpy as np\nQ = np.ones((3, 2))\nK = np.ones((3, 2))\nV = np.ones((3, 2))\nprint(np.round(lightning_attention(Q, K, V, 0.5), 4))", + "output": "[[2. 2. ]\n [3. 3. 
]\n [3.5 3.5]]", + "reasoning": "At t=0: S = outer([1,1],[1,1]) = [[1,1],[1,1]], O = [1,1]@S = [2,2]. At t=1: S = 0.5*S + outer = [[1.5,1.5],[1.5,1.5]], O = [3,3]. At t=2: S = 0.5*S + outer = [[1.75,1.75],[1.75,1.75]], O = [3.5,3.5]." + }, + "test_cases": [ + { + "test": "import numpy as np\nQ = np.array([[1.0, 0.0]])\nK = np.array([[1.0, 0.0]])\nV = np.array([[1.0, 2.0]])\nprint(np.round(lightning_attention(Q, K, V, 0.9), 4))", + "expected_output": "[[1. 2.]]" + }, + { + "test": "import numpy as np\nQ = np.array([[1.0, 0.0], [0.0, 1.0]])\nK = np.array([[1.0, 0.0], [0.0, 1.0]])\nV = np.array([[1.0, 2.0], [3.0, 4.0]])\nprint(np.round(lightning_attention(Q, K, V, 1.0), 4))", + "expected_output": "[[1. 2.]\n [3. 4.]]" + }, + { + "test": "import numpy as np\nQ = np.ones((3, 2))\nK = np.ones((3, 2))\nV = np.ones((3, 2))\nprint(np.round(lightning_attention(Q, K, V, 0.5), 4))", + "expected_output": "[[2. 2. ]\n [3. 3. ]\n [3.5 3.5]]" + }, + { + "test": "import numpy as np\nnp.random.seed(42)\nQ = np.random.randn(4, 3)\nK = np.random.randn(4, 3)\nV = np.random.randn(4, 3)\nprint(np.round(lightning_attention(Q, K, V, 0.9), 4))", + "expected_output": "[[ 0.3988 -0.0812 0.8431]\n [-0.8582 0.538 -1.0621]\n [ 1.4379 -4.9831 0.7769]\n [-0.9745 -0.3103 -2.1484]]" + }, + { + "test": "import numpy as np\nQ = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nK = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nV = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nprint(np.round(lightning_attention(Q, K, V, 0.0), 4))", + "expected_output": "[[1. 0.]\n [1. 0.]\n [1. 
0.]]" + } + ], + "pytorch_starter_code": "import torch\n\ndef lightning_attention(Q, K, V, decay):\n \"\"\"\n Implement causal linear attention with exponential decay.\n\n Args:\n Q (torch.Tensor): Query tensor of shape (seq_len, head_dim).\n K (torch.Tensor): Key tensor of shape (seq_len, head_dim).\n V (torch.Tensor): Value tensor of shape (seq_len, head_dim).\n decay (float): Exponential decay factor (lambda), between 0 and 1.\n\n Returns:\n torch.Tensor: Output tensor of shape (seq_len, head_dim).\n \"\"\"\n pass", + "pytorch_solution": "import torch\n\ndef lightning_attention(Q, K, V, decay):\n seq_len, head_dim = Q.shape\n S = torch.zeros((head_dim, head_dim), dtype=Q.dtype)\n output = torch.zeros_like(Q)\n\n for t in range(seq_len):\n S = decay * S + torch.outer(K[t], V[t])\n output[t] = Q[t] @ S\n\n return output.float()", + "pytorch_test_cases": [ + { + "test": "import torch\nQ = torch.tensor([[1.0, 0.0]])\nK = torch.tensor([[1.0, 0.0]])\nV = torch.tensor([[1.0, 2.0]])\nprint(torch.round(lightning_attention(Q, K, V, 0.9), decimals=4))", + "expected_output": "tensor([[1., 2.]])" + }, + { + "test": "import torch\nQ = torch.ones((3, 2))\nK = torch.ones((3, 2))\nV = torch.ones((3, 2))\nprint(torch.round(lightning_attention(Q, K, V, 0.5), decimals=4))", + "expected_output": "tensor([[2.0000, 2.0000],\n [3.0000, 3.0000],\n [3.5000, 3.5000]])" + }, + { + "test": "import torch\nQ = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nK = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nV = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nprint(torch.round(lightning_attention(Q, K, V, 0.0), decimals=4))", + "expected_output": "tensor([[1., 0.],\n [1., 0.],\n [1., 0.]])" + } + ] +} \ No newline at end of file diff --git a/build/459.json b/build/459.json new file mode 100644 index 00000000..6bdaeae5 --- /dev/null +++ b/build/459.json @@ -0,0 +1,46 @@ +{ + "id": "459", + "title": "Implement Multi-Token Prediction", + "difficulty": "hard", + "category": "Deep 
Learning", + "video": "", + "likes": "0", + "dislikes": "0", + "contributor": [], + "pytorch_difficulty": "hard", + "description": "## Problem\n\nWrite a Python function `multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future)` that implements the Multi-Token Prediction (MTP) training objective. The function takes:\n- `hidden_states`: array of shape `(seq_len, hidden_dim)` — the main model's last hidden states\n- `target_ids`: array of shape `(seq_len,)` — the target token IDs (integers)\n- `embedding_weight`: array of shape `(vocab_size, hidden_dim)` — token embedding matrix\n- `proj_weights`: list of `num_future` weight arrays, each of shape `(hidden_dim, 2 * hidden_dim)` — projection for each MTP head\n- `lm_head_weight`: array of shape `(vocab_size, hidden_dim)` — language model head weight\n- `num_future`: number of future tokens to predict (int)\n\nFor each future step `k` (1 to `num_future`):\n1. Get the target embedding for position shifted by `k-1`: `emb = embedding_weight[target_ids[k-1:seq_len-num_future+k-1]]`\n2. Concatenate `hidden_states[:seq_len-num_future]` with `emb` along the feature dimension\n3. Project through `proj_weights[k-1]` (matrix multiply)\n4. Compute logits via `lm_head_weight`\n5. Compute cross-entropy loss against `target_ids[k:seq_len-num_future+k]`\n\nReturn the average loss across all steps as a float. Only use numpy.", + "learn_section": "# **Multi-Token Prediction (MTP)**\n\n## **1. Definition**\nMulti-Token Prediction is an auxiliary training objective where the model predicts not just the next token, but multiple future tokens simultaneously. Instead of a single prediction head, the model uses additional lightweight heads that each predict a different future token.\n\n## **2. 
Motivation**\n- **Better representations:** Predicting further into the future forces hidden states to encode more long-range information.\n- **Faster inference:** MTP heads can be used for speculative decoding, where multiple token candidates are generated in one forward pass and verified in parallel.\n- **Improved sample efficiency:** Each training example provides multiple supervision signals.\n\n## **3. Architecture**\nGiven the main model's hidden states $H \\in \\mathbb{R}^{L \\times d}$ and $N$ MTP heads:\n\nFor each head $k \\in \\{1, \\ldots, N\\}$:\n\n**Step 1: Get target embeddings (shifted by k-1 positions)**\n$$E_k = \\text{Embed}(\\text{targets}[k-1 : L-N+k-1])$$\n\n**Step 2: Concatenate with hidden states**\n$$C_k = [H_{:L-N} \\; ; \\; E_k] \\quad \\in \\mathbb{R}^{(L-N) \\times 2d}$$\n\n**Step 3: Project back to hidden dimension**\n$$\\tilde{H}_k = C_k \\cdot W_k^T \\quad \\in \\mathbb{R}^{(L-N) \\times d}$$\n\n**Step 4: Compute logits**\n$$\\text{logits}_k = \\tilde{H}_k \\cdot W_{\\text{lm}}^T \\quad \\in \\mathbb{R}^{(L-N) \\times V}$$\n\n**Step 5: Cross-entropy loss against shifted targets**\n$$\\mathcal{L}_k = \\text{CE}(\\text{logits}_k, \\text{targets}[k : L-N+k])$$\n\n**Final loss:**\n$$\\mathcal{L}_{\\text{MTP}} = \\frac{1}{N} \\sum_{k=1}^{N} \\mathcal{L}_k$$\n\n## **4. Key Details**\n- Each MTP head has its own projection matrix $W_k$ but shares the embedding table and LM head with the main model\n- The concatenation of hidden states with target embeddings provides \"teacher forcing\" — each head sees the ground truth of previous positions\n- During inference, MTP heads enable speculative decoding for faster generation\n\n## **5. Role in MiniMax M2.5**\nMiniMax M2.5 uses 3 MTP modules, each containing 1 transformer layer. 
These auxiliary heads are trained jointly with the main model and can be used for speculative decoding during inference.", + "starter_code": "import numpy as np\n\n# Implement your function below.\ndef multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future):\n \"\"\"\n Implement Multi-Token Prediction training objective.\n\n Args:\n hidden_states (np.ndarray): Main model hidden states, shape (seq_len, hidden_dim).\n target_ids (np.ndarray): Target token IDs, shape (seq_len,).\n embedding_weight (np.ndarray): Token embedding matrix, shape (vocab_size, hidden_dim).\n proj_weights (list): List of num_future projection weights, each shape (hidden_dim, 2*hidden_dim).\n lm_head_weight (np.ndarray): LM head weight, shape (vocab_size, hidden_dim).\n num_future (int): Number of future tokens to predict.\n\n Returns:\n float: Average cross-entropy loss across all MTP heads.\n \"\"\"\n pass", + "solution": "import numpy as np\n\ndef multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future):\n seq_len, hidden_dim = hidden_states.shape\n effective_len = seq_len - num_future\n total_loss = 0.0\n\n for k in range(1, num_future + 1):\n emb = embedding_weight[target_ids[k-1:effective_len+k-1]]\n concat = np.concatenate([hidden_states[:effective_len], emb], axis=-1)\n projected = concat @ proj_weights[k-1].T\n logits = projected @ lm_head_weight.T\n\n targets = target_ids[k:effective_len+k]\n logits_shifted = logits - np.max(logits, axis=-1, keepdims=True)\n exp_logits = np.exp(logits_shifted)\n log_sum_exp = np.log(np.sum(exp_logits, axis=-1))\n correct_logits = logits_shifted[np.arange(effective_len), targets]\n loss = -np.mean(correct_logits - log_sum_exp)\n total_loss += loss\n\n return float(total_loss / num_future)", + "example": { + "input": "import numpy as np\nnp.random.seed(42)\nhidden_dim = 4\nvocab_size = 10\nseq_len = 6\nnum_future = 2\nhidden_states = 
np.random.randn(seq_len, hidden_dim)\ntarget_ids = np.array([3, 1, 4, 1, 5, 9])\nembedding_weight = np.random.randn(vocab_size, hidden_dim)\nproj_weights = [np.random.randn(hidden_dim, 2*hidden_dim) for _ in range(num_future)]\nlm_head_weight = np.random.randn(vocab_size, hidden_dim)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future), 4))", + "output": "6.5678", + "reasoning": "With 2 MTP heads over sequence length 6, each head predicts 4 tokens. Head 1 predicts targets[1:5] from hidden[0:4] + embed(targets[0:4]). Head 2 predicts targets[2:6]. The average cross-entropy loss across both heads is 6.5678." + }, + "test_cases": [ + { + "test": "import numpy as np\nnp.random.seed(42)\nhidden_dim = 4\nvocab_size = 10\nseq_len = 6\nnum_future = 2\nhidden_states = np.random.randn(seq_len, hidden_dim)\ntarget_ids = np.array([3, 1, 4, 1, 5, 9])\nembedding_weight = np.random.randn(vocab_size, hidden_dim)\nproj_weights = [np.random.randn(hidden_dim, 2*hidden_dim) for _ in range(num_future)]\nlm_head_weight = np.random.randn(vocab_size, hidden_dim)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future), 4))", + "expected_output": "6.5678" + }, + { + "test": "import numpy as np\nnp.random.seed(0)\nhidden_states = np.random.randn(4, 3)\ntarget_ids = np.array([0, 1, 2, 0])\nembedding_weight = np.random.randn(3, 3)\nproj_weights = [np.random.randn(3, 6)]\nlm_head_weight = np.random.randn(3, 3)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, 1), 4))", + "expected_output": "3.1856" + }, + { + "test": "import numpy as np\nnp.random.seed(7)\nhidden_states = np.random.randn(5, 2)\ntarget_ids = np.array([0, 1, 0, 1, 0])\nembedding_weight = np.random.randn(2, 2)\nproj_weights = [np.random.randn(2, 4), np.random.randn(2, 4)]\nlm_head_weight = np.random.randn(2, 
2)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, 2), 4))", + "expected_output": "0.656" + } + ], + "pytorch_starter_code": "import torch\n\ndef multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future):\n \"\"\"\n Implement Multi-Token Prediction training objective.\n\n Args:\n hidden_states (torch.Tensor): Main model hidden states, shape (seq_len, hidden_dim).\n target_ids (torch.Tensor): Target token IDs (long), shape (seq_len,).\n embedding_weight (torch.Tensor): Token embedding matrix, shape (vocab_size, hidden_dim).\n proj_weights (list): List of num_future projection weights, each shape (hidden_dim, 2*hidden_dim).\n lm_head_weight (torch.Tensor): LM head weight, shape (vocab_size, hidden_dim).\n num_future (int): Number of future tokens to predict.\n\n Returns:\n float: Average cross-entropy loss across all MTP heads.\n \"\"\"\n pass", + "pytorch_solution": "import torch\nimport torch.nn.functional as F\n\ndef multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future):\n seq_len, hidden_dim = hidden_states.shape\n effective_len = seq_len - num_future\n total_loss = 0.0\n\n for k in range(1, num_future + 1):\n emb = embedding_weight[target_ids[k-1:effective_len+k-1]]\n concat = torch.cat([hidden_states[:effective_len], emb], dim=-1)\n projected = concat @ proj_weights[k-1].T\n logits = projected @ lm_head_weight.T\n\n targets = target_ids[k:effective_len+k]\n loss = F.cross_entropy(logits, targets)\n total_loss += loss.item()\n\n return float(total_loss / num_future)", + "pytorch_test_cases": [ + { + "test": "import torch\ntorch.manual_seed(42)\nhidden_dim = 4\nvocab_size = 10\nseq_len = 6\nnum_future = 2\nhidden_states = torch.randn(seq_len, hidden_dim)\ntarget_ids = torch.tensor([3, 1, 4, 1, 5, 9])\nembedding_weight = torch.randn(vocab_size, hidden_dim)\nproj_weights = [torch.randn(hidden_dim, 
2*hidden_dim) for _ in range(num_future)]\nlm_head_weight = torch.randn(vocab_size, hidden_dim)\nloss = multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future)\nprint(round(loss, 4))", + "expected_output": "8.7782" + }, + { + "test": "import torch\ntorch.manual_seed(0)\nhidden_states = torch.randn(4, 3)\ntarget_ids = torch.tensor([0, 1, 2, 0])\nembedding_weight = torch.randn(3, 3)\nproj_weights = [torch.randn(3, 6)]\nlm_head_weight = torch.randn(3, 3)\nloss = multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, 1)\nprint(round(loss, 4))", + "expected_output": "1.4607" + } + ] +} \ No newline at end of file diff --git a/questions/457_implement-sigmoid-moe-router/description.md b/questions/457_implement-sigmoid-moe-router/description.md new file mode 100644 index 00000000..159fe043 --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/description.md @@ -0,0 +1,17 @@ +## Problem + +Write a Python function `sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k)` that implements the sigmoid-based Mixture of Experts routing used in MiniMax M2.5. The function takes: +- `hidden_states`: array of shape `(num_tokens, hidden_dim)` +- `gate_weight`: array of shape `(num_experts, hidden_dim)` +- `score_bias`: array of shape `(num_experts,)` for expert load balancing +- `top_k`: number of experts to select per token + +The function should: +1. Compute router logits via matrix multiplication +2. Apply sigmoid activation (not softmax) to get routing weights +3. Add the score bias to determine expert selection +4. Select the top-k experts per token +5. Gather the actual sigmoid weights (without bias) for selected experts +6. Normalize the selected weights to sum to 1 + +Return a tuple of `(top_k_weights, top_k_indices)` where weights has shape `(num_tokens, top_k)` and indices has shape `(num_tokens, top_k)`. Only use numpy. 
diff --git a/questions/457_implement-sigmoid-moe-router/example.json b/questions/457_implement-sigmoid-moe-router/example.json new file mode 100644 index 00000000..f0fa4f7b --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/example.json @@ -0,0 +1,5 @@ +{ + "input": "import numpy as np\nhidden = np.array([[1.0, 0.0], [0.0, 1.0]])\ngate = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])\nbias = np.zeros(4)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(np.round(w, 4))\nprint(idx)", + "output": "[[0.5938 0.4062]\n [0.5938 0.4062]]\n[[0 1]\n [1 0]]", + "reasoning": "For token [1,0]: logits=[1,0,-1,0], sigmoid=[0.731,0.5,0.269,0.5]. Top-2 by score are experts 0,1 with weights [0.731,0.5]. Normalized: [0.731/1.231, 0.5/1.231] = [0.5938, 0.4062]." +} diff --git a/questions/457_implement-sigmoid-moe-router/learn.md b/questions/457_implement-sigmoid-moe-router/learn.md new file mode 100644 index 00000000..970a48fb --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/learn.md @@ -0,0 +1,41 @@ +# **Sigmoid MoE Router with Bias Correction** + +## **1. Definition** +The Sigmoid MoE Router is the expert routing mechanism used in MiniMax M2.5. Unlike traditional MoE routers that use softmax scoring, this router uses **sigmoid activation** with a **learned bias correction** for expert selection, followed by weight normalization. + +## **2. Why Sigmoid Instead of Softmax?** +- **Independent scoring:** Sigmoid scores each expert independently, while softmax creates competition between experts. This allows for more flexible expert utilization. +- **Better load balancing:** The learned bias correction term helps distribute tokens more evenly across experts without auxiliary losses dominating training. +- **Simpler gradient flow:** Sigmoid gradients don't depend on other experts' scores. + +## **3. 
Algorithm** +Given hidden states $H \in \mathbb{R}^{T \times d}$, gate weights $W_g \in \mathbb{R}^{E \times d}$, and bias correction $b \in \mathbb{R}^{E}$: + +**Step 1: Compute logits** +$$\text{logits} = H \cdot W_g^T \quad \in \mathbb{R}^{T \times E}$$ + +**Step 2: Apply sigmoid** +$$w = \sigma(\text{logits}) = \frac{1}{1 + e^{-\text{logits}}} \quad \in \mathbb{R}^{T \times E}$$ + +**Step 3: Bias-corrected scores for selection** +$$s = w + b \quad \in \mathbb{R}^{T \times E}$$ + +**Step 4: Top-k selection** +$$\text{indices} = \text{argsort}(s, \text{descending})[:, :k]$$ + +**Step 5: Gather actual weights (without bias)** +$$w_{\text{selected}} = \text{gather}(w, \text{indices})$$ + +**Step 6: Normalize** +$$\hat{w} = \frac{w_{\text{selected}}}{\sum_{j=1}^{k} w_{\text{selected}, j}}$$ + +## **4. Key Design Choices** +- The **bias is only used for selection**, not for the final weights. This decouples load balancing from the actual routing weights. +- The bias $b$ is a **learned parameter** that adjusts which experts are preferred, compensating for natural imbalances. +- The final weights are **normalized** so they sum to 1, making the weighted combination of expert outputs a proper convex combination. + +## **5. 
Role in MiniMax M2.5** +- **256 experts** per layer, **top-8** selected per token +- Gate: `Linear(3072, 256, bias=False)` +- Bias correction: learned vector of size 256 +- Only ~3% of experts activated per token (8/256) diff --git a/questions/457_implement-sigmoid-moe-router/meta.json b/questions/457_implement-sigmoid-moe-router/meta.json new file mode 100644 index 00000000..7af82b2c --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/meta.json @@ -0,0 +1,11 @@ +{ + "id": "457", + "title": "Implement Sigmoid MoE Router with Bias Correction", + "difficulty": "hard", + "category": "Deep Learning", + "video": "", + "likes": "0", + "dislikes": "0", + "contributor": [], + "pytorch_difficulty": "hard" +} diff --git a/questions/457_implement-sigmoid-moe-router/pytorch/solution.py b/questions/457_implement-sigmoid-moe-router/pytorch/solution.py new file mode 100644 index 00000000..aa9cb592 --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/pytorch/solution.py @@ -0,0 +1,13 @@ +import torch + +def sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k): + logits = hidden_states @ gate_weight.T + routing_weights = torch.sigmoid(logits) + scores_for_choice = routing_weights + score_bias + + top_k_indices = torch.topk(scores_for_choice, top_k, dim=-1).indices + + top_k_weights = torch.gather(routing_weights, 1, top_k_indices) + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + + return top_k_weights.float(), top_k_indices diff --git a/questions/457_implement-sigmoid-moe-router/pytorch/starter_code.py b/questions/457_implement-sigmoid-moe-router/pytorch/starter_code.py new file mode 100644 index 00000000..53b2757f --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/pytorch/starter_code.py @@ -0,0 +1,16 @@ +import torch + +def sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k): + """ + Implement sigmoid-based MoE routing with bias correction. 
+ + Args: + hidden_states (torch.Tensor): Token representations, shape (num_tokens, hidden_dim). + gate_weight (torch.Tensor): Gate projection weights, shape (num_experts, hidden_dim). + score_bias (torch.Tensor): Learned bias for load balancing, shape (num_experts,). + top_k (int): Number of experts to select per token. + + Returns: + tuple: (top_k_weights, top_k_indices) + """ + pass diff --git a/questions/457_implement-sigmoid-moe-router/pytorch/tests.json b/questions/457_implement-sigmoid-moe-router/pytorch/tests.json new file mode 100644 index 00000000..6d888679 --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/pytorch/tests.json @@ -0,0 +1,14 @@ +[ + { + "test": "import torch\nhidden = torch.tensor([[1.0, 0.0], [0.0, 1.0]])\ngate = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])\nbias = torch.zeros(4)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(torch.round(w, decimals=4))\nprint(idx)", + "expected_output": "tensor([[0.5938, 0.4062],\n [0.5938, 0.4062]])\ntensor([[0, 1],\n [1, 0]])" + }, + { + "test": "import torch\nhidden = torch.tensor([[0.0, 0.0]])\ngate = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])\nbias = torch.tensor([0.0, 0.0, 1.0])\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(torch.round(w, decimals=4))\nprint(idx)", + "expected_output": "tensor([[0.5000, 0.5000]])\ntensor([[2, 0]])" + }, + { + "test": "import torch\nhidden = torch.tensor([[1.0, 1.0]])\ngate = torch.tensor([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]])\nbias = torch.zeros(3)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 1)\nprint(torch.round(w, decimals=4))\nprint(idx)", + "expected_output": "tensor([[1.]])\ntensor([[1]])" + } +] diff --git a/questions/457_implement-sigmoid-moe-router/solution.py b/questions/457_implement-sigmoid-moe-router/solution.py new file mode 100644 index 00000000..1df1aeef --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/solution.py @@ -0,0 +1,13 @@ +import numpy as np + +def 
sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k): + logits = hidden_states @ gate_weight.T + routing_weights = 1.0 / (1.0 + np.exp(-logits)) + scores_for_choice = routing_weights + score_bias + + top_k_indices = np.argsort(-scores_for_choice, axis=-1, kind='stable')[:, :top_k] + + top_k_weights = np.take_along_axis(routing_weights, top_k_indices, axis=-1) + top_k_weights = top_k_weights / np.sum(top_k_weights, axis=-1, keepdims=True) + + return top_k_weights.astype(float), top_k_indices diff --git a/questions/457_implement-sigmoid-moe-router/starter_code.py b/questions/457_implement-sigmoid-moe-router/starter_code.py new file mode 100644 index 00000000..5762a9f2 --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/starter_code.py @@ -0,0 +1,19 @@ +import numpy as np + +# Implement your function below. +def sigmoid_moe_router(hidden_states, gate_weight, score_bias, top_k): + """ + Implement sigmoid-based MoE routing with bias correction. + + Args: + hidden_states (np.ndarray): Token representations, shape (num_tokens, hidden_dim). + gate_weight (np.ndarray): Gate projection weights, shape (num_experts, hidden_dim). + score_bias (np.ndarray): Learned bias for load balancing, shape (num_experts,). + top_k (int): Number of experts to select per token. + + Returns: + tuple: (top_k_weights, top_k_indices) + - top_k_weights: Normalized routing weights, shape (num_tokens, top_k). + - top_k_indices: Selected expert indices, shape (num_tokens, top_k). 
+ """ + pass diff --git a/questions/457_implement-sigmoid-moe-router/tests.json b/questions/457_implement-sigmoid-moe-router/tests.json new file mode 100644 index 00000000..fd8bae26 --- /dev/null +++ b/questions/457_implement-sigmoid-moe-router/tests.json @@ -0,0 +1,18 @@ +[ + { + "test": "import numpy as np\nhidden = np.array([[1.0, 0.0], [0.0, 1.0]])\ngate = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])\nbias = np.zeros(4)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[0.5938 0.4062]\n [0.5938 0.4062]]\n[[0 1]\n [1 0]]" + }, + { + "test": "import numpy as np\nhidden = np.array([[0.0, 0.0]])\ngate = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])\nbias = np.array([0.0, 0.0, 1.0])\nw, idx = sigmoid_moe_router(hidden, gate, bias, 2)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[0.5 0.5]]\n[[2 0]]" + }, + { + "test": "import numpy as np\nnp.random.seed(42)\nhidden = np.random.randn(2, 4)\ngate = np.random.randn(6, 4)\nbias = np.array([0.1, -0.1, 0.2, -0.2, 0.0, 0.0])\nw, idx = sigmoid_moe_router(hidden, gate, bias, 3)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[0.6001 0.2588 0.1412]\n [0.6644 0.2488 0.0868]]\n[[5 4 0]\n [5 0 2]]" + }, + { + "test": "import numpy as np\nhidden = np.array([[1.0, 1.0]])\ngate = np.array([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]])\nbias = np.zeros(3)\nw, idx = sigmoid_moe_router(hidden, gate, bias, 1)\nprint(np.round(w, 4))\nprint(idx)", + "expected_output": "[[1.]]\n[[1]]" + } +] diff --git a/questions/458_implement-lightning-attention/description.md b/questions/458_implement-lightning-attention/description.md new file mode 100644 index 00000000..36cf4af8 --- /dev/null +++ b/questions/458_implement-lightning-attention/description.md @@ -0,0 +1,13 @@ +## Problem + +Write a Python function `lightning_attention(Q, K, V, decay)` that implements causal linear attention with exponential decay, as used in Lightning Attention. 
The function takes: +- `Q`: query array of shape `(seq_len, head_dim)` +- `K`: key array of shape `(seq_len, head_dim)` +- `V`: value array of shape `(seq_len, head_dim)` +- `decay`: a float decay factor (lambda) between 0 and 1 + +Instead of computing softmax(QK^T)V, compute the output using the recurrent form: +- Maintain a state `S` of shape `(head_dim, head_dim)` initialized to zeros +- At each timestep t: `S_t = decay * S_{t-1} + K_t^T @ V_t`, then `O_t = Q_t @ S_t` + +Return the output array of shape `(seq_len, head_dim)` as floats. Only use numpy. diff --git a/questions/458_implement-lightning-attention/example.json b/questions/458_implement-lightning-attention/example.json new file mode 100644 index 00000000..c0fbcaab --- /dev/null +++ b/questions/458_implement-lightning-attention/example.json @@ -0,0 +1,5 @@ +{ + "input": "import numpy as np\nQ = np.ones((3, 2))\nK = np.ones((3, 2))\nV = np.ones((3, 2))\nprint(np.round(lightning_attention(Q, K, V, 0.5), 4))", + "output": "[[2. 2. ]\n [3. 3. ]\n [3.5 3.5]]", + "reasoning": "At t=0: S = outer([1,1],[1,1]) = [[1,1],[1,1]], O = [1,1]@S = [2,2]. At t=1: S = 0.5*S + outer = [[1.5,1.5],[1.5,1.5]], O = [3,3]. At t=2: S = 0.5*S + outer = [[1.75,1.75],[1.75,1.75]], O = [3.5,3.5]." +} diff --git a/questions/458_implement-lightning-attention/learn.md b/questions/458_implement-lightning-attention/learn.md new file mode 100644 index 00000000..e86423bd --- /dev/null +++ b/questions/458_implement-lightning-attention/learn.md @@ -0,0 +1,41 @@ +# **Lightning Attention (Linear Attention with Decay)** + +## **1. Definition** +Lightning Attention is a linear attention mechanism that replaces the quadratic softmax attention with a recurrent formulation. It achieves $O(nd^2)$ complexity instead of $O(n^2d)$, making it efficient for very long sequences. It was developed by the MiniMax team and used in MiniMax-01. + +## **2. 
Standard Attention vs Linear Attention** + +**Standard (Softmax) Attention:** +$$O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V \quad \in O(n^2 d)$$ + +**Linear Attention (kernel trick):** +$$O_t = Q_t \cdot \sum_{s \leq t} K_s^T V_s = Q_t \cdot S_t \quad \in O(n d^2)$$ + +Where $S_t = \sum_{s \leq t} K_s^T V_s$ is a running state that accumulates key-value outer products. + +## **3. Adding Exponential Decay** +To prevent the state from growing unboundedly and to focus on more recent tokens, Lightning Attention adds an exponential decay factor $\lambda$: + +$$S_t = \lambda \cdot S_{t-1} + K_t^T V_t$$ +$$O_t = Q_t \cdot S_t$$ + +Where: +- $S_t \in \mathbb{R}^{d \times d}$ is the recurrent state +- $\lambda \in (0, 1)$ is the decay factor +- $K_t^T V_t$ is the outer product of key and value at position $t$ + +## **4. Recurrent Interpretation** +The decay creates an exponentially weighted sum over history: + +$$S_t = \sum_{s=1}^{t} \lambda^{t-s} K_s^T V_s$$ + +More recent tokens contribute more strongly than distant ones, providing a natural notion of locality. + +## **5. Advantages** +- **Linear complexity:** $O(nd^2)$ instead of $O(n^2d)$ — better for long sequences when $n \gg d$ +- **Constant memory per step:** Only need to maintain the $d \times d$ state matrix +- **Streamable:** Can process tokens one at a time in a recurrent fashion +- **Infinite context (in theory):** No fixed context window limitation + +## **6. Role in MiniMax Architecture** +In MiniMax-01, the predecessor to M2.5, Lightning Attention was used in a hybrid pattern: 7 linear attention layers followed by 1 softmax attention layer, repeated across the network. This combined the efficiency of linear attention with the expressiveness of softmax attention for long-range dependencies. 
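The equivalence between the recurrence in Section 3 and the exponentially weighted sum in Section 4 can be sanity-checked with a short numpy sketch (function names here are illustrative, not part of the reference solution):

```python
import numpy as np

def lightning_attention_recurrent(Q, K, V, decay):
    # Recurrent form: S_t = decay * S_{t-1} + K_t^T V_t, then O_t = Q_t S_t
    seq_len, d = Q.shape
    S = np.zeros((d, d))
    out = np.zeros((seq_len, d))
    for t in range(seq_len):
        S = decay * S + np.outer(K[t], V[t])
        out[t] = Q[t] @ S
    return out

def lightning_attention_closed_form(Q, K, V, decay):
    # Closed form: O_t = Q_t * sum_{s<=t} decay^{t-s} K_s^T V_s
    seq_len, d = Q.shape
    out = np.zeros((seq_len, d))
    for t in range(seq_len):
        S = sum(decay ** (t - s) * np.outer(K[s], V[s]) for s in range(t + 1))
        out[t] = Q[t] @ S
    return out

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 5, 4))
assert np.allclose(lightning_attention_recurrent(Q, K, V, 0.9),
                   lightning_attention_closed_form(Q, K, V, 0.9))
```

Note the two forms only match exactly because the recurrence folds one factor of `decay` into the state at every step; this is also why the recurrent version runs in $O(nd^2)$ while the closed form as written is $O(n^2 d^2)$.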
diff --git a/questions/458_implement-lightning-attention/meta.json b/questions/458_implement-lightning-attention/meta.json new file mode 100644 index 00000000..7c5b1df1 --- /dev/null +++ b/questions/458_implement-lightning-attention/meta.json @@ -0,0 +1,11 @@ +{ + "id": "458", + "title": "Implement Lightning Attention (Linear Attention)", + "difficulty": "hard", + "category": "Deep Learning", + "video": "", + "likes": "0", + "dislikes": "0", + "contributor": [], + "pytorch_difficulty": "hard" +} diff --git a/questions/458_implement-lightning-attention/pytorch/solution.py b/questions/458_implement-lightning-attention/pytorch/solution.py new file mode 100644 index 00000000..1db32fde --- /dev/null +++ b/questions/458_implement-lightning-attention/pytorch/solution.py @@ -0,0 +1,12 @@ +import torch + +def lightning_attention(Q, K, V, decay): + seq_len, head_dim = Q.shape + S = torch.zeros((head_dim, head_dim), dtype=Q.dtype, device=Q.device) + output = torch.zeros_like(Q) + + for t in range(seq_len): + S = decay * S + torch.outer(K[t], V[t]) + output[t] = Q[t] @ S + + return output.float() diff --git a/questions/458_implement-lightning-attention/pytorch/starter_code.py b/questions/458_implement-lightning-attention/pytorch/starter_code.py new file mode 100644 index 00000000..adeafe6c --- /dev/null +++ b/questions/458_implement-lightning-attention/pytorch/starter_code.py @@ -0,0 +1,16 @@ +import torch + +def lightning_attention(Q, K, V, decay): + """ + Implement causal linear attention with exponential decay. + + Args: + Q (torch.Tensor): Query tensor of shape (seq_len, head_dim). + K (torch.Tensor): Key tensor of shape (seq_len, head_dim). + V (torch.Tensor): Value tensor of shape (seq_len, head_dim). + decay (float): Exponential decay factor (lambda), between 0 and 1. + + Returns: + torch.Tensor: Output tensor of shape (seq_len, head_dim). 
+ """ + pass diff --git a/questions/458_implement-lightning-attention/pytorch/tests.json b/questions/458_implement-lightning-attention/pytorch/tests.json new file mode 100644 index 00000000..4b6f79c3 --- /dev/null +++ b/questions/458_implement-lightning-attention/pytorch/tests.json @@ -0,0 +1,14 @@ +[ + { + "test": "import torch\nQ = torch.tensor([[1.0, 0.0]])\nK = torch.tensor([[1.0, 0.0]])\nV = torch.tensor([[1.0, 2.0]])\nprint(torch.round(lightning_attention(Q, K, V, 0.9), decimals=4))", + "expected_output": "tensor([[1., 2.]])" + }, + { + "test": "import torch\nQ = torch.ones((3, 2))\nK = torch.ones((3, 2))\nV = torch.ones((3, 2))\nprint(torch.round(lightning_attention(Q, K, V, 0.5), decimals=4))", + "expected_output": "tensor([[2.0000, 2.0000],\n [3.0000, 3.0000],\n [3.5000, 3.5000]])" + }, + { + "test": "import torch\nQ = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nK = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nV = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nprint(torch.round(lightning_attention(Q, K, V, 0.0), decimals=4))", + "expected_output": "tensor([[1., 0.],\n [1., 0.],\n [1., 0.]])" + } +] diff --git a/questions/458_implement-lightning-attention/solution.py b/questions/458_implement-lightning-attention/solution.py new file mode 100644 index 00000000..8735b8a4 --- /dev/null +++ b/questions/458_implement-lightning-attention/solution.py @@ -0,0 +1,12 @@ +import numpy as np + +def lightning_attention(Q, K, V, decay): + seq_len, head_dim = Q.shape + S = np.zeros((head_dim, head_dim)) + output = np.zeros((seq_len, head_dim)) + + for t in range(seq_len): + S = decay * S + np.outer(K[t], V[t]) + output[t] = Q[t] @ S + + return output.astype(float) diff --git a/questions/458_implement-lightning-attention/starter_code.py b/questions/458_implement-lightning-attention/starter_code.py new file mode 100644 index 00000000..d0c5de0c --- /dev/null +++ b/questions/458_implement-lightning-attention/starter_code.py @@ -0,0 +1,17 @@ +import numpy 
as np + +# Implement your function below. +def lightning_attention(Q, K, V, decay): + """ + Implement causal linear attention with exponential decay. + + Args: + Q (np.ndarray): Query array of shape (seq_len, head_dim). + K (np.ndarray): Key array of shape (seq_len, head_dim). + V (np.ndarray): Value array of shape (seq_len, head_dim). + decay (float): Exponential decay factor (lambda), between 0 and 1. + + Returns: + np.ndarray: Output array of shape (seq_len, head_dim). + """ + pass diff --git a/questions/458_implement-lightning-attention/tests.json b/questions/458_implement-lightning-attention/tests.json new file mode 100644 index 00000000..6a662769 --- /dev/null +++ b/questions/458_implement-lightning-attention/tests.json @@ -0,0 +1,22 @@ +[ + { + "test": "import numpy as np\nQ = np.array([[1.0, 0.0]])\nK = np.array([[1.0, 0.0]])\nV = np.array([[1.0, 2.0]])\nprint(np.round(lightning_attention(Q, K, V, 0.9), 4))", + "expected_output": "[[1. 2.]]" + }, + { + "test": "import numpy as np\nQ = np.array([[1.0, 0.0], [0.0, 1.0]])\nK = np.array([[1.0, 0.0], [0.0, 1.0]])\nV = np.array([[1.0, 2.0], [3.0, 4.0]])\nprint(np.round(lightning_attention(Q, K, V, 1.0), 4))", + "expected_output": "[[1. 2.]\n [3. 4.]]" + }, + { + "test": "import numpy as np\nQ = np.ones((3, 2))\nK = np.ones((3, 2))\nV = np.ones((3, 2))\nprint(np.round(lightning_attention(Q, K, V, 0.5), 4))", + "expected_output": "[[2. 2. ]\n [3. 3. 
]\n [3.5 3.5]]" + }, + { + "test": "import numpy as np\nnp.random.seed(42)\nQ = np.random.randn(4, 3)\nK = np.random.randn(4, 3)\nV = np.random.randn(4, 3)\nprint(np.round(lightning_attention(Q, K, V, 0.9), 4))", + "expected_output": "[[ 0.3988 -0.0812 0.8431]\n [-0.8582 0.538 -1.0621]\n [ 1.4379 -4.9831 0.7769]\n [-0.9745 -0.3103 -2.1484]]" + }, + { + "test": "import numpy as np\nQ = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nK = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nV = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])\nprint(np.round(lightning_attention(Q, K, V, 0.0), 4))", + "expected_output": "[[1. 0.]\n [1. 0.]\n [1. 0.]]" + } +] diff --git a/questions/459_implement-multi-token-prediction/description.md b/questions/459_implement-multi-token-prediction/description.md new file mode 100644 index 00000000..24a2cd58 --- /dev/null +++ b/questions/459_implement-multi-token-prediction/description.md @@ -0,0 +1,18 @@ +## Problem + +Write a Python function `multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future)` that implements the Multi-Token Prediction (MTP) training objective. The function takes: +- `hidden_states`: array of shape `(seq_len, hidden_dim)` — the main model's last hidden states +- `target_ids`: array of shape `(seq_len,)` — the target token IDs (integers) +- `embedding_weight`: array of shape `(vocab_size, hidden_dim)` — token embedding matrix +- `proj_weights`: list of `num_future` weight arrays, each of shape `(hidden_dim, 2 * hidden_dim)` — projection for each MTP head +- `lm_head_weight`: array of shape `(vocab_size, hidden_dim)` — language model head weight +- `num_future`: number of future tokens to predict (int) + +For each future step `k` (1 to `num_future`): +1. Get the target embedding for position shifted by `k-1`: `emb = embedding_weight[target_ids[k-1:seq_len-num_future+k-1]]` +2. Concatenate `hidden_states[:seq_len-num_future]` with `emb` along the feature dimension +3. 
Project through `proj_weights[k-1]` (matrix multiply) +4. Compute logits via `lm_head_weight` +5. Compute cross-entropy loss against `target_ids[k:seq_len-num_future+k]` + +Return the average loss across all steps as a float. Only use numpy. diff --git a/questions/459_implement-multi-token-prediction/example.json b/questions/459_implement-multi-token-prediction/example.json new file mode 100644 index 00000000..634b3f4b --- /dev/null +++ b/questions/459_implement-multi-token-prediction/example.json @@ -0,0 +1,5 @@ +{ + "input": "import numpy as np\nnp.random.seed(42)\nhidden_dim = 4\nvocab_size = 10\nseq_len = 6\nnum_future = 2\nhidden_states = np.random.randn(seq_len, hidden_dim)\ntarget_ids = np.array([3, 1, 4, 1, 5, 9])\nembedding_weight = np.random.randn(vocab_size, hidden_dim)\nproj_weights = [np.random.randn(hidden_dim, 2*hidden_dim) for _ in range(num_future)]\nlm_head_weight = np.random.randn(vocab_size, hidden_dim)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future), 4))", + "output": "6.5678", + "reasoning": "With 2 MTP heads over sequence length 6, each head predicts 4 tokens. Head 1 predicts targets[1:5] from hidden[0:4] + embed(targets[0:4]). Head 2 predicts targets[2:6]. The average cross-entropy loss across both heads is 6.5678." +} diff --git a/questions/459_implement-multi-token-prediction/learn.md b/questions/459_implement-multi-token-prediction/learn.md new file mode 100644 index 00000000..1a0eeaa0 --- /dev/null +++ b/questions/459_implement-multi-token-prediction/learn.md @@ -0,0 +1,40 @@ +# **Multi-Token Prediction (MTP)** + +## **1. Definition** +Multi-Token Prediction is an auxiliary training objective where the model predicts not just the next token, but multiple future tokens simultaneously. Instead of a single prediction head, the model uses additional lightweight heads that each predict a different future token. + +## **2. 
Motivation** +- **Better representations:** Predicting further into the future forces hidden states to encode more long-range information. +- **Faster inference:** MTP heads can be used for speculative decoding, where multiple token candidates are generated in one forward pass and verified in parallel. +- **Improved sample efficiency:** Each training example provides multiple supervision signals. + +## **3. Architecture** +Given the main model's hidden states $H \in \mathbb{R}^{L \times d}$ and $N$ MTP heads: + +For each head $k \in \{1, \ldots, N\}$: + +**Step 1: Get target embeddings (shifted by k-1 positions)** +$$E_k = \text{Embed}(\text{targets}[k-1 : L-N+k-1])$$ + +**Step 2: Concatenate with hidden states** +$$C_k = [H_{:L-N} \; ; \; E_k] \quad \in \mathbb{R}^{(L-N) \times 2d}$$ + +**Step 3: Project back to hidden dimension** +$$\tilde{H}_k = C_k \cdot W_k^T \quad \in \mathbb{R}^{(L-N) \times d}$$ + +**Step 4: Compute logits** +$$\text{logits}_k = \tilde{H}_k \cdot W_{\text{lm}}^T \quad \in \mathbb{R}^{(L-N) \times V}$$ + +**Step 5: Cross-entropy loss against shifted targets** +$$\mathcal{L}_k = \text{CE}(\text{logits}_k, \text{targets}[k : L-N+k])$$ + +**Final loss:** +$$\mathcal{L}_{\text{MTP}} = \frac{1}{N} \sum_{k=1}^{N} \mathcal{L}_k$$ + +## **4. Key Details** +- Each MTP head has its own projection matrix $W_k$ but shares the embedding table and LM head with the main model +- The concatenation of hidden states with target embeddings provides "teacher forcing" — each head sees the ground truth of previous positions +- During inference, MTP heads enable speculative decoding for faster generation + +## **5. Role in MiniMax M2.5** +MiniMax M2.5 uses 3 MTP modules, each containing 1 transformer layer. These auxiliary heads are trained jointly with the main model and can be used for speculative decoding during inference. 
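Steps 1-5 above for a single head can be sketched in numpy under toy shapes (all sizes and variable names here are illustrative; the slicing mirrors the reference solution):

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, vocab, N = 6, 4, 10, 2                 # seq_len, hidden_dim, vocab size, num MTP heads
H = rng.standard_normal((L, d))              # main-model hidden states
targets = rng.integers(0, vocab, size=L)     # ground-truth token ids
embed = rng.standard_normal((vocab, d))      # shared embedding table
W_k = rng.standard_normal((d, 2 * d))        # projection of head k
W_lm = rng.standard_normal((vocab, d))       # shared LM head

k = 1                                        # first MTP head
eff = L - N                                  # positions each head scores
E_k = embed[targets[k - 1:eff + k - 1]]      # Step 1: teacher-forced target embeddings
C_k = np.concatenate([H[:eff], E_k], axis=-1)   # Step 2: (eff, 2d)
logits = (C_k @ W_k.T) @ W_lm.T                 # Steps 3-4: (eff, vocab)

# Step 5: numerically stable cross-entropy against targets shifted by k
z = logits - logits.max(axis=-1, keepdims=True)
log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
loss_k = -log_probs[np.arange(eff), targets[k:eff + k]].mean()
```

Averaging `loss_k` over all `N` heads gives the final $\mathcal{L}_{\text{MTP}}$.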
diff --git a/questions/459_implement-multi-token-prediction/meta.json b/questions/459_implement-multi-token-prediction/meta.json new file mode 100644 index 00000000..207a975a --- /dev/null +++ b/questions/459_implement-multi-token-prediction/meta.json @@ -0,0 +1,11 @@ +{ + "id": "459", + "title": "Implement Multi-Token Prediction", + "difficulty": "hard", + "category": "Deep Learning", + "video": "", + "likes": "0", + "dislikes": "0", + "contributor": [], + "pytorch_difficulty": "hard" +} diff --git a/questions/459_implement-multi-token-prediction/pytorch/solution.py b/questions/459_implement-multi-token-prediction/pytorch/solution.py new file mode 100644 index 00000000..1dbf3543 --- /dev/null +++ b/questions/459_implement-multi-token-prediction/pytorch/solution.py @@ -0,0 +1,19 @@ +import torch +import torch.nn.functional as F + +def multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future): + seq_len, hidden_dim = hidden_states.shape + effective_len = seq_len - num_future + total_loss = 0.0 + + for k in range(1, num_future + 1): + emb = embedding_weight[target_ids[k-1:effective_len+k-1]] + concat = torch.cat([hidden_states[:effective_len], emb], dim=-1) + projected = concat @ proj_weights[k-1].T + logits = projected @ lm_head_weight.T + + targets = target_ids[k:effective_len+k] + loss = F.cross_entropy(logits, targets) + total_loss += loss.item() + + return float(total_loss / num_future) diff --git a/questions/459_implement-multi-token-prediction/pytorch/starter_code.py b/questions/459_implement-multi-token-prediction/pytorch/starter_code.py new file mode 100644 index 00000000..833a393c --- /dev/null +++ b/questions/459_implement-multi-token-prediction/pytorch/starter_code.py @@ -0,0 +1,18 @@ +import torch + +def multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future): + """ + Implement Multi-Token Prediction training objective. 
+ + Args: + hidden_states (torch.Tensor): Main model hidden states, shape (seq_len, hidden_dim). + target_ids (torch.Tensor): Target token IDs (long), shape (seq_len,). + embedding_weight (torch.Tensor): Token embedding matrix, shape (vocab_size, hidden_dim). + proj_weights (list): List of num_future projection weights, each shape (hidden_dim, 2*hidden_dim). + lm_head_weight (torch.Tensor): LM head weight, shape (vocab_size, hidden_dim). + num_future (int): Number of future tokens to predict. + + Returns: + float: Average cross-entropy loss across all MTP heads. + """ + pass diff --git a/questions/459_implement-multi-token-prediction/pytorch/tests.json b/questions/459_implement-multi-token-prediction/pytorch/tests.json new file mode 100644 index 00000000..1ee0f60e --- /dev/null +++ b/questions/459_implement-multi-token-prediction/pytorch/tests.json @@ -0,0 +1,10 @@ +[ + { + "test": "import torch\ntorch.manual_seed(42)\nhidden_dim = 4\nvocab_size = 10\nseq_len = 6\nnum_future = 2\nhidden_states = torch.randn(seq_len, hidden_dim)\ntarget_ids = torch.tensor([3, 1, 4, 1, 5, 9])\nembedding_weight = torch.randn(vocab_size, hidden_dim)\nproj_weights = [torch.randn(hidden_dim, 2*hidden_dim) for _ in range(num_future)]\nlm_head_weight = torch.randn(vocab_size, hidden_dim)\nloss = multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future)\nprint(round(loss, 4))", + "expected_output": "8.7782" + }, + { + "test": "import torch\ntorch.manual_seed(0)\nhidden_states = torch.randn(4, 3)\ntarget_ids = torch.tensor([0, 1, 2, 0])\nembedding_weight = torch.randn(3, 3)\nproj_weights = [torch.randn(3, 6)]\nlm_head_weight = torch.randn(3, 3)\nloss = multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, 1)\nprint(round(loss, 4))", + "expected_output": "1.4607" + } +] diff --git a/questions/459_implement-multi-token-prediction/solution.py 
b/questions/459_implement-multi-token-prediction/solution.py new file mode 100644 index 00000000..2ee63090 --- /dev/null +++ b/questions/459_implement-multi-token-prediction/solution.py @@ -0,0 +1,22 @@ +import numpy as np + +def multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future): + seq_len, hidden_dim = hidden_states.shape + effective_len = seq_len - num_future + total_loss = 0.0 + + for k in range(1, num_future + 1): + emb = embedding_weight[target_ids[k-1:effective_len+k-1]] + concat = np.concatenate([hidden_states[:effective_len], emb], axis=-1) + projected = concat @ proj_weights[k-1].T + logits = projected @ lm_head_weight.T + + targets = target_ids[k:effective_len+k] + logits_shifted = logits - np.max(logits, axis=-1, keepdims=True) + exp_logits = np.exp(logits_shifted) + log_sum_exp = np.log(np.sum(exp_logits, axis=-1)) + correct_logits = logits_shifted[np.arange(effective_len), targets] + loss = -np.mean(correct_logits - log_sum_exp) + total_loss += loss + + return float(total_loss / num_future) diff --git a/questions/459_implement-multi-token-prediction/starter_code.py b/questions/459_implement-multi-token-prediction/starter_code.py new file mode 100644 index 00000000..ae101b23 --- /dev/null +++ b/questions/459_implement-multi-token-prediction/starter_code.py @@ -0,0 +1,19 @@ +import numpy as np + +# Implement your function below. +def multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future): + """ + Implement Multi-Token Prediction training objective. + + Args: + hidden_states (np.ndarray): Main model hidden states, shape (seq_len, hidden_dim). + target_ids (np.ndarray): Target token IDs, shape (seq_len,). + embedding_weight (np.ndarray): Token embedding matrix, shape (vocab_size, hidden_dim). + proj_weights (list): List of num_future projection weights, each shape (hidden_dim, 2*hidden_dim). 
+ lm_head_weight (np.ndarray): LM head weight, shape (vocab_size, hidden_dim). + num_future (int): Number of future tokens to predict. + + Returns: + float: Average cross-entropy loss across all MTP heads. + """ + pass diff --git a/questions/459_implement-multi-token-prediction/tests.json b/questions/459_implement-multi-token-prediction/tests.json new file mode 100644 index 00000000..9ee6ac9b --- /dev/null +++ b/questions/459_implement-multi-token-prediction/tests.json @@ -0,0 +1,14 @@ +[ + { + "test": "import numpy as np\nnp.random.seed(42)\nhidden_dim = 4\nvocab_size = 10\nseq_len = 6\nnum_future = 2\nhidden_states = np.random.randn(seq_len, hidden_dim)\ntarget_ids = np.array([3, 1, 4, 1, 5, 9])\nembedding_weight = np.random.randn(vocab_size, hidden_dim)\nproj_weights = [np.random.randn(hidden_dim, 2*hidden_dim) for _ in range(num_future)]\nlm_head_weight = np.random.randn(vocab_size, hidden_dim)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, num_future), 4))", + "expected_output": "6.5678" + }, + { + "test": "import numpy as np\nnp.random.seed(0)\nhidden_states = np.random.randn(4, 3)\ntarget_ids = np.array([0, 1, 2, 0])\nembedding_weight = np.random.randn(3, 3)\nproj_weights = [np.random.randn(3, 6)]\nlm_head_weight = np.random.randn(3, 3)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, 1), 4))", + "expected_output": "3.1856" + }, + { + "test": "import numpy as np\nnp.random.seed(7)\nhidden_states = np.random.randn(5, 2)\ntarget_ids = np.array([0, 1, 0, 1, 0])\nembedding_weight = np.random.randn(2, 2)\nproj_weights = [np.random.randn(2, 4), np.random.randn(2, 4)]\nlm_head_weight = np.random.randn(2, 2)\nprint(round(multi_token_prediction(hidden_states, target_ids, embedding_weight, proj_weights, lm_head_weight, 2), 4))", + "expected_output": "0.656" + } +]
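As a closing sanity check of question 457's central design choice (the bias steers which experts are selected but never the final mixing weights), a minimal numpy sketch mirroring the reference solution's steps, reproducing its second numpy test case (the helper name `sigmoid_router` is illustrative):

```python
import numpy as np

def sigmoid_router(h, W_g, b, k):
    w = 1.0 / (1.0 + np.exp(-(h @ W_g.T)))                     # sigmoid routing weights
    idx = np.argsort(-(w + b), axis=-1, kind='stable')[:, :k]  # bias used for selection only
    sel = np.take_along_axis(w, idx, axis=-1)                  # gather weights without bias
    return sel / sel.sum(axis=-1, keepdims=True), idx

h = np.array([[0.0, 0.0]])                 # sigmoid(0) = 0.5 for every expert
W_g = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
w0, idx0 = sigmoid_router(h, W_g, np.zeros(3), 2)              # no bias: experts 0, 1 win ties
w1, idx1 = sigmoid_router(h, W_g, np.array([0.0, 0.0, 1.0]), 2)  # bias pulls expert 2 into top-k
# idx changes (0,1 -> 2,0) but the normalized weights stay [0.5, 0.5] in both cases
```

The bias flips the selection from experts `[0, 1]` to `[2, 0]`, yet both runs return normalized weights of `[0.5, 0.5]`, since the raw sigmoid scores are identical and the bias never enters the gathered weights.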