Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions torchbenchmark/util/env_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
This file may be loaded without torch packages installed, e.g., in OnDemand CI.
"""

import argparse
import copy
import logging
import os
import shutil
from collections.abc import Mapping
from contextlib import contextmanager, ExitStack
from typing import Any, Dict, List, Optional
from importlib.metadata import version
from typing import Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
import torchbenchmark


MAIN_RANDOM_SEED = 1337
Expand Down Expand Up @@ -191,14 +194,9 @@ def deterministic_torch_manual_seed(*args, **kwargs):


def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
    """Return the installed version string for each requested package.

    Args:
        packages: Package names to look up.

    Returns:
        Mapping of each package name to its installed version string.

    Raises:
        importlib.metadata.PackageNotFoundError: if a name is not installed.
    """
    # Alias the import so a local cannot shadow it (the diffed original both
    # assigned a `version` string and then called `version(module)` on it).
    from importlib.metadata import version as dist_version

    versions: Dict[str, str] = {}
    for module in packages:
        # NOTE(review): importlib.metadata keys on the *distribution* name,
        # which can differ from the import name (e.g. "yaml" vs "PyYAML") —
        # confirm callers pass distribution names.
        versions[module] = dist_version(module)
    return versions


Expand Down Expand Up @@ -257,7 +255,7 @@ def save_deterministic_dict(name: str):
torch.backends.cuda.matmul.allow_tf32
)

if not name in UNSUPPORTED_USE_DETERMINISTIC_ALGORITHMS:
if name not in UNSUPPORTED_USE_DETERMINISTIC_ALGORITHMS:
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
Expand Down Expand Up @@ -449,7 +447,7 @@ def reduce_to_scalar_loss(out):
return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
out.keys()
)
elif out == None:
elif out is None:
return 0.0
raise NotImplementedError("Don't know how to reduce", type(out))

Expand Down Expand Up @@ -696,7 +694,7 @@ def maybe_cast(tbmodel, model, example_inputs):
equal_nan=equal_nan,
):
is_same = False
except Exception as e:
except Exception:
# Sometimes torch.allclose may throw RuntimeError
is_same = False

Expand Down Expand Up @@ -747,7 +745,7 @@ def maybe_cast(tbmodel, model, example_inputs):
tol=tolerance,
):
is_same = False
except Exception as e:
except Exception:
# Sometimes torch.allclose may throw RuntimeError
is_same = False

Expand Down
6 changes: 2 additions & 4 deletions utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import subprocess
import sys
from importlib.metadata import version
from pathlib import Path
from typing import Dict, List

Expand All @@ -24,9 +24,7 @@ def __exit__(self, exc_type, exc_value, traceback):
def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
    """Return the installed version string for each requested package.

    Replaces the old per-package subprocess probe (which spawned a fresh
    interpreter to read ``module.__version__``) with a direct metadata
    lookup; the diffed original interleaved both, calling the ``version``
    string as a function.

    Args:
        packages: Package names to look up.

    Returns:
        Mapping of each package name to its installed version string.

    Raises:
        importlib.metadata.PackageNotFoundError: if a name is not installed.
    """
    # Local aliased import keeps the helper self-contained and un-shadowable.
    from importlib.metadata import version as dist_version

    return {module: dist_version(module) for module in packages}


Expand Down
Loading