Skip to content

Commit dc1f45d

Browse files
Dave Berenbaumpre-commit-ci[bot]
andauthored
Revert dvcyaml (#398)
* log_sklearn_plot: fix .json replacement * dvcyaml: only write if saving exp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cd84219 commit dc1f45d

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

src/dvclive/live.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,14 @@ def next_step(self):
193193
self._step = 0
194194

195195
self.make_summary()
196-
make_dvcyaml(self)
196+
197+
if (
198+
self._dvc_repo is not None
199+
and not self._inside_dvc_exp
200+
and self._save_dvc_exp
201+
):
202+
make_dvcyaml(self)
203+
197204
self.make_report()
198205
self.make_checkpoint()
199206
self.step += 1
@@ -297,7 +304,6 @@ def make_report(self):
297304

298305
def end(self):
299306
self.make_summary(update_step=False)
300-
make_dvcyaml(self)
301307
if self._studio_url and self._studio_token:
302308
if "done" not in self._studio_events_to_skip:
303309
if not post_to_studio(self, "done", logger):
@@ -311,6 +317,7 @@ def end(self):
311317
and not self._inside_dvc_exp
312318
and self._save_dvc_exp
313319
):
320+
make_dvcyaml(self)
314321
self._dvc_repo.experiments.save(
315322
name=self._exp_name, include_untracked=self.dir
316323
)

tests/test_dvc.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,39 +20,44 @@ def test_get_dvc_repo(tmp_dir):
2020
assert isinstance(get_dvc_repo(), Repo)
2121

2222

23-
def test_make_dvcyaml(tmp_dir):
23+
def test_make_dvcyaml_empty(tmp_dir):
2424
live = Live()
2525
make_dvcyaml(live)
2626

2727
assert load_yaml(live.dvc_file) == {}
2828

29+
30+
def test_make_dvcyaml_param(tmp_dir):
2931
live = Live()
3032
live.log_param("foo", 1)
31-
live.next_step()
33+
make_dvcyaml(live)
3234

3335
assert load_yaml(live.dvc_file) == {
3436
"params": ["params.yaml"],
3537
}
3638

39+
40+
def test_make_dvcyaml_metrics(tmp_dir):
41+
live = Live()
3742
live.log_metric("bar", 2)
38-
live.end()
43+
make_dvcyaml(live)
3944

4045
assert load_yaml(live.dvc_file) == {
4146
"metrics": ["metrics.json"],
42-
"params": ["params.yaml"],
4347
"plots": [os.path.join("plots", "metrics")],
4448
}
4549

4650

4751
def test_make_dvcyaml_all_plots(tmp_dir):
48-
with Live() as live:
49-
live.log_param("foo", 1)
50-
live.log_metric("bar", 2)
51-
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
52-
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
53-
live.log_sklearn_plot(
54-
"roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc"
55-
)
52+
live = Live()
53+
live.log_param("foo", 1)
54+
live.log_metric("bar", 2)
55+
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
56+
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
57+
live.log_sklearn_plot(
58+
"roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc"
59+
)
60+
make_dvcyaml(live)
5661

5762
assert load_yaml(live.dvc_file) == {
5863
"metrics": ["metrics.json"],
@@ -138,6 +143,7 @@ def test_exp_save_on_end(tmp_dir, mocker, save):
138143
assert live._baseline_rev is None
139144
assert live._exp_name is None
140145
dvc_repo.experiments.save.assert_not_called()
146+
assert not (tmp_dir / live.dvc_file).exists()
141147

142148

143149
def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
@@ -152,6 +158,7 @@ def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
152158
assert live._baseline_rev == "foo"
153159
assert live._exp_name == "bar"
154160
assert live._inside_dvc_exp
161+
assert not (tmp_dir / live.dvc_file).exists()
155162

156163

157164
def test_exp_save_skip_on_dvc_repro(tmp_dir, mocker):
@@ -165,3 +172,17 @@ def test_exp_save_skip_on_dvc_repro(tmp_dir, mocker):
165172
live.end()
166173

167174
dvc_repo.experiments.save.assert_not_called()
175+
assert not (tmp_dir / live.dvc_file).exists()
176+
177+
178+
@pytest.mark.parametrize("save", [True, False])
179+
def test_dvcyaml_on_next_step(tmp_dir, mocker, save):
180+
dvc_repo = mocker.MagicMock()
181+
dvc_repo.index.stages = []
182+
with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo):
183+
live = Live(save_dvc_exp=save)
184+
live.next_step()
185+
if save:
186+
assert (tmp_dir / live.dvc_file).exists()
187+
else:
188+
assert not (tmp_dir / live.dvc_file).exists()

0 commit comments

Comments
 (0)