Commit dbb9694

fix model save for moe_sharding
1 parent 10014d8 commit dbb9694

File tree: 1 file changed

paddlenlp/trainer/training_args.py

Lines changed: 25 additions & 2 deletions
@@ -1266,7 +1266,6 @@ def __post_init__(self):
 
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
-
            if ShardingOption.OFFLOAD in self.sharding:
                warnings.warn("`offload` is not supported NOW!")
 
@@ -2327,6 +2326,19 @@ def sharding_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def moe_sharding_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_moe_sharding_parallel_world_size") and hcg.get_moe_sharding_parallel_world_size() > 0:
+                # hybrid expert parallel
+                moe_sharding_group = hcg.get_moe_sharding_parallel_group()
+                return max(moe_sharding_group.rank, 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def tensor_parallel_rank(self):
         if self.use_hybrid_parallel:
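The new moe_sharding_parallel_rank property resolves the rank from the hybrid communicate group only when a MoE sharding group is actually configured, and falls back to 0 otherwise. Below is a minimal sketch of that fallback chain using stand-in stub objects instead of fleet.get_hybrid_communicate_group(); the stub classes and sample values are illustrative, not part of the commit or the Paddle API.

# Illustrative stand-ins for the group objects returned by
# fleet.get_hybrid_communicate_group(); not the real Paddle API.
class _StubGroup:
    def __init__(self, rank):
        self.rank = rank

class _StubHCG:
    def __init__(self, moe_sharding_world_size, moe_sharding_rank):
        self._ws = moe_sharding_world_size
        self._group = _StubGroup(moe_sharding_rank)

    def get_moe_sharding_parallel_world_size(self):
        return self._ws

    def get_moe_sharding_parallel_group(self):
        return self._group


def moe_sharding_parallel_rank(hcg, use_hybrid_parallel=True):
    """Same fallback chain as the new property: hybrid parallel enabled,
    MoE sharding group present and non-empty, otherwise rank 0."""
    if not use_hybrid_parallel:
        return 0
    if hasattr(hcg, "get_moe_sharding_parallel_world_size") and hcg.get_moe_sharding_parallel_world_size() > 0:
        return max(hcg.get_moe_sharding_parallel_group().rank, 0)
    return 0


print(moe_sharding_parallel_rank(_StubHCG(4, 2)))   # 2: MoE sharding configured
print(moe_sharding_parallel_rank(_StubHCG(0, -1)))  # 0: no MoE sharding group
print(moe_sharding_parallel_rank(object()))         # 0: hcg lacks the accessor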
@@ -2405,6 +2417,8 @@ def weight_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
+            if self.use_expert_parallel and self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                name.append(self._format_name("moe_sharding", self.expert_parallel_rank, self.expert_parallel_degree))
             return "_".join(name)
 
         else:
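With the added branch, a run that combines expert parallelism (expert_parallel_degree > 1) with MoE sharding gets an extra "moe_sharding…" component in its weight-file suffix. The sketch below shows how the suffix composes; format_name is a hypothetical stand-in for TrainingArguments._format_name, whose exact padding in PaddleNLP may differ.

# Hypothetical helper standing in for TrainingArguments._format_name;
# the real method's rank formatting may differ.
def format_name(prefix, rank, degree):
    return f"{prefix}{rank:0{len(str(degree - 1))}d}"


def weight_name_suffix(tp_rank=0, tp_degree=1, pp_rank=0, pp_degree=1,
                       use_expert_parallel=False, ep_rank=0, ep_degree=1,
                       moe_sharding_degree=0, dp_rank=0, dp_degree=1):
    """Sketch of the hybrid-parallel suffix composition after this commit."""
    name = []
    if tp_degree > 1:
        name.append(format_name("tp", tp_rank, tp_degree))
    if pp_degree > 1:
        name.append(format_name("pp", pp_rank, pp_degree))
    if use_expert_parallel and ep_degree <= 1:
        name.append(format_name("moe", dp_rank, dp_degree))
    if use_expert_parallel and moe_sharding_degree >= 1 and ep_degree > 1:
        name.append(format_name("moe_sharding", ep_rank, ep_degree))
    return "_".join(name)


# e.g. tp=2, pp=2, expert parallel degree 4 with MoE sharding enabled
print(weight_name_suffix(tp_rank=1, tp_degree=2, pp_rank=0, pp_degree=2,
                         use_expert_parallel=True, ep_rank=3, ep_degree=4,
                         moe_sharding_degree=2))
# -> "tp1_pp0_moe_sharding3"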
@@ -2534,7 +2548,16 @@ def should_save_model_state(self):
             return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
-            return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+            if self.use_expert_parallel:
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                    # for sharding parallel + moe, we save checkpoint on moe_sharding rank 0
+                    return self.moe_sharding_parallel_rank == 0
+                else:
+                    # for data parallel + moe, we save checkpoint on sharding rank 0
+                    return self.sharding_parallel_rank == 0
+            else:
+                # for no moe, we save checkpoint on sharding and data rank 0
+                return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
         else:
             return self.process_index == 0 or self.use_expert_parallel
 
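The save decision for the hybrid-parallel path now branches on whether expert parallelism and MoE sharding are both active. A minimal sketch of the new rank-selection rule as a pure function; the parameter names mirror the TrainingArguments fields, but the function itself is only illustrative.

def should_save_on_this_rank(use_expert_parallel, expert_parallel_degree,
                             moe_sharding_parallel_degree,
                             moe_sharding_parallel_rank,
                             sharding_parallel_rank, data_parallel_rank):
    """Sketch of the hybrid-parallel branch of should_save_model_state after this commit."""
    if use_expert_parallel:
        if moe_sharding_parallel_degree >= 1 and expert_parallel_degree > 1:
            # sharding parallel + MoE: save on moe_sharding rank 0
            return moe_sharding_parallel_rank == 0
        # data parallel + MoE: save on sharding rank 0
        return sharding_parallel_rank == 0
    # no MoE: save only on sharding rank 0 and data-parallel rank 0
    return sharding_parallel_rank == 0 and data_parallel_rank == 0


# Previously, with expert parallelism on, every rank with sharding rank 0 saved;
# with MoE sharding configured, only moe_sharding rank 0 saves now.
print(should_save_on_this_rank(True, 4, 2, 0, 0, 1))  # True
print(should_save_on_this_rank(True, 4, 2, 1, 0, 0))  # False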
