Skip to content

Commit 8b77c21

Browse files
committed
[CI] Update local_rank in DeterministicDDPTestCase.create_pg
1 parent 5c7a66d commit 8b77c21

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

xtuner/_testing/testcase.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from torch.testing._internal.common_distributed import DistributedTestBase, MultiProcessTestCase, logger, TEST_SKIPS, c10d
22
import torch
3+
import torch.distributed as dist
34
import threading
45
import sys
56
import os
@@ -91,3 +92,7 @@ def _check_loss_curve(
9192
raise AssertionError(
9293
f"Failed to check relative error of loss, expected: {losses_ref}, got {losses}, Mean diff: {avg_relative_diff}")
9394

95+
def create_pg(self, device):
96+
ret = super().create_pg(device)
97+
os.environ["LOCAL_RANK"] = str(dist.get_rank() % torch.cuda.device_count())
98+
return ret

0 commit comments

Comments
 (0)