diff --git a/LEARNING_GUIDE.md b/LEARNING_GUIDE.md new file mode 100644 index 0000000..874f078 --- /dev/null +++ b/LEARNING_GUIDE.md @@ -0,0 +1,551 @@ +# Learning Guide: Deep Understanding of Bidirectional Attention + +Welcome! This guide will help you build deep intuitions about bidirectional attention and modern LLM architectures using this diffusion model as a teaching tool. + +## ๐ŸŽฏ Learning Objectives + +By the end of this guide, you will understand: +1. **What bidirectional attention is** and how it differs from causal attention +2. **Why bidirectional attention enables parallel decoding** in diffusion models +3. **Modern architectural components**: RoPE, RMSNorm, multi-head attention, QK normalization +4. **When to use causal vs bidirectional** attention for different tasks +5. **Best practices** for implementing transformer architectures + +--- + +## ๐Ÿ“š Structured Learning Path + +### Phase 1: Conceptual Understanding (30-45 minutes) + +**Start here to build mental models:** + +1. **Read the tutorial** (20 min) + ```bash + open docs/bidirectional_attention_tutorial.md + ``` + - Comprehensive explanation with diagrams + - Mathematical foundations + - Comparison with causal attention + - Real-world use cases + +2. **Review the annotated code** (15 min) + ```bash + open model.py + # Focus on the BidirectionalAttention class (lines 132-263) + ``` + - Every line is explained + - Shape transformations documented + - Design decisions justified + +3. **Quick check**: Can you answer these questions? + - What does `is_causal=False` mean? + - Why can't GPT use bidirectional attention? + - What are the three roles of Q, K, V? + +--- + +### Phase 2: Hands-On Experimentation (45-60 minutes) + +**Learn by doing:** + +#### Experiment 1: Compare Attention Types + +Run the comparison tool to see patterns side-by-side: + +```bash +uv run attention_comparison.py +``` + +**Expected output:** +- `attention_comparison.png` - Visual comparison of causal vs bidirectional +- Console output explaining key differences +- Statistics showing how outputs differ + +**What to observe:** +- Causal attention: Lower triangular pattern +- Bidirectional attention: Full matrix pattern +- How this affects the output representations + +**Try modifying:** +```python +# In attention_comparison.py, line 372: +visualize_attention_comparison(seq_len=16, n_embd=128, n_head=8) +``` +Experiment with different parameters to see how attention scales. + +--- + +#### Experiment 2: Visualize Your Trained Model + +See attention patterns from the actual trained model: + +```bash +uv run visualize_model_attention.py +``` + +**Expected output:** +- `actual_attention_pattern.png` - Real attention weights from trained model +- `diffusion_attention_evolution.png` - How generation evolves step-by-step +- Console analysis of attention patterns + +**What to observe:** +- How different heads learn different patterns +- How tokens gradually get decoded during diffusion +- Which tokens attend to which (full bidirectional access) + +**Try modifying:** +```python +# In visualize_model_attention.py, line 297: +visualize_actual_attention_matrix( + model, + text="Your custom text here", # Change this! + save_path="custom_attention.png" +) +``` + +--- + +#### Experiment 3: Generate Text and Observe + +Generate text to see bidirectional attention in action: + +```bash +uv run sample.py +``` + +**What's happening under the hood:** +1. Start with all `[MASK]` tokens +2. Model uses bidirectional attention to see full context +3. 
High-confidence tokens get decoded
+4. Process repeats with better context each time
+5. Final coherent text emerges
+
+**Key insight:** This parallel decoding is ONLY possible because of bidirectional attention. Each token can see its neighbors to maintain coherence.
+
+---
+
+### Phase 3: Deep Dive into Components (60-90 minutes)
+
+**Master the individual architectural components:**
+
+#### Component 1: Multi-Head Attention
+
+**Concept:** Instead of one attention pattern, learn multiple patterns in parallel.
+
+**In the code (model.py:132-263):**
+```python
+self.n_head = 6           # 6 different attention patterns
+self.head_dim = 384 // 6  # = 64 dimensions per head
+```
+
+**Why multiple heads?**
+- Head 1 might focus on grammar
+- Head 2 might focus on semantics
+- Head 3 might focus on local context
+- Head 4 might focus on distant dependencies
+- etc.
+
+**Experiment:**
+Try visualizing different heads in `visualize_model_attention.py` - notice how each learns different patterns!
+
+---
+
+#### Component 2: RoPE (Rotary Position Embeddings)
+
+**Concept:** Encode position through geometric rotation, not learned embeddings.
+
+**In the code (model.py:67-129):**
+```python
+q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
+```
+
+**Why RoPE?**
+- Relative position encoding (distance matters, not absolute position)
+- Better extrapolation to longer sequences
+- Used in LLaMA, Mistral, and most modern LLMs
+
+**Mathematical insight:**
+```
+Rotation by angle θ:
+[x1']   [cos θ  -sin θ] [x1]
+[x2'] = [sin θ   cos θ] [x2]
+```
+
+Tokens at positions m and n have relative rotation `θ_m - θ_n`, which the dot product naturally captures!
+
+**Further reading:**
+- Original paper: https://arxiv.org/abs/2104.09864
+- Explanation: https://blog.eleuther.ai/rotary-embeddings/
+
+---
+
+#### Component 3: RMSNorm
+
+**Concept:** Simpler normalization that works just as well as LayerNorm.
+
+**In the code (model.py:39-64):**
+```python
+def norm(x):
+    # Purely functional rmsnorm with no learnable params
+    return F.rms_norm(x, (x.size(-1),))
+```
+
+**Formula:**
+```
+RMSNorm(x) = x / sqrt(mean(x²) + ε)
+```
+
+**Why simpler than LayerNorm?**
+- No learnable parameters in this repo's version (the original RMSNorm paper keeps an optional gain)
+- Doesn't subtract the mean (only rescales)
+- Faster and equally effective in practice
+- Used in LLaMA and Mistral
+
+**Comparison:**
+```
+LayerNorm: y = (x - mean) / sqrt(var) * gamma + beta   [6 operations, 2 params]
+RMSNorm:   y = x / sqrt(mean(x²))                      [3 operations, 0 params]
+```
+
+---
+
+#### Component 4: QK Normalization
+
+**Concept:** Normalize queries and keys before computing attention.
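+
+Before jumping to where this happens in the repo, here is a tiny standalone sketch of *why* it helps. It only uses PyTorch and toy tensors - the shapes and the ×10 scale are made up for illustration, and it needs a PyTorch new enough to have `F.rms_norm` (which the repo's `norm()` already relies on):
+
+```python
+import torch
+import torch.nn.functional as F
+
+torch.manual_seed(0)
+head_dim = 64
+# Pretend the query/key scale has drifted upward during training (the ×10 is arbitrary)
+q = torch.randn(1, 1, 8, head_dim) * 10.0
+k = torch.randn(1, 1, 8, head_dim) * 10.0
+
+def max_logit(q, k):
+    # Largest pre-softmax attention score
+    return ((q @ k.transpose(-2, -1)) / head_dim**0.5).abs().max().item()
+
+qn, kn = F.rms_norm(q, (head_dim,)), F.rms_norm(k, (head_dim,))
+print("without QK norm:", max_logit(q, k))    # in the hundreds
+print("with QK norm:   ", max_logit(qn, kn))  # roughly 2-3
+```
+
+With the raw tensors the largest logit lands in the hundreds, so the softmax collapses onto a single token and gradients vanish for everything else; after RMS-normalizing `q` and `k` the logits stay of order one and the softmax remains well behaved.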
+ +**In the code (model.py:181):** +```python +q, k = norm(q), norm(k) # Normalize Q and K +``` + +**Why?** +- Prevents attention logits from exploding (stabilizes training) +- Allows higher learning rates +- More stable gradients + +**Without QK norm:** +``` +scores = Q @ K^T / sqrt(d) +# Can become very large โ†’ unstable softmax โ†’ vanishing/exploding gradients +``` + +**With QK norm:** +``` +Q_norm, K_norm = normalize(Q), normalize(K) +scores = Q_norm @ K_norm^T / sqrt(d) +# Stays in reasonable range โ†’ stable softmax โ†’ better training +``` + +--- + +### Phase 4: Code Walkthrough (45-60 minutes) + +**Trace through a forward pass step-by-step:** + +#### Exercise: Manual Forward Pass + +Let's trace what happens when we process the text `"KING:"` + +**Step 1: Tokenization** +```python +from model import encode_text +tokens = encode_text("KING:") +# Result: [75, 73, 78, 71, 58] (ASCII values) +``` + +**Step 2: Embeddings** +```python +x = model.token_emb(tokens) # [5, 384] +t_emb = model.time_emb(t) # [1, 384] +x = x + t_emb # Broadcast time across sequence +``` + +**Step 3: First Attention Layer** +```python +# Inside BidirectionalAttention.forward() +q = self.c_q(x).view(1, 5, 6, 64) # [batch, seq_len, n_head, head_dim] +k = self.c_k(x).view(1, 5, 6, 64) +v = self.c_v(x).view(1, 5, 6, 64) + +# Apply RoPE +q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + +# Normalize +q, k = norm(q), norm(k) + +# Reshape for attention +q = q.transpose(1, 2) # [1, 6, 5, 64] (batch, heads, seq, dim) +k = k.transpose(1, 2) +v = v.transpose(1, 2) + +# Compute attention - THE KEY STEP! +# This computes: softmax(Q @ K^T / sqrt(64)) @ V +# is_causal=False means NO MASK - bidirectional! +y = F.scaled_dot_product_attention(q, k, v, is_causal=False) +# Result: [1, 6, 5, 64] + +# Merge heads +y = y.transpose(1, 2).view(1, 5, 384) # [batch, seq_len, n_embd] + +# Output projection +y = self.c_proj(y) +``` + +**Step 4: MLP (Feedforward)** +```python +x = x + attn_output # Residual connection +x = x + mlp(norm(x)) # Another residual connection +``` + +**Step 5: Repeat for 6 Layers** +```python +for block in self.blocks: + x = block(x, cos_sin) # Each block does attention + MLP +``` + +**Step 6: Output Head** +```python +logits = self.output_head(x) # [1, 5, 128] - predict token for each position +``` + +Try adding print statements in `model.py` to see these shapes during actual execution! + +--- + +### Phase 5: Experiments and Modifications (60-90 minutes) + +**Solidify understanding by breaking things:** + +#### Experiment 1: Switch to Causal Attention + +**Hypothesis:** Making attention causal will break the diffusion process. + +**Modification:** +```python +# In model.py, line 235, change: +y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # was False +``` + +**Run:** +```bash +uv run sample.py +``` + +**Expected result:** +- Degraded quality (might work a bit due to context, but much worse) +- Tokens can't see future context for coherence +- Parallel decoding becomes less effective + +**Why it breaks:** +In diffusion, we need bidirectional context: +``` +[M][M][M] โ†’ [M]"cat"[M] +``` +When predicting the first token, it NEEDS to see "cat" to know it should be "The" (with capital T). + +**Revert the change before continuing!** + +--- + +#### Experiment 2: Remove RoPE + +**Hypothesis:** Without positional encoding, the model loses position information. 
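+
+You can check the underlying reason without training anything: attention by itself is permutation-equivariant. A standalone sketch (it does not import from `model.py`; the shapes are arbitrary):
+
+```python
+import torch
+import torch.nn.functional as F
+
+torch.manual_seed(0)
+x = torch.randn(1, 6, 32)   # (batch, seq_len, dim) - toy values
+perm = torch.randperm(6)
+
+def attn(x):
+    # Plain self-attention with q = k = v and no positional information at all
+    return F.scaled_dot_product_attention(x, x, x, is_causal=False)
+
+out, out_shuffled = attn(x), attn(x[:, perm])
+# Shuffling the input tokens just shuffles the outputs in exactly the same way:
+print(torch.allclose(out[:, perm], out_shuffled, atol=1e-5))  # True
+```
+
+Because the output is only reordered, token order carries no information at all until RoPE (or some other positional signal) is added.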
+ +**Modification:** +```python +# In model.py, line 165, comment out RoPE: +# q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) +``` + +**Expected result:** +- Model can't distinguish token positions +- "The cat" and "cat The" look identical to the model +- Generation becomes incoherent + +**Why positions matter:** +Attention is permutation-invariant without position info! +``` +attn("The cat sat") == attn("sat cat The") # Without position encoding +``` + +--- + +#### Experiment 3: Single Head vs Multi-Head + +**Hypothesis:** Multiple heads help learn diverse patterns. + +**Modification:** +```python +# In model.py, change config (line 33): +n_head: int = 1 # was 6 +``` + +**Run:** +```bash +uv run visualize_model_attention.py +``` + +**Expected result:** +- Only one attention pattern (less expressive) +- Model has harder time learning complex dependencies +- Training might be less stable + +**Why multi-head helps:** +Different heads specialize: +- Local context (nearby tokens) +- Global context (distant tokens) +- Syntax patterns +- Semantic patterns + +--- + +#### Experiment 4: Visualize QK Normalization Effect + +**Hypothesis:** QK norm prevents attention from becoming too peaked. + +**Modification:** +```python +# In model.py, line 181, comment out normalization: +# q, k = norm(q), norm(k) +``` + +**What to check:** +- Does attention become more concentrated (peaky)? +- Does training become unstable? + +**Expected:** +- Attention weights might collapse to one-hot (all weight on one token) +- Less smooth attention patterns +- Harder to train + +--- + +### Phase 6: Comparison with Other Architectures (30 minutes) + +**Understand the design space:** + +| Architecture | Attention Type | Use Case | Example Models | +|--------------|---------------|----------|----------------| +| **GPT** | Causal (unidirectional) | Autoregressive generation | GPT-3, GPT-4, LLaMA, Mistral | +| **BERT** | Bidirectional | Masked language modeling | BERT, RoBERTa | +| **T5** | Encoder: Bidirectional
Decoder: Causal | Seq2seq tasks | T5, BART | +| **This model** | Bidirectional | Diffusion generation | This repo! | + +**Key insight:** The attention type is chosen based on the task: +- **Generate left-to-right?** โ†’ Causal +- **Understand full context?** โ†’ Bidirectional +- **Both?** โ†’ Encoder-decoder (T5) + +--- + +### Phase 7: Additional Resources (Optional) + +**Deepen your knowledge:** + +#### Papers to Read: +1. **Attention Is All You Need** (Vaswani et al., 2017) + - Original transformer paper + - https://arxiv.org/abs/1706.03762 + +2. **BERT: Pre-training of Deep Bidirectional Transformers** (Devlin et al., 2018) + - Introduces bidirectional pre-training + - https://arxiv.org/abs/1810.04805 + +3. **RoFormer: Enhanced Transformer with Rotary Position Embedding** (Su et al., 2021) + - Introduces RoPE + - https://arxiv.org/abs/2104.09864 + +4. **Root Mean Square Layer Normalization** (Zhang & Sennrich, 2019) + - RMSNorm paper + - https://arxiv.org/abs/1910.07467 + +#### Blog Posts: +1. **The Illustrated Transformer** by Jay Alammar + - Visual explanation of attention + - https://jalammar.github.io/illustrated-transformer/ + +2. **Understanding RoPE** by EleutherAI + - Deep dive into rotary embeddings + - https://blog.eleuther.ai/rotary-embeddings/ + +3. **Flash Attention** by Tri Dao + - How to make attention efficient + - https://arxiv.org/abs/2205.14135 + +--- + +## ๐ŸŽ“ Assessment: Test Your Understanding + +Try answering these questions without looking at the materials: + +### Basic Understanding: +1. What does "bidirectional" mean in the context of attention? +2. What are the shapes of Q, K, V in multi-head attention? +3. Why do we scale attention scores by `sqrt(head_dim)`? +4. What's the difference between `is_causal=True` and `is_causal=False`? + +### Intermediate: +5. Why can't GPT use bidirectional attention during generation? +6. How does RoPE encode position information? +7. What's the advantage of RMSNorm over LayerNorm? +8. Why do we normalize Q and K before computing attention? + +### Advanced: +9. How does bidirectional attention enable parallel decoding in diffusion? +10. What would happen if we used causal attention in this diffusion model? +11. How do different attention heads specialize during training? +12. Why is RoPE applied to Q and K but not V? + +--- + +## ๐Ÿš€ Next Steps + +After mastering bidirectional attention, explore: + +1. **Other modern architectures:** + - Flash Attention (efficient attention) + - Sparse attention (reduce O(nยฒ) complexity) + - Cross-attention (encoder-decoder) + +2. **Training techniques:** + - Mixed precision training + - Gradient checkpointing + - Learning rate schedules + +3. **Scaling laws:** + - How model size affects performance + - Optimal batch size and learning rate + +4. 
**Advanced topics:** + - LoRA (efficient fine-tuning) + - Quantization (8-bit, 4-bit) + - Distributed training + +--- + +## ๐Ÿ“ Summary + +You now deeply understand: + +โœ… **Bidirectional attention** - Full context access, no causal mask +โœ… **When to use it** - Diffusion models, BERT-style tasks, full-context understanding +โœ… **How it differs from causal** - Matrix structure, use cases, trade-offs +โœ… **Modern components** - RoPE, RMSNorm, multi-head, QK norm +โœ… **Implementation details** - Every line of code explained +โœ… **Practical experience** - Ran experiments, visualized patterns + +**Key insight:** The choice between causal and bidirectional attention fundamentally determines what your model can do: +- **Causal:** Sequential generation, online processing +- **Bidirectional:** Full context understanding, parallel refinement + +This tiny-diffusion repo demonstrates how bidirectional attention enables a completely different generation paradigm: iterative refinement instead of left-to-right autoregression. + +--- + +## ๐Ÿ’ฌ Questions? + +If you have questions or want to explore specific topics deeper: +1. Review the relevant section in `docs/bidirectional_attention_tutorial.md` +2. Check the detailed comments in `model.py` +3. Run the visualization tools to build intuition +4. Experiment by modifying the code + +Happy learning! ๐ŸŽ‰ diff --git a/attention_comparison.py b/attention_comparison.py new file mode 100644 index 0000000..aa7f230 --- /dev/null +++ b/attention_comparison.py @@ -0,0 +1,387 @@ +""" +Side-by-side comparison of Causal vs Bidirectional Attention + +This module demonstrates the key differences between causal (GPT-style) and +bidirectional (BERT-style) attention mechanisms. + +Run this to see: +1. How masking affects attention patterns +2. Output differences between the two approaches +3. 
Visual comparison of attention matrices +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np + + +class CausalAttention(nn.Module): + """ + Causal (unidirectional) attention - used in GPT, LLaMA, Mistral + + Key characteristic: Each token can only attend to itself and previous tokens + This is enforced through a causal mask (lower triangular matrix) + """ + + def __init__(self, n_embd, n_head): + super().__init__() + self.n_head = n_head + self.n_embd = n_embd + self.head_dim = n_embd // n_head + assert n_embd % n_head == 0, "n_embd must be divisible by n_head" + + # Three projection matrices for Q, K, V + self.c_q = nn.Linear(n_embd, n_embd, bias=False) + self.c_k = nn.Linear(n_embd, n_embd, bias=False) + self.c_v = nn.Linear(n_embd, n_embd, bias=False) + + # Output projection + self.c_proj = nn.Linear(n_embd, n_embd, bias=False) + + def forward(self, x, return_attn_weights=False): + """ + Args: + x: Input tensor of shape (batch, seq_len, n_embd) + return_attn_weights: If True, also return attention weights for visualization + + Returns: + output: Attention output of shape (batch, seq_len, n_embd) + attn_weights: (optional) Attention weights of shape (batch, n_head, seq_len, seq_len) + """ + B, T, C = x.size() + + # Step 1: Project input to Q, K, V + # Each projection learns a different transformation + q = self.c_q(x).view(B, T, self.n_head, self.head_dim) + k = self.c_k(x).view(B, T, self.n_head, self.head_dim) + v = self.c_v(x).view(B, T, self.n_head, self.head_dim) + + # Step 2: Transpose to get (batch, n_head, seq_len, head_dim) + # This makes it easier to process each head independently + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Step 3: Compute attention scores + # scores[b, h, i, j] = similarity between query_i and key_j in head h + scores = torch.matmul(q, k.transpose(-2, -1)) # (B, n_head, T, T) + + # Step 4: Scale by sqrt(head_dim) to prevent vanishing gradients + scores = scores / (self.head_dim ** 0.5) + + # Step 5: Apply CAUSAL MASK + # This is the key difference from bidirectional attention! 
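+        # (For example, with T = 4 the boolean mask built below is
+        #      [[1, 0, 0, 0],
+        #       [1, 1, 0, 0],
+        #       [1, 1, 1, 0],
+        #       [1, 1, 1, 1]]
+        #  where rows are query positions and columns are key positions;
+        #  1 means "may attend", 0 means "masked out".)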
+ # Create lower triangular mask: 1s on and below diagonal, 0s above + causal_mask = torch.tril(torch.ones(T, T, device=x.device)).bool() + + # Set masked positions to -inf so they become 0 after softmax + scores = scores.masked_fill(~causal_mask, float('-inf')) + + # Step 6: Apply softmax to get attention weights + # Each row sums to 1.0 (probability distribution) + attn_weights = F.softmax(scores, dim=-1) + + # Handle NaN from -inf (when entire row is masked) + attn_weights = torch.where( + torch.isnan(attn_weights), + torch.zeros_like(attn_weights), + attn_weights + ) + + # Step 7: Apply attention weights to values + # This computes the weighted combination of value vectors + output = torch.matmul(attn_weights, v) # (B, n_head, T, head_dim) + + # Step 8: Merge heads back together + output = output.transpose(1, 2).contiguous() # (B, T, n_head, head_dim) + output = output.view(B, T, C) # (B, T, n_embd) + + # Step 9: Final output projection + output = self.c_proj(output) + + if return_attn_weights: + return output, attn_weights + return output + + +class BidirectionalAttention(nn.Module): + """ + Bidirectional (omnidirectional) attention - used in BERT, diffusion models + + Key characteristic: Each token can attend to ALL tokens (past, present, future) + No causal mask is applied, allowing full context awareness + """ + + def __init__(self, n_embd, n_head): + super().__init__() + self.n_head = n_head + self.n_embd = n_embd + self.head_dim = n_embd // n_head + assert n_embd % n_head == 0, "n_embd must be divisible by n_head" + + # Three projection matrices for Q, K, V + self.c_q = nn.Linear(n_embd, n_embd, bias=False) + self.c_k = nn.Linear(n_embd, n_embd, bias=False) + self.c_v = nn.Linear(n_embd, n_embd, bias=False) + + # Output projection + self.c_proj = nn.Linear(n_embd, n_embd, bias=False) + + def forward(self, x, return_attn_weights=False): + """ + Args: + x: Input tensor of shape (batch, seq_len, n_embd) + return_attn_weights: If True, also return attention weights for visualization + + Returns: + output: Attention output of shape (batch, seq_len, n_embd) + attn_weights: (optional) Attention weights of shape (batch, n_head, seq_len, seq_len) + """ + B, T, C = x.size() + + # Step 1: Project input to Q, K, V + q = self.c_q(x).view(B, T, self.n_head, self.head_dim) + k = self.c_k(x).view(B, T, self.n_head, self.head_dim) + v = self.c_v(x).view(B, T, self.n_head, self.head_dim) + + # Step 2: Transpose to get (batch, n_head, seq_len, head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Step 3: Compute attention scores + scores = torch.matmul(q, k.transpose(-2, -1)) # (B, n_head, T, T) + + # Step 4: Scale by sqrt(head_dim) + scores = scores / (self.head_dim ** 0.5) + + # Step 5: NO CAUSAL MASK! 
+ # This is the key difference: we don't mask any positions + # Every token can attend to every other token + + # Step 6: Apply softmax to get attention weights + attn_weights = F.softmax(scores, dim=-1) + + # Step 7: Apply attention weights to values + output = torch.matmul(attn_weights, v) # (B, n_head, T, head_dim) + + # Step 8: Merge heads back together + output = output.transpose(1, 2).contiguous() + output = output.view(B, T, C) + + # Step 9: Final output projection + output = self.c_proj(output) + + if return_attn_weights: + return output, attn_weights + return output + + +def visualize_attention_comparison(seq_len=8, n_embd=64, n_head=4): + """ + Visualize the attention patterns for both causal and bidirectional attention + + This creates a side-by-side comparison showing: + 1. Causal attention pattern (lower triangular) + 2. Bidirectional attention pattern (full matrix) + """ + # Create sample input + torch.manual_seed(42) + batch_size = 1 + x = torch.randn(batch_size, seq_len, n_embd) + + # Initialize both attention types + causal_attn = CausalAttention(n_embd, n_head) + bidir_attn = BidirectionalAttention(n_embd, n_head) + + # Get outputs and attention weights + _, causal_weights = causal_attn(x, return_attn_weights=True) + _, bidir_weights = bidir_attn(x, return_attn_weights=True) + + # Convert to numpy for plotting + causal_weights = causal_weights[0].detach().numpy() # (n_head, seq_len, seq_len) + bidir_weights = bidir_weights[0].detach().numpy() + + # Create visualization + fig, axes = plt.subplots(2, n_head, figsize=(4*n_head, 8)) + fig.suptitle('Attention Pattern Comparison: Causal vs Bidirectional', fontsize=16, fontweight='bold') + + for head_idx in range(n_head): + # Plot causal attention + ax1 = axes[0, head_idx] + im1 = ax1.imshow(causal_weights[head_idx], cmap='viridis', aspect='auto', vmin=0, vmax=1) + ax1.set_title(f'Causal - Head {head_idx}', fontweight='bold') + ax1.set_xlabel('Key Position') + ax1.set_ylabel('Query Position') + plt.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04) + + # Add grid + ax1.set_xticks(range(seq_len)) + ax1.set_yticks(range(seq_len)) + ax1.grid(True, alpha=0.3) + + # Plot bidirectional attention + ax2 = axes[1, head_idx] + im2 = ax2.imshow(bidir_weights[head_idx], cmap='viridis', aspect='auto', vmin=0, vmax=1) + ax2.set_title(f'Bidirectional - Head {head_idx}', fontweight='bold') + ax2.set_xlabel('Key Position') + ax2.set_ylabel('Query Position') + plt.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04) + + # Add grid + ax2.set_xticks(range(seq_len)) + ax2.set_yticks(range(seq_len)) + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig('attention_comparison.png', dpi=150, bbox_inches='tight') + print("Visualization saved to: attention_comparison.png") + plt.close() + + +def demonstrate_output_differences(): + """ + Show how the two attention types produce different outputs for the same input + """ + print("\n" + "="*70) + print("DEMONSTRATING OUTPUT DIFFERENCES") + print("="*70) + + # Setup + torch.manual_seed(42) + batch_size = 1 + seq_len = 6 + n_embd = 64 + n_head = 4 + + # Create input with clear pattern + x = torch.randn(batch_size, seq_len, n_embd) + + # Initialize both attention types with same weights for fair comparison + causal_attn = CausalAttention(n_embd, n_head) + bidir_attn = BidirectionalAttention(n_embd, n_head) + + # Copy weights to make comparison fair + bidir_attn.c_q.weight.data = causal_attn.c_q.weight.data.clone() + bidir_attn.c_k.weight.data = causal_attn.c_k.weight.data.clone() + bidir_attn.c_v.weight.data 
= causal_attn.c_v.weight.data.clone() + bidir_attn.c_proj.weight.data = causal_attn.c_proj.weight.data.clone() + + # Get outputs + causal_out, causal_weights = causal_attn(x, return_attn_weights=True) + bidir_out, bidir_weights = bidir_attn(x, return_attn_weights=True) + + print("\nInput shape:", x.shape) + print("Causal output shape:", causal_out.shape) + print("Bidirectional output shape:", bidir_out.shape) + + print("\n" + "-"*70) + print("ATTENTION WEIGHT STATISTICS (First Head)") + print("-"*70) + + # Analyze first head + causal_head0 = causal_weights[0, 0].detach() + bidir_head0 = bidir_weights[0, 0].detach() + + print("\nCausal Attention (each row = where this token looks):") + print(causal_head0.numpy().round(3)) + + print("\nBidirectional Attention (each row = where this token looks):") + print(bidir_head0.numpy().round(3)) + + print("\n" + "-"*70) + print("KEY OBSERVATIONS") + print("-"*70) + + print("\n1. Causal Attention Pattern:") + print(" - Lower triangular structure (zeros above diagonal)") + print(" - Token 0 only looks at itself") + print(" - Token 5 can look at all previous tokens (0-5)") + + print("\n2. Bidirectional Attention Pattern:") + print(" - Full matrix (no zeros)") + print(" - Every token looks at all tokens") + print(" - Token 0 can see future tokens (1-5)") + + print("\n3. Output Differences:") + output_diff = torch.abs(causal_out - bidir_out).mean().item() + print(f" - Mean absolute difference: {output_diff:.6f}") + print(f" - Max absolute difference: {torch.abs(causal_out - bidir_out).max().item():.6f}") + + print("\n" + "="*70 + "\n") + + +def demonstrate_use_cases(): + """ + Show practical examples of when to use each attention type + """ + print("\n" + "="*70) + print("USE CASE DEMONSTRATIONS") + print("="*70) + + print("\n๐Ÿ“ SCENARIO 1: Next Token Prediction (Text Generation)") + print("-" * 70) + print("Task: Given 'The cat sat on', predict next word") + print("โœ“ Use: CAUSAL attention") + print("Why: When predicting 'the', we shouldn't see 'mat' that comes after") + print("This prevents the model from 'cheating' during training") + + print("\n๐ŸŽญ SCENARIO 2: Masked Language Modeling (BERT-style)") + print("-" * 70) + print("Task: Given 'The cat [MASK] on the mat', predict masked word") + print("โœ“ Use: BIDIRECTIONAL attention") + print("Why: We need context from both sides ('cat' and 'on') to predict 'sat'") + + print("\n๐ŸŽจ SCENARIO 3: Diffusion-based Generation (This Repo!)") + print("-" * 70) + print("Task: Iteratively refine: [M][M][M] โ†’ [M]'cat'[M] โ†’ 'The''cat'[M] โ†’ 'The''cat''sat'") + print("โœ“ Use: BIDIRECTIONAL attention") + print("Why: Each token needs to see neighbors to maintain coherence during refinement") + print(" 'cat' needs to see 'The' to know it should be lowercase") + + print("\n๐Ÿ’ฌ SCENARIO 4: Classification (Sentiment Analysis)") + print("-" * 70) + print("Task: Given 'This movie was great!', predict positive/negative") + print("โœ“ Use: BIDIRECTIONAL attention") + print("Why: Need to understand full sentence context (word 'great' at end is crucial)") + + print("\n๐Ÿ”„ SCENARIO 5: Translation Encoder") + print("-" * 70) + print("Task: Encode French sentence before translating to English") + print("โœ“ Use: BIDIRECTIONAL attention") + print("Why: Full sentence context helps understand meaning before translation") + + print("\n" + "="*70 + "\n") + + +def main(): + """ + Run all demonstrations + """ + print("\n" + "="*70) + print("ATTENTION MECHANISM COMPARISON TOOL") + print("Causal (GPT-style) vs Bidirectional 
(BERT-style)") + print("="*70) + + # Run demonstrations + print("\n[1/3] Visualizing attention patterns...") + visualize_attention_comparison(seq_len=8, n_embd=64, n_head=4) + + print("\n[2/3] Analyzing output differences...") + demonstrate_output_differences() + + print("\n[3/3] Showing use cases...") + demonstrate_use_cases() + + print("\nโœ… All demonstrations complete!") + print("\nNext steps:") + print("1. Open 'attention_comparison.png' to see visual patterns") + print("2. Experiment with different seq_len and n_head values") + print("3. Try modifying model.py to switch between attention types") + print("4. Read docs/bidirectional_attention_tutorial.md for deeper understanding") + + +if __name__ == "__main__": + main() diff --git a/docs/bidirectional_attention_tutorial.md b/docs/bidirectional_attention_tutorial.md new file mode 100644 index 0000000..3a576b7 --- /dev/null +++ b/docs/bidirectional_attention_tutorial.md @@ -0,0 +1,439 @@ +# Deep Dive: Bidirectional Attention + +## Table of Contents +1. [Introduction](#introduction) +2. [What is Attention?](#what-is-attention) +3. [Causal vs Bidirectional Attention](#causal-vs-bidirectional) +4. [Mathematical Formulation](#mathematical-formulation) +5. [Implementation Deep Dive](#implementation-deep-dive) +6. [When to Use Each Type](#when-to-use) +7. [Visualization](#visualization) + +--- + +## Introduction + +Attention is the core mechanism that allows transformers to model relationships between tokens in a sequence. The key distinction between **causal (unidirectional)** and **bidirectional** attention determines what information each token can "see" when computing its representation. + +### Quick Analogy +Imagine reading a sentence to predict the next word: + +**Causal Attention (GPT):** Like reading left-to-right with a piece of paper covering future words. Each word can only look at itself and previous words. + +**Bidirectional Attention (BERT/Diffusion):** Like reading the entire sentence at once. Each word can look at ALL other words, past and future. + +--- + +## What is Attention? + +At its core, attention computes a weighted combination of values based on the similarity between queries and keys. + +### The Three Components + +For each token, we compute three vectors: + +1. **Query (Q):** "What am I looking for?" +2. **Key (K):** "What do I contain?" +3. **Value (V):** "What information should I pass forward?" + +### The Attention Formula + +``` +Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V +``` + +Let's break this down step by step: + +1. **Q @ K^T:** Compute similarity between every query and every key + - Shape: (batch, heads, seq_len, seq_len) + - Each position gets a score for how relevant every other position is + +2. **/ sqrt(d_k):** Scale by square root of key dimension + - Prevents dot products from becoming too large + - Keeps gradients stable + +3. **softmax(...):** Convert scores to probabilities + - Each row sums to 1.0 + - High scores become higher, low scores become lower + +4. 
**@ V:** Weighted combination of values + - Mix information from all positions based on attention weights + +--- + +## Causal vs Bidirectional + +### Causal Attention (GPT, LLaMA, Mistral) + +**Mask Pattern:** +``` +Token 0: can see [0] +Token 1: can see [0, 1] +Token 2: can see [0, 1, 2] +Token 3: can see [0, 1, 2, 3] +``` + +**Attention Matrix (โœ“ = allowed, โœ— = masked):** +``` + T0 T1 T2 T3 +T0 โœ“ โœ— โœ— โœ— +T1 โœ“ โœ“ โœ— โœ— +T2 โœ“ โœ“ โœ“ โœ— +T3 โœ“ โœ“ โœ“ โœ“ +``` + +**Why?** For autoregressive generation, we must prevent "cheating" by looking at future tokens we're trying to predict. + +**Use Cases:** +- Text generation (GPT, LLaMA) +- Code generation (Codex, Code LLaMA) +- Any task where you generate left-to-right + +### Bidirectional Attention (BERT, Diffusion Models) + +**Mask Pattern:** +``` +Token 0: can see [0, 1, 2, 3] +Token 1: can see [0, 1, 2, 3] +Token 2: can see [0, 1, 2, 3] +Token 3: can see [0, 1, 2, 3] +``` + +**Attention Matrix (โœ“ = allowed, โœ— = masked):** +``` + T0 T1 T2 T3 +T0 โœ“ โœ“ โœ“ โœ“ +T1 โœ“ โœ“ โœ“ โœ“ +T2 โœ“ โœ“ โœ“ โœ“ +T3 โœ“ โœ“ โœ“ โœ“ +``` + +**Why?** When you need full context understanding or iterative refinement, seeing the entire sequence helps. + +**Use Cases:** +- Masked language modeling (BERT) +- Diffusion models (this repo!) +- Encoders for translation/summarization +- Any task where full context is available + +--- + +## Mathematical Formulation + +### Step-by-Step Computation + +Let's walk through a concrete example with: +- Sequence length: 4 tokens +- Embedding dimension: 8 +- Number of heads: 2 +- Head dimension: 4 (= 8 / 2) + +**Input:** +```python +x = [batch, seq_len, n_embd] # [B, 4, 8] +``` + +**Step 1: Project to Q, K, V** +```python +Q = x @ W_q # [B, 4, 8] โ†’ [B, 4, 8] +K = x @ W_k # [B, 4, 8] โ†’ [B, 4, 8] +V = x @ W_v # [B, 4, 8] โ†’ [B, 4, 8] +``` + +**Step 2: Split into Multiple Heads** +```python +Q = Q.reshape(B, 4, 2, 4) # [B, seq_len, n_head, head_dim] +K = K.reshape(B, 4, 2, 4) +V = V.reshape(B, 4, 2, 4) + +# Transpose for easier computation +Q = Q.transpose(1, 2) # [B, n_head, seq_len, head_dim] = [B, 2, 4, 4] +K = K.transpose(1, 2) # [B, 2, 4, 4] +V = V.transpose(1, 2) # [B, 2, 4, 4] +``` + +**Step 3: Compute Attention Scores** +```python +scores = Q @ K.transpose(-2, -1) # [B, 2, 4, 4] +# Each head gets its own [4, 4] attention matrix +``` + +**Example Attention Matrix (before masking):** +``` + K0 K1 K2 K3 +Q0 [0.8 0.2 0.1 0.3] +Q1 [0.3 0.9 0.4 0.2] +Q2 [0.1 0.3 0.7 0.5] +Q3 [0.2 0.1 0.6 0.8] +``` + +**Step 4: Scale** +```python +scores = scores / sqrt(head_dim) # Divide by sqrt(4) = 2.0 +``` + +**Step 5: Apply Mask (if causal)** +```python +# For CAUSAL attention: +mask = torch.triu(torch.ones(4, 4), diagonal=1).bool() +scores = scores.masked_fill(mask, -inf) + +# After masking: + K0 K1 K2 K3 +Q0 [0.4 -inf -inf -inf] +Q1 [0.15 0.45 -inf -inf] +Q2 [0.05 0.15 0.35 -inf] +Q3 [0.1 0.05 0.3 0.4] + +# For BIDIRECTIONAL attention: +# No masking! 
Use scores as-is +``` + +**Step 6: Softmax** +```python +attn_weights = softmax(scores, dim=-1) # [B, 2, 4, 4] + +# For bidirectional (no masking): + K0 K1 K2 K3 +Q0 [0.50 0.20 0.15 0.15] โ† Sums to 1.0 +Q1 [0.15 0.55 0.20 0.10] โ† Sums to 1.0 +Q2 [0.10 0.15 0.45 0.30] โ† Sums to 1.0 +Q3 [0.15 0.10 0.30 0.45] โ† Sums to 1.0 +``` + +**Step 7: Apply Attention to Values** +```python +output = attn_weights @ V # [B, 2, 4, 4] @ [B, 2, 4, 4] = [B, 2, 4, 4] +``` + +**Step 8: Merge Heads** +```python +output = output.transpose(1, 2) # [B, 4, 2, 4] +output = output.reshape(B, 4, 8) # [B, seq_len, n_embd] +``` + +**Step 9: Final Projection** +```python +output = output @ W_o # [B, 4, 8] โ†’ [B, 4, 8] +``` + +--- + +## Implementation Deep Dive + +Let's examine the actual code from `model.py`: + +```python +class BidirectionalAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.n_head = config.n_head + self.n_embd = config.n_embd + self.head_dim = self.n_embd // self.n_head + assert self.n_embd % self.n_head == 0 + + # Three separate weight matrices + self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False) + self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False) + self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False) + + # Output projection + self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) + + def forward(self, x, cos_sin): + B, T, C = x.size() + + # Project to Q, K, V + q = self.c_q(x).view(B, T, self.n_head, self.head_dim) + k = self.c_k(x).view(B, T, self.n_head, self.head_dim) + v = self.c_v(x).view(B, T, self.n_head, self.head_dim) + + # Apply Rotary Embeddings (we'll cover this separately) + cos, sin = cos_sin + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + + # QK normalization (stabilizes training) + q, k = norm(q), norm(k) + + # Reshape for multi-head attention + q, k, v = ( + q.transpose(1, 2), # [B, n_head, T, head_dim] + k.transpose(1, 2), + v.transpose(1, 2), + ) + + # THE KEY LINE: is_causal=False means BIDIRECTIONAL! + y = F.scaled_dot_product_attention(q, k, v, is_causal=False) + + # Merge heads and project + y = y.transpose(1, 2).contiguous().view(B, T, -1) + y = self.c_proj(y) + return y +``` + +### Key Insights + +**1. `is_causal=False`** +This single parameter makes it bidirectional! PyTorch's `scaled_dot_product_attention` implements: +- `is_causal=True`: Applies lower triangular mask (GPT-style) +- `is_causal=False`: No mask (BERT-style, our case) + +**2. Separate Q, K, V Projections** +```python +self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False) +self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False) +self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False) +``` +Three different learned transformations allow the model to: +- **Q:** Learn what patterns to look for +- **K:** Learn what patterns to advertise +- **V:** Learn what information to pass forward + +**3. Multi-Head Attention** +```python +self.n_head = 6 # From config +self.head_dim = 384 // 6 = 64 +``` +Instead of one big attention matrix, we have 6 smaller ones that learn different patterns: +- Head 1 might focus on grammar +- Head 2 might focus on semantic relationships +- Head 3 might focus on nearby tokens +- etc. + +**4. QK Normalization** +```python +q, k = norm(q), norm(k) # RMSNorm +``` +Modern technique to stabilize training by normalizing queries and keys before computing attention. + +--- + +## When to Use Each Type + +### Use Causal (Unidirectional) When: + +1. 
**Autoregressive Generation** + - GPT-style text generation + - You generate one token at a time, left-to-right + - Each token only knows about the past + +2. **Online/Streaming Tasks** + - Real-time processing where future context isn't available + - Chat applications, live transcription + +3. **Language Modeling** + - Training objective: predict next token given previous tokens + - Examples: GPT-2, GPT-3, LLaMA, Mistral + +### Use Bidirectional When: + +1. **Full Context Understanding** + - BERT-style masked language modeling + - Classification tasks (sentiment, topic, etc.) + - You have the full sequence and want to understand it + +2. **Iterative Refinement** + - Diffusion models (like this repo!) + - Start with noise, iteratively denoise using full context + - Each position needs to see neighbors to refine predictions + +3. **Encoder Models** + - Translation encoders (encode source sentence) + - Summarization encoders + - Any task where you encode before decoding + +4. **Fill-in-the-Blank Tasks** + - Masked token prediction + - Text infilling + - Code completion in the middle of a function + +--- + +## Why Bidirectional for Diffusion? + +In this repo, we use a **diffusion process** to generate text: + +1. **Start:** Fully masked sequence: `[MASK] [MASK] [MASK] [MASK]` +2. **Step 1:** Predict all tokens using context, unmask high-confidence ones: `[MASK] "the" [MASK] [MASK]` +3. **Step 2:** Re-predict using partial context: `"Once" "the" [MASK] "king"` +4. **Step 3:** Continue until done: `"Once" "the" "old" "king"` + +**Why bidirectional helps:** +- Token 3 can see Token 1 to maintain coherence +- Token 1 can see Token 3 to ensure consistency +- Parallel decoding: predict all positions simultaneously +- Refinement: each iteration uses better context from previous predictions + +Compare to GPT: +- GPT: `"Once"` โ†’ `"the"` โ†’ `"old"` โ†’ `"king"` (sequential) +- Diffusion: `[M][M][M][M]` โ†’ `[M]"the"[M][M]` โ†’ `"Once""the""old""king"` (parallel) + +--- + +## Visualization + +Imagine a sentence: **"The cat sat"** + +### Causal Attention Pattern +``` +Query: "The" โ†’ Attends to: ["The"] +Query: "cat" โ†’ Attends to: ["The", "cat"] +Query: "sat" โ†’ Attends to: ["The", "cat", "sat"] +``` + +**Visual:** +``` + The cat sat +The [โ–ˆ] [ ] [ ] +cat [โ–ˆ] [โ–ˆ] [ ] +sat [โ–ˆ] [โ–ˆ] [โ–ˆ] +``` + +### Bidirectional Attention Pattern +``` +Query: "The" โ†’ Attends to: ["The", "cat", "sat"] +Query: "cat" โ†’ Attends to: ["The", "cat", "sat"] +Query: "sat" โ†’ Attends to: ["The", "cat", "sat"] +``` + +**Visual:** +``` + The cat sat +The [โ–ˆ] [โ–ˆ] [โ–ˆ] +cat [โ–ˆ] [โ–ˆ] [โ–ˆ] +sat [โ–ˆ] [โ–ˆ] [โ–ˆ] +``` + +--- + +## Key Takeaways + +1. **Attention is about weighted combinations** - each token mixes information from other tokens based on relevance + +2. **Causal = lower triangular mask** - prevents looking at future tokens (GPT, LLaMA) + +3. **Bidirectional = no mask** - allows looking everywhere (BERT, this diffusion model) + +4. **The choice depends on your task:** + - Need to generate sequentially? โ†’ Causal + - Have full context available? โ†’ Bidirectional + +5. **Implementation difference is tiny** - literally just `is_causal=True` vs `is_causal=False` + +6. **But the implications are huge:** + - Causal: online, streaming, autoregressive + - Bidirectional: offline, full context, parallel + +--- + +## Next Steps + +Now that you understand bidirectional attention, explore: + +1. **Run the visualization tool** (coming next) to see attention patterns +2. 
**Compare implementations** - see causal vs bidirectional side-by-side +3. **Experiment** - try modifying the code to use causal attention and see what breaks +4. **Study RoPE** - understand how positional information is encoded +5. **Dive into diffusion** - understand why bidirectional attention enables parallel decoding + +Ready to see these concepts in action? Let's build visualization tools! diff --git a/docs/quick_reference.md b/docs/quick_reference.md new file mode 100644 index 0000000..ad04cee --- /dev/null +++ b/docs/quick_reference.md @@ -0,0 +1,260 @@ +# Quick Reference: Bidirectional Attention + +A one-page reference for quick lookups while coding. + +--- + +## Core Concepts + +### Attention Formula +``` +Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V +``` + +### Causal vs Bidirectional + +| Aspect | Causal (GPT) | Bidirectional (BERT) | +|--------|-------------|---------------------| +| **Mask** | Lower triangular | No mask | +| **Each token sees** | Only past | All tokens | +| **Use case** | Generation | Understanding | +| **Matrix pattern** | โ—ฃ (triangle) | โ—ผ (full) | +| **PyTorch** | `is_causal=True` | `is_causal=False` | + +--- + +## Architecture Components + +### Multi-Head Attention +```python +# Split into heads +q = q.view(B, T, n_head, head_dim).transpose(1, 2) +k = k.view(B, T, n_head, head_dim).transpose(1, 2) +v = v.view(B, T, n_head, head_dim).transpose(1, 2) + +# Compute attention +output = F.scaled_dot_product_attention(q, k, v, is_causal=False) + +# Merge heads +output = output.transpose(1, 2).view(B, T, n_embd) +``` + +### RoPE (Rotary Position Embeddings) +```python +# 2D rotation for each dimension pair +y1 = x1 * cos + x2 * sin +y2 = x1 * (-sin) + x2 * cos +``` + +**Benefits:** +- Relative position encoding +- Better extrapolation +- No learned parameters + +### RMSNorm +```python +# Simpler than LayerNorm +RMSNorm(x) = x / sqrt(mean(xยฒ) + ฮต) +``` + +**Advantages:** +- No learnable params +- Faster computation +- Equally effective + +### QK Normalization +```python +# Normalize before attention +q, k = norm(q), norm(k) +scores = q @ k.transpose(-2, -1) +``` + +**Why:** +- Stabilizes training +- Prevents exploding logits +- Enables higher learning rates + +--- + +## Common Shape Transformations + +### Input โ†’ Multi-Head +```python +# Input: (B, T, C) +x = x.view(B, T, n_head, head_dim) # Split into heads +x = x.transpose(1, 2) # (B, n_head, T, head_dim) +``` + +### Multi-Head โ†’ Output +```python +# Input: (B, n_head, T, head_dim) +x = x.transpose(1, 2) # (B, T, n_head, head_dim) +x = x.contiguous().view(B, T, n_embd) # Merge heads +``` + +### Attention Computation +```python +# Q, K, V: (B, n_head, T, head_dim) +scores = q @ k.transpose(-2, -1) # (B, n_head, T, T) +attn = softmax(scores / sqrt(head_dim)) # (B, n_head, T, T) +output = attn @ v # (B, n_head, T, head_dim) +``` + +--- + +## Code Snippets + +### Basic Bidirectional Attention +```python +class BidirectionalAttention(nn.Module): + def __init__(self, n_embd, n_head): + super().__init__() + self.n_head = n_head + self.head_dim = n_embd // n_head + + self.c_q = nn.Linear(n_embd, n_embd, bias=False) + self.c_k = nn.Linear(n_embd, n_embd, bias=False) + self.c_v = nn.Linear(n_embd, n_embd, bias=False) + self.c_proj = nn.Linear(n_embd, n_embd, bias=False) + + def forward(self, x): + B, T, C = x.size() + + # Project to Q, K, V + q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) + k = self.c_k(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) + v = 
self.c_v(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) + + # Bidirectional attention + y = F.scaled_dot_product_attention(q, k, v, is_causal=False) + + # Merge heads + y = y.transpose(1, 2).contiguous().view(B, T, C) + return self.c_proj(y) +``` + +### Causal Attention (for comparison) +```python +# Only change: is_causal=True +y = F.scaled_dot_product_attention(q, k, v, is_causal=True) +``` + +### Manual Attention (for understanding) +```python +# Explicit implementation of what scaled_dot_product_attention does +scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim) + +# For causal: mask upper triangle +if is_causal: + mask = torch.triu(torch.ones(T, T), diagonal=1).bool() + scores = scores.masked_fill(mask, float('-inf')) + +attn_weights = F.softmax(scores, dim=-1) +output = attn_weights @ v +``` + +--- + +## Debugging Tips + +### Check Shapes +```python +print(f"Input: {x.shape}") # (B, T, C) +print(f"Q: {q.shape}") # (B, n_head, T, head_dim) +print(f"Scores: {scores.shape}") # (B, n_head, T, T) +print(f"Output: {output.shape}") # (B, T, C) +``` + +### Visualize Attention +```python +# Extract attention weights +with torch.no_grad(): + scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim) + attn = F.softmax(scores, dim=-1) + +# Plot +import matplotlib.pyplot as plt +plt.imshow(attn[0, 0].cpu(), cmap='viridis') +plt.title('Attention Pattern - Head 0') +plt.xlabel('Key Position') +plt.ylabel('Query Position') +plt.colorbar() +plt.show() +``` + +### Common Issues +```python +# Issue: Shape mismatch after transpose +# Solution: Use .contiguous() before .view() +x = x.transpose(1, 2).contiguous().view(B, T, -1) + +# Issue: Attention is all NaN +# Solution: Check for -inf masking or numerical instability +# - Add small epsilon: scores / (math.sqrt(d) + 1e-8) +# - Use QK normalization + +# Issue: Out of memory +# Solution: Use Flash Attention or reduce batch size +# - Flash Attention: Automatic in PyTorch 2.0+ when available +# - Or implement custom memory-efficient attention +``` + +--- + +## Performance Tips + +### Use Flash Attention +```python +# PyTorch 2.0+ automatically uses Flash Attention when possible +# No code changes needed! 
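+# (The Flash / memory-efficient kernels only kick in under certain conditions -
+#  e.g. CUDA tensors in fp16/bf16 with supported head sizes; otherwise PyTorch
+#  quietly falls back to the plain math implementation, which is still correct.)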
+y = F.scaled_dot_product_attention(q, k, v, is_causal=False) +``` + +### Mixed Precision Training +```python +from torch.cuda.amp import autocast + +with autocast(): + output = attention(x) +``` + +### Gradient Checkpointing +```python +from torch.utils.checkpoint import checkpoint + +# Trade memory for computation +output = checkpoint(attention_layer, x) +``` + +--- + +## Testing Checklist + +- [ ] Input shape: (batch, seq_len, n_embd) +- [ ] Output shape matches input shape +- [ ] Attention weights sum to 1.0 per query +- [ ] Causal mask applied correctly (if causal) +- [ ] No NaN or Inf values in output +- [ ] Memory usage reasonable +- [ ] Forward/backward pass completes + +--- + +## Resources + +**Files in this repo:** +- `model.py` - Annotated implementation +- `attention_comparison.py` - Causal vs bidirectional demo +- `visualize_model_attention.py` - Visualize trained model +- `docs/bidirectional_attention_tutorial.md` - Deep dive +- `LEARNING_GUIDE.md` - Structured learning path + +**Papers:** +- Attention Is All You Need: https://arxiv.org/abs/1706.03762 +- BERT: https://arxiv.org/abs/1810.04805 +- RoPE: https://arxiv.org/abs/2104.09864 +- Flash Attention: https://arxiv.org/abs/2205.14135 + +**Blogs:** +- The Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/ +- RoPE Explained: https://blog.eleuther.ai/rotary-embeddings/ diff --git a/model.py b/model.py index 35cd7ed..e375be1 100644 --- a/model.py +++ b/model.py @@ -37,57 +37,306 @@ class DiffusionConfig: def norm(x): + """ + RMSNorm (Root Mean Square Normalization) - Modern normalization technique + + Used in: LLaMA, Mistral, and many recent LLMs + Formula: RMSNorm(x) = x / sqrt(mean(xยฒ) + ฮต) + + Advantages over LayerNorm: + - Simpler: No learnable parameters (scale/bias) + - Faster: Fewer operations + - Effective: Works just as well in practice + + LayerNorm vs RMSNorm: + - LayerNorm: y = (x - mean(x)) / sqrt(var(x)) * gamma + beta + - RMSNorm: y = x / sqrt(mean(xยฒ)) + + RMSNorm focuses only on the scale of the input, not the mean, + which is sufficient for attention mechanisms in transformers. + + Args: + x: Input tensor of any shape + Returns: + Normalized tensor with same shape + """ # Purely functional rmsnorm with no learnable params return F.rms_norm(x, (x.size(-1),)) def apply_rotary_emb(x, cos, sin): - assert x.ndim == 4 # multihead attention - d = x.shape[3] // 2 - x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves - y1 = x1 * cos + x2 * sin # rotate pairs of dims - y2 = x1 * (-sin) + x2 * cos - out = torch.cat([y1, y2], 3) # re-assemble - out = out.to(x.dtype) # ensure input/output dtypes match + """ + Apply Rotary Position Embeddings (RoPE) + + RoPE is a modern technique for encoding position information in transformers. + Used in: LLaMA, PaLM, GPT-NeoX, Code LLaMA, and many recent models. + + Key insight: Instead of adding position embeddings, RoPE rotates the + embedding vectors by an angle proportional to their position. + + Mathematical intuition: + - Treat pairs of dimensions as 2D vectors + - Rotate each pair by position-dependent angle ฮธ_pos + - Tokens at position m and n have relative rotation ฮธ_m - ฮธ_n + - Dot product naturally encodes relative position! + + Benefits: + 1. Relative position encoding: Distance between tokens matters, not absolute position + 2. Better extrapolation: Can handle longer sequences than seen during training + 3. No extra parameters: Position info encoded through geometry + 4. 
Efficient: Just element-wise multiplications + + 2D Rotation formula (what we're doing to each pair of dimensions): + [x1'] [cos ฮธ -sin ฮธ] [x1] + [x2'] = [sin ฮธ cos ฮธ] [x2] + + Which expands to: + x1' = x1 * cos ฮธ + x2 * sin ฮธ + x2' = x1 * (-sin ฮธ) + x2 * cos ฮธ + + Args: + x: Input tensor of shape (batch, n_head, seq_len, head_dim) + cos: Cosine values for rotation, shape (1, seq_len, 1, head_dim//2) + sin: Sine values for rotation, shape (1, seq_len, 1, head_dim//2) + + Returns: + Rotated tensor of same shape as input + + Example: + If head_dim = 64, we treat it as 32 pairs of 2D vectors + Each pair is rotated by a different frequency + Low frequencies encode coarse position info + High frequencies encode fine-grained position info + """ + assert x.ndim == 4, "Expected 4D tensor for multi-head attention" # (B, n_head, T, head_dim) + + # Split embedding dimension into two halves + # Each consecutive pair (x1, x2) will be rotated together + d = x.shape[3] // 2 # Half of head_dim + x1, x2 = x[..., :d], x[..., d:] # Split along last dimension + + # Apply 2D rotation to each pair + # This is the core of RoPE: geometric rotation in embedding space + y1 = x1 * cos + x2 * sin # First component after rotation + y2 = x1 * (-sin) + x2 * cos # Second component after rotation + + # Concatenate rotated pairs back together + out = torch.cat([y1, y2], dim=3) # Re-assemble along last dimension + + # Ensure output dtype matches input (important for mixed precision training) + out = out.to(x.dtype) + return out class BidirectionalAttention(nn.Module): + """ + Bidirectional Multi-Head Self-Attention + + This is the core component that makes this diffusion model different from GPT. + Key insight: BIDIRECTIONAL means each token can attend to ALL other tokens, + not just previous ones. This enables parallel decoding during diffusion. + + Architecture components: + - Multi-head attention: Multiple attention patterns learned in parallel + - RoPE (Rotary Position Embeddings): Encodes position information + - QK Normalization: Stabilizes training by normalizing queries and keys + + Comparison with GPT: + - GPT: Causal mask (lower triangular) - each token sees only past + - This: No mask (full matrix) - each token sees everything + + Mathematical formula: + Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V + + Where: + - Q (Query): "What am I looking for?" + - K (Key): "What information do I have?" + - V (Value): "What information should I pass forward?" + """ + def __init__(self, config): super().__init__() - self.n_head = config.n_head - self.n_embd = config.n_embd - self.head_dim = self.n_embd // self.n_head - assert self.n_embd % self.n_head == 0 + self.n_head = config.n_head # Number of attention heads (6 in default config) + self.n_embd = config.n_embd # Embedding dimension (384 in default config) + self.head_dim = self.n_embd // self.n_head # Dimension per head (384/6 = 64) + + # Ensure embedding dimension is evenly divisible by number of heads + assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head" + + # Three separate learned projections for Q, K, V + # Why three? 
They learn different roles: + # - c_q: Projects input to "queries" (what to look for) + # - c_k: Projects input to "keys" (what to advertise) + # - c_v: Projects input to "values" (what to communicate) + # bias=False follows modern practice (e.g., LLaMA, GPT-2) self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False) + + # Output projection - combines information from all heads self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) def forward(self, x, cos_sin): - B, T, C = x.size() + """ + Forward pass through bidirectional attention + + Args: + x: Input tensor of shape (batch, seq_len, n_embd) + - batch: Number of sequences being processed in parallel + - seq_len: Number of tokens in each sequence + - n_embd: Embedding dimension for each token + cos_sin: Tuple of (cos, sin) for rotary embeddings + Used to encode positional information - # Project the input to get queries, keys, and values + Returns: + y: Output tensor of shape (batch, seq_len, n_embd) + Contextually enriched representations after attention + + Step-by-step process: + 1. Project input to Q, K, V + 2. Split into multiple heads + 3. Apply RoPE (positional encoding) + 4. Normalize Q and K + 5. Compute attention: softmax(QK^T / sqrt(d)) V + 6. Merge heads back together + 7. Final output projection + """ + B, T, C = x.size() # Batch size, sequence length (Time), embedding dimension (Channels) + + # ============================================================================ + # STEP 1: Project input to Queries, Keys, and Values + # ============================================================================ + # Each position gets three different representations: + # q[i]: "What patterns should position i look for?" + # k[i]: "What patterns does position i contain?" + # v[i]: "What information should position i contribute?" + # + # Shape transitions: + # x: (B, T, C) โ†’ linear projection โ†’ (B, T, C) โ†’ view โ†’ (B, T, n_head, head_dim) + # + # Example with default config (B=32, T=256, C=384, n_head=6, head_dim=64): + # (32, 256, 384) โ†’ c_q โ†’ (32, 256, 384) โ†’ view โ†’ (32, 256, 6, 64) q = self.c_q(x).view(B, T, self.n_head, self.head_dim) k = self.c_k(x).view(B, T, self.n_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_head, self.head_dim) - # Apply Rotary Embeddings to queries and keys + # ============================================================================ + # STEP 2: Apply Rotary Position Embeddings (RoPE) + # ============================================================================ + # RoPE is a modern technique for encoding positional information + # Used in: LLaMA, PaLM, GPT-NeoX, and many recent LLMs + # + # Why RoPE? 
+ # - Encodes relative positions through rotation in embedding space + # - Better extrapolation to longer sequences than learned embeddings + # - Doesn't require explicit position embeddings + # + # How it works: + # - Treats pairs of dimensions as 2D coordinates + # - Rotates them by an angle proportional to position + # - Tokens close together have similar rotation angles + # + # Applied to Q and K (not V) because attention uses QยทK^T for similarity cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + + # ============================================================================ + # STEP 3: QK Normalization + # ============================================================================ + # Normalize queries and keys before computing attention + # This is a modern technique that improves training stability + # + # Benefits: + # - Prevents attention logits from exploding + # - More stable gradients during training + # - Allows for higher learning rates + # + # norm() is RMSNorm (Root Mean Square Normalization): + # RMSNorm(x) = x / sqrt(mean(xยฒ) + ฮต) + # Simpler and more efficient than LayerNorm, used in LLaMA q, k = norm(q), norm(k) # QK norm + + # ============================================================================ + # STEP 4: Reshape for multi-head attention computation + # ============================================================================ + # Transpose to put head dimension before sequence dimension + # This allows us to compute attention for all heads in parallel + # + # Shape transition: (B, T, n_head, head_dim) โ†’ (B, n_head, T, head_dim) + # + # Why? Matrix multiplication will operate on the last two dimensions, + # so we want: (B, n_head, T_q, head_dim) @ (B, n_head, head_dim, T_k) + # = (B, n_head, T_q, T_k) โ† attention scores q, k, v = ( - q.transpose(1, 2), + q.transpose(1, 2), # (B, T, H, D) -> (B, H, T, D) k.transpose(1, 2), v.transpose(1, 2), - ) # (B, T, H, D) -> (B, H, T, D) + ) - # Bidirectional attention - no causal masking + # ============================================================================ + # STEP 5: Compute Bidirectional Attention + # ============================================================================ + # This is THE KEY LINE that makes this bidirectional! + # + # F.scaled_dot_product_attention implements: + # 1. scores = (Q @ K^T) / sqrt(head_dim) # Attention scores + # 2. if is_causal: scores = mask_upper_triangle(scores) # NOT APPLIED HERE! + # 3. attn_weights = softmax(scores, dim=-1) # Normalize to probabilities + # 4. output = attn_weights @ V # Weighted combination of values + # + # is_causal=False means NO MASKING: + # - Each position can attend to ALL positions (past, present, future) + # - Attention matrix is FULL, not lower triangular + # - Example attention pattern for "The cat sat": + # The cat sat + # The โœ“ โœ“ โœ“ (can see all tokens) + # cat โœ“ โœ“ โœ“ (can see all tokens) + # sat โœ“ โœ“ โœ“ (can see all tokens) + # + # Compare with causal (is_causal=True): + # The cat sat + # The โœ“ โœ— โœ— (can only see "The") + # cat โœ“ โœ“ โœ— (can only see "The" and "cat") + # sat โœ“ โœ“ โœ“ (can see all tokens) + # + # Why bidirectional for diffusion? 
+        # - We iteratively refine tokens in parallel
+        # - Each token needs context from neighbors to stay coherent
+        # - Example: [M][M][M] → [M]"cat"[M] → "The""cat"[M] → "The""cat""sat"
+        #   "The" is filled in while already seeing "cat" to its right;
+        #   "cat" can see "The" to its left, so context flows in both directions
+        #
+        # Performance note: PyTorch's scaled_dot_product_attention is highly optimized
+        # Uses Flash Attention when available (https://arxiv.org/abs/2205.14135)
         y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
-        # Re-assemble the heads and project back
+        # Output shape: (B, n_head, T, head_dim)
+        # Each head has computed its own attention pattern
+
+        # ============================================================================
+        # STEP 6: Merge heads back together
+        # ============================================================================
+        # We now have outputs from all heads, need to combine them
+        #
+        # Shape transitions:
+        #   (B, n_head, T, head_dim) → transpose → (B, T, n_head, head_dim)
+        #                            → view → (B, T, n_embd)
+        #
+        # Example: (32, 6, 256, 64) → transpose → (32, 256, 6, 64)
+        #                           → view → (32, 256, 384)
+        #
+        # .contiguous() ensures memory layout is correct for view()
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
+
+        # ============================================================================
+        # STEP 7: Final output projection
+        # ============================================================================
+        # Mix information from all heads with a learned projection
+        # This allows heads to be combined in useful ways
+        # Shape: (B, T, n_embd) → (B, T, n_embd)
         y = self.c_proj(y)
+
         return y
diff --git a/visualize_model_attention.py b/visualize_model_attention.py
new file mode 100644
index 0000000..cc07f25
--- /dev/null
+++ b/visualize_model_attention.py
@@ -0,0 +1,347 @@
+"""
+Visualize Attention Patterns from the Trained Diffusion Model
+
+This script loads the trained model and visualizes its actual attention patterns
+during text generation, showing how bidirectional attention enables the diffusion process.
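+
+Run it with (as described in the learning guide):
+
+    uv run visualize_model_attention.py
+
+It writes actual_attention_pattern.png and diffusion_attention_evolution.png
+to the working directory.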
+""" + +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np +from model import DiffusionTransformer, DiffusionConfig, encode_text, decode_tokens + + +class AttentionExtractor: + """ + Helper class to extract attention weights from the model during forward pass + """ + def __init__(self): + self.attention_weights = [] + + def hook_fn(self, module, input, output): + """Hook function to capture attention weights""" + # This will be called during forward pass + # We'll store the attention scores before they're used + pass + + def clear(self): + """Clear stored attention weights""" + self.attention_weights = [] + + +def visualize_attention_during_diffusion( + model, + context_text="KING HENRY:", + max_steps=10, + seq_len=64, + save_path="diffusion_attention_evolution.png" +): + """ + Visualize how attention patterns evolve during the diffusion denoising process + + Args: + model: Trained DiffusionTransformer + context_text: Initial context for generation + max_steps: Number of diffusion steps to visualize + seq_len: Sequence length to generate + save_path: Path to save visualization + """ + device = model.get_device() + model.eval() + + # Encode context + context_tokens = encode_text(context_text) + context_len = len(context_tokens) + batch_size = 1 + + # Start from masked sequence + x = torch.full( + (batch_size, seq_len), + model.config.mask_token_id, + dtype=torch.long, + device=device, + ) + x[:, :context_len] = context_tokens[:context_len].to(device) + + # Track evolution + snapshots = [] + attention_snapshots = [] + + # Track which positions are masked + masked_positions = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) + masked_positions[:, :context_len] = False + + # Simplified diffusion process with visualization + steps_to_visualize = np.linspace(0, min(max_steps, 50), min(max_steps, 10), dtype=int) + + for step in range(min(max_steps, 50)): + if not masked_positions.any(): + break + + # Forward pass + t_batch = torch.full((batch_size,), step, device=device, dtype=torch.long) + t_batch = torch.clamp(t_batch, 0, model.config.diffusion_steps - 1) + + with torch.no_grad(): + # Get logits + logits = model.forward(x, t_batch) + + # Decode with confidence threshold + probs = F.softmax(logits / 1.0, dim=-1) + confidences, predicted_tokens = torch.max(probs, dim=-1) + + # Select positions above threshold + above_threshold = (confidences >= 0.95) & masked_positions + + # Ensure at least one token per batch + if masked_positions[0].any() and not above_threshold[0].any(): + masked_confidences = confidences[0].clone() + masked_confidences[~masked_positions[0]] = -float("inf") + best_idx = torch.argmax(masked_confidences) + above_threshold[0, best_idx] = True + + # Update + x = torch.where(above_threshold, predicted_tokens, x) + masked_positions = masked_positions & ~above_threshold + + # Store snapshots at key steps + if step in steps_to_visualize: + snapshots.append({ + 'step': step, + 'text': decode_tokens(x[0].cpu()), + 'masked': masked_positions[0].cpu().clone(), + 'confidences': confidences[0].cpu().clone() + }) + + # Create visualization + n_snapshots = len(snapshots) + fig, axes = plt.subplots(n_snapshots, 1, figsize=(16, 3*n_snapshots)) + if n_snapshots == 1: + axes = [axes] + + fig.suptitle(f'Diffusion Process: Iterative Denoising with Bidirectional Attention\nContext: "{context_text}"', + fontsize=14, fontweight='bold') + + for idx, snapshot in enumerate(snapshots): + ax = axes[idx] + + # Get data + text = snapshot['text'] 
+ masked = snapshot['masked'].numpy() + confidences = snapshot['confidences'].numpy() + + # Create color map: masked=red, unmasked=green + colors = np.where(masked, 0.3, 1.0) # Darker for masked + + # Plot + tokens = [text[i:i+1] for i in range(len(text))] + x_pos = np.arange(len(tokens)) + + bars = ax.bar(x_pos, confidences, color=plt.cm.RdYlGn(colors), alpha=0.8, edgecolor='black', linewidth=0.5) + + # Add token labels + for i, (token, conf, is_masked) in enumerate(zip(tokens, confidences, masked)): + # Clean up special characters for display + display_char = repr(token)[1:-1] if ord(token) < 32 else token + label = f"{display_char}\n{'[M]' if is_masked else f'{conf:.2f}'}" + ax.text(i, -0.1, label, ha='center', va='top', fontsize=8, + fontweight='bold' if not is_masked else 'normal') + + ax.set_xlim(-0.5, len(tokens)-0.5) + ax.set_ylim(-0.2, 1.1) + ax.set_ylabel('Confidence', fontweight='bold') + ax.set_title(f'Step {snapshot["step"]} | Masked: {masked.sum()}/{len(masked)} tokens', + fontweight='bold', loc='left') + ax.axhline(y=0.95, color='r', linestyle='--', linewidth=2, alpha=0.5, label='Threshold (0.95)') + ax.legend(loc='upper right') + ax.set_xticks([]) + ax.grid(axis='y', alpha=0.3) + + plt.tight_layout() + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"\nโœ… Visualization saved to: {save_path}") + plt.close() + + # Print final generated text + print(f"\n{'='*70}") + print("FINAL GENERATED TEXT") + print('='*70) + print(decode_tokens(x[0].cpu())) + print('='*70) + + +def compute_attention_pattern(model, input_tokens, timestep=0): + """ + Compute and visualize the actual attention pattern from the model + + This extracts the Q, K matrices from the first attention layer + and computes the attention scores manually for visualization + """ + device = model.get_device() + model.eval() + + with torch.no_grad(): + # Prepare input + x_t = input_tokens.unsqueeze(0).to(device) # Add batch dim + t = torch.tensor([timestep], device=device) + B, T = x_t.size() + + # Get embeddings (replicate model's forward pass up to first attention layer) + x = model.token_emb(x_t) + t_emb = model.time_emb(t) + x = x + t_emb.unsqueeze(1) + x = F.rms_norm(x, (x.size(-1),)) + + # Get first attention layer + first_attn = model.blocks[0].attn + + # Compute Q, K + q = first_attn.c_q(x).view(B, T, first_attn.n_head, first_attn.head_dim) + k = first_attn.c_k(x).view(B, T, first_attn.n_head, first_attn.head_dim) + + # Apply RoPE (simplified - without RoPE for clarity) + q = q.transpose(1, 2) # (B, n_head, T, head_dim) + k = k.transpose(1, 2) + + # Compute attention scores + scores = torch.matmul(q, k.transpose(-2, -1)) / (first_attn.head_dim ** 0.5) + attn_weights = F.softmax(scores, dim=-1) + + return attn_weights.cpu() + + +def visualize_actual_attention_matrix( + model, + text="KING HENRY: The king", + save_path="actual_attention_pattern.png" +): + """ + Visualize the actual attention matrix from the trained model + """ + # Encode text + tokens = encode_text(text) + + # Get attention weights + attn_weights = compute_attention_pattern(model, tokens, timestep=0) + attn_weights = attn_weights[0].numpy() # Remove batch dim + + # Create visualization + n_heads = attn_weights.shape[0] + fig, axes = plt.subplots(1, n_heads, figsize=(5*n_heads, 4)) + if n_heads == 1: + axes = [axes] + + fig.suptitle(f'Bidirectional Attention Patterns from Trained Model\nInput: "{text}"', + fontsize=14, fontweight='bold') + + token_labels = [repr(text[i])[1:-1] if ord(text[i]) < 32 else text[i] for i in range(len(text))] 
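+    # (Control characters are escaped via repr() above so labels such as '\n' stay readable.)
+
+    # One panel per head: each head's full T x T attention matrix is plotted; with
+    # no causal mask, weights can appear above the diagonal as well as below it.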
+ + for head_idx in range(n_heads): + ax = axes[head_idx] + + im = ax.imshow(attn_weights[head_idx], cmap='viridis', aspect='auto', vmin=0, vmax=1) + ax.set_title(f'Head {head_idx}', fontweight='bold') + ax.set_xlabel('Key Position (what we look at)', fontweight='bold') + ax.set_ylabel('Query Position (who is looking)', fontweight='bold') + + # Add token labels + ax.set_xticks(range(len(tokens))) + ax.set_yticks(range(len(tokens))) + ax.set_xticklabels(token_labels, rotation=45, ha='right') + ax.set_yticklabels(token_labels) + + # Add colorbar + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + # Add grid + ax.grid(True, alpha=0.3, color='white', linewidth=0.5) + + plt.tight_layout() + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"\nโœ… Attention matrix visualization saved to: {save_path}") + plt.close() + + # Print analysis + print(f"\n{'='*70}") + print("ATTENTION PATTERN ANALYSIS") + print('='*70) + print(f"\nInput text: '{text}'") + print(f"Sequence length: {len(tokens)}") + print(f"Number of heads: {n_heads}") + + print("\n" + "-"*70) + print("Key Observations:") + print("-"*70) + print("1. Each position can attend to ALL positions (full matrix)") + print("2. No causal mask - darker squares appear everywhere, not just below diagonal") + print("3. Different heads learn different patterns:") + for head_idx in range(min(n_heads, 3)): + avg_attn = attn_weights[head_idx].mean() + std_attn = attn_weights[head_idx].std() + print(f" Head {head_idx}: mean={avg_attn:.3f}, std={std_attn:.3f}") + print("\n4. This bidirectional attention enables:") + print(" - Each token to gather context from neighbors") + print(" - Parallel refinement during diffusion") + print(" - Coherent generation without left-to-right constraint") + print('='*70) + + +def main(): + """ + Main visualization routine + """ + print("\n" + "="*70) + print("TRAINED MODEL ATTENTION VISUALIZATION") + print("="*70) + + # Load model + print("\n[1/3] Loading trained model...") + config = DiffusionConfig() + model = DiffusionTransformer(config) + + try: + model.load_state_dict(torch.load('weights/diffusion_model.pt', map_location='cpu')) + print("โœ… Model loaded successfully from weights/diffusion_model.pt") + except FileNotFoundError: + print("โš ๏ธ No trained weights found. Using random initialization.") + print(" (Patterns will still show bidirectional structure, just not trained)") + model.init_weights() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + model.eval() + + # Visualize attention matrix + print("\n[2/3] Visualizing attention patterns from trained model...") + visualize_actual_attention_matrix( + model, + text="KING HENRY: The", + save_path="actual_attention_pattern.png" + ) + + # Visualize diffusion process + print("\n[3/3] Visualizing diffusion denoising process...") + visualize_attention_during_diffusion( + model, + context_text="KING HENRY:", + max_steps=20, + seq_len=40, + save_path="diffusion_attention_evolution.png" + ) + + print("\n" + "="*70) + print("โœ… ALL VISUALIZATIONS COMPLETE") + print("="*70) + print("\nGenerated files:") + print("1. actual_attention_pattern.png - Attention matrix from trained model") + print("2. 
diffusion_attention_evolution.png - How text evolves during generation") + print("\nNext steps:") + print("- Open the PNG files to see the visualizations") + print("- Try different input texts to see different patterns") + print("- Compare with attention_comparison.py to see causal vs bidirectional") + print("- Read docs/bidirectional_attention_tutorial.md for deep understanding") + + +if __name__ == "__main__": + main()
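
As a final self-contained check of the `is_causal` flag highlighted in the annotated attention code, the sketch below (independent of the repository; it uses only PyTorch with random toy tensors, and every name is local to the example) compares the bidirectional and causal patterns produced by `F.scaled_dot_product_attention` against a manual `softmax(QK^T / sqrt(d))` computation:

```python
# Toy comparison of is_causal=False vs is_causal=True (not part of the repository code).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 1, 1, 4, 8                      # batch, heads, tokens, head_dim (toy sizes)
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

# Manual attention weights: softmax(Q K^T / sqrt(D))
scores = (q @ k.transpose(-2, -1)) / D**0.5  # (B, H, T, T)
bidir_weights = scores.softmax(dim=-1)       # full matrix: every query row uses all T keys

causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
causal_weights = scores.masked_fill(~causal_mask, float("-inf")).softmax(dim=-1)

print("bidirectional keys visible per query:", (bidir_weights[0, 0] > 0).sum(-1).tolist())  # [4, 4, 4, 4]
print("causal keys visible per query:       ", (causal_weights[0, 0] > 0).sum(-1).tolist())  # [1, 2, 3, 4]

# The fused kernel matches the manual computation; only the masking step differs.
y_bidir = F.scaled_dot_product_attention(q, k, v, is_causal=False)
y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(y_bidir, bidir_weights @ v, atol=1e-5)
assert torch.allclose(y_causal, causal_weights @ v, atol=1e-5)
```

The diffusion model's attention is the first case: a full T x T pattern with no upper-triangular mask, which is what lets masked tokens condition on both their left and right neighbors during denoising.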