Skip to content

Commit 138debe

Browse files
pacifikusdaavoo
andauthored
Catalyst integration (#139)
* LGBM integration added * Tests for lgbm integration * test_lgbm: added tmp_dir fixture * Update dvclive/lgbm.py Save model at the end of each iteration Co-authored-by: David de la Iglesia Castro <[email protected]> * Huggingface integration * Add on_log event * Update dvclive/huggingface.py Co-authored-by: David de la Iglesia Castro <[email protected]> * fix: huggingface test after callback changes * revert last commit * fix: huggingface test after calback changes * Updated test_huggingface * Updated test_huggingface * Catalyst integration * Callback rename * Rename callback in test * Update dvclive/catalyst.py Co-authored-by: David de la Iglesia Castro <[email protected]> * upd: catalyst tests * Fix: cross platform tests Co-authored-by: David de la Iglesia Castro <[email protected]>
1 parent b00ffda commit 138debe

File tree

3 files changed

+132
-1
lines changed

3 files changed

+132
-1
lines changed

dvclive/catalyst.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from catalyst.core.callback import Callback, CallbackOrder
2+
3+
import dvclive
4+
5+
6+
class DvcLiveCallback(Callback):
7+
def __init__(self, model_file=None):
8+
super().__init__(order=CallbackOrder.external)
9+
self.model_file = model_file
10+
11+
def on_epoch_end(self, runner) -> None:
12+
step = runner.stage_epoch_step
13+
14+
for loader_key, per_loader_metrics in runner.epoch_metrics.items():
15+
for key, value in per_loader_metrics.items():
16+
key = key.replace("/", "_")
17+
dvclive.log(f"{loader_key}/{key}", float(value), step)
18+
19+
if self.model_file:
20+
checkpoint = runner.engine.pack_checkpoint(
21+
model=runner.model,
22+
criterion=runner.criterion,
23+
optimizer=runner.optimizer,
24+
scheduler=runner.scheduler,
25+
)
26+
runner.engine.save_checkpoint(checkpoint, self.model_file)
27+
dvclive.next_step()

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ def run(self):
4141
xgb = ["xgboost"]
4242
lgbm = ["lightgbm"]
4343
hugginface = ["transformers", "datasets"]
44+
catalyst = ["catalyst"]
4445

45-
all_libs = mmcv + tf + xgb + lgbm + hugginface
46+
all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst
4647

4748
tests_requires = [
4849
"pylint==2.5.3",
@@ -75,6 +76,7 @@ def run(self):
7576
"xgb": xgb,
7677
"lgbm": lgbm,
7778
"huggingface": hugginface,
79+
"catalyst": catalyst,
7880
},
7981
keywords="data-science metrics machine-learning developer-tools ai",
8082
python_requires=">=3.6",

tests/test_catalyst.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import os
2+
3+
import pytest
4+
from catalyst import dl
5+
from catalyst.contrib.datasets import MNIST
6+
from catalyst.data import ToTensor
7+
from catalyst.utils.torch import get_available_engine
8+
from torch import nn, optim
9+
from torch.utils.data import DataLoader
10+
11+
import dvclive
12+
from dvclive.catalyst import DvcLiveCallback
13+
14+
# pylint: disable=redefined-outer-name, unused-argument
15+
16+
17+
@pytest.fixture
18+
def loaders():
19+
train_data = MNIST(
20+
os.getcwd(), train=True, download=True, transform=ToTensor()
21+
)
22+
valid_data = MNIST(
23+
os.getcwd(), train=False, download=True, transform=ToTensor()
24+
)
25+
return {
26+
"train": DataLoader(train_data, batch_size=32),
27+
"valid": DataLoader(valid_data, batch_size=32),
28+
}
29+
30+
31+
@pytest.fixture
32+
def runner():
33+
return dl.SupervisedRunner(
34+
engine=get_available_engine(),
35+
input_key="features",
36+
output_key="logits",
37+
target_key="targets",
38+
loss_key="loss",
39+
)
40+
41+
42+
def test_catalyst_callback(tmp_dir, runner, loaders):
43+
dvclive.init("dvc_logs")
44+
45+
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
46+
criterion = nn.CrossEntropyLoss()
47+
optimizer = optim.Adam(model.parameters(), lr=0.02)
48+
49+
runner.train(
50+
model=model,
51+
criterion=criterion,
52+
optimizer=optimizer,
53+
loaders=loaders,
54+
num_epochs=2,
55+
callbacks=[
56+
dl.AccuracyCallback(input_key="logits", target_key="targets"),
57+
DvcLiveCallback(),
58+
],
59+
logdir="./logs",
60+
valid_loader="valid",
61+
valid_metric="loss",
62+
minimize_valid_metric=True,
63+
verbose=True,
64+
load_best_on_end=True,
65+
)
66+
67+
assert os.path.exists("dvc_logs")
68+
69+
train_path = tmp_dir / "dvc_logs/train"
70+
valid_path = tmp_dir / "dvc_logs/valid"
71+
72+
assert train_path.is_dir()
73+
assert valid_path.is_dir()
74+
assert (train_path / "accuracy.tsv").exists()
75+
76+
77+
def test_catalyst_model_file(tmp_dir, runner, loaders):
78+
dvclive.init("dvc_logs")
79+
80+
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
81+
criterion = nn.CrossEntropyLoss()
82+
optimizer = optim.Adam(model.parameters(), lr=0.02)
83+
84+
runner.train(
85+
model=model,
86+
engine=runner.engine,
87+
criterion=criterion,
88+
optimizer=optimizer,
89+
loaders=loaders,
90+
num_epochs=2,
91+
callbacks=[
92+
dl.AccuracyCallback(input_key="logits", target_key="targets"),
93+
DvcLiveCallback("model.pth"),
94+
],
95+
logdir="./logs",
96+
valid_loader="valid",
97+
valid_metric="loss",
98+
minimize_valid_metric=True,
99+
verbose=True,
100+
load_best_on_end=True,
101+
)
102+
assert (tmp_dir / "model.pth").is_file()

0 commit comments

Comments
 (0)