Skip to content

Commit aad6a01

Browse files
authored
lightning: Only force init if report="notebook". (#595)
Closes #594.
1 parent 9088996 commit aad6a01

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/dvclive/lightning.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def __init__(
5757
self._live_init["dir"] = dir
5858
self._experiment = experiment
5959
self._version = run_name
60-
# Force Live instantiation
61-
self.experiment # noqa: B018
60+
if report == "notebook":
61+
# Force Live instantiation
62+
self.experiment # noqa: B018
6263

6364
@property
6465
def name(self):

tests/test_frameworks/test_lightning.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.optim import SGD, Adam
1616
from torch.utils.data import DataLoader, Dataset
1717

18+
from dvclive import Live
1819
from dvclive.lightning import DVCLiveLogger
1920
except ImportError:
2021
pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
@@ -239,3 +240,14 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio
239240
# Without `self.experiment._latest_studio_step -= 1`
240241
# This would be empty
241242
assert len(val_loss["data"]) == 1
243+
244+
245+
def test_lightning_force_init(tmp_dir, mocker):
246+
"""Regression test for https://github.com/iterative/dvclive/issues/594
247+
Only call Live.__init__ when report is notebook.
248+
"""
249+
init = mocker.spy(Live, "__init__")
250+
DVCLiveLogger()
251+
init.assert_not_called()
252+
DVCLiveLogger(report="notebook")
253+
init.assert_called_once()

0 commit comments

Comments
 (0)