88from dvclive .env import DVCLIVE_OPEN
99from dvclive .plots import Image as LiveImage
1010from dvclive .plots import Metric
11- from dvclive .plots .sklearn import ConfusionMatrix , SKLearnPlot
11+ from dvclive .plots .sklearn import ConfusionMatrix , Roc , SKLearnPlot
1212from dvclive .report import (
1313 get_image_renderers ,
1414 get_metrics_renderers ,
@@ -31,6 +31,12 @@ def test_get_renderers(tmp_dir, mocker):
3131 live .next_step ()
3232
3333 live .log_sklearn_plot ("confusion_matrix" , [0 , 0 , 1 , 1 ], [1 , 0 , 0 , 1 ])
34+ live .log_sklearn_plot (
35+ "confusion_matrix" , [0 , 0 , 1 , 1 ], [1 , 0 , 0 , 1 ], name = "train/cm"
36+ )
37+ live .log_sklearn_plot (
38+ "roc" , [0 , 0 , 1 , 1 ], [1 , 0.1 , 0 , 1 ], name = "roc_curve"
39+ )
3440
3541 image_renderers = get_image_renderers (
3642 tmp_dir / live .plots_dir / LiveImage .subfolder
@@ -63,16 +69,27 @@ def test_get_renderers(tmp_dir, mocker):
6369 assert scalar_renderers [0 ].name == "static/foo/bar"
6470
6571 plot_renderers = get_plot_renderers (
66- tmp_dir / live .plots_dir / SKLearnPlot .subfolder
72+ tmp_dir / live .plots_dir / SKLearnPlot .subfolder , live
6773 )
68- assert len (plot_renderers ) == 1
69- assert plot_renderers [0 ].datapoints == [
70- {"actual" : "0" , "rev" : "workspace" , "predicted" : "1" },
71- {"actual" : "0" , "rev" : "workspace" , "predicted" : "0" },
72- {"actual" : "1" , "rev" : "workspace" , "predicted" : "0" },
73- {"actual" : "1" , "rev" : "workspace" , "predicted" : "1" },
74- ]
75- assert plot_renderers [0 ].properties == ConfusionMatrix .get_properties ()
74+ assert len (plot_renderers ) == 3
75+ for plot_renderer in plot_renderers :
76+ if plot_renderer .name == "roc_curve" :
77+ assert plot_renderer .datapoints == [
78+ {"fpr" : 0.0 , "rev" : "workspace" , "threshold" : 2.0 , "tpr" : 0.0 },
79+ {"fpr" : 0.5 , "rev" : "workspace" , "threshold" : 1.0 , "tpr" : 0.5 },
80+ {"fpr" : 1.0 , "rev" : "workspace" , "threshold" : 0.1 , "tpr" : 0.5 },
81+ {"fpr" : 1.0 , "rev" : "workspace" , "threshold" : 0.0 , "tpr" : 1.0 },
82+ ]
83+ assert plot_renderer .properties == Roc .get_properties ()
84+
85+ else :
86+ assert plot_renderer .datapoints == [
87+ {"actual" : "0" , "rev" : "workspace" , "predicted" : "1" },
88+ {"actual" : "0" , "rev" : "workspace" , "predicted" : "0" },
89+ {"actual" : "1" , "rev" : "workspace" , "predicted" : "0" },
90+ {"actual" : "1" , "rev" : "workspace" , "predicted" : "1" },
91+ ]
92+ assert plot_renderer .properties == ConfusionMatrix .get_properties ()
7693
7794 metrics_renderer = get_metrics_renderers (live .metrics_file )[0 ]
7895 assert metrics_renderer .datapoints == [{"step" : 1 , "foo" : {"bar" : 1 }}]
0 commit comments