@@ -1266,7 +1266,6 @@ def __post_init__(self):
 
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
-
             if ShardingOption.OFFLOAD in self.sharding:
                 warnings.warn("`offload` is not supported NOW!")
 
@@ -2327,6 +2326,17 @@ def sharding_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def moe_sharding_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_moe_sharding_parallel_group"):
+                return max(hcg.get_moe_sharding_parallel_group().rank, 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def tensor_parallel_rank(self):
         if self.use_hybrid_parallel:
@@ -2405,6 +2415,8 @@ def weight_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree > 1:
+                name.append(self._format_name("moe_sharding", self.expert_parallel_rank, self.expert_parallel_degree))
             return "_".join(name)
 
         else:
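
A minimal standalone sketch of the new suffix branching, for review purposes only: format_name below is a hypothetical stand-in for the file's _format_name helper (its exact rendering is defined elsewhere in training_args.py), and all rank/degree values are invented.

def format_name(prefix: str, rank: int, degree: int) -> str:
    # Hypothetical rendering for illustration; not the real _format_name.
    return f"{prefix}{rank:02d}"

def weight_name_suffix_sketch(use_expert_parallel, expert_parallel_degree,
                              data_parallel_rank, data_parallel_degree,
                              expert_parallel_rank):
    name = []
    if use_expert_parallel and expert_parallel_degree <= 1:
        # unchanged path: tag the checkpoint with the data-parallel rank
        name.append(format_name("moe", data_parallel_rank, data_parallel_degree))
    if use_expert_parallel and expert_parallel_degree > 1:
        # new path: tag the checkpoint with the expert-parallel rank instead
        name.append(format_name("moe_sharding", expert_parallel_rank, expert_parallel_degree))
    return "_".join(name)

print(weight_name_suffix_sketch(True, 1, 3, 8, 0))  # moe03
print(weight_name_suffix_sketch(True, 4, 3, 8, 2))  # moe_sharding02
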
@@ -2534,7 +2546,9 @@ def should_save_model_state(self):
             return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
-            return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            return (
+                self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            ) or (self.expert_parallel_degree > 1 and self.moe_sharding_parallel_rank == 0)
         else:
             return self.process_index == 0 or self.use_expert_parallel
 
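
The last hunk widens the save gate: in addition to the existing rule (sharding rank 0, and either data-parallel rank 0 or expert parallelism enabled), a process now also saves when expert_parallel_degree > 1 and it sits at moe-sharding rank 0. A self-contained sketch of that predicate, with invented rank values (this is not the TrainingArguments code itself):

from dataclasses import dataclass

@dataclass
class Ranks:
    sharding_parallel_rank: int
    data_parallel_rank: int
    moe_sharding_parallel_rank: int
    expert_parallel_degree: int
    use_expert_parallel: bool

def should_save(r: Ranks) -> bool:
    # Mirrors the revised expression in should_save_model_state.
    return (
        r.sharding_parallel_rank == 0 and (r.data_parallel_rank == 0 or r.use_expert_parallel)
    ) or (r.expert_parallel_degree > 1 and r.moe_sharding_parallel_rank == 0)

print(should_save(Ranks(1, 0, 0, 2, True)))   # True: second branch (moe-sharding rank 0)
print(should_save(Ranks(1, 0, 1, 2, True)))   # False: neither branch holds
print(should_save(Ranks(0, 1, 1, 1, False)))  # False: old rule needs dp rank 0 here
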