Problem Statement
Today the codebase contains multiple independent ProcessGroup management implementations for different parallel scenarios, namely:
- PipelineModule uses PipelineParallelGrid for pipeline parallelism.
- AutoTP uses one set of APIs (whose names start with letters) from utils.groups.
- MoE uses another set of APIs (whose names start with _) from utils.groups.
- Ulysses SP uses a DeviceMesh created on deepspeed.initialize.
All of those implementations handle process group creation (either via DeviceMesh or custom logic) under complex combinations of parallelism. This significantly increases the maintenance burden and makes it harder to introduce new parallelism techniques.
Besides, that divergence also confuses users. deepspeed.initialize() accepts two related parameters, mpu and mesh_param: mpu configures DP/TP (without PP/EP/SP), mesh_param configures SP, and the PP/EP parallelism topology must be configured at model initialization instead. That is a poor experience for users who want to try out different configurations in search of the best efficiency.
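For illustration, a hedged sketch of today's split configuration paths (parameter shapes follow current usage as far as we understand it and may differ across versions; model, ds_config, and mpu are placeholders):

```python
# Sketch of the current split interface; shapes here are assumptions,
# not a definitive spec.
import deepspeed

# Path 1: DP/TP via an mpu object (a Megatron-style model-parallel unit
# exposing get_data_parallel_group() etc.); no PP/EP/SP on this path.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config, mpu=mpu)

# Path 2: SP via mesh_param, which builds a DeviceMesh internally
# (assumed here to be a (dp_size, sp_size) tuple).
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config,
                                       mesh_param=(2, 4))

# Path 3: PP/EP topology is configured on the model itself, e.g. when
# constructing PipelineModule or MoE layers, before initialize() is called.
```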
Proposal
We would like to refactor those ProcessGroup management facilities so that:
- A unified DeviceMesh-based module serves all parallelism strategies and their combinations.
- Let DeviceMesh instances create and manage ProcessGroups when possible.
- Simplify the mesh topology configuration interface of deepspeed.initialize (see the sketch after this list).
- (Advanced) Allow different models in the same world to use different parallelism strategies. This is mainly for RL post-training, but needs more careful feasibility investigation due to the global-map design of ProcessGroup. See [RFC] Megatron-LM and MCore maintaining issues for veRL (volcengine/verl#897), MCore pain point 1, for more information.
- (Advanced) Make it easy to extend ZeRO to support HSDP-style replicate + shard parallelism for asymmetric GPU clusters. The extension itself will be tracked in a separate issue.
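As a rough sketch of what the simplified interface could look like (entirely hypothetical: the parameter name reuses today's mesh_param, and the dict shape is a placeholder rather than an agreed design):

```python
# Hypothetical unified topology configuration (names are placeholders).
# One mesh description covers DP/TP/PP/EP/SP and their combinations,
# replacing the separate mpu / mesh_param / model-side PP/EP knobs.
import deepspeed

engine, _, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config,
    mesh_param={"dp": 2, "tp": 2, "pp": 2, "sp": 1},  # hypothetical shape
)
# ProcessGroups would then be fetched from the per-model mesh, e.g.:
# engine.mesh_device.get_group("tp")
```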
While DeviceMesh supports multiple dimensions, it does not fit all parallelism techniques and thus needs extension. Essentially, for each dimension, DeviceMesh creates a ProcessGroup among the ranks that share the same coordinates on every other dimension. That fits the needs of DP, TP, and SP, but PP and EP have additional requirements: PP needs P2P groups between adjacent stages, and EP needs global data parallel groups. We would like to subclass DeviceMesh in DeepSpeed (e.g., DeepSpeedDeviceMesh?) to collect parallelism settings from both the model and the DeepSpeed configuration and create those additional groups. Each model would have its own DeepSpeedDeviceMesh instance at DeepSpeedEngine.mesh_device and fetch ProcessGroups from there.
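A minimal sketch of that extension, written as a wrapper around torch's DeviceMesh rather than a literal subclass for brevity (the class name, dimension names, and extra-group bookkeeping are placeholders, not a committed design):

```python
# Minimal sketch: wrap a torch DeviceMesh and add the two group kinds
# DeviceMesh cannot express (PP P2P groups, EP global-DP groups).
# All names here are placeholders for illustration.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


class DeepSpeedDeviceMesh:
    def __init__(self, device_type, mesh_shape, mesh_dim_names):
        self.mesh = init_device_mesh(device_type, mesh_shape,
                                     mesh_dim_names=mesh_dim_names)
        # Per-dimension groups come straight from DeviceMesh.
        self.get_group = self.mesh.get_group
        # Extra groups layered on top of the mesh:
        self.p2p_groups = {}        # (src_rank, dst_rank) -> ProcessGroup
        self.expert_dp_groups = []  # global DP groups spanning EP ranks
        if "pp" in mesh_dim_names:
            self._create_pp_p2p_groups()

    def _create_pp_p2p_groups(self):
        # For each pipeline, create one 2-rank group per pair of adjacent
        # stages. dist.new_group is collective: every rank must execute
        # every call in the same order.
        pp_dim = self.mesh.mesh_dim_names.index("pp")
        ranks = self.mesh.mesh.movedim(pp_dim, 0)  # stages on dim 0
        ranks = ranks.reshape(ranks.size(0), -1)   # (num_stages, num_pipes)
        for stage in range(ranks.size(0) - 1):
            for pipe in range(ranks.size(1)):
                pair = (int(ranks[stage, pipe]), int(ranks[stage + 1, pipe]))
                self.p2p_groups[pair] = dist.new_group(list(pair))
```

Regular per-dimension groups would still come from the mesh (e.g. get_group("tp")), while PP sends/receives would look up p2p_groups[(src, dst)]; the EP global data parallel groups would be built the same way from the model's expert-parallel configuration.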
This is still an early-stage idea; any comments or suggestions are welcome.