
Commit 371998e

fix model save for moe_sharding
1 parent 10014d8 commit 371998e

File tree

1 file changed (+16, -2)


paddlenlp/trainer/training_args.py

Lines changed: 16 additions & 2 deletions
@@ -1266,7 +1266,6 @@ def __post_init__(self):
 
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
-
             if ShardingOption.OFFLOAD in self.sharding:
                 warnings.warn("`offload` is not supported NOW!")
 
@@ -2327,6 +2326,17 @@ def sharding_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def moe_sharding_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_moe_sharding_parallel_group"):
+                return max(hcg.get_moe_sharding_parallel_group().rank, 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def tensor_parallel_rank(self):
         if self.use_hybrid_parallel:
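The new moe_sharding_parallel_rank property mirrors the existing sharding_parallel_rank: it returns this process's rank inside the MoE sharding communication group when the running fleet build exposes one, and falls back to 0 otherwise. Below is a minimal sketch of the same guard-and-clamp pattern; FakeHCG is a hypothetical stand-in for the object returned by fleet.get_hybrid_communicate_group(), not part of Paddle.

class _Group:
    """Hypothetical stand-in for a Paddle communication group."""
    def __init__(self, rank):
        self.rank = rank  # assumption: -1 marks "not a member of this group"

class FakeHCG:
    """Hypothetical hybrid-communicate-group stub with a MoE sharding group."""
    def get_moe_sharding_parallel_group(self):
        return _Group(rank=-1)

def moe_sharding_parallel_rank(hcg):
    # hasattr() guards against fleet versions that predate MoE sharding groups;
    # max(..., 0) clamps the -1 "not a member" sentinel to rank 0.
    if hasattr(hcg, "get_moe_sharding_parallel_group"):
        return max(hcg.get_moe_sharding_parallel_group().rank, 0)
    return 0

assert moe_sharding_parallel_rank(FakeHCG()) == 0  # sentinel clamped to 0
assert moe_sharding_parallel_rank(object()) == 0   # method absent -> 0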
@@ -2405,6 +2415,8 @@ def weight_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree > 1:
+                name.append(self._format_name("moe_sharding", self.expert_parallel_rank, self.expert_parallel_degree))
             return "_".join(name)
 
         else:
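With expert parallelism enabled (expert_parallel_degree > 1), the checkpoint suffix now carries a moe_sharding segment keyed by the expert-parallel rank, so weights written by different expert groups no longer collide on disk. A rough re-statement of the two MoE branches follows; the _format_name here is a hypothetical formatter whose zero-padding is an assumption (the real helper is a TrainingArguments method and may format differently).

def _format_name(prefix, rank, degree):
    # Assumed format: prefix + zero-padded rank, e.g. "moe_sharding02".
    width = max(2, len(str(degree)))
    return f"{prefix}{rank:0{width}d}"

def moe_suffix(use_expert_parallel, ep_degree, ep_rank, dp_rank, dp_degree):
    name = []
    if use_expert_parallel and ep_degree <= 1:
        # pre-existing path: legacy MoE keys the suffix by data-parallel rank
        name.append(_format_name("moe", dp_rank, dp_degree))
    if use_expert_parallel and ep_degree > 1:
        # path added by this commit: key by the expert-parallel rank instead
        name.append(_format_name("moe_sharding", ep_rank, ep_degree))
    return "_".join(name)

print(moe_suffix(True, 4, 2, 0, 1))  # -> "moe_sharding02" (padding assumed)
print(moe_suffix(True, 1, 0, 3, 8))  # -> "moe03"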
@@ -2534,7 +2546,9 @@ def should_save_model_state(self):
             return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
-            return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            return (
+                self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            ) or (self.expert_parallel_degree > 1 and self.moe_sharding_parallel_rank == 0)
         else:
             return self.process_index == 0 or self.use_expert_parallel
 
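Reading the diff, should_save_model_state previously returned True only on sharding rank 0, which could leave expert-sharded weights unsaved in an expert-parallel run; the commit additionally lets rank 0 of each MoE sharding group save. A pure-function re-statement of the new condition, for reasoning about which ranks write (illustrative only, not the class method itself):

def should_save_model_state(sharding_parallel_rank, data_parallel_rank,
                            use_expert_parallel, expert_parallel_degree,
                            moe_sharding_parallel_rank):
    # New condition from this commit: the original sharding/data-parallel
    # rank-0 check, OR rank 0 of a MoE sharding group under expert parallelism.
    return (
        sharding_parallel_rank == 0
        and (data_parallel_rank == 0 or use_expert_parallel)
    ) or (expert_parallel_degree > 1 and moe_sharding_parallel_rank == 0)

# Before the fix, a rank with sharding_parallel_rank != 0 never saved; with
# expert parallelism enabled, rank 0 of each MoE sharding group now does:
assert should_save_model_state(1, 0, True, 4, 0) is True
assert should_save_model_state(1, 0, True, 4, 1) is False
assert should_save_model_state(0, 0, False, 1, 0) is True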