Skip to content

Commit 3414163

Browse files
Fixes #812 (#813)
* Fixes #812 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * named arguments in log_sklearn_plot * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * replaced sklearn_kwargs with kwargs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 228e9a8 commit 3414163

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

src/dvclive/live.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,10 @@ def log_sklearn_plot(
624624
labels: Union[List, np.ndarray],
625625
predictions: Union[List, Tuple, np.ndarray],
626626
name: Optional[str] = None,
627+
title: Optional[str] = None,
628+
x_label: Optional[str] = None,
629+
y_label: Optional[str] = None,
630+
normalized: Optional[bool] = None,
627631
**kwargs,
628632
):
629633
"""
@@ -638,14 +642,17 @@ def log_sklearn_plot(
638642
"roc"): a supported plot type.
639643
labels (List | np.ndarray): array of ground truth labels.
640644
predictions (List | np.ndarray): array of predicted labels (for
641-
`"confusion_matrix"`) or predicted probabilities (for other plots).
645+
`"confusion_matrix"`) or predicted probabilities (for other plots).
642646
name (str): optional name of the output file. If not provided, `kind` will
643-
be used as name.
647+
be used as name.
648+
title (str): optional title to be displayed.
649+
x_label (str): optional label for the x axis.
650+
y_label (str): optional label for the y axis.
651+
normalized (bool): optional, `confusion_matrix` with values normalized to
652+
`<0, 1>` range.
644653
kwargs: additional arguments to tune the result. Arguments are passed to the
645654
scikit-learn function (e.g. `drop_intermediate=True` for the `"roc"`
646-
type). Plus extra arguments supported by the type of a plot are:
647-
- `normalized`: default to `False`. `confusion_matrix` with values
648-
normalized to `<0, 1>` range.
655+
type).
649656
Raises:
650657
InvalidPlotTypeError: thrown if the provided `kind` does not correspond to
651658
any of the supported plots.
@@ -654,9 +661,15 @@ def log_sklearn_plot(
654661

655662
plot_config = {
656663
k: v
657-
for k, v in kwargs.items()
658-
if k in ("title", "x_label", "y_label", "normalized")
664+
for k, v in {
665+
"title": title,
666+
"x_label": x_label,
667+
"y_label": y_label,
668+
"normalized": normalized,
669+
}.items()
670+
if v is not None
659671
}
672+
660673
name = name or kind
661674
if name in self._plots:
662675
plot = self._plots[name]
@@ -666,11 +679,8 @@ def log_sklearn_plot(
666679
else:
667680
raise InvalidPlotTypeError(name)
668681

669-
sklearn_kwargs = {
670-
k: v for k, v in kwargs.items() if k not in plot_config or k != "normalized"
671-
}
672682
plot.step = self.step
673-
plot.dump(val, **sklearn_kwargs)
683+
plot.dump(val, **kwargs)
674684
logger.debug(f"Logged {name}")
675685

676686
def _read_params(self):

tests/plots/test_sklearn.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_custom_title(tmp_dir, y_true_y_pred_y_score):
162162
live = Live()
163163
out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder
164164

165-
y_true, y_pred, _ = y_true_y_pred_y_score
165+
y_true, y_pred, y_score = y_true_y_pred_y_score
166166

167167
live.log_sklearn_plot(
168168
"confusion_matrix",
@@ -174,8 +174,38 @@ def test_custom_title(tmp_dir, y_true_y_pred_y_score):
174174
live.log_sklearn_plot(
175175
"confusion_matrix", y_true, y_pred, name="val/cm", title="Val Confusion Matrix"
176176
)
177+
live.log_sklearn_plot(
178+
"precision_recall",
179+
y_true,
180+
y_score,
181+
name="val/prc",
182+
title="Val Precision Recall",
183+
)
177184
assert (out / "train" / "cm.json").exists()
178185
assert (out / "val" / "cm.json").exists()
186+
assert (out / "val" / "prc.json").exists()
179187

180188
assert live._plots["train/cm"].plot_config["title"] == "Train Confusion Matrix"
181189
assert live._plots["val/cm"].plot_config["title"] == "Val Confusion Matrix"
190+
assert live._plots["val/prc"].plot_config["title"] == "Val Precision Recall"
191+
192+
193+
def test_custom_labels(tmp_dir, y_true_y_pred_y_score):
194+
"""https://github.com/iterative/dvclive/issues/453"""
195+
live = Live()
196+
out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder
197+
198+
y_true, _, y_score = y_true_y_pred_y_score
199+
200+
live.log_sklearn_plot(
201+
"precision_recall",
202+
y_true,
203+
y_score,
204+
name="val/prc",
205+
x_label="x_test",
206+
y_label="y_test",
207+
)
208+
assert (out / "val" / "prc.json").exists()
209+
210+
assert live._plots["val/prc"].plot_config["x_label"] == "x_test"
211+
assert live._plots["val/prc"].plot_config["y_label"] == "y_test"

0 commit comments

Comments
 (0)