Skip to content

Commit bf070dd

Browse files
committed
fast.ai: Handle resuming.
Found some bugs while working on https://github.com/iterative/dvc-get-started-cv/pull/59. - Use the occasion to also remove the redundant logging of `epoch` as a metric. - Expose the `with_opt` option. - Don't increase the step when resuming.
1 parent 63d1b20 commit bf070dd

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

src/dvclive/fastai.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,36 @@
77

88

99
class DVCLiveCallback(Callback):
10-
def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs):
10+
def __init__(
11+
self,
12+
model_file: Optional[str] = None,
13+
with_opt: bool = False,
14+
live: Optional[Live] = None,
15+
**kwargs
16+
):
1117
super().__init__()
1218
self.model_file = model_file
19+
self.with_opt = with_opt
1320
self.live = live if live is not None else Live(**kwargs)
1421

1522
def after_epoch(self):
23+
logged_metrics = False
1624
for key, value in zip(
1725
self.learn.recorder.metric_names, self.learn.recorder.log
1826
):
27+
if key == "epoch":
28+
continue
1929
self.live.log_metric(
2030
standardize_metric_name(key, __name__), float(value)
2131
)
32+
logged_metrics = True
2233

23-
if self.model_file:
24-
self.learn.save(self.model_file)
25-
self.live.next_step()
34+
# When resuming (i.e. passing `start_epoch` to learner)
35+
# fast.ai calls after_epoch but we don't want to increase the step.
36+
if logged_metrics:
37+
if self.model_file:
38+
self.learn.save(self.model_file, with_opt=self.with_opt)
39+
self.live.next_step()
2640

2741
def after_fit(self):
2842
self.live.end()

tests/test_fastai.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ def data_loader():
4141
def test_fastai_callback(tmp_dir, data_loader, mocker):
4242
learn = tabular_learner(data_loader, metrics=accuracy)
4343
learn.remove_cb(ProgressCallback)
44-
learn.model_dir = os.path.abspath("./")
45-
callback = DVCLiveCallback("model")
44+
callback = DVCLiveCallback()
4645
live = callback.live
4746

4847
spy = mocker.spy(live, "end")
@@ -58,18 +57,38 @@ def test_fastai_callback(tmp_dir, data_loader, mocker):
5857
assert train_path.is_dir()
5958
assert valid_path.is_dir()
6059
assert (metrics_path / "accuracy.tsv").exists()
60+
assert not (metrics_path / "epoch.tsv").exists()
6161

6262

63-
def test_fastai_model_file(tmp_dir, data_loader):
63+
def test_fastai_model_file(tmp_dir, data_loader, mocker):
6464
learn = tabular_learner(data_loader, metrics=accuracy)
6565
learn.remove_cb(ProgressCallback)
6666
learn.model_dir = os.path.abspath("./")
67-
learn.fit_one_cycle(2, cbs=[DVCLiveCallback("model")])
67+
save = mocker.spy(learn, "save")
68+
learn.fit_one_cycle(2, cbs=[DVCLiveCallback("model", with_opt=True)])
6869
assert (tmp_dir / "model.pth").is_file()
70+
save.assert_called_with("model", with_opt=True)
6971

7072

7173
def test_fastai_pass_logger():
7274
logger = Live("train_logs")
7375

7476
assert DVCLiveCallback().live is not logger
7577
assert DVCLiveCallback(live=logger).live is logger
78+
79+
80+
def test_fast_ai_resume(tmp_dir, data_loader, mocker):
81+
learn = tabular_learner(data_loader, metrics=accuracy)
82+
learn.remove_cb(ProgressCallback)
83+
callback = DVCLiveCallback()
84+
live = callback.live
85+
86+
spy = mocker.spy(live, "next_step")
87+
learn.fit_one_cycle(2, cbs=[callback])
88+
assert spy.call_count == 2
89+
90+
callback = DVCLiveCallback(resume=True)
91+
live = callback.live
92+
spy = mocker.spy(live, "next_step")
93+
learn.fit_one_cycle(3, cbs=[callback], start_epoch=live.step - 1)
94+
assert spy.call_count == 1

0 commit comments

Comments
 (0)