Skip to content

Commit 786c83a

Browse files
dberenbaumdberenbaum
authored andcommitted
lightning: add tests for log_model
1 parent c9452c7 commit 786c83a

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/dvclive/lightning.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,21 +141,21 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
141141
self.experiment.next_step()
142142

143143
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
144-
self._checkpoint_callback = checkpoint_callback
144+
if self._log_model in [True, "all"]:
145+
self._checkpoint_callback = checkpoint_callback
146+
self._scan_checkpoints(checkpoint_callback)
145147
if self._log_model == "all" or (
146148
self._log_model is True and checkpoint_callback.save_top_k == -1
147149
):
148150
self._save_checkpoints(checkpoint_callback)
149151

150152
@rank_zero_only
151153
def finalize(self, status: str) -> None:
152-
checkpoint_callback = self._checkpoint_callback
153-
# Save model checkpoints.
154-
if self._log_model is True:
155-
self._save_checkpoints(checkpoint_callback)
156154
# Log best model.
157-
if self._log_model in (True, "all"):
158-
best_model_path = checkpoint_callback.best_model_path
155+
if self._checkpoint_callback:
156+
self._scan_checkpoints(self._checkpoint_callback)
157+
self._save_checkpoints(self._checkpoint_callback)
158+
best_model_path = self._checkpoint_callback.best_model_path
159159
self.experiment.log_artifact(
160160
best_model_path, name="best", type="model", cache=False
161161
)

tests/test_frameworks/test_lightning.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
try:
1010
import torch
11-
from pytorch_lightning import LightningModule
12-
from pytorch_lightning.trainer import Trainer
11+
from lightning import LightningModule
12+
from lightning.pytorch import Trainer
13+
from lightning.pytorch.callbacks import ModelCheckpoint
1314
from torch import nn
1415
from torch.nn import functional as F # noqa: N812
1516
from torch.optim import SGD, Adam
@@ -18,7 +19,7 @@
1819
from dvclive import Live
1920
from dvclive.lightning import DVCLiveLogger
2021
except ImportError:
21-
pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
22+
pytest.skip("skipping lightning tests", allow_module_level=True)
2223

2324

2425
class XORDataset(Dataset):
@@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir):
161162
assert dvclive_logger.experiment._cache_images is True
162163

163164

165+
@pytest.mark.parametrize("log_model", [False, True, "all"])
166+
@pytest.mark.parametrize("save_top_k", [1, -1])
167+
def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k):
168+
model = LitXOR()
169+
dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model)
170+
checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k)
171+
trainer = Trainer(
172+
logger=dvclive_logger,
173+
max_epochs=2,
174+
log_every_n_steps=1,
175+
callbacks=[checkpoint],
176+
)
177+
log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact")
178+
trainer.fit(model)
179+
180+
# Check that log_artifact is called.
181+
if log_model is False:
182+
log_artifact.assert_not_called()
183+
elif (log_model is True) and (save_top_k != -1):
184+
# called once to cache, then again to log best artifact
185+
assert log_artifact.call_count == 2
186+
else:
187+
# once per epoch plus two calls at the end (see above)
188+
assert log_artifact.call_count == 4
189+
190+
# Check that checkpoint files does not grow with each run.
191+
num_checkpoints = len(os.listdir(tmp_dir / "model"))
192+
if log_model in [True, "all"]:
193+
trainer.fit(model)
194+
assert len(os.listdir(tmp_dir / "model")) == num_checkpoints
195+
196+
164197
def test_lightning_steps(tmp_dir, mocker):
165198
model = LitXOR()
166199
# Handle kwargs passed to Live.

0 commit comments

Comments
 (0)