
Commit 41a161a

add health check (#1275)
1 parent 7f4f794 commit 41a161a

File tree

3 files changed: +141 -0 lines changed


tests/utils/test_check_health.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@

import os
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from unittest.mock import patch, Mock

import xtuner.v1.utils.check_health as check_health

from xtuner.v1.utils.device import get_device


DEVICE = get_device()


def fake_health_job(dtype, loop=10):
    if dist.get_rank() == 1:
        print(f"rank {dist.get_rank()} world size {dist.get_world_size()} return 0.0")
        return torch.tensor(0.0, dtype=dtype, device=DEVICE)
    else:
        print(f"rank {dist.get_rank()} world size {dist.get_world_size()} return 1.0")
        return torch.tensor(1.0, dtype=dtype, device=DEVICE)


class TestCheckHealth(DistributedTestBase):
    def create_pg(self, device):
        ret = super().create_pg(device)
        os.environ["LOCAL_RANK"] = str(dist.get_rank())
        torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
        return ret

    def test_check_health_normal(self):
        self.create_pg(DEVICE)

        self.assertTrue(check_health.check_health())

    def test_check_health_failed(self):
        self.create_pg(DEVICE)

        with patch("xtuner.v1.utils.check_health.health_job", fake_health_job):
            self.assertFalse(check_health.check_health())
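
For context, test_check_health_failed works by patching the module-level health_job so that rank 1 returns a different scalar than every other rank; since check_health MIN-reduces its pass/fail flags across ranks, every rank then sees the failure. A hypothetical standalone version of the same experiment is sketched below; the script name, launch command, and the assumption that no per-rank device pinning is needed are illustrative and not part of this commit.

# repro_health_check.py -- hypothetical script, not part of this commit.
# Launch with e.g.: torchrun --nproc-per-node=2 repro_health_check.py
# Assumes torchrun's environment is enough for init_process_group() and that
# per-rank device pinning (done in the test's create_pg) is not needed here.
from unittest.mock import patch

import torch
import torch.distributed as dist

import xtuner.v1.utils.check_health as check_health
from xtuner.v1.utils.device import get_device


def fake_health_job(dtype, loop=10):
    # Rank 1 deliberately disagrees with every other rank.
    value = 0.0 if dist.get_rank() == 1 else 1.0
    return torch.tensor(value, dtype=dtype, device=get_device())


if __name__ == "__main__":
    dist.init_process_group()
    print("baseline healthy:", check_health.check_health())      # expected: True
    with patch("xtuner.v1.utils.check_health.health_job", fake_health_job):
        print("patched healthy:", check_health.check_health())   # expected: False
    dist.destroy_process_group()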

xtuner/v1/train/trainer.py

Lines changed: 16 additions & 0 deletions
@@ -47,6 +47,7 @@
    log_format,
    record_git_info,
)
+from xtuner.v1.utils.check_health import check_health
from xtuner.v1.utils.device import get_device, get_torch_device_module

from .toy_tokenizer import UTF8ByteTokenizer

@@ -172,6 +173,7 @@ class TrainerConfig(BaseModel):
    checkpoint_maxkeep: int | None = -1
    skip_checkpoint_validation: bool = False  # Suggest enabled if fsdp_size is larger than 512
    snapshot_interval: int | None = None
+    check_health_interval: int | None = None
    hf_interval: int | None = None
    hf_max_keep: int | None = None
    exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl"

@@ -289,6 +291,7 @@ def __init__(
        checkpoint_maxkeep: int | None = -1,
        skip_checkpoint_validation: bool = False,  # Suggest enabled if fsdp_size is larger than 512
        snapshot_interval: int | None = None,
+        check_health_interval: int | None = None,
        hf_interval: int | None = None,
        hf_max_keep: int | None = None,
        exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl",

@@ -337,6 +340,7 @@ def __init__(
        self._checkpoint_interval = checkpoint_interval
        self._checkpoint_maxkeep = checkpoint_maxkeep
        self._snapshot_interval = snapshot_interval
+        self._check_health_interval = check_health_interval
        self._hf_max_keep = hf_max_keep
        self._hf_interval = hf_interval

@@ -481,6 +485,7 @@ def from_config(cls, config: TrainerConfig) -> Self:
            checkpoint_maxkeep=config.checkpoint_maxkeep,
            skip_checkpoint_validation=config.skip_checkpoint_validation,
            snapshot_interval=config.snapshot_interval,
+            check_health_interval=config.check_health_interval,
            hf_interval=config.hf_interval,
            hf_max_keep=config.hf_max_keep,
            exp_tracker=config.exp_tracker,

@@ -586,6 +591,7 @@ def fit(self):
                )

            self._lr_scheduler.step()
+            self._maybe_check_health()
            self._maybe_save_hf()
            ckpt_saved = self._maybe_save(is_snapshot=False)
            if not ckpt_saved:

@@ -806,6 +812,16 @@ def warmup_fn(x):
        )
        return lr_scheduler

+    def _maybe_check_health(self):
+        if (
+            (self._check_health_interval is not None and self.cur_step % self._check_health_interval == 0)
+            or (self._checkpoint_interval is not None and self.cur_step % self._checkpoint_interval == 0)
+            or (self._snapshot_interval is not None and self.cur_step % self._snapshot_interval == 0)
+        ):
+            if not check_health():
+                raise RuntimeError("Health check failed, exit training")
+            logger.info(f"Health check passed at step {self.cur_step}")
+
    def _maybe_save(self, is_snapshot: bool = False) -> bool:
        ckp_interval = self._checkpoint_interval if not is_snapshot else self._snapshot_interval
        if ckp_interval is None:
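
Configuration-wise, check_health_interval is optional; as _maybe_check_health above shows, the check also fires on checkpoint and snapshot steps even when it is left as None. A hedged sketch of enabling it through TrainerConfig follows; the Trainer class name, the concrete interval values, and which other config fields are required are assumptions, since this diff only shows the new field.

# Hypothetical configuration snippet; every field not shown here is elided.
from xtuner.v1.train.trainer import Trainer, TrainerConfig  # Trainer name assumed from the file path

config = TrainerConfig(
    # ... model / dataset / optimizer / checkpoint fields as usual ...
    snapshot_interval=500,       # the health check piggybacks on snapshot and checkpoint steps
    check_health_interval=100,   # plus an explicit check every 100 steps
)

trainer = Trainer.from_config(config)
trainer.fit()  # raises RuntimeError("Health check failed, exit training") if a check fails mid-run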

xtuner/v1/utils/check_health.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@

from collections import defaultdict

import torch
import torch.distributed as dist
import torch.nn.functional as F

from xtuner.v1.utils import get_logger
from xtuner.v1.utils.device import get_device


logger = get_logger()

DEVICE = get_device()


def health_job(dtype, loop=10):
    # use independent generator to avoid affecting the global generator
    x = torch.rand(128, 128, generator=torch.Generator(device=DEVICE).manual_seed(12345), dtype=dtype, device=DEVICE)
    dist.broadcast(x, src=0)

    y = x
    for _ in range(loop):
        y = F.normalize(y, dim=0)
        torch.matmul(x, y, out=y)
    y = y.mean()
    return y


def check_health(loop=10):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    dtype = torch.bfloat16
    rtol = 1.6e-2
    atol = 1e-5
    # from torch.testing.assert_close:
    # +---------------------------+------------+----------+
    # | ``dtype``                 | ``rtol``   | ``atol`` |
    # +===========================+============+==========+
    # | :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
    # +---------------------------+------------+----------+
    # | :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
    # +---------------------------+------------+----------+
    # | :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
    # +---------------------------+------------+----------+

    y = health_job(dtype, loop)

    # gather check
    y_list = [torch.tensor(0.0, dtype=dtype, device=DEVICE) for _ in range(world_size)] if rank == 0 else None
    dist.gather(y, y_list)
    gather_check = torch.tensor(1, dtype=torch.int32, device=DEVICE)
    if rank == 0:
        for i in range(world_size):
            if not torch.allclose(y, y_list[i], rtol=rtol, atol=atol):
                gather_check = torch.tensor(0, dtype=torch.int32, device=DEVICE)
                break
    dist.all_reduce(gather_check, op=dist.ReduceOp.MIN)

    # all reduce check
    z = y.clone()
    dist.all_reduce(z, op=dist.ReduceOp.AVG)
    all_reduce_check = (
        torch.tensor(1, dtype=torch.int32, device=DEVICE)
        if torch.allclose(y, z, rtol=rtol, atol=atol)
        else torch.tensor(0, dtype=torch.int32, device=DEVICE)
    )
    dist.all_reduce(all_reduce_check, op=dist.ReduceOp.MIN)

    if gather_check.item() == 1 and all_reduce_check.item() == 1:
        return True

    if rank == 0:  # log
        logger.error(
            f"Health check failed: gather_check={gather_check.item()}, all_reduce_check={all_reduce_check.item()}. rtol={rtol}, atol={atol}."
        )
        logger.error(f"All reduce check info: y: {y.item()}, z: {z.item()}")

        y2rank = defaultdict(list)
        for ranki, yi in enumerate(y_list):
            y2rank[yi.item()].append(ranki)
        for yi, ranks in y2rank.items():
            logger.error(f"Gather check info: rank {sorted(ranks)}: {yi}")

    return False
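
For intuition, the per-rank workload in health_job is fully deterministic (fixed seed, fixed shapes), so healthy ranks should agree to within the bfloat16 tolerances quoted in the comment above; both the gather check and the all-reduce check are plain torch.allclose comparisons at those tolerances, and MIN-reducing the two flags means every rank returns the same boolean. The single-process sketch below illustrates just the workload and the tolerance comparison; it skips the broadcast, rewrites the in-place matmul functionally, and is illustrative only, not how check_health is meant to be invoked.

# Single-process sketch of the workload and the tolerance comparison; this is
# an illustration, not part of the commit.
import torch
import torch.nn.functional as F


def local_health_job(dtype=torch.bfloat16, loop=10, device="cpu"):
    # Same fixed-seed 128x128 workload as health_job, minus dist.broadcast.
    x = torch.rand(
        128, 128,
        generator=torch.Generator(device=device).manual_seed(12345),
        dtype=dtype, device=device,
    )
    y = x
    for _ in range(loop):
        y = F.normalize(y, dim=0)
        y = torch.matmul(x, y)  # health_job writes this result in place via out=y
    return y.mean()


a = local_health_job()
b = local_health_job()
# The same comparison check_health applies across ranks (bfloat16 tolerances).
print(torch.allclose(a, b, rtol=1.6e-2, atol=1e-5))  # expected: True on a healthy device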
