Skip to content

Commit 9f9c31b

Browse files
committed
frameworks: Call live.end at the end of training.
Ensure `post_to_studio` inside `live.end` is only called once.
1 parent: da345f5 | commit: 9f9c31b

File tree

14 files changed, with 80 additions and 10 deletions.

src/dvclive/catalyst.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,3 +27,6 @@ def on_epoch_end(self, runner) -> None:
2727
)
2828
utils.save_checkpoint(checkpoint, self.model_file)
2929
self.live.next_step()
30+
31+
def on_experiment_end(self, runner): # pylint: disable=unused-argument
32+
self.live.end()

src/dvclive/fastai.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,3 +23,6 @@ def after_epoch(self):
2323
if self.model_file:
2424
self.learn.save(self.model_file)
2525
self.live.next_step()
26+
27+
def after_fit(self):
28+
self.live.end()

src/dvclive/huggingface.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,3 +42,12 @@ def on_epoch_end(
4242
tokenizer = kwargs.get("tokenizer")
4343
if tokenizer:
4444
tokenizer.save_pretrained(self.model_file)
45+
46+
def on_train_end(
47+
self,
48+
args: TrainingArguments,
49+
state: TrainerState,
50+
control: TrainerControl,
51+
**kwargs
52+
):
53+
self.live.end()

src/dvclive/keras.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,3 +54,8 @@ def on_epoch_end(
5454
else:
5555
self.model.save(self.model_file)
5656
self.live.next_step()
57+
58+
def on_train_end(
59+
self, logs: Optional[Dict] = None
60+
): # pylint: disable=unused-argument
61+
self.live.end()

src/dvclive/lightning.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,3 +69,7 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
6969
metric_name = standardize_metric_name(metric_name, __name__)
7070
self.experiment.log_metric(name=metric_name, val=metric_val)
7171
self.experiment.next_step()
72+
73+
@rank_zero_only
74+
def finalize(self, status: str) -> None:
75+
self.experiment.end()

src/dvclive/live.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def __init__(
4343
):
4444
self._dir: str = dir
4545
self._resume: bool = resume or env2bool(env.DVCLIVE_RESUME)
46-
46+
self._ended: bool = False
4747
self.studio_url = os.getenv(env.STUDIO_REPO_URL, None)
4848
self.studio_token = os.getenv(env.STUDIO_TOKEN, None)
4949
self.rev = None
@@ -243,8 +243,10 @@ def make_report(self):
243243
def end(self):
244244
self.make_summary()
245245
if self.report_mode == "studio":
246-
if not post_to_studio(self, "done", logger):
247-
logger.warning("`post_to_studio` `done` event failed.")
246+
if not self._ended:
247+
if not post_to_studio(self, "done", logger):
248+
logger.warning("`post_to_studio` `done` event failed.")
249+
self._ended = True
248250
else:
249251
self.make_report()
250252

src/dvclive/xgb.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,3 +26,7 @@ def after_iteration(self, model, epoch, evals_log):
2626
if self.model_file:
2727
model.save_model(self.model_file)
2828
self.live.next_step()
29+
30+
def after_training(self, model):
31+
self.live.end()
32+
return model

tests/test_catalyst.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,13 +49,17 @@ def runner_params():
4949
}
5050

5151

52-
def test_catalyst_callback(tmp_dir, runner, runner_params):
52+
def test_catalyst_callback(tmp_dir, runner, runner_params, mocker):
53+
callback = DVCLiveCallback()
54+
live = callback.live
55+
spy = mocker.spy(live, "end")
56+
5357
runner.train(
5458
**runner_params,
5559
num_epochs=2,
5660
callbacks=[
5761
dl.AccuracyCallback(input_key="logits", target_key="targets"),
58-
DVCLiveCallback(),
62+
callback,
5963
],
6064
logdir="./logs",
6165
valid_loader="valid",
@@ -64,6 +68,7 @@ def test_catalyst_callback(tmp_dir, runner, runner_params):
6468
verbose=True,
6569
load_best_on_end=True,
6670
)
71+
spy.assert_called_once()
6772

6873
assert os.path.exists("dvclive")
6974

tests/test_fastai.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,13 +38,16 @@ def data_loader():
3838
return xor_loader
3939

4040

41-
def test_fastai_callback(tmp_dir, data_loader):
41+
def test_fastai_callback(tmp_dir, data_loader, mocker):
4242
learn = tabular_learner(data_loader, metrics=accuracy)
4343
learn.remove_cb(ProgressCallback)
4444
learn.model_dir = os.path.abspath("./")
4545
callback = DVCLiveCallback("model")
4646
live = callback.live
47+
48+
spy = mocker.spy(live, "end")
4749
learn.fit_one_cycle(2, cbs=[callback])
50+
spy.assert_called_once()
4851

4952
assert os.path.exists(live.dir)
5053

tests/test_huggingface.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ def args():
101101
)
102102

103103

104-
def test_huggingface_integration(tmp_dir, model, args, data):
104+
def test_huggingface_integration(tmp_dir, model, args, data, mocker):
105105
trainer = Trainer(
106106
model,
107107
args,
@@ -110,8 +110,11 @@ def test_huggingface_integration(tmp_dir, model, args, data):
110110
compute_metrics=compute_metrics,
111111
)
112112
callback = DVCLiveCallback()
113+
live = callback.live
114+
spy = mocker.spy(live, "end")
113115
trainer.add_callback(callback)
114116
trainer.train()
117+
spy.assert_called_once()
115118

116119
live = callback.live
117120
assert os.path.exists(live.dir)

0 commit comments

Comments
 (0)