Skip to content

Commit 7d5c088

Browse files
author
Dave Berenbaum
authored
post data to studio on end (#738)
* post data to studio on end * fix test for dvc studio config
1 parent d2f861e commit 7d5c088

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

src/dvclive/live.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,9 @@ def end(self):
590590

591591
self.save_dvc_exp()
592592

593+
# Post any data that hasn't been sent
594+
self.post_to_studio("data")
595+
# Mark experiment as done
593596
self.post_to_studio("done")
594597

595598
cleanup_dvclive_step_completed()

tests/test_post_to_studio.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,24 @@ def test_post_to_studio_failed_start_request(
132132
assert mocked_post.call_count == 1
133133

134134

135-
def test_post_to_studio_end_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
135+
def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
136136
mocked_post, _ = mocked_studio_post
137137
with Live() as live:
138138
live.log_metric("foo", 1)
139139
live.next_step()
140140

141-
assert mocked_post.call_count == 4
141+
expected_done_calls = [
142+
call
143+
for call in mocked_post.call_args_list
144+
if call.kwargs["json"]["type"] == "done"
145+
]
142146
live.end()
143-
assert mocked_post.call_count == 4
147+
actual_done_calls = [
148+
call
149+
for call in mocked_post.call_args_list
150+
if call.kwargs["json"]["type"] == "done"
151+
]
152+
assert expected_done_calls == actual_done_calls
144153

145154

146155
@pytest.mark.studio()
@@ -157,7 +166,9 @@ def test_post_to_studio_skip_start_and_done_on_env_var(
157166
live.log_metric("foo", 1)
158167
live.next_step()
159168

160-
assert mocked_post.call_count == 2
169+
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
170+
assert "start" not in call_types
171+
assert "done" not in call_types
161172

162173

163174
@pytest.mark.studio()
@@ -169,14 +180,15 @@ def test_post_to_studio_dvc_studio_config(
169180
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40)
170181
monkeypatch.setenv(DVC_EXP_NAME, "bar")
171182
monkeypatch.setenv(DVC_ROOT, tmp_dir)
183+
monkeypatch.delenv(DVC_STUDIO_TOKEN)
172184

173185
mocked_dvc_repo.config = {"studio": {"token": "token"}}
174186

175187
with Live() as live:
176188
live.log_metric("foo", 1)
177189
live.next_step()
178190

179-
assert mocked_post.call_count == 2
191+
assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token"
180192

181193

182194
@pytest.mark.studio()
@@ -236,7 +248,9 @@ def test_post_to_studio_inside_dvc_exp(
236248
live.log_metric("foo", 1)
237249
live.next_step()
238250

239-
assert mocked_post.call_count == 2
251+
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
252+
assert "start" not in call_types
253+
assert "done" not in call_types
240254

241255

242256
@pytest.mark.studio()
@@ -370,3 +384,15 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post):
370384
"https://0.0.0.0/api/live",
371385
**get_studio_call("start", exp_name="custom-name"),
372386
)
387+
388+
389+
def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post):
390+
live = Live()
391+
live._studio_events_to_skip.add("start")
392+
live._studio_events_to_skip.add("done")
393+
live.log_metric("foo", 1)
394+
live.end()
395+
396+
mocked_post, _ = mocked_studio_post
397+
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
398+
assert "data" in call_types

0 commit comments

Comments
 (0)