Skip to content

Commit 5297151

Browse files
authored
live.log_metric: Cast nan and inf to string. (#677)
* live.log_metric: Cast `nan` and `inf` to string. - Support string metrics (forgot to support them when DVC added support) - Don't recast to float when sending to Studio. Closes iterative/studio-support#93 * Fix and move test
1 parent fdd84af commit 5297151

File tree

5 files changed

+47
-15
lines changed

5 files changed

+47
-15
lines changed

src/dvclive/live.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import math
34
import os
45
import shutil
56
from pathlib import Path
@@ -299,13 +300,16 @@ def next_step(self):
299300
def log_metric(
300301
self,
301302
name: str,
302-
val: Union[int, float],
303+
val: Union[int, float, str],
303304
timestamp: bool = False,
304305
plot: bool = True,
305306
):
306307
if not Metric.could_log(val):
307308
raise InvalidDataTypeError(name, type(val))
308309

310+
if not isinstance(val, str) and (math.isnan(val) or math.isinf(val)):
311+
val = str(val)
312+
309313
if name in self._metrics:
310314
metric = self._metrics[name]
311315
else:

src/dvclive/plots/metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class Metric(Data):
1414

1515
@staticmethod
1616
def could_log(val: object) -> bool:
17-
if isinstance(val, (int, float)):
17+
if isinstance(val, (int, float, str)):
1818
return True
1919
if (
2020
val.__class__.__module__ == "numpy"

src/dvclive/studio.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: SLF001
22
import base64
3+
import math
34
import os
45
from pathlib import Path
56

@@ -21,7 +22,11 @@ def _cast_to_numbers(datapoints):
2122
elif k == "timestamp":
2223
continue
2324
else:
24-
datapoint[k] = float(v)
25+
float_v = float(v)
26+
if math.isnan(float_v) or math.isinf(float_v):
27+
datapoint[k] = str(v)
28+
else:
29+
datapoint[k] = float_v
2530
return datapoints
2631

2732

tests/test_log_metric.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import math
2+
3+
import numpy as np
4+
import pytest
5+
6+
from dvclive import Live
7+
from dvclive.error import InvalidDataTypeError
8+
9+
10+
@pytest.mark.parametrize("invalid_type", [{0: 1}, [0, 1], (0, 1)])
11+
def test_invalid_metric_type(tmp_dir, invalid_type):
12+
dvclive = Live()
13+
14+
with pytest.raises(
15+
InvalidDataTypeError,
16+
match=f"Data 'm' has not supported type {type(invalid_type)}",
17+
):
18+
dvclive.log_metric("m", invalid_type)
19+
20+
21+
@pytest.mark.parametrize(
22+
("val"),
23+
[math.inf, math.nan, np.nan, np.inf],
24+
)
25+
def test_log_metric_inf_nan(tmp_dir, val):
26+
with Live() as live:
27+
live.log_metric("metric", val)
28+
assert live.summary["metric"] == str(val)
29+
30+
31+
def test_log_metic_str(tmp_dir):
32+
with Live() as live:
33+
live.log_metric("metric", "foo")
34+
assert live.summary["metric"] == "foo"

tests/test_main.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from PIL import Image
77

88
from dvclive import Live, env
9-
from dvclive.error import InvalidDataTypeError, InvalidParameterTypeError
9+
from dvclive.error import InvalidParameterTypeError
1010
from dvclive.plots import Metric
1111
from dvclive.serialize import load_yaml
1212
from dvclive.utils import parse_metrics, parse_tsv
@@ -272,17 +272,6 @@ def test_log_reset_with_set_step(tmp_dir):
272272
assert read_latest(dvclive, "val_m") == (2, 1)
273273

274274

275-
@pytest.mark.parametrize("invalid_type", [{0: 1}, [0, 1], "foo", (0, 1)])
276-
def test_invalid_metric_type(tmp_dir, invalid_type):
277-
dvclive = Live()
278-
279-
with pytest.raises(
280-
InvalidDataTypeError,
281-
match=f"Data 'm' has not supported type {type(invalid_type)}",
282-
):
283-
dvclive.log_metric("m", invalid_type)
284-
285-
286275
def test_get_step_resume(tmp_dir):
287276
dvclive = Live()
288277

0 commit comments

Comments
 (0)