Skip to content

Commit 36382c5

Browse files
pacifikusdaavoo
andauthored
Added LightGBM (#134)
* LGBM integration added * Tests for lgbm integration * test_lgbm: added tmp_dir fixture * Update dvclive/lgbm.py Save model at the end of each iteration Co-authored-by: David de la Iglesia Castro <[email protected]> Co-authored-by: David de la Iglesia Castro <[email protected]>
1 parent bff563d commit 36382c5

File tree

3 files changed

+90
-1
lines changed

3 files changed

+90
-1
lines changed

dvclive/lgbm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import dvclive
2+
3+
4+
class DvcLiveCallback:
5+
def __init__(self, model_file=None):
6+
super().__init__()
7+
self.model_file = model_file
8+
9+
def __call__(self, env):
10+
for eval_result in env.evaluation_result_list:
11+
metric = eval_result[1]
12+
value = eval_result[2]
13+
dvclive.log(metric, value)
14+
if self.model_file:
15+
env.model.save_model(self.model_file)
16+
dvclive.next_step()

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ def run(self):
3939
mmcv = ["mmcv", "torch", "torchvision"]
4040
tf = ["tensorflow"]
4141
xgb = ["xgboost"]
42+
lgbm = ["lightgbm"]
4243

43-
all_libs = mmcv + tf + xgb
44+
all_libs = mmcv + tf + xgb + lgbm
4445

4546
tests_requires = [
4647
"pylint==2.5.3",
@@ -71,6 +72,7 @@ def run(self):
7172
"all": all_libs,
7273
"tf": tf,
7374
"xgb": xgb,
75+
"lgbm": lgbm,
7476
},
7577
keywords="data-science metrics machine-learning developer-tools ai",
7678
python_requires=">=3.6",

tests/test_lgbm.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
3+
import lightgbm as lgbm
4+
import numpy as np
5+
import pandas as pd
6+
import pytest
7+
from funcy import first
8+
from sklearn import datasets
9+
from sklearn.model_selection import train_test_split
10+
11+
import dvclive
12+
from dvclive.lgbm import DvcLiveCallback
13+
from tests.test_main import read_logs
14+
15+
# pylint: disable=redefined-outer-name, unused-argument
16+
17+
18+
@pytest.fixture
19+
def model_params():
20+
return {"objective": "multiclass", "n_estimators": 5, "seed": 0}
21+
22+
23+
@pytest.fixture
24+
def iris_data():
25+
iris = datasets.load_iris()
26+
x = pd.DataFrame(iris["data"], columns=iris["feature_names"])
27+
y = iris["target"]
28+
x_train, x_test, y_train, y_test = train_test_split(
29+
x, y, test_size=0.33, random_state=42
30+
)
31+
return (x_train, y_train), (x_test, y_test)
32+
33+
34+
def test_lgbm_integration(tmp_dir, model_params, iris_data):
35+
dvclive.init("logs")
36+
model = lgbm.LGBMClassifier()
37+
model.set_params(**model_params)
38+
39+
model.fit(
40+
iris_data[0][0],
41+
iris_data[0][1],
42+
eval_set=(iris_data[1][0], iris_data[1][1]),
43+
eval_metric=["multi_logloss"],
44+
callbacks=[DvcLiveCallback()],
45+
)
46+
47+
assert os.path.exists("logs")
48+
49+
logs, _ = read_logs("logs")
50+
assert len(logs) == 1
51+
assert len(first(logs.values())) == 5
52+
53+
54+
def test_lgbm_model_file(tmp_dir, model_params, iris_data):
55+
dvclive.init("logs")
56+
model = lgbm.LGBMClassifier()
57+
model.set_params(**model_params)
58+
59+
model.fit(
60+
iris_data[0][0],
61+
iris_data[0][1],
62+
eval_set=(iris_data[1][0], iris_data[1][1]),
63+
eval_metric=["multi_logloss"],
64+
callbacks=[DvcLiveCallback("lgbm_model")],
65+
)
66+
67+
preds = model.predict(iris_data[1][0])
68+
model2 = lgbm.Booster(model_file="lgbm_model")
69+
preds2 = model2.predict(iris_data[1][0])
70+
preds2 = np.argmax(preds2, axis=1)
71+
assert np.sum(np.abs(preds2 - preds)) == 0

0 commit comments

Comments
 (0)