Skip to content

Commit 9088996

Browse files
sisp, daavoo, and pre-commit-ci[bot]
authored
xgb: infer metric data names from evals and deprecate metric_data (#587)
* xgb: add support for multiple metrics data sets
* xgb: infer metric data names from `evals` and deprecate `metric_data`
* xgb: improve code aesthetics

  Co-authored-by: David de la Iglesia Castro <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: David de la Iglesia Castro <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ce68fe1 commit 9088996

File tree

2 files changed

+54
-11
lines changed

2 files changed

+54
-11
lines changed

src/dvclive/xgb.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: ARG002
22
from typing import Optional
3+
from warnings import warn
34

45
from xgboost.callback import TrainingCallback
56

@@ -8,18 +9,29 @@
89

910
class DVCLiveCallback(TrainingCallback):
1011
def __init__(
11-
self, metric_data, model_file=None, live: Optional[Live] = None, **kwargs
12+
self,
13+
metric_data: Optional[str] = None,
14+
model_file=None,
15+
live: Optional[Live] = None,
16+
**kwargs,
1217
):
1318
super().__init__()
19+
if metric_data is not None:
20+
warn(
21+
"`metric_data` is deprecated and will be removed",
22+
category=DeprecationWarning,
23+
stacklevel=2,
24+
)
1425
self._metric_data = metric_data
1526
self.model_file = model_file
1627
self.live = live if live is not None else Live(**kwargs)
1728

1829
def after_iteration(self, model, epoch, evals_log):
19-
for key, values in evals_log[self._metric_data].items():
20-
if values:
21-
latest_metric = values[-1]
22-
self.live.log_metric(key, latest_metric)
30+
if self._metric_data:
31+
evals_log = {"": evals_log[self._metric_data]}
32+
for subdir, data in evals_log.items():
33+
for key, values in data.items():
34+
self.live.log_metric(f"{subdir}/{key}" if subdir else key, values[-1])
2335
if self.model_file:
2436
model.save_model(self.model_file)
2537
self.live.next_step()

tests/test_frameworks/test_xgboost.py

Lines changed: 37 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,18 @@
11
import os
2+
from contextlib import nullcontext
23

34
import pytest
45

56
from dvclive import Live
7+
from dvclive.plots.metric import Metric
68
from dvclive.utils import parse_metrics
79

810
try:
911
import numpy as np
1012
import pandas as pd
1113
import xgboost as xgb
1214
from sklearn import datasets
15+
from sklearn.model_selection import train_test_split
1316

1417
from dvclive.xgb import DVCLiveCallback
1518
except ImportError:
@@ -29,24 +32,52 @@ def iris_data():
2932
return xgb.DMatrix(x, y)
3033

3134

32-
def test_xgb_integration(tmp_dir, train_params, iris_data, mocker):
33-
callback = DVCLiveCallback("eval_data")
35+
@pytest.fixture()
36+
def iris_train_eval_data():
37+
iris = datasets.load_iris()
38+
x_train, x_eval, y_train, y_eval = train_test_split(
39+
iris.data, iris.target, random_state=0
40+
)
41+
return (xgb.DMatrix(x_train, y_train), xgb.DMatrix(x_eval, y_eval))
42+
43+
44+
@pytest.mark.parametrize(
45+
("metric_data", "subdirs", "context"),
46+
[
47+
(
48+
"eval",
49+
("",),
50+
pytest.warns(DeprecationWarning, match="`metric_data`.+deprecated"),
51+
),
52+
(None, ("train", "eval"), nullcontext()),
53+
],
54+
)
55+
def test_xgb_integration(
56+
tmp_dir, train_params, iris_train_eval_data, metric_data, subdirs, context, mocker
57+
):
58+
with context:
59+
callback = DVCLiveCallback(metric_data)
3460
live = callback.live
3561
spy = mocker.spy(live, "end")
62+
data_train, data_eval = iris_train_eval_data
3663
xgb.train(
3764
train_params,
38-
iris_data,
65+
data_train,
3966
callbacks=[callback],
4067
num_boost_round=5,
41-
evals=[(iris_data, "eval_data")],
68+
evals=[(data_train, "train"), (data_eval, "eval")],
4269
)
4370
spy.assert_called_once()
4471

4572
assert os.path.exists("dvclive")
4673

4774
logs, _ = parse_metrics(callback.live)
48-
assert len(logs) == 1
49-
assert len(list(logs.values())[0]) == 5
75+
assert len(logs) == len(subdirs)
76+
assert list(map(len, logs.values())) == [5] * len(logs)
77+
scalars = os.path.join(callback.live.plots_dir, Metric.subfolder)
78+
assert all(
79+
os.path.join(scalars, subdir, "mlogloss.tsv") in logs for subdir in subdirs
80+
)
5081

5182

5283
def test_xgb_model_file(tmp_dir, train_params, iris_data):

0 commit comments

Comments
 (0)