Description
When using PEFT LoRA with modules_to_save (e.g., ['wte','ff_out']) under DeepSpeed ZeRO-3, training either fails with shape mismatches or consumes far more GPU memory than expected.
It seems that the modules_to_save parameters are not sharded like the rest of the model under ZeRO-3, which may leave a full, unpartitioned copy of each of them on every GPU.
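For reference, here is a minimal diagnostic sketch I use to see which parameters ZeRO-3 is actually managing. It assumes `model` is the PEFT-wrapped model after DeepSpeed initialization and relies on the ds_* attributes that ZeRO-3 attaches to partitioned parameters; the name filters ('wte', 'ff_out', 'lora_', 'original_module') just mirror the modules from this report.

```python
def report_zero3_partitioning(model):
    # Sketch: a parameter managed by ZeRO-3 carries ds_* attributes; its local
    # shard shows up as an empty shape while ds_shape keeps the full size.
    for name, param in model.named_parameters():
        if not any(key in name for key in ("wte", "ff_out", "lora_", "original_module")):
            continue
        is_ds_managed = hasattr(param, "ds_id")
        local_shape = tuple(param.shape)
        full_shape = tuple(param.ds_shape) if is_ds_managed else local_shape
        print(f"{name}: ds_managed={is_ds_managed}, local_shape={local_shape}, full_shape={full_shape}")
```

In my runs, the lora_* parameters and the base weights show up as ds-managed, while the modules_to_save copies appear to behave differently.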
Reproduction
1. Load a model with LoRA using PEFT
2. Enable modules_to_save=['wte','ff_out']
3. Launch training with DeepSpeed ZeRO-3 (zero3_init_flag=False); a minimal reproduction sketch follows this list
4. Observe that:
• wte.original_module.weight may have shape [0, hidden_size] or trigger a shape mismatch at load time
• Memory usage increases drastically (batch size must drop to 1)
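Minimal reproduction sketch. The base model checkpoint, LoRA target modules, and hyperparameters below are placeholders; only modules_to_save=['wte','ff_out'] and ZeRO-3 with zero3_init_flag=False are taken from this report.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder base model; it must expose module names 'wte' and 'ff_out'.
model = AutoModelForCausalLM.from_pretrained("your-base-model")

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],  # placeholder LoRA targets
    modules_to_save=["wte", "ff_out"],              # fully fine-tuned modules from this report
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Launched with something like:
#   accelerate launch --config_file zero3.yaml train.py
# where the accelerate config sets:
#   deepspeed_config:
#     zero_stage: 3
#     zero3_init_flag: false
```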
Expected behavior
The modules_to_save parameters should be properly partitioned under ZeRO-3, so that LoRA training can proceed normally without excessive memory usage.
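Phrased as a check (a sketch, assuming the ds_* attributes are present on every ZeRO-3-managed parameter): after deepspeed.initialize, every trainable parameter, including the modules_to_save copies, should be partitioned rather than fully replicated.

```python
def check_trainable_params_partitioned(engine):
    # Sketch: collect trainable parameters that ZeRO-3 is not managing at all.
    unpartitioned = [
        name
        for name, param in engine.module.named_parameters()
        if param.requires_grad and not hasattr(param, "ds_id")
    ]
    if unpartitioned:
        raise RuntimeError(
            "Trainable parameters not partitioned by ZeRO-3: " + ", ".join(unpartitioned)
        )
```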