Skip to content

Commit ee7968b

Browse files
authored
log_image: Support matplotlib.figure.Figure as input. (#658)
Closes #224
1 parent e8d008e commit ee7968b

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

src/dvclive/plots/image.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from pathlib import Path, PurePath
22

3+
from dvclive.utils import isinstance_without_import
4+
35
from .base import Data
46

57

@@ -17,20 +19,25 @@ def output_path(self) -> Path:
1719
def could_log(val: object) -> bool:
1820
acceptable = {
1921
("numpy", "ndarray"),
22+
("matplotlib.figure", "Figure"),
2023
("PIL.Image", "Image"),
2124
}
2225
for cls in type(val).mro():
23-
if (cls.__module__, cls.__name__) in acceptable:
26+
if any(isinstance_without_import(val, *cls) for cls in acceptable):
2427
return True
2528
if isinstance(val, (PurePath, str)):
2629
return True
2730
return False
2831

2932
def dump(self, val, **kwargs) -> None: # noqa: ARG002
30-
if val.__class__.__module__ == "numpy":
33+
if isinstance_without_import(val, "numpy", "ndarray"):
3134
from PIL import Image as ImagePIL
3235

33-
pil_image = ImagePIL.fromarray(val)
34-
else:
35-
pil_image = val
36-
pil_image.save(self.output_path)
36+
ImagePIL.fromarray(val).save(self.output_path)
37+
elif isinstance_without_import(val, "matplotlib.figure", "Figure"):
38+
import matplotlib.pyplot as plt
39+
40+
plt.savefig(self.output_path)
41+
plt.close(val)
42+
elif isinstance_without_import(val, "PIL.Image", "Image"):
43+
val.save(self.output_path)

src/dvclive/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,10 @@ def clean_and_copy_into(src: StrPath, dst: StrPath) -> str:
148148
shutil.copy2(src, dst_path)
149149

150150
return str(dst_path)
151+
152+
153+
def isinstance_without_import(val, module, name):
154+
for cls in type(val).mro():
155+
if (cls.__module__, cls.__name__) == (module, name):
156+
return True
157+
return False

tests/plots/test_image.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import matplotlib.pyplot as plt
12
import numpy as np
23
import pytest
34
from PIL import Image
@@ -100,3 +101,17 @@ def test_custom_class(tmp_dir):
100101
live.log_image("image.png", extended_img)
101102

102103
assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()
104+
105+
106+
def test_matplotlib(tmp_dir):
107+
live = Live()
108+
fig, ax = plt.subplots()
109+
ax.plot([1, 2, 3, 4])
110+
111+
assert plt.fignum_exists(fig.number)
112+
113+
live.log_image("image.png", fig)
114+
115+
assert not plt.fignum_exists(fig.number)
116+
117+
assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()

0 commit comments

Comments
 (0)