11from pathlib import Path , PurePath
22
3+ from dvclive .utils import isinstance_without_import
4+
35from .base import Data
46
57
@@ -17,20 +19,25 @@ def output_path(self) -> Path:
1719 def could_log (val : object ) -> bool :
1820 acceptable = {
1921 ("numpy" , "ndarray" ),
22+ ("matplotlib.figure" , "Figure" ),
2023 ("PIL.Image" , "Image" ),
2124 }
2225 for cls in type (val ).mro ():
23- if ( cls . __module__ , cls . __name__ ) in acceptable :
26+ if any ( isinstance_without_import ( val , * cls ) for cls in acceptable ) :
2427 return True
2528 if isinstance (val , (PurePath , str )):
2629 return True
2730 return False
2831
2932 def dump (self , val , ** kwargs ) -> None : # noqa: ARG002
30- if val . __class__ . __module__ == "numpy" :
33+ if isinstance_without_import ( val , "numpy" , "ndarray" ) :
3134 from PIL import Image as ImagePIL
3235
33- pil_image = ImagePIL .fromarray (val )
34- else :
35- pil_image = val
36- pil_image .save (self .output_path )
36+ ImagePIL .fromarray (val ).save (self .output_path )
37+ elif isinstance_without_import (val , "matplotlib.figure" , "Figure" ):
38+ import matplotlib .pyplot as plt
39+
40+ plt .savefig (self .output_path )
41+ plt .close (val )
42+ elif isinstance_without_import (val , "PIL.Image" , "Image" ):
43+ val .save (self .output_path )
0 commit comments