Skip to content

Commit fdd84af

Browse files
daavoo and dberenbaum authored
utils: Add catch_and_warn. (#673)
dvc: Extract `ensure_dir_is_tracked` to `dvc`. Co-authored-by: dberenbaum <[email protected]>
1 parent 8a80256 commit fdd84af

File tree

5 files changed

+73
-60
lines changed

5 files changed

+73
-60
lines changed

src/dvclive/dvc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,19 @@ def find_overlapping_stage(dvc_repo: "Repo", path: StrPath) -> Optional["Stage"]
150150
if str(out.fs_path) in abs_path:
151151
return stage
152152
return None
153+
154+
155+
def ensure_dir_is_tracked(directory: str, dvc_repo: "Repo") -> None:
    """Add untracked files under ``directory`` to the repo's SCM.

    Every path reported by ``dvc_repo.scm.untracked_files()`` is kept when it
    matches the ``directory`` gitwildmatch pattern and does not match any
    entry of ``dvc_repo.index.outs``. The surviving paths, if any, are passed
    to ``dvc_repo.scm.add`` in a single call.

    Args:
        directory: Pattern (gitwildmatch syntax) selecting the files to track.
        dvc_repo: Repo object exposing ``index.outs`` and ``scm``.
    """
    from pathspec import PathSpec

    in_directory = PathSpec.from_lines("gitwildmatch", [directory])
    dvc_outputs = PathSpec.from_lines(
        "gitwildmatch", [str(out) for out in dvc_repo.index.outs]
    )

    pending = []
    for candidate in dvc_repo.scm.untracked_files():
        # Only files selected by the directory pattern are of interest.
        if not in_directory.match_file(candidate):
            continue
        # Paths matching an entry of `index.outs` are left alone.
        if dvc_outputs.match_file(candidate):
            continue
        pending.append(candidate)

    if not pending:
        return
    dvc_repo.scm.add(pending)

src/dvclive/live.py

Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from pathlib import Path
66
from typing import Any, Dict, List, Optional, Set, Union
77

8+
from dvc.exceptions import DvcException
89
from dvc_studio_client.post_live_metrics import post_live_metrics
910
from funcy import set_in
10-
from pathspec import PathSpec
1111
from ruamel.yaml.representer import RepresenterError
1212

1313
from . import env
1414
from .dvc import (
15+
ensure_dir_is_tracked,
1516
find_overlapping_stage,
1617
get_dvc_repo,
1718
get_random_exp_name,
@@ -31,6 +32,7 @@
3132
from .studio import get_dvc_studio_config, get_studio_updates
3233
from .utils import (
3334
StrPath,
35+
catch_and_warn,
3436
clean_and_copy_into,
3537
env2bool,
3638
inside_notebook,
@@ -124,6 +126,7 @@ def _init_cleanup(self):
124126
if self.dvc_file and os.path.exists(self.dvc_file):
125127
os.remove(self.dvc_file)
126128

129+
@catch_and_warn(DvcException, logger)
127130
def _init_dvc(self):
128131
from dvc.scm import NoSCM
129132

@@ -453,35 +456,31 @@ def log_artifact(
453456
name,
454457
)
455458

459+
@catch_and_warn(DvcException, logger)
456460
def cache(self, path):
457-
try:
458-
if self._inside_dvc_exp:
459-
existing_stage = find_overlapping_stage(self._dvc_repo, path)
460-
461-
if existing_stage:
462-
if existing_stage.cmd:
463-
logger.info(
464-
f"Skipping `dvc add {path}` because it is already being"
465-
" tracked automatically as an output of `dvc exp run`."
466-
)
467-
return # skip caching
468-
logger.warning(
469-
f"To track '{path}' automatically during `dvc exp run`:"
470-
f"\n1. Run `dvc remove {existing_stage.addressing}` "
471-
"to stop tracking it outside the pipeline."
472-
"\n2. Add it as an output of the pipeline stage."
473-
)
474-
else:
475-
logger.warning(
476-
f"To track '{path}' automatically during `dvc exp run`, "
477-
"add it as an output of the pipeline stage."
478-
)
461+
if self._inside_dvc_exp:
462+
existing_stage = find_overlapping_stage(self._dvc_repo, path)
479463

480-
stage = self._dvc_repo.add(str(path))
464+
if existing_stage:
465+
if existing_stage.cmd:
466+
logger.info(
467+
f"Skipping `dvc add {path}` because it is already being"
468+
" tracked automatically as an output of `dvc exp run`."
469+
)
470+
return # skip caching
471+
logger.warning(
472+
f"To track '{path}' automatically during `dvc exp run`:"
473+
f"\n1. Run `dvc remove {existing_stage.addressing}` "
474+
"to stop tracking it outside the pipeline."
475+
"\n2. Add it as an output of the pipeline stage."
476+
)
477+
else:
478+
logger.warning(
479+
f"To track '{path}' automatically during `dvc exp run`, "
480+
"add it as an output of the pipeline stage."
481+
)
481482

482-
except Exception as e: # noqa: BLE001
483-
logger.warning(f"Failed to dvc add {path}: {e}")
484-
return
483+
stage = self._dvc_repo.add(str(path))
485484

486485
dvc_file = stage[0].addressing
487486

@@ -539,7 +538,10 @@ def end(self):
539538
if self._dvcyaml:
540539
self.make_dvcyaml()
541540

542-
self._ensure_paths_are_tracked_in_dvc_exp()
541+
if self._inside_dvc_exp and self._dvc_repo:
542+
catch_and_warn(DvcException, logger)(ensure_dir_is_tracked)(
543+
self.dir, self._dvc_repo
544+
)
543545

544546
self.save_dvc_exp()
545547

@@ -582,35 +584,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
582584
self._inside_with = False
583585
self.end()
584586

585-
def _ensure_paths_are_tracked_in_dvc_exp(self):
586-
if self._inside_dvc_exp and self._dvc_repo:
587-
dir_spec = PathSpec.from_lines("gitwildmatch", [self.dir])
588-
outs_spec = PathSpec.from_lines(
589-
"gitwildmatch", [str(o) for o in self._dvc_repo.index.outs]
590-
)
591-
try:
592-
paths_to_track = [
593-
f
594-
for f in self._dvc_repo.scm.untracked_files()
595-
if (dir_spec.match_file(f) and not outs_spec.match_file(f))
596-
]
597-
if paths_to_track:
598-
self._dvc_repo.scm.add(paths_to_track)
599-
except Exception as e: # noqa: BLE001
600-
logger.warning(f"Failed to git add paths:\n{e}")
601-
587+
@catch_and_warn(DvcException, logger, mark_dvclive_only_ended)
602588
def save_dvc_exp(self):
603589
if self._save_dvc_exp:
604-
from dvc.exceptions import DvcException
605-
606-
try:
607-
self._experiment_rev = self._dvc_repo.experiments.save(
608-
name=self._exp_name,
609-
include_untracked=self._include_untracked,
610-
force=True,
611-
message=self._exp_message,
612-
)
613-
except DvcException as e:
614-
logger.warning(f"Failed to save experiment:\n{e}")
615-
finally:
616-
mark_dvclive_only_ended()
590+
self._experiment_rev = self._dvc_repo.experiments.save(
591+
name=self._exp_name,
592+
include_untracked=self._include_untracked,
593+
force=True,
594+
message=self._exp_message,
595+
)

src/dvclive/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,19 @@ def isinstance_without_import(val, module, name):
155155
if (cls.__module__, cls.__name__) == (module, name):
156156
return True
157157
return False
158+
159+
160+
def catch_and_warn(exception, logger, on_finally=None):
    """Build a decorator that logs ``exception`` as a warning instead of raising.

    Args:
        exception: Exception class (or tuple of classes) to intercept.
        logger: Object with a ``warning(msg)`` method used to report the error.
        on_finally: Optional zero-argument callable invoked after every call
            of the wrapped function, whether it returned or raised.

    Returns:
        A decorator. The decorated function returns the wrapped function's
        result, or ``None`` when a matching exception was caught and logged.
        Exceptions of any other type propagate unchanged, after
        ``on_finally`` has run.
    """
    # Function-scope import keeps this block self-contained.
    from functools import wraps

    def decorator(func):
        # `wraps` preserves the wrapped function's __name__/__doc__ so that
        # decorated methods stay identifiable in tracebacks and introspection.
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exception as e:
                logger.warning(f"Error in {func.__name__}: {e}")
            finally:
                if on_finally is not None:
                    on_finally()

        return wrapper

    return decorator

tests/test_dvc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
import pytest
4+
from dvc.exceptions import DvcException
45
from dvc.repo import Repo
56
from dvc.scm import NoSCM
67
from PIL import Image
@@ -279,7 +280,7 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch):
279280
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo")
280281
monkeypatch.setenv(DVC_EXP_NAME, "bar")
281282
mocked_dvc_repo.scm.untracked_files.return_value = ["dvclive/metrics.json"]
282-
mocked_dvc_repo.scm.add.side_effect = Exception("foo")
283+
mocked_dvc_repo.scm.add.side_effect = DvcException("foo")
283284

284285
with Live(report=None) as live:
285286
live.summary["foo"] = 1

tests/test_log_artifact.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33

44
import pytest
5+
from dvc.exceptions import DvcException
56

67
from dvclive import Live
78
from dvclive.serialize import load_yaml
@@ -216,7 +217,7 @@ def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo):
216217

217218
def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_repo):
218219
(tmp_dir / "model.pth").touch()
219-
mocked_dvc_repo.add.side_effect = Exception
220+
mocked_dvc_repo.add.side_effect = DvcException("foo")
220221
with Live(save_dvc_exp=True) as live:
221222
live.log_artifact("model.pth", type="model")
222223

0 commit comments

Comments
 (0)