diff --git a/torchbenchmark/util/env_check.py b/torchbenchmark/util/env_check.py
index 9473f7bec..1476d353d 100644
--- a/torchbenchmark/util/env_check.py
+++ b/torchbenchmark/util/env_check.py
@@ -3,14 +3,17 @@
 This file may be loaded without torch packages installed, e.g., in OnDemand CI.
 """
 
-import argparse
 import copy
 import logging
 import os
 import shutil
 from collections.abc import Mapping
 from contextlib import contextmanager, ExitStack
-from typing import Any, Dict, List, Optional
+from importlib.metadata import version
+from typing import Dict, List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torchbenchmark
 
 MAIN_RANDOM_SEED = 1337
 
@@ -191,14 +194,9 @@ def deterministic_torch_manual_seed(*args, **kwargs):
 
 
 def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
-    import subprocess
-    import sys
-
     versions = {}
     for module in packages:
-        cmd = [sys.executable, "-c", f"import {module}; print({module}.__version__)"]
-        version = subprocess.check_output(cmd).decode().strip()
-        versions[module] = version
+        versions[module] = version(module)
     return versions
 
 
@@ -257,7 +255,7 @@ def save_deterministic_dict(name: str):
         torch.backends.cuda.matmul.allow_tf32
     )
 
-    if not name in UNSUPPORTED_USE_DETERMINISTIC_ALGORITHMS:
+    if name not in UNSUPPORTED_USE_DETERMINISTIC_ALGORITHMS:
         torch.use_deterministic_algorithms(True)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.allow_tf32 = False
@@ -449,7 +447,7 @@ def reduce_to_scalar_loss(out):
         return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
             out.keys()
         )
-    elif out == None:
+    elif out is None:
         return 0.0
     raise NotImplementedError("Don't know how to reduce", type(out))
 
@@ -696,7 +694,7 @@ def maybe_cast(tbmodel, model, example_inputs):
                 equal_nan=equal_nan,
             ):
                 is_same = False
-        except Exception as e:
+        except Exception:
             # Sometimes torch.allclose may throw RuntimeError
             is_same = False
 
@@ -747,7 +745,7 @@ def maybe_cast(tbmodel, model, example_inputs):
                     tol=tolerance,
                 ):
                     is_same = False
-            except Exception as e:
+            except Exception:
                 # Sometimes torch.allclose may throw RuntimeError
                 is_same = False
 
diff --git a/utils/__init__.py b/utils/__init__.py
index 9bc06335d..a0ed0493b 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,5 +1,5 @@
-import subprocess
 import sys
+from importlib.metadata import version
 from pathlib import Path
 from typing import Dict, List
 
@@ -24,9 +24,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
     versions = {}
     for module in packages:
-        cmd = [sys.executable, "-c", f"import {module}; print({module}.__version__)"]
-        version = subprocess.check_output(cmd).decode().strip()
-        versions[module] = version
+        versions[module] = version(module)
     return versions
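
Note on the importlib.metadata swap: version() resolves the installed *distribution* name, whereas the old subprocess path imported the module and read module.__version__. The two names usually coincide for the packages queried here, but not always. Below is a minimal sketch of the new lookup under that assumption; the PackageNotFoundError fallback is illustrative only and is not part of this patch:

from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List

def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
    versions = {}
    for module in packages:
        try:
            # Read the version string from the installed distribution's
            # metadata; no subprocess and no import of the module needed.
            versions[module] = version(module)
        except PackageNotFoundError:
            # Illustrative fallback: the import name can differ from the
            # distribution name (e.g. module "PIL" ships in "Pillow").
            versions[module] = "unknown"
    return versions

print(get_pkg_versions(["torch", "numpy"]))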