1 change: 1 addition & 0 deletions docs/source-fabric/api/accelerators.rst
@@ -20,3 +20,4 @@ Accelerators
CUDAAccelerator
MPSAccelerator
XLAAccelerator
MUSAAccelerator
77 changes: 77 additions & 0 deletions docs/source-pytorch/accelerators/musa.rst
@@ -0,0 +1,77 @@
:orphan:

MUSA training (Advanced)
========================
**Audience:** Users looking to train models on MooreThreads devices using the MUSA accelerator.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

----

MUSAAccelerator Overview
------------------------
torch_musa is a Python package that extends PyTorch to take full advantage of the computing power of
MooreThreads graphics cards. Combined with PyTorch, it lets users harness MooreThreads GPUs for
training and inference.

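For example, once ``torch_musa`` is installed, tensors can target the ``"musa"`` device type directly.
A minimal sketch (it assumes a working ``torch_musa`` installation, which registers the ``musa``
backend when imported):

.. code-block:: python

    import torch
    import torch_musa  # noqa: F401  # importing registers the "musa" device type

    x = torch.randn(8, 28 * 28, device="musa")  # allocated on the first MooreThreads GPU
    print(x.device)
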
When modules share weights, PyTorch Lightning automatically finds the shared weights and re-ties them
after the modules are moved to the MUSA device under the hood. This ensures that the weights are
shared among the modules rather than copied independently.

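A minimal sketch of the kind of weight sharing meant here (hypothetical module, for illustration
only):

.. code-block:: python

    import torch.nn as nn
    import pytorch_lightning as L


    class WeightSharingModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            # layer_3 shares its weight with layer_1; Lightning re-ties it
            # after the module is moved to the MUSA device
            self.layer_3.weight = self.layer_1.weight

        def forward(self, x):
            return self.layer_3(self.layer_2(self.layer_1(x)))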

Example:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.data as data
    import torchvision as tv
    import pytorch_lightning as L


    # Step 1: Define a LightningModule
    class LitAutoEncoder(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

        def forward(self, x):
            # in lightning, forward defines the prediction/inference actions
            embedding = self.encoder(x)
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop. It is independent of forward
            x, _ = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer


    def main():
        # -------------------
        # Step 2: Define data
        # -------------------
        dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
        train, val = data.random_split(dataset, [55000, 5000])

        # -------------------
        # Step 3: Train
        # -------------------
        autoencoder = LitAutoEncoder()
        # we also support accelerator="auto" or accelerator="musa"
        trainer = L.Trainer(accelerator="gpu")
        trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))


    if __name__ == "__main__":
        main()

----

MUSA
----
MUSA is the library that interfaces PyTorch with the MooreThreads graphics cards.
For more information, check out `MUSA <https://github.com/MooreThreads/torch_musa>`_.
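
A quick way to sanity-check an installation (a sketch; it assumes ``torch_musa`` exposes the
``torch.musa`` namespace used throughout this integration):

.. code-block:: python

    import torch
    import torch_musa  # noqa: F401

    print(torch.musa.is_available())  # True if at least one MooreThreads GPU is usable
    print(torch.musa.device_count())  # number of visible MUSA devices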
1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
@@ -14,6 +14,7 @@ accelerators
CPUAccelerator
CUDAAccelerator
XLAAccelerator
MUSAAccelerator

callbacks
---------
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/accelerator.rst
@@ -128,3 +128,4 @@ Accelerator API
CUDAAccelerator
MPSAccelerator
XLAAccelerator
MUSAAccelerator
1 change: 1 addition & 0 deletions src/lightning/fabric/accelerators/__init__.py
@@ -16,6 +16,7 @@
from lightning.fabric.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning.fabric.accelerators.cuda import CUDAAccelerator, find_usable_cuda_devices # noqa: F401
from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.fabric.accelerators.musa import MUSAAccelerator # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401
from lightning.fabric.utilities.registry import _register_classes
186 changes: 186 additions & 0 deletions src/lightning/fabric/accelerators/musa.py
@@ -0,0 +1,186 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache
from typing import Optional, Union

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.rank_zero import rank_zero_info


class MUSAAccelerator(Accelerator):
    """Accelerator for MUSA devices."""

    @override
    def setup_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError:
                If the selected device is not of type MUSA.
        """
        if device.type != "musa":
            raise ValueError(f"Device should be MUSA, got {device} instead.")
        _check_musa_matmul_precision(device)
        torch.musa.set_device(device)

    @override
    def teardown(self) -> None:
        _clear_musa_memory()

    @staticmethod
    @override
    def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]:
        """Accelerator device parsing logic."""
        from lightning.fabric.utilities.device_parser import _parse_gpu_ids

        return _parse_gpu_ids(devices, include_musa=True)

    @staticmethod
    @override
    def get_parallel_devices(devices: list[int]) -> list[torch.device]:
        """Gets parallel devices for the Accelerator."""
        return [torch.device("musa", i) for i in devices]

    @staticmethod
    @override
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return num_musa_devices()

    @staticmethod
    @override
    def is_available() -> bool:
        return num_musa_devices() > 0

    @staticmethod
    @override
    def name() -> str:
        return "musa"

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        accelerator_registry.register(
            cls.name(),
            cls,
            description=cls.__name__,
        )
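

# Example (sketch): selecting this accelerator explicitly with Fabric, assuming
# at least two MUSA devices are visible:
#
#     from lightning.fabric import Fabric
#
#     fabric = Fabric(accelerator="musa", devices=2)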


def find_usable_musa_devices(num_devices: int = -1) -> list[int]:
    """Returns a list of all available and usable MUSA GPU devices.

    A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function
    tests for each GPU on the system until the target number of usable devices is found.

    A subset of GPUs on the system might be used by other processes, and if the GPU is configured to operate in
    'exclusive' mode (configurable by the admin), then only one process is allowed to occupy it.

    Args:
        num_devices: The number of devices you want to request. By default, this function will return as many as
            there are usable MUSA GPU devices available.

    Warning:
        If multiple processes call this function at the same time, there can be race conditions in the case where
        both processes determine that the device is unoccupied, leading to one of them crashing later on.

    """
    if num_devices == 0:
        return []
    visible_devices = _get_all_visible_musa_devices()
    if not visible_devices:
        raise ValueError(
            f"You requested to find {num_devices} devices but there are no visible MUSA devices on this machine."
        )
    if num_devices > len(visible_devices):
        raise ValueError(
            f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs."
        )

    available_devices = []
    unavailable_devices = []

    for gpu_idx in visible_devices:
        try:
            torch.tensor(0, device=torch.device("musa", gpu_idx))
        except RuntimeError:
            unavailable_devices.append(gpu_idx)
            continue

        available_devices.append(gpu_idx)
        if len(available_devices) == num_devices:
            # exit early if we found the right number of GPUs
            break

    if num_devices != -1 and len(available_devices) != num_devices:
        raise RuntimeError(
            f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available."
            f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment."
        )
    return available_devices
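
# Example usage (sketch): request two currently free MUSA GPUs,
#
#     devices = find_usable_musa_devices(2)  # e.g. [0, 1] if both are idle
#
# and pass the result to the ``devices=`` argument of Fabric or the Trainer.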


def _get_all_visible_musa_devices() -> list[int]:
    """Returns a list of all visible MUSA GPU devices.

    Devices masked by the environment variable ``MUSA_VISIBLE_DEVICES`` won't be returned here. For example, assume
    you have 8 physical GPUs. If ``MUSA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list
    ``[0, 1, 2]`` because these are the three visible GPUs after applying the mask ``MUSA_VISIBLE_DEVICES``.

    """
    return list(range(num_musa_devices()))


def num_musa_devices() -> int:
    """Returns the number of available MUSA devices."""
    return torch.musa.device_count()


def is_musa_available() -> bool:
    """Returns a bool indicating if MUSA is currently available."""
    return torch.musa.is_available()


def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool:
    major, _ = torch.musa.get_device_capability(device)
    # compute capability 8 and later leverage tensor cores, where this setting becomes useful
    # (mirrors the CUDA Ampere check this helper is modeled on)
    return major >= 8


@lru_cache(1)  # show the warning only ever once
def _check_musa_matmul_precision(device: torch.device) -> None:
    if not torch.musa.is_available() or not _is_ampere_or_later(device):
        return
    # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
    # `set_float32_matmul_precision`
    if torch.get_float32_matmul_precision() == "highest":  # default
        rank_zero_info(
            f"You are using a MUSA device ({torch.musa.get_device_name(device)!r}) that has Tensor Cores. To properly"
            " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
            " precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"
            "torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision"
        )
    # note: no need to change `torch.backends.cudnn.allow_tf32` as it's enabled by default:
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
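
# Example (sketch): how a user would act on the info message above before training,
# trading float32 matmul precision for tensor-core throughput:
#
#     torch.set_float32_matmul_precision("high")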


def _clear_musa_memory() -> None:
    # strangely, the attribute can be undefined when torch.compile is used
    if hasattr(torch._C, "_musa_clearMublasWorkspaces"):
        # https://github.com/pytorch/pytorch/issues/95668
        torch._C._musa_clearMublasWorkspaces()
    torch.musa.empty_cache()
6 changes: 4 additions & 2 deletions src/lightning/fabric/cli.py
@@ -21,7 +21,7 @@
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args

-from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
+from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, MUSAAccelerator
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
@@ -196,9 +196,11 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
        else:
            raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'")
    if accelerator == "gpu":
-        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
+        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_musa=True)
    elif accelerator == "cuda":
        parsed_devices = CUDAAccelerator.parse_devices(devices)
+    elif accelerator == "musa":
+        parsed_devices = MUSAAccelerator.parse_devices(devices)
    elif accelerator == "mps":
        parsed_devices = MPSAccelerator.parse_devices(devices)
    elif accelerator == "tpu":
9 changes: 7 additions & 2 deletions src/lightning/fabric/connector.py
@@ -23,6 +23,7 @@
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.cuda import CUDAAccelerator
from lightning.fabric.accelerators.mps import MPSAccelerator
+from lightning.fabric.accelerators.musa import MUSAAccelerator
from lightning.fabric.accelerators.xla import XLAAccelerator
from lightning.fabric.plugins import (
BitsandbytesPrecision,
@@ -322,6 +323,8 @@ def _choose_auto_accelerator() -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if MUSAAccelerator.is_available():
return "musa"
return "cpu"

@staticmethod
@@ -330,6 +333,8 @@ def _choose_gpu_accelerator_backend() -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if MUSAAccelerator.is_available():
return "musa"
raise RuntimeError("No supported gpu backend found!")

    def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -400,8 +405,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
        if self._num_nodes_flag > 1:
            return "ddp"
        if len(self._parallel_devices) <= 1:
-            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator, MUSAAccelerator)) or (
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "musa")
            ):
                device = _determine_root_gpu_device(self._parallel_devices)
            else: