Skip to content

Commit 354a2f5

Browse files
committed
[bug] lightning default dir integration fix
1 parent 7fe6cdf commit 354a2f5

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/dvclive/lightning.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ def __init__(
2020

2121
super().__init__()
2222
self._prefix = prefix
23-
self._live_init = {
24-
"dir": dir,
25-
"resume": resume,
26-
}
23+
self._live_init: Dict[str, Any] = {"resume": resume}
24+
if dir is not None:
25+
self._live_init["dir"] = dir
2726
self._experiment = experiment
2827
self._version = run_name
2928

tests/test_lightning.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,18 @@ def test_lightning_integration(tmp_dir):
109109
assert os.path.join(scalars, "train", "epoch", "loss.tsv") in logs
110110
assert os.path.join(scalars, "train", "step", "loss.tsv") in logs
111111
assert os.path.join(scalars, "epoch.tsv") in logs
112+
113+
114+
def test_lightning_default_dir(tmp_dir):
115+
model = LitXOR()
116+
# If `dir` is not provided handle it properly, use default value
117+
dvclive_logger = DVCLiveLogger("test_run")
118+
trainer = Trainer(
119+
logger=dvclive_logger,
120+
max_epochs=2,
121+
enable_checkpointing=False,
122+
log_every_n_steps=1,
123+
)
124+
trainer.fit(model)
125+
126+
assert os.path.exists("dvclive")

0 commit comments

Comments
 (0)