@@ -1266,7 +1266,6 @@ def __post_init__(self):

         # use_hybrid_parallel
         if self.use_hybrid_parallel:
-
             if ShardingOption.OFFLOAD in self.sharding:
                 warnings.warn("`offload` is not supported NOW!")

@@ -2327,6 +2326,19 @@ def sharding_parallel_rank(self):
         else:
             return 0

+    @property
+    def moe_sharding_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_moe_sharding_parallel_world_size") and hcg.get_moe_sharding_parallel_world_size() > 0:
+                # hybrid expert parallel
+                moe_sharding_group = hcg.get_moe_sharding_parallel_group()
+                return max(moe_sharding_group.rank, 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def tensor_parallel_rank(self):
         if self.use_hybrid_parallel:
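
Note (not part of the diff): a minimal standalone sketch of the same rank lookup, using only the Paddle calls already shown in the hunk above; the helper name get_moe_sharding_rank is hypothetical.

from paddle.distributed import fleet

def get_moe_sharding_rank(use_hybrid_parallel: bool) -> int:
    # Fall back to rank 0 whenever hybrid parallelism or the MoE sharding group is absent.
    if not use_hybrid_parallel:
        return 0
    hcg = fleet.get_hybrid_communicate_group()
    if not hasattr(hcg, "get_moe_sharding_parallel_world_size"):
        return 0
    if hcg.get_moe_sharding_parallel_world_size() <= 0:
        return 0
    # Clamp negative ranks (processes outside the group) to 0, as the diff does.
    return max(hcg.get_moe_sharding_parallel_group().rank, 0)
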
@@ -2405,6 +2417,8 @@ def weight_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
+            if self.use_expert_parallel and self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                name.append(self._format_name("moe_sharding", self.expert_parallel_rank, self.expert_parallel_degree))
             return "_".join(name)

         else:
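
Note (not part of the diff): a hedged sketch of how the new suffix segment composes, assuming _format_name concatenates the prefix with a zero-padded rank; the exact padding used by PaddleNLP may differ.

def _format_name(prefix: str, rank: int, degree: int) -> str:
    # Assumed formatting: prefix plus a zero-padded rank, e.g. "tp01".
    width = max(2, len(str(degree - 1)))
    return f"{prefix}{rank:0>{width}d}"

# With expert parallelism enabled (expert_parallel_degree > 1) and a MoE sharding
# group present, the weight file suffix gains a "moe_sharding" segment keyed by
# the expert-parallel rank:
parts = [
    _format_name("tp", 1, 4),            # tensor_parallel_rank / degree
    _format_name("pp", 0, 2),            # pipeline_parallel_rank / degree
    _format_name("moe_sharding", 3, 8),  # expert_parallel_rank / degree
]
print("_".join(parts))  # tp01_pp00_moe_sharding03
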
@@ -2534,7 +2548,16 @@ def should_save_model_state(self):
             return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
-            return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            if self.use_expert_parallel:
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                    # for sharding parallel + moe, save the checkpoint on moe_sharding rank 0
+                    return self.moe_sharding_parallel_rank == 0
+                else:
+                    # for data parallel + moe, save the checkpoint on sharding rank 0
+                    return self.sharding_parallel_rank == 0
+            else:
+                # without moe, save the checkpoint on sharding rank 0 and data rank 0
+                return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
         else:
             return self.process_index == 0 or self.use_expert_parallel

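
Note (not part of the diff): the new save-rank decision rewritten as a standalone predicate, with ranks and degrees passed in explicitly so each branch is easy to exercise; all names here are illustrative, not PaddleNLP APIs.

def should_save_on_this_rank(
    use_expert_parallel: bool,
    expert_parallel_degree: int,
    moe_sharding_parallel_degree: int,
    moe_sharding_parallel_rank: int,
    sharding_parallel_rank: int,
    data_parallel_rank: int,
) -> bool:
    # Mirrors the hybrid-parallel branch of should_save_model_state above.
    if use_expert_parallel:
        if moe_sharding_parallel_degree >= 1 and expert_parallel_degree > 1:
            # sharding parallel + MoE: only moe_sharding rank 0 writes the checkpoint
            return moe_sharding_parallel_rank == 0
        # data parallel + MoE: only sharding rank 0 writes the checkpoint
        return sharding_parallel_rank == 0
    # no MoE: only the rank that is 0 in both the sharding and data-parallel groups writes
    return sharding_parallel_rank == 0 and data_parallel_rank == 0

# Example: data parallel + MoE (expert_parallel_degree == 1): every replica whose
# sharding rank is 0 saves, regardless of its data-parallel rank.
assert should_save_on_this_rank(
    use_expert_parallel=True,
    expert_parallel_degree=1,
    moe_sharding_parallel_degree=0,
    moe_sharding_parallel_rank=0,
    sharding_parallel_rank=0,
    data_parallel_rank=1,
)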