diff --git a/docs/source-fabric/api/accelerators.rst b/docs/source-fabric/api/accelerators.rst index 7e8444dd8e47e..218d80b4d6920 100644 --- a/docs/source-fabric/api/accelerators.rst +++ b/docs/source-fabric/api/accelerators.rst @@ -20,3 +20,4 @@ Accelerators CUDAAccelerator MPSAccelerator XLAAccelerator + MUSAAccelerator diff --git a/docs/source-pytorch/accelerators/musa.rst b/docs/source-pytorch/accelerators/musa.rst new file mode 100644 index 0000000000000..bc59a0355bf21 --- /dev/null +++ b/docs/source-pytorch/accelerators/musa.rst @@ -0,0 +1,77 @@ +:orphan: + +MUSA training (Advanced) +======================== +**Audience:** Users looking to train models on MooreThreads device using MUSA accelerator. + +.. warning:: This is an :ref:`experimental ` feature. + +---- + +MUSAAccelerator Overview +-------------------- +torch_musa is an extended Python package based on PyTorch that enables full utilization of MooreThreads graphics cards' +super computing power. Combined with PyTorch, users can take advantage of the strong power of MooreThreads graphics cards +through torch_musa. + +PyTorch Lightning automatically finds these weights and ties them after the modules are moved to the +MUSA device under the hood. It will ensure that the weights among the modules are shared but not copied +independently. + + +Example: + +.. code-block:: python + import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F + import pytorch_lightning as L + + # Step 1: Define a LightningModule + class LitAutoEncoder(L.LightningModule): + def __init__(self): + super().__init__() + self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) + self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding + + def training_step(self, batch, batch_idx): + # training_step defines the train loop. It is independent of forward + x, _ = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + + def main(): + # ------------------- + # Step 2: Define data + # ------------------- + dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) + train, val = data.random_split(dataset, [55000, 5000]) + + # ------------------- + # Step 3: Train + # ------------------- + autoencoder = LitAutoEncoder() + # we also support accelerator="auto" or accelerator="musa" + trainer = L.Trainer(accelerator="gpu") + trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val)) + + if __name__ == '__main__': + + main() +---- + +MUSA +---- +MUSA is the library that interfaces PyTorch with the MooreThreads graphics cards. +For more information check out `MUSA `_. diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 278cc98ef5547..34b9c5f7cbebb 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -14,6 +14,7 @@ accelerators CPUAccelerator CUDAAccelerator XLAAccelerator + MUSAAccelerator callbacks --------- diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst index 174ba5ee7b42c..efb34e7aafdad 100644 --- a/docs/source-pytorch/extensions/accelerator.rst +++ b/docs/source-pytorch/extensions/accelerator.rst @@ -128,3 +128,4 @@ Accelerator API CUDAAccelerator MPSAccelerator XLAAccelerator + MUSAAccelerator diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py index 3d4b43f75c762..acb996000d0e9 100644 --- a/src/lightning/fabric/accelerators/__init__.py +++ b/src/lightning/fabric/accelerators/__init__.py @@ -16,6 +16,7 @@ from lightning.fabric.accelerators.cpu import CPUAccelerator # noqa: F401 from lightning.fabric.accelerators.cuda import CUDAAccelerator, find_usable_cuda_devices # noqa: F401 from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401 +from lightning.fabric.accelerators.musa import MUSAAccelerator # noqa: F401 from lightning.fabric.accelerators.registry import _AcceleratorRegistry from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401 from lightning.fabric.utilities.registry import _register_classes diff --git a/src/lightning/fabric/accelerators/musa.py b/src/lightning/fabric/accelerators/musa.py new file mode 100644 index 0000000000000..98815b6a83775 --- /dev/null +++ b/src/lightning/fabric/accelerators/musa.py @@ -0,0 +1,186 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import lru_cache +from typing import Optional, Union + +import torch +from typing_extensions import override + +from lightning.fabric.accelerators.accelerator import Accelerator +from lightning.fabric.accelerators.registry import _AcceleratorRegistry +from lightning.fabric.utilities.rank_zero import rank_zero_info + + +class MUSAAccelerator(Accelerator): + """Accelerator for MUSA devices.""" + + @override + def setup_device(self, device: torch.device) -> None: + """ + Raises: + ValueError: + If the selected device is not of type MUSA. + """ + if device.type != "musa": + raise ValueError(f"Device should be MUSA, got {device} instead.") + _check_musa_matmul_precision(device) + torch.musa.set_device(device) + + @override + def teardown(self) -> None: + _clear_musa_memory() + + @staticmethod + @override + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + """Accelerator device parsing logic.""" + from lightning.fabric.utilities.device_parser import _parse_gpu_ids + + return _parse_gpu_ids(devices, include_musa=True) + + @staticmethod + @override + def get_parallel_devices(devices: list[int]) -> list[torch.device]: + """Gets parallel devices for the Accelerator.""" + return [torch.device("musa", i) for i in devices] + + @staticmethod + @override + def auto_device_count() -> int: + """Get the devices when set to auto.""" + return num_musa_devices() + + @staticmethod + @override + def is_available() -> bool: + return num_musa_devices() > 0 + + @staticmethod + @override + def name() -> str: + return "musa" + + @classmethod + @override + def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None: + accelerator_registry.register( + cls.name(), + cls, + description=cls.__name__, + ) + + +def find_usable_musa_devices(num_devices: int = -1) -> list[int]: + """Returns a list of all available and usable MUSA GPU devices. + + A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function + tests for each GPU on the system until the target number of usable devices is found. + + A subset of GPUs on the system might be used by other processes, and if the GPU is configured to operate in + 'exclusive' mode (configurable by the admin), then only one process is allowed to occupy it. + + Args: + num_devices: The number of devices you want to request. By default, this function will return as many as there + are usable MUSA GPU devices available. + + Warning: + If multiple processes call this function at the same time, there can be race conditions in the case where + both processes determine that the device is unoccupied, leading into one of them crashing later on. + + """ + if num_devices == 0: + return [] + visible_devices = _get_all_visible_musa_devices() + if not visible_devices: + raise ValueError( + f"You requested to find {num_devices} devices but there are no visible MUSA devices on this machine." + ) + if num_devices > len(visible_devices): + raise ValueError( + f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs." + ) + + available_devices = [] + unavailable_devices = [] + + for gpu_idx in visible_devices: + try: + torch.tensor(0, device=torch.device("musa", gpu_idx)) + except RuntimeError: + unavailable_devices.append(gpu_idx) + continue + + available_devices.append(gpu_idx) + if len(available_devices) == num_devices: + # exit early if we found the right number of GPUs + break + + if num_devices != -1 and len(available_devices) != num_devices: + raise RuntimeError( + f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available." + f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment." + ) + return available_devices + + +def _get_all_visible_musa_devices() -> list[int]: + """Returns a list of all visible MUSA GPU devices. + + Devices masked by the environment variabale ``MUSA_VISIBLE_DEVICES`` won't be returned here. For example, assume you + have 8 physical GPUs. If ``MUSA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list ``[0, 1, 2]`` + because these are the three visible GPUs after applying the mask ``MUSA_VISIBLE_DEVICES``. + + """ + return list(range(num_musa_devices())) + + +def num_musa_devices() -> int: + """Returns the number of available MUSA devices.""" + return torch.musa.device_count() + + +def is_musa_available() -> bool: + """Returns a bool indicating if MUSA is currently available.""" + # We set `PYTORCH_NVML_BASED_MUSA_CHECK=1` in lightning.fabric.__init__.py + return torch.musa.is_available() + + +def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool: + major, _ = torch.musa.get_device_capability(device) + return major >= 8 # Ampere and later leverage tensor cores, where this setting becomes useful + + +@lru_cache(1) # show the warning only ever once +def _check_musa_matmul_precision(device: torch.device) -> None: + if not torch.musa.is_available() or not _is_ampere_or_later(device): + return + # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and + # `set_float32_matmul_precision` + if torch.get_float32_matmul_precision() == "highest": # default + rank_zero_info( + f"You are using a MUSA device ({torch.musa.get_device_name(device)!r}) that has Tensor Cores. To properly" + " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off" + " precision for performance. For more details, read https://pytorch.org/docs/stable/generated/" + "torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision" + ) + # note: no need change `torch.backends.cudnn.allow_tf32` as it's enabled by default: + # https://pytorch.org/docs/stable/notes/musa.html#tensorfloat-32-tf32-on-ampere-devices + + +def _clear_musa_memory() -> None: + # strangely, the attribute function be undefined when torch.compile is used + if hasattr(torch._C, "_musa_clearCublasWorkspaces"): + # https://github.com/pytorch/pytorch/issues/95668 + torch._C._musa_clearMublasWorkspaces() + torch.musa.empty_cache() diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 594bb46f4b362..69e1de1f11889 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -21,7 +21,7 @@ from lightning_utilities.core.imports import RequirementCache from typing_extensions import get_args -from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, MUSAAccelerator from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args @@ -196,9 +196,11 @@ def _get_num_processes(accelerator: str, devices: str) -> int: else: raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'") if accelerator == "gpu": - parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) + parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_musa=True) elif accelerator == "cuda": parsed_devices = CUDAAccelerator.parse_devices(devices) + elif accelerator == "musa": + parsed_devices = MUSAAccelerator.parse_devices(devices) elif accelerator == "mps": parsed_devices = MPSAccelerator.parse_devices(devices) elif accelerator == "tpu": diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index b3289debbd522..bc6b6ff5510e4 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -23,6 +23,7 @@ from lightning.fabric.accelerators.accelerator import Accelerator from lightning.fabric.accelerators.cuda import CUDAAccelerator from lightning.fabric.accelerators.mps import MPSAccelerator +from lightning.fabric.accelerators.musa import MUSAAccelerator from lightning.fabric.accelerators.xla import XLAAccelerator from lightning.fabric.plugins import ( BitsandbytesPrecision, @@ -322,6 +323,8 @@ def _choose_auto_accelerator() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if MUSAAccelerator.is_available(): + return "musa" return "cpu" @staticmethod @@ -330,6 +333,8 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if MUSAAccelerator.is_available(): + return "musa" raise RuntimeError("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -400,8 +405,8 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator, MUSAAccelerator)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "musa") ): device = _determine_root_gpu_device(self._parallel_devices) else: diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 8bdacc0f523f5..84305cf8f74a0 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -50,6 +50,7 @@ def _parse_gpu_ids( gpus: Optional[Union[int, str, list[int]]], include_cuda: bool = False, include_mps: bool = False, + include_musa: bool = False, ) -> Optional[list[int]]: """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`. @@ -61,6 +62,7 @@ def _parse_gpu_ids( Any int N > 0 indicates that GPUs [0..N) should be used. include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing. include_mps: A boolean value indicating whether to include MPS devices for GPU parsing. + include_musa: A boolean value indicating whether to include MUSA devices for GPU parsing. Returns: A list of GPUs to be used or ``None`` if no GPUs were requested @@ -70,7 +72,7 @@ def _parse_gpu_ids( If no GPUs are available but the value of gpus variable indicates request for GPUs .. note:: - ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + ``include_cuda`` ``include_musa`` and ``include_mps`` default to ``False`` so that you only have to specify which device type to use and all other devices are not disabled. """ @@ -84,7 +86,9 @@ def _parse_gpu_ids( # We know the user requested GPUs therefore if some of the # requested GPUs are not available an exception is thrown. gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) + gpus = _normalize_parse_gpu_input_to_list( + gpus, include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa + ) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") @@ -92,7 +96,8 @@ def _parse_gpu_ids( torch.distributed.is_available() and torch.distributed.is_torchelastic_launched() and len(gpus) != 1 - and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)) + == 1 ): # Omit sanity check on torchelastic because by default it shows one visible GPU per process return gpus @@ -100,7 +105,7 @@ def _parse_gpu_ids( # Check that GPUs are unique. Duplicate GPUs are not supported by the backend. _check_unique(gpus) - return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) + return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa) def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[int, list[int]]: @@ -113,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False) -> list[int]: +def _sanitize_gpu_ids( + gpus: list[int], include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False +) -> list[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -128,9 +135,11 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: If machine has fewer available GPUs than requested. """ - if sum((include_cuda, include_mps)) == 0: + if sum((include_cuda, include_mps, include_musa)) == 0: raise ValueError("At least one gpu type should be specified!") - all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + all_available_gpus = _get_all_available_gpus( + include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa + ) for gpu in gpus: if gpu not in all_available_gpus: raise MisconfigurationException( @@ -140,7 +149,7 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool + gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool, include_musa: bool ) -> Optional[list[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): @@ -150,22 +159,26 @@ def _normalize_parse_gpu_input_to_list( if not gpus: # gpus==0 return None if gpus == -1: - return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa) return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> list[int]: +def _get_all_available_gpus( + include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False +) -> list[int]: """ Returns: A list of all available GPUs """ from lightning.fabric.accelerators.cuda import _get_all_visible_cuda_devices from lightning.fabric.accelerators.mps import _get_all_available_mps_gpus + from lightning.fabric.accelerators.musa import _get_all_visible_musa_devices cuda_gpus = _get_all_visible_cuda_devices() if include_cuda else [] mps_gpus = _get_all_available_mps_gpus() if include_mps else [] - return cuda_gpus + mps_gpus + musa_gpus = _get_all_visible_musa_devices() if include_musa else [] + return cuda_gpus + mps_gpus + musa_gpus + musa_gpus def _check_unique(device_ids: list[int]) -> None: @@ -210,6 +223,7 @@ def _select_auto_accelerator() -> str: """Choose the accelerator type (str) based on availability.""" from lightning.fabric.accelerators.cuda import CUDAAccelerator from lightning.fabric.accelerators.mps import MPSAccelerator + from lightning.fabric.accelerators.musa import MUSAAccelerator from lightning.fabric.accelerators.xla import XLAAccelerator if XLAAccelerator.is_available(): @@ -218,4 +232,6 @@ def _select_auto_accelerator() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if MUSAAccelerator.is_available(): + return "musa" return "cpu" diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index d085e4138d742..48e1b20cda7c4 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -23,6 +23,7 @@ from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.cuda import num_cuda_devices from lightning.fabric.accelerators.mps import MPSAccelerator +from lightning.fabric.accelerators.musa import MUSAAccelerator from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 @@ -36,6 +37,7 @@ def _runif_reasons( bf16_cuda: bool = False, tpu: bool = False, mps: Optional[bool] = None, + musa: Optional[bool] = None, skip_windows: bool = False, standalone: bool = False, deepspeed: bool = False, @@ -53,6 +55,8 @@ def _runif_reasons( tpu: Require that TPU is available. mps: If True: Require that MPS (Apple Silicon) is available, if False: Explicitly Require that MPS is not available + musa: If True: Require that MUSA (Device) is available, + if False: Explicitly Require that MUSA is not available skip_windows: Skip for Windows platform. standalone: Mark the test as standalone, our CI will run it in a separate process. This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set. @@ -108,6 +112,12 @@ def _runif_reasons( elif not mps and MPSAccelerator.is_available(): reasons.append("not MPS") + if musa is not None: + if musa and not MUSAAccelerator.is_available(): + reasons.append("MUSA") + elif not musa and MUSAAccelerator.is_available(): + reasons.append("not MUSA") + if standalone: if os.getenv("PL_RUN_STANDALONE_TESTS", "0") != "1": reasons.append("Standalone execution") diff --git a/src/lightning/pytorch/accelerators/__init__.py b/src/lightning/pytorch/accelerators/__init__.py index d7c2197aa5ed4..be609456f9f68 100644 --- a/src/lightning/pytorch/accelerators/__init__.py +++ b/src/lightning/pytorch/accelerators/__init__.py @@ -17,6 +17,7 @@ "CUDAAccelerator", "MPSAccelerator", "XLAAccelerator", + "MUSAAccelerator", "find_usable_cuda_devices", ] @@ -29,6 +30,7 @@ from lightning.pytorch.accelerators.cpu import CPUAccelerator from lightning.pytorch.accelerators.cuda import CUDAAccelerator from lightning.pytorch.accelerators.mps import MPSAccelerator +from lightning.pytorch.accelerators.musa import MUSAAccelerator from lightning.pytorch.accelerators.xla import XLAAccelerator AcceleratorRegistry = _AcceleratorRegistry() diff --git a/src/lightning/pytorch/accelerators/musa.py b/src/lightning/pytorch/accelerators/musa.py new file mode 100644 index 0000000000000..573ad638a9a14 --- /dev/null +++ b/src/lightning/pytorch/accelerators/musa.py @@ -0,0 +1,117 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any, Optional, Union + +import torch +from typing_extensions import override + +import lightning.pytorch as pl +from lightning.fabric.accelerators.musa import _check_musa_matmul_precision, _clear_musa_memory, num_musa_devices +from lightning.fabric.accelerators.registry import _AcceleratorRegistry +from lightning.fabric.utilities.device_parser import _parse_gpu_ids +from lightning.fabric.utilities.types import _DEVICE +from lightning.pytorch.accelerators.accelerator import Accelerator +from lightning.pytorch.utilities.exceptions import MisconfigurationException + +_log = logging.getLogger(__name__) + + +class MUSAAccelerator(Accelerator): + """Accelerator for MUSA devices.""" + + @override + def setup_device(self, device: torch.device) -> None: + """ + Raises: + MisconfigurationException: + If the selected device is not GPU. + """ + if device.type != "musa": + raise MisconfigurationException(f"Device should be GPU, got {device} instead") + _check_musa_matmul_precision(device) + torch.musa.set_device(device) + + @override + def setup(self, trainer: "pl.Trainer") -> None: + # TODO refactor input from trainer to local_rank @four4fish + self.set_musa_flags(trainer.local_rank) + _clear_musa_memory() + + @staticmethod + def set_musa_flags(local_rank: int) -> None: + # set the correct musa visible devices (using pci order) + os.environ["MUSA_DEVICE_ORDER"] = "PCI_BUS_ID" + all_gpu_ids = ",".join(str(x) for x in range(num_musa_devices())) + devices = os.getenv("MUSA_VISIBLE_DEVICES", all_gpu_ids) + _log.info(f"LOCAL_RANK: {local_rank} - MUSA_VISIBLE_DEVICES: [{devices}]") + + @override + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + """Gets stats for the given GPU device. + + Args: + device: GPU device for which to get stats + + Returns: + A dictionary mapping the metrics to their values. + + Raises: + FileNotFoundError: + If mthreds-gmi installation not found + + """ + return torch.musa.memory_stats(device) + + @override + def teardown(self) -> None: + _clear_musa_memory() + + @staticmethod + @override + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + """Accelerator device parsing logic.""" + return _parse_gpu_ids(devices, include_musa=True) + + @staticmethod + @override + def get_parallel_devices(devices: list[int]) -> list[torch.device]: + """Gets parallel devices for the Accelerator.""" + return [torch.device("musa", i) for i in devices] + + @staticmethod + @override + def auto_device_count() -> int: + """Get the devices when set to auto.""" + return num_musa_devices() + + @staticmethod + @override + def is_available() -> bool: + return num_musa_devices() > 0 + + @staticmethod + @override + def name() -> str: + return "musa" + + @classmethod + @override + def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None: + accelerator_registry.register( + cls.name(), + cls, + description=cls.__name__, + ) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 391e9dd5d0f25..078799322e075 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -34,7 +34,7 @@ from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.fabric.utilities.data import AttributeDict from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH -from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator +from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, MUSAAccelerator, XLAAccelerator from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from lightning.pytorch.utilities.migration import pl_legacy_patch from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint @@ -111,6 +111,8 @@ def _default_map_location(storage: "UntypedStorage", location: str) -> Optional[ and not CUDAAccelerator.is_available() or location.startswith("xla") and not XLAAccelerator.is_available() + or location.startswith("musa") + and not MUSAAccelerator.is_available() ): return storage.cpu() return None # default behavior by `torch.load()` diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index d51b11ea6fb12..d8bb7e376eb23 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -35,6 +35,7 @@ from lightning.pytorch.accelerators.accelerator import Accelerator from lightning.pytorch.accelerators.cuda import CUDAAccelerator from lightning.pytorch.accelerators.mps import MPSAccelerator +from lightning.pytorch.accelerators.musa import MUSAAccelerator from lightning.pytorch.accelerators.xla import XLAAccelerator from lightning.pytorch.plugins import ( _PLUGIN_INPUT, @@ -338,6 +339,8 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if MUSAAccelerator.is_available(): + return "musa" raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -409,8 +412,8 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator, MUSAAccelerator)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "musa") ): device = _determine_root_gpu_device(self._parallel_devices) else: diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 4f522c7c008bc..8275de261995c 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -18,7 +18,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator +from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, MUSAAccelerator, XLAAccelerator from lightning.pytorch.loggers.logger import DummyLogger from lightning.pytorch.profilers import ( AdvancedProfiler, @@ -156,11 +156,14 @@ def _log_device_info(trainer: "pl.Trainer") -> None: elif MPSAccelerator.is_available(): gpu_available = True gpu_type = " (mps)" + elif MUSAAccelerator.is_available(): + gpu_available = True + gpu_type = " (musa)" else: gpu_available = False gpu_type = "" - gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) + gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator, MUSAAccelerator)) rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}") num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0 @@ -171,6 +174,8 @@ def _log_device_info(trainer: "pl.Trainer") -> None: and not isinstance(trainer.accelerator, CUDAAccelerator) or MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator) + or MUSAAccelerator.is_available() + and not isinstance(trainer.accelerator, MUSAAccelerator) ): rank_zero_warn( "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.", diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 4c5b3bb6b4712..710293131bd9a 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -32,6 +32,7 @@ def _runif_reasons( bf16_cuda: bool = False, tpu: bool = False, mps: Optional[bool] = None, + musa: Optional[bool] = None, skip_windows: bool = False, standalone: bool = False, deepspeed: bool = False, @@ -56,6 +57,8 @@ def _runif_reasons( tpu: Require that TPU is available. mps: If True: Require that MPS (Apple Silicon) is available, if False: Explicitly Require that MPS is not available + musa: If True: Require that MUSA (Device) is available, + if False: Explicitly Require that MUSA is not available skip_windows: Skip for Windows platform. standalone: Mark the test as standalone, our CI will run it in a separate process. This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set. @@ -79,6 +82,7 @@ def _runif_reasons( bf16_cuda=bf16_cuda, tpu=tpu, mps=mps, + musa=musa, skip_windows=skip_windows, standalone=standalone, deepspeed=deepspeed, diff --git a/tests/tests_fabric/accelerators/test_musa.py b/tests/tests_fabric/accelerators/test_musa.py new file mode 100644 index 0000000000000..c354ca622a29c --- /dev/null +++ b/tests/tests_fabric/accelerators/test_musa.py @@ -0,0 +1,59 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest +import torch + +from lightning.fabric.accelerators.musa import MUSAAccelerator +from tests_fabric.helpers.runif import RunIf + +_MAYBE_MUSA = "musa" if MUSAAccelerator.is_available() else "cpu" + + +@mock.patch("lightning.fabric.accelerators.musa.num_musa_devices", return_value=2) +@RunIf(musa=True) +def test_auto_device_count(_): + assert MUSAAccelerator.auto_device_count() == 2 + + +@RunIf(musa=True) +def test_musa_availability(): + assert MUSAAccelerator.is_available() + + +def test_init_device_with_wrong_device_type(): + with pytest.raises(ValueError, match="Device should be MUSA"): + MUSAAccelerator().setup_device(torch.device("cpu")) + + +@RunIf(musa=True) +@pytest.mark.parametrize( + ("devices", "expected"), + [ + ([], []), + ([1], [torch.device(_MAYBE_MUSA, 1)]), + ([3, 1], [torch.device(_MAYBE_MUSA, 3), torch.device(_MAYBE_MUSA, 1)]), + ], +) +def test_get_parallel_devices(devices, expected): + assert MUSAAccelerator.get_parallel_devices(devices) == expected + + +@mock.patch("torch.musa.set_device") +@mock.patch("torch.musa.get_device_capability", return_value=(7, 0)) +def test_set_cuda_device(_, set_device_mock): + device = torch.device(_MAYBE_MUSA, 1) + MUSAAccelerator().setup_device(device) + set_device_mock.assert_called_once_with(device) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 51c4b320d5525..91a00a20903ed 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -10,7 +10,7 @@ from lightning_utilities.core.imports import RequirementCache import lightning.fabric -from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, MUSAAccelerator from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher @@ -40,7 +40,12 @@ def spawn_launch(fn, parallel_devices): """Copied from ``tests_pytorch.core.test_results.spawn_launch``""" # TODO: the accelerator and cluster_environment should be optional to just launch processes, but this requires lazy # initialization to be implemented - device_to_accelerator = {"cuda": CUDAAccelerator, "mps": MPSAccelerator, "cpu": CPUAccelerator} + device_to_accelerator = { + "cuda": CUDAAccelerator, + "mps": MPSAccelerator, + "cpu": CPUAccelerator, + "musa": MUSAAccelerator, + } accelerator_cls = device_to_accelerator[parallel_devices[0].type] strategy = DDPStrategy( accelerator=accelerator_cls(), diff --git a/tests/tests_pytorch/accelerators/test_musa.py b/tests/tests_pytorch/accelerators/test_musa.py new file mode 100644 index 0000000000000..7b11e60cd429e --- /dev/null +++ b/tests/tests_pytorch/accelerators/test_musa.py @@ -0,0 +1,58 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest + +from lightning.pytorch import Trainer +from lightning.pytorch.accelerators import MUSAAccelerator +from lightning.pytorch.demos.boring_classes import BoringModel +from tests_pytorch.helpers.runif import RunIf + + +@RunIf(musa=True) +def test_musa_availability(): + assert MUSAAccelerator.is_available() + + +def test_warning_if_musa_not_used(musa_count_1): + with pytest.warns(UserWarning, match="GPU available but not used"): + Trainer(accelerator="cpu") + + +@RunIf(musa=True) +@pytest.mark.parametrize("accelerator_value", ["musa", MUSAAccelerator()]) +def test_trainer_musa_accelerator(accelerator_value): + trainer = Trainer(accelerator=accelerator_value, devices=1) + assert isinstance(trainer.accelerator, MUSAAccelerator) + assert trainer.num_devices == 1 + + +@RunIf(musa=True) +@mock.patch("torch.musa.set_device") +def test_set_musa_device(set_device_mock, tmp_path, monkeypatch): + monkeypatch.setenv("MUSA_DEVICE_ORDER", "PCI_BUS_ID") # 或其他需要的值 + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + fast_dev_run=True, + accelerator="gpu", + devices=1, + enable_checkpointing=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + set_device_mock.assert_called_once() diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 878298c6bfd94..5af287b87f42c 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -204,6 +204,31 @@ def cuda_count_4(monkeypatch): mock_cuda_count(monkeypatch, 4) +def mock_musa_count(monkeypatch, n: int) -> None: + monkeypatch.setattr(lightning.fabric.accelerators.musa, "num_musa_devices", lambda: n) + monkeypatch.setattr(lightning.pytorch.accelerators.musa, "num_musa_devices", lambda: n) + + +@pytest.fixture +def musa_count_0(monkeypatch): + mock_musa_count(monkeypatch, 0) + + +@pytest.fixture +def musa_count_1(monkeypatch): + mock_musa_count(monkeypatch, 1) + + +@pytest.fixture +def musa_count_2(monkeypatch): + mock_musa_count(monkeypatch, 2) + + +@pytest.fixture +def musa_count_4(monkeypatch): + mock_musa_count(monkeypatch, 4) + + def mock_mps_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.fabric.accelerators.mps, "_get_all_available_mps_gpus", lambda: [0] if n > 0 else []) monkeypatch.setattr(lightning.fabric.accelerators.mps.MPSAccelerator, "is_available", lambda *_: n > 0)