Skip to content

Commit bf38034

Browse files
mattseddondaavoo
andauthored
Add step completed signal file for VS Code (#688)
* Add step completed signal file for VS Code * Use new DVC environment variable * Write file if no env variable present * Bump `dvc>=3.17.0`. --------- Co-authored-by: daavoo <[email protected]>
1 parent 914a7e6 commit bf38034

File tree

5 files changed

+103
-1
lines changed

5 files changed

+103
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ classifiers = [
3131
]
3232
dynamic = ["version"]
3333
dependencies = [
34-
"dvc>=2.58.0",
34+
"dvc>=3.17.0",
3535
"dvc-render>=0.5.0,<1.0",
3636
"dvc-studio-client>=0.10.0,<1",
3737
"funcy",

src/dvclive/dvc.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from dvclive.serialize import dump_yaml
1010
from dvclive.utils import StrPath
1111

12+
from . import env
13+
1214
if TYPE_CHECKING:
1315
from dvc.repo import Repo
1416
from dvc.stage import Stage
@@ -27,6 +29,11 @@ def _dvclive_only_signal_file(root_dir: StrPath) -> str:
2729
return os.path.join(dvc_exps_run_dir, "DVCLIVE_ONLY")
2830

2931

32+
def _dvclive_step_completed_signal_file(root_dir: StrPath) -> str:
33+
dvc_exps_run_dir = _dvc_exps_run_dir(root_dir)
34+
return os.path.join(dvc_exps_run_dir, "DVCLIVE_STEP_COMPLETED")
35+
36+
3037
def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]:
3138
if not root:
3239
root = os.getcwd()
@@ -46,6 +53,10 @@ def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]:
4653
return None
4754

4855

56+
def _find_non_queue_root() -> Optional[str]:
57+
return os.getenv(env.DVC_ROOT) or _find_dvc_root()
58+
59+
4960
def _write_file(file: str, contents: Dict[str, Union[str, int]]):
5061
import builtins
5162

@@ -106,6 +117,39 @@ def make_dvcyaml(live) -> None:
106117
dump_yaml(dvcyaml, live.dvc_file)
107118

108119

120+
def mark_dvclive_step_completed(step: int) -> None:
121+
"""
122+
https://github.com/iterative/vscode-dvc/issues/4528
123+
Signal DVC VS Code extension that
124+
a step has been completed for an experiment running in the queue
125+
"""
126+
non_queue_root_dir = _find_non_queue_root()
127+
128+
if not non_queue_root_dir:
129+
return
130+
131+
exp_run_dir = _dvc_exps_run_dir(non_queue_root_dir)
132+
os.makedirs(exp_run_dir, exist_ok=True)
133+
134+
signal_file = _dvclive_step_completed_signal_file(non_queue_root_dir)
135+
136+
_write_file(signal_file, {"pid": os.getpid(), "step": step})
137+
138+
139+
def cleanup_dvclive_step_completed() -> None:
140+
non_queue_root_dir = _find_non_queue_root()
141+
142+
if not non_queue_root_dir:
143+
return
144+
145+
signal_file = _dvclive_step_completed_signal_file(non_queue_root_dir)
146+
147+
if not os.path.exists(signal_file):
148+
return
149+
150+
os.remove(signal_file)
151+
152+
109153
def mark_dvclive_only_started(exp_name: str) -> None:
110154
"""
111155
Signal DVC VS Code extension that

src/dvclive/env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
DVC_CHECKPOINT = "DVC_CHECKPOINT"
55
DVC_EXP_BASELINE_REV = "DVC_EXP_BASELINE_REV"
66
DVC_EXP_NAME = "DVC_EXP_NAME"
7+
DVC_ROOT = "DVC_ROOT"

src/dvclive/live.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313

1414
from . import env
1515
from .dvc import (
16+
cleanup_dvclive_step_completed,
1617
ensure_dir_is_tracked,
1718
find_overlapping_stage,
1819
get_dvc_repo,
1920
get_random_exp_name,
2021
make_dvcyaml,
2122
mark_dvclive_only_ended,
2223
mark_dvclive_only_started,
24+
mark_dvclive_step_completed,
2325
)
2426
from .error import (
2527
InvalidDataTypeError,
@@ -295,6 +297,7 @@ def next_step(self):
295297
self.make_dvcyaml()
296298

297299
self.make_report()
300+
mark_dvclive_step_completed(self.step)
298301
self.step += 1
299302

300303
def log_metric(
@@ -570,6 +573,8 @@ def end(self):
570573
else:
571574
self.make_report()
572575

576+
cleanup_dvclive_step_completed()
577+
573578
def read_step(self):
574579
if Path(self.metrics_file).exists():
575580
latest = self.read_latest()

tests/test_main.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,58 @@ def test_context_manager_skips_end_calls(tmp_dir):
400400
assert (tmp_dir / live.metrics_file).exists()
401401

402402

403+
@pytest.mark.vscode()
404+
@pytest.mark.parametrize("dvc_root", [True, False])
405+
def test_vscode_dvclive_step_completed_signal_file(
406+
tmp_dir, dvc_root, mocker, monkeypatch
407+
):
408+
signal_file = os.path.join(
409+
tmp_dir, ".dvc", "tmp", "exps", "run", "DVCLIVE_STEP_COMPLETED"
410+
)
411+
cwd = tmp_dir
412+
test_pid = 12345
413+
414+
if dvc_root:
415+
cwd = tmp_dir / ".dvc" / "tmp" / "exps" / "asdasasf"
416+
monkeypatch.setenv(env.DVC_ROOT, tmp_dir)
417+
(cwd / ".dvc").mkdir(parents=True)
418+
419+
assert not os.path.exists(signal_file)
420+
421+
dvc_repo = mocker.MagicMock()
422+
dvc_repo.index.stages = []
423+
dvc_repo.config = {}
424+
dvc_repo.scm.get_rev.return_value = "current_rev"
425+
dvc_repo.scm.get_ref.return_value = None
426+
dvc_repo.scm.no_commits = False
427+
with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo), mocker.patch(
428+
"dvclive.live.os.getpid", return_value=test_pid
429+
):
430+
dvclive = Live(save_dvc_exp=True)
431+
assert not os.path.exists(signal_file)
432+
dvclive.next_step()
433+
assert dvclive.step == 1
434+
435+
if dvc_root:
436+
assert os.path.exists(signal_file)
437+
with open(signal_file, encoding="utf-8") as f:
438+
assert json.load(f) == {"pid": test_pid, "step": 0}
439+
440+
else:
441+
assert not os.path.exists(signal_file)
442+
443+
dvclive.next_step()
444+
assert dvclive.step == 2
445+
446+
if dvc_root:
447+
with open(signal_file, encoding="utf-8") as f:
448+
assert json.load(f) == {"pid": test_pid, "step": 1}
449+
450+
dvclive.end()
451+
452+
assert not os.path.exists(signal_file)
453+
454+
403455
@pytest.mark.vscode()
404456
@pytest.mark.parametrize("dvc_root", [True, False])
405457
def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker):

0 commit comments

Comments
 (0)