@@ -132,15 +132,24 @@ def test_post_to_studio_failed_start_request(
132132 assert mocked_post .call_count == 1
133133
134134
135- def test_post_to_studio_end_only_once (tmp_dir , mocked_dvc_repo , mocked_studio_post ):
135+ def test_post_to_studio_done_only_once (tmp_dir , mocked_dvc_repo , mocked_studio_post ):
136136 mocked_post , _ = mocked_studio_post
137137 with Live () as live :
138138 live .log_metric ("foo" , 1 )
139139 live .next_step ()
140140
141- assert mocked_post .call_count == 4
141+ expected_done_calls = [
142+ call
143+ for call in mocked_post .call_args_list
144+ if call .kwargs ["json" ]["type" ] == "done"
145+ ]
142146 live .end ()
143- assert mocked_post .call_count == 4
147+ actual_done_calls = [
148+ call
149+ for call in mocked_post .call_args_list
150+ if call .kwargs ["json" ]["type" ] == "done"
151+ ]
152+ assert expected_done_calls == actual_done_calls
144153
145154
146155@pytest .mark .studio ()
@@ -157,7 +166,9 @@ def test_post_to_studio_skip_start_and_done_on_env_var(
157166 live .log_metric ("foo" , 1 )
158167 live .next_step ()
159168
160- assert mocked_post .call_count == 2
169+ call_types = [call .kwargs ["json" ]["type" ] for call in mocked_post .call_args_list ]
170+ assert "start" not in call_types
171+ assert "done" not in call_types
161172
162173
163174@pytest .mark .studio ()
@@ -169,14 +180,15 @@ def test_post_to_studio_dvc_studio_config(
169180 monkeypatch .setenv (DVC_EXP_BASELINE_REV , "f" * 40 )
170181 monkeypatch .setenv (DVC_EXP_NAME , "bar" )
171182 monkeypatch .setenv (DVC_ROOT , tmp_dir )
183+ monkeypatch .delenv (DVC_STUDIO_TOKEN )
172184
173185 mocked_dvc_repo .config = {"studio" : {"token" : "token" }}
174186
175187 with Live () as live :
176188 live .log_metric ("foo" , 1 )
177189 live .next_step ()
178190
179- assert mocked_post .call_count == 2
191+ assert mocked_post .call_args . kwargs [ "headers" ][ "Authorization" ] == "token token"
180192
181193
182194@pytest .mark .studio ()
@@ -236,7 +248,9 @@ def test_post_to_studio_inside_dvc_exp(
236248 live .log_metric ("foo" , 1 )
237249 live .next_step ()
238250
239- assert mocked_post .call_count == 2
251+ call_types = [call .kwargs ["json" ]["type" ] for call in mocked_post .call_args_list ]
252+ assert "start" not in call_types
253+ assert "done" not in call_types
240254
241255
242256@pytest .mark .studio ()
@@ -370,3 +384,15 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post):
370384 "https://0.0.0.0/api/live" ,
371385 ** get_studio_call ("start" , exp_name = "custom-name" ),
372386 )
387+
388+
389+ def test_post_to_studio_if_done_skipped (tmp_dir , mocked_dvc_repo , mocked_studio_post ):
390+ live = Live ()
391+ live ._studio_events_to_skip .add ("start" )
392+ live ._studio_events_to_skip .add ("done" )
393+ live .log_metric ("foo" , 1 )
394+ live .end ()
395+
396+ mocked_post , _ = mocked_studio_post
397+ call_types = [call .kwargs ["json" ]["type" ] for call in mocked_post .call_args_list ]
398+ assert "data" in call_types
0 commit comments