Skip to content

Commit e9ed5a0

Browse files
daavooDave Berenbaum
andcommitted
report: Handle custom names used in sklearn_plots.
Closes #370 Co-authored-by: Dave Berenbaum <[email protected]>
1 parent 1c26ee1 commit e9ed5a0

File tree

2 files changed

+64
-26
lines changed

2 files changed

+64
-26
lines changed

src/dvclive/report.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from dvclive import Live
1818

1919

20+
# noqa pylint: disable=protected-access
21+
22+
2023
def get_scalar_renderers(metrics_path):
2124
renderers = []
2225
for suffix in Metric.suffixes:
@@ -55,17 +58,33 @@ def get_image_renderers(images_folder):
5558
return renderers
5659

5760

58-
def get_plot_renderers(plots_folder):
61+
def get_plot_renderers(plots_folder, live):
5962
renderers = []
6063
for suffix in SKLearnPlot.suffixes:
6164
for file in Path(plots_folder).rglob(f"*{suffix}"):
62-
name = file.stem
65+
name = file.relative_to(plots_folder).with_suffix("").as_posix()
66+
properties = {}
67+
68+
if name in SKLEARN_PLOTS:
69+
properties = SKLEARN_PLOTS[name].get_properties()
70+
data_field = name
71+
else:
72+
# Plot with custom name
73+
logged_plot = live._plots[name]
74+
for default_name, plot_class in SKLEARN_PLOTS.items():
75+
if isinstance(logged_plot, plot_class):
76+
properties = plot_class.get_properties()
77+
data_field = default_name
78+
break
79+
6380
data = json.loads(file.read_text())
64-
if name in data:
65-
data = data[name]
81+
82+
if data_field in data:
83+
data = data[data_field]
84+
6685
for row in data:
6786
row["rev"] = "workspace"
68-
properties = SKLEARN_PLOTS[name].get_properties()
87+
6988
renderers.append(VegaRenderer(data, name, **properties))
7089
return renderers
7190

@@ -94,19 +113,21 @@ def get_params_renderers(dvclive_params):
94113
return []
95114

96115

97-
def make_report(dvclive: "Live"):
98-
plots_path = Path(dvclive.plots_dir)
116+
def make_report(live: "Live"):
117+
plots_path = Path(live.plots_dir)
99118

100119
renderers = []
101-
renderers.extend(get_params_renderers(dvclive.params_file))
102-
renderers.extend(get_metrics_renderers(dvclive.metrics_file))
120+
renderers.extend(get_params_renderers(live.params_file))
121+
renderers.extend(get_metrics_renderers(live.metrics_file))
103122
renderers.extend(get_scalar_renderers(plots_path / Metric.subfolder))
104123
renderers.extend(get_image_renderers(plots_path / Image.subfolder))
105-
renderers.extend(get_plot_renderers(plots_path / SKLearnPlot.subfolder))
106-
107-
if dvclive.report_mode == "html":
108-
render_html(renderers, dvclive.report_file, refresh_seconds=5)
109-
elif dvclive.report_mode == "md":
110-
render_markdown(renderers, dvclive.report_file)
124+
renderers.extend(
125+
get_plot_renderers(plots_path / SKLearnPlot.subfolder, live)
126+
)
127+
128+
if live.report_mode == "html":
129+
render_html(renderers, live.report_file, refresh_seconds=5)
130+
elif live.report_mode == "md":
131+
render_markdown(renderers, live.report_file)
111132
else:
112-
raise ValueError(f"Invalid `mode` {dvclive.report_mode}.")
133+
raise ValueError(f"Invalid `mode` {live.report_mode}.")

tests/test_report.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dvclive.env import DVCLIVE_OPEN
99
from dvclive.plots import Image as LiveImage
1010
from dvclive.plots import Metric
11-
from dvclive.plots.sklearn import ConfusionMatrix, SKLearnPlot
11+
from dvclive.plots.sklearn import ConfusionMatrix, Roc, SKLearnPlot
1212
from dvclive.report import (
1313
get_image_renderers,
1414
get_metrics_renderers,
@@ -31,6 +31,12 @@ def test_get_renderers(tmp_dir, mocker):
3131
live.next_step()
3232

3333
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1])
34+
live.log_sklearn_plot(
35+
"confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1], name="train/cm"
36+
)
37+
live.log_sklearn_plot(
38+
"roc", [0, 0, 1, 1], [1, 0.1, 0, 1], name="roc_curve"
39+
)
3440

3541
image_renderers = get_image_renderers(
3642
tmp_dir / live.plots_dir / LiveImage.subfolder
@@ -63,16 +69,27 @@ def test_get_renderers(tmp_dir, mocker):
6369
assert scalar_renderers[0].name == "static/foo/bar"
6470

6571
plot_renderers = get_plot_renderers(
66-
tmp_dir / live.plots_dir / SKLearnPlot.subfolder
72+
tmp_dir / live.plots_dir / SKLearnPlot.subfolder, live
6773
)
68-
assert len(plot_renderers) == 1
69-
assert plot_renderers[0].datapoints == [
70-
{"actual": "0", "rev": "workspace", "predicted": "1"},
71-
{"actual": "0", "rev": "workspace", "predicted": "0"},
72-
{"actual": "1", "rev": "workspace", "predicted": "0"},
73-
{"actual": "1", "rev": "workspace", "predicted": "1"},
74-
]
75-
assert plot_renderers[0].properties == ConfusionMatrix.get_properties()
74+
assert len(plot_renderers) == 3
75+
for plot_renderer in plot_renderers:
76+
if plot_renderer.name == "roc_curve":
77+
assert plot_renderer.datapoints == [
78+
{"fpr": 0.0, "rev": "workspace", "threshold": 2.0, "tpr": 0.0},
79+
{"fpr": 0.5, "rev": "workspace", "threshold": 1.0, "tpr": 0.5},
80+
{"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5},
81+
{"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0},
82+
]
83+
assert plot_renderer.properties == Roc.get_properties()
84+
85+
else:
86+
assert plot_renderer.datapoints == [
87+
{"actual": "0", "rev": "workspace", "predicted": "1"},
88+
{"actual": "0", "rev": "workspace", "predicted": "0"},
89+
{"actual": "1", "rev": "workspace", "predicted": "0"},
90+
{"actual": "1", "rev": "workspace", "predicted": "1"},
91+
]
92+
assert plot_renderer.properties == ConfusionMatrix.get_properties()
7693

7794
metrics_renderer = get_metrics_renderers(live.metrics_file)[0]
7895
assert metrics_renderer.datapoints == [{"step": 1, "foo": {"bar": 1}}]

0 commit comments

Comments
 (0)