|
8 | 8 |
|
9 | 9 | try: |
10 | 10 | import torch |
11 | | - from pytorch_lightning import LightningModule |
12 | | - from pytorch_lightning.trainer import Trainer |
| 11 | + from lightning import LightningModule |
| 12 | + from lightning.pytorch import Trainer |
| 13 | + from lightning.pytorch.callbacks import ModelCheckpoint |
13 | 14 | from torch import nn |
14 | 15 | from torch.nn import functional as F # noqa: N812 |
15 | 16 | from torch.optim import SGD, Adam |
|
18 | 19 | from dvclive import Live |
19 | 20 | from dvclive.lightning import DVCLiveLogger |
20 | 21 | except ImportError: |
21 | | - pytest.skip("skipping pytorch_lightning tests", allow_module_level=True) |
| 22 | + pytest.skip("skipping lightning tests", allow_module_level=True) |
22 | 23 |
|
23 | 24 |
|
24 | 25 | class XORDataset(Dataset): |
@@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir): |
161 | 162 | assert dvclive_logger.experiment._cache_images is True |
162 | 163 |
|
163 | 164 |
|
| 165 | +@pytest.mark.parametrize("log_model", [False, True, "all"]) |
| 166 | +@pytest.mark.parametrize("save_top_k", [1, -1]) |
| 167 | +def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k): |
| 168 | + model = LitXOR() |
| 169 | + dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model) |
| 170 | + checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k) |
| 171 | + trainer = Trainer( |
| 172 | + logger=dvclive_logger, |
| 173 | + max_epochs=2, |
| 174 | + log_every_n_steps=1, |
| 175 | + callbacks=[checkpoint], |
| 176 | + ) |
| 177 | + log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact") |
| 178 | + trainer.fit(model) |
| 179 | + |
| 180 | +    # Check that log_artifact is called the expected number of times.
| 181 | + if log_model is False: |
| 182 | + log_artifact.assert_not_called() |
| 183 | + elif (log_model is True) and (save_top_k != -1): |
| 184 | + # called once to cache, then again to log best artifact |
| 185 | + assert log_artifact.call_count == 2 |
| 186 | + else: |
| 187 | + # once per epoch plus two calls at the end (see above) |
| 188 | + assert log_artifact.call_count == 4 |
| 189 | + |
| 190 | +    # Check that the number of checkpoint files does not grow with each run.
| 191 | + num_checkpoints = len(os.listdir(tmp_dir / "model")) |
| 192 | + if log_model in [True, "all"]: |
| 193 | + trainer.fit(model) |
| 194 | + assert len(os.listdir(tmp_dir / "model")) == num_checkpoints |
| 195 | + |
| 196 | + |
164 | 197 | def test_lightning_steps(tmp_dir, mocker): |
165 | 198 | model = LitXOR() |
166 | 199 | # Handle kwargs passed to Live. |
|
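For context, a minimal sketch of the user-facing behavior that test_lightning_log_model exercises, using the same APIs imported in this diff and the LitXOR module defined earlier in the test file (the dir path here is an arbitrary example):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from dvclive.lightning import DVCLiveLogger

    # log_model=True logs the best checkpoint once training ends;
    # log_model="all" also logs a checkpoint after every epoch.
    logger = DVCLiveLogger(dir="dvclive_logs", log_model=True)
    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=1)
    trainer = Trainer(logger=logger, max_epochs=2, callbacks=[checkpoint])
    trainer.fit(LitXOR())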