# Megatron + Ray Fault Tolerant Training

This example implements PPO-style distributed training using Megatron and Ray with comprehensive fault tolerance capabilities. The system can automatically recover from actor failures during training by using backup actors and re-initializing process groups.

## Key Features

### Fault Tolerance Mechanisms

1. **Actor Health Monitoring**: Continuously monitors the health of distributed training actors (see the sketch after this list)
2. **Backup Actor Pool**: Pre-allocated backup actors ready to replace failed workers
3. **Automatic Recovery**: Seamlessly recovers from failures by:
   - Detecting dead actors
   - Destroying old process groups
   - Replacing failed actors with backup actors
   - Re-initializing process groups with the new world size
   - Reloading model and optimizer state from checkpoints

4. **Distributed Checkpointing**: Implements efficient sharded checkpoint saving/loading using Megatron's distributed checkpointing
5. **Process Group Management**: Handles NCCL process group initialization, destruction, and re-initialization

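As a rough illustration of the health monitoring in point 1, a check loop could look like the following sketch. The `health_check` remote method is hypothetical (a lightweight no-op on each actor); only the Ray calls themselves are real API.

```python
import ray

# Hedged sketch: `health_check` is a hypothetical no-op remote method on each
# training actor; `actors` is the current list of Ray actor handles.
def find_dead_ranks(actors, timeout_s: float = 30.0):
    dead = []
    for rank, actor in enumerate(actors):
        try:
            # A dead actor raises RayActorError; a hung one hits the timeout.
            ray.get(actor.health_check.remote(), timeout=timeout_s)
        except (ray.exceptions.RayActorError, ray.exceptions.GetTimeoutError):
            dead.append(rank)
    return dead
```
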
### Parallelism Support

- **Data Parallelism (DP)**: Distributes training data across multiple GPUs
- **Tensor Parallelism (TP)**: Splits model tensors across GPUs
- **Pipeline Parallelism (PP)**: Distributes model layers across GPUs
- **Context Parallelism (CP)**: Splits activations along the sequence dimension across GPUs for long-context training

### Advanced Training Features

- **PPO Training**: Implements Proximal Policy Optimization with micro-batch accumulation (see the sketch below)
- **Mixed Precision**: Supports BF16 training for improved performance
- **Gradient Accumulation**: Handles micro-batches with automatic gradient accumulation
- **Distributed Optimizer**: Uses Megatron's distributed optimizer for memory efficiency

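For reference, the clipped surrogate objective that a PPO policy step optimizes looks like the generic PyTorch sketch below. This is not the example's actual loss code; tensor names are illustrative, and per-token masking is omitted.

```python
import torch

# Generic clipped PPO surrogate loss for one micro-batch (illustrative only).
def ppo_policy_loss(logprobs, old_logprobs, advantages, clip_eps: float = 0.2):
    # Probability ratio between the current policy and the behavior policy.
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Maximize the clipped surrogate, i.e. minimize its negative mean.
    return -torch.min(unclipped, clipped).mean()
```
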
## Architecture

### Core Components

1. **MegatronActor** (`megatron_actor.py`):
   - Individual training actor wrapping Megatron models
   - Handles model initialization, forward/backward passes, and checkpointing
   - Supports dynamic process group re-initialization

2. **MegatronActorGroup** (`megatron_actor.py`):
   - Manages a group of distributed actors
   - Implements fault recovery logic
   - Coordinates distributed training operations

3. **Dispatch System** (`dispatch.py`):
   - **MeshDispatch**: Distributes data across the device mesh (DP, SP, TP, PP)
   - **PassThroughDispatch**: Broadcasts the same data/commands to all actors
   - Handles data sharding and result collection

4. **Training Batch** (`training_batch.py`):
   - Defines input/output batch structures for PPO training
   - Supports chunking and concatenation for distributed operations (see the sketch after this list)

5. **Checkpoint I/O** (`file_io.py`):
   - Cloud-aware file I/O supporting S3, GCS, and local storage
   - Efficient checkpoint upload/download with parallel transfers

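To illustrate the chunking behavior mentioned in component 4, a batch of per-sample tensors can be split into one shard per data-parallel rank roughly as follows. This is a hedged sketch: the `Batch` fields and the `chunk` helper are illustrative, not the example's actual training batch API.

```python
from dataclasses import dataclass
import torch

@dataclass
class Batch:
    # Illustrative stand-in for the example's training batch structure.
    input_ids: torch.Tensor   # [batch, seq_len]
    advantages: torch.Tensor  # [batch]

def chunk(batch: Batch, num_shards: int) -> list[Batch]:
    # Split every field along the batch dimension, one shard per DP rank.
    ids_chunks = batch.input_ids.chunk(num_shards, dim=0)
    adv_chunks = batch.advantages.chunk(num_shards, dim=0)
    return [Batch(ids, adv) for ids, adv in zip(ids_chunks, adv_chunks)]

shards = chunk(Batch(torch.zeros(16, 8, dtype=torch.long), torch.zeros(16)), 4)
assert len(shards) == 4 and shards[0].input_ids.shape[0] == 4
```
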
## Getting Started

### Quick Start

```bash
uv run --isolated main.py
```

This will:
1. Create a placement group with worker and backup GPUs
2. Initialize the actor group and model
3. Run a training step
4. Save a checkpoint
5. Simulate a failure by killing actors (see the sketch below)
6. Recover from the failure using backup actors
7. Resume training after recovery

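Step 5 can be as simple as forcibly killing a few Ray actors. A minimal sketch (how `main.py` actually triggers the failure may differ, and `actor_group.actors` is an assumed attribute):

```python
import ray

# Forcibly terminate two worker actors without restarting them, so the
# health check sees them as dead and recovery kicks in.
for actor in actor_group.actors[:2]:
    ray.kill(actor, no_restart=True)
```
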
### Configuration

Edit the `Config` class in `main.py` to customize:

```python
@dataclass
class Config:
    model: str = "Qwen/Qwen3-0.6B"  # HuggingFace model name
    num_nodes: int = 1
    num_gpus_per_node: int = 4
    num_spare_gpus: int = 4  # Backup actors for fault tolerance
    mini_batch_size: int = 16
    micro_train_batch_size_per_gpu: int = 2

    # Megatron parallelism settings
    megatron_config: MegatronConfig = field(default_factory=MegatronConfig)
```

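With these defaults (4 training GPUs and TP = PP = CP = 1, so a data-parallel size of 4), the number of gradient-accumulation micro-batches per GPU works out as below. This is the usual arithmetic under the assumption that `mini_batch_size` is global across data-parallel ranks, not code from the example:

```python
# Gradient accumulation arithmetic for the default Config values above.
mini_batch_size = 16
micro_train_batch_size_per_gpu = 2
data_parallel_size = 4  # num_gpus_per_node // (TP * PP * CP) = 4 // 1

samples_per_gpu = mini_batch_size // data_parallel_size                  # 4
accumulation_steps = samples_per_gpu // micro_train_batch_size_per_gpu   # 2
print(accumulation_steps)  # -> 2
```
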
### Megatron Parallelism Configuration

```python
@dataclass
class MegatronConfig:
    tensor_model_parallel_size: int = 1    # TP degree
    pipeline_model_parallel_size: int = 1  # PP degree
    context_parallel_size: int = 1         # CP degree
    expert_model_parallel_size: int = 1    # For MoE models
```

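The product of TP, PP, and CP must divide the number of training GPUs; whatever remains becomes the data-parallel degree. A hypothetical override for the 4-GPU default above:

```python
# Hypothetical override: TP=2 on 4 GPUs leaves a data-parallel size of
# 4 // (2 * 1 * 1) = 2. Adjust mini_batch_size accordingly.
megatron_config = MegatronConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    context_parallel_size=1,
)
```
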
## Fault Recovery Workflow

1. **Training Phase**:
   - Actors perform distributed training using Megatron
   - Periodic checkpoints are saved to cloud storage

2. **Failure Detection**:
   - The system detects actor failures via health checks
   - Identifies the affected data parallel groups

3. **Recovery Process** (sketched in code below):
   - Destroy old process groups on healthy actors
   - Pop backup actors from the backup pool
   - Insert backup actors at the failed ranks
   - Update the world size and reassign ranks
   - Re-initialize process groups with the new configuration
   - Reload model/optimizer state from the checkpoint

4. **Resume Training**:
   - Continue training with the recovered actor group
   - No training progress is lost beyond the last checkpoint

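A condensed sketch of the recovery process in step 3, assuming hypothetical per-actor remote methods (`destroy_process_group`, `init_process_group`, `load_checkpoint`) and an already-computed list of dead ranks; the real `MegatronActorGroup` API will differ.

```python
import ray

def recover(actors, backup_pool, dead_ranks, ckpt_path):
    # Tear down the old NCCL process groups on the surviving actors.
    survivors = [a for r, a in enumerate(actors) if r not in dead_ranks]
    ray.get([a.destroy_process_group.remote() for a in survivors])

    # Insert a backup actor at each failed rank while the pool lasts; ranks
    # that cannot be backfilled are dropped, so the world size may change.
    for rank in sorted(dead_ranks):
        actors[rank] = backup_pool.pop() if backup_pool else None
    actors = [a for a in actors if a is not None]

    # Re-initialize process groups with the new world size and reassigned
    # ranks, then reload model and optimizer state from the last checkpoint.
    world_size = len(actors)
    ray.get([a.init_process_group.remote(rank=r, world_size=world_size)
             for r, a in enumerate(actors)])
    ray.get([a.load_checkpoint.remote(ckpt_path) for a in actors])
    return actors
```
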
## Advanced Usage

### Custom Dispatch Types

Register custom dispatch strategies:

```python
from dispatch import register_dispatch_type, Dispatch

class CustomDispatch(Dispatch):
    # Implement dispatch, collect, and validate methods
    pass

register_dispatch_type("custom", CustomDispatch)
```

### CPU Offloading (Experimental)

For faster recovery, offload model/optimizer state to CPU memory:

```python
# Before failure
ray.get(actor_group.async_run_ray_method("pass_through", "offload_to_cpu"))

# After recovery, on healthy actors
ray.get(actor_group.async_run_ray_method("pass_through", "backload_to_gpu"))
```

## Dependencies

See `pyproject.toml` for the full dependency list. Key dependencies:
- Ray for distributed orchestration
- Megatron-Core for model parallelism
- PyTorch with CUDA support
- Transformers for model loading
- vLLM and related libraries

## Running on Anyscale

Submit the job using:

```bash
anyscale job submit -f job.yaml
```

The job configuration in `job.yaml` specifies:
- Container image with dependencies
- GPU instance types (g6e.12xlarge with 4x L40S GPUs)
- Resource limits and scaling
- Environment variables for NCCL configuration

## Limitations and Future Work

- Virtual pipeline parallelism not yet supported
- CPU offloading optimization in progress
- Async checkpoint saving planned for future releases

## References

- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
- [Ray Documentation](https://docs.ray.io/)
- [Anyscale Platform](https://docs.anyscale.com/)