 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import warnings
 from typing import Any, Union
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
-
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+from lightning.fabric.utilities.device_parser import _check_data_type
 
 
 class XLAAccelerator(Accelerator):
@@ -36,38 +31,38 @@ class XLAAccelerator(Accelerator):
3631 """
3732
3833 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
39- _raise_enterprise_not_available ()
34+ if not _XLA_AVAILABLE :
35+ raise ModuleNotFoundError (str (_XLA_AVAILABLE ))
36+ if not _using_pjrt ():
37+ raise RuntimeError ("The XLA XRT runtime is not supported anymore." )
4038 super ().__init__ (* args , ** kwargs )
4139
42- from pytorch_lightning_enterprise .accelerators .xla import XLAAccelerator as EnterpriseXLAAccelerator
43-
44- self .accelerator_impl = EnterpriseXLAAccelerator (* args , ** kwargs )
45-
4640 @override
4741 def setup_device (self , device : torch .device ) -> None :
48- return self . accelerator_impl . setup_device ( device )
42+ pass
4943
5044 @override
5145 def teardown (self ) -> None :
52- return self . accelerator_impl . teardown ()
46+ pass
5347
5448 @staticmethod
5549 @override
5650 def parse_devices (devices : Union [int , str , list [int ]]) -> Union [int , list [int ]]:
5751 """Accelerator device parsing logic."""
58- _raise_enterprise_not_available ()
59- from pytorch_lightning_enterprise .accelerators .xla import XLAAccelerator as EnterpriseXLAAccelerator
60-
61- return EnterpriseXLAAccelerator .parse_devices (devices )
52+ return _parse_tpu_devices (devices )
6253
6354 @staticmethod
6455 @override
6556 def get_parallel_devices (devices : Union [int , list [int ]]) -> list [torch .device ]:
6657 """Gets parallel devices for the Accelerator."""
67- _raise_enterprise_not_available ()
68- from pytorch_lightning_enterprise .accelerators .xla import XLAAccelerator as EnterpriseXLAAccelerator
69-
70- return EnterpriseXLAAccelerator .get_parallel_devices (devices )
58+ devices = _parse_tpu_devices (devices )
59+ if isinstance (devices , int ):
60+ return [torch .device ("xla" , i ) for i in range (devices )]
61+ # list of devices is not supported, just a specific index, fine to access [0]
62+ return [torch .device ("xla" , devices [0 ])]
63+ # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
64+ # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
65+ # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
7166
7267 @staticmethod
7368 @override
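
For reference, a quick sketch of what the restored `get_parallel_devices` returns for the two accepted shapes (an integer count or a single-element index list). It only constructs plain `torch.device` objects, exactly as the added lines above do, and "xla" is a device type that core PyTorch accepts even without `torch_xla` installed:

```python
import torch

# devices=4: one device per index
print([torch.device("xla", i) for i in range(4)])
# [device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

# devices=[5]: only the requested index
print([torch.device("xla", 5)])
# [device(type='xla', index=5)]
```
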
@@ -76,10 +71,16 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+        if not _XLA_AVAILABLE:
+            return 0
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+
+            return tpu.num_available_devices()
+        from torch_xla.experimental import tpu
 
-        return EnterpriseXLAAccelerator.auto_device_count()
+        device_count_on_version = {2: 8, 3: 8, 4: 4}
+        return device_count_on_version.get(tpu.version(), 8)
 
     @staticmethod
     @override
@@ -91,9 +92,6 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
-        except ModuleNotFoundError as e:
-            warnings.warn(str(e))
-            return False
 
     @staticmethod
     @override
@@ -108,3 +106,74 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             cls,
             description=cls.__name__,
         )
+
+
+# PJRT support requires this minimum version
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+
+
+def _using_pjrt() -> bool:
+    # `using_pjrt` is removed in torch_xla 2.5
+    if _XLA_GREATER_EQUAL_2_5:
+        from torch_xla import runtime as xr
+
+        return xr.device_type() is not None
+    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla import runtime as xr
+
+        return xr.using_pjrt()
+
+    from torch_xla.experimental import pjrt
+
+    return pjrt.using_pjrt()
+
+
+def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+    """Parses the TPU devices given in the format as accepted by the
+    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
+
+    Args:
+        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used.
+            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used.
+            A single-element list of int or string can be used to indicate the specific TPU core to use.
+
+    Returns:
+        A list of tpu cores to be used.
+
+    """
+    _check_data_type(devices)
+    if isinstance(devices, str):
+        devices = _parse_tpu_devices_str(devices)
+    _check_tpu_devices_valid(devices)
+    return devices
+
+
+def _check_tpu_devices_valid(devices: object) -> None:
+    device_count = XLAAccelerator.auto_device_count()
+    if (
+        # support number of devices
+        isinstance(devices, int)
+        and devices in {1, device_count}
+        # support picking a specific device
+        or isinstance(devices, (list, tuple))
+        and len(devices) == 1
+        and 0 <= devices[0] <= device_count - 1
+    ):
+        return
+    raise ValueError(
+        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
+    )
+
+
+def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
+    devices = devices.strip()
+    try:
+        return int(devices)
+    except ValueError:
+        try:
+            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
+        except ValueError:
+            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
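
To make the device-string parsing added at the bottom of the file concrete, here is a small self-contained illustration. It mirrors the `_parse_tpu_devices_str` logic above rather than importing the private helper, so it runs without torch_xla or the patched module installed; the function name here is for illustration only:

```python
from typing import Union


def parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
    # Mirror of the _parse_tpu_devices_str helper added in this diff (illustration only).
    devices = devices.strip()
    try:
        return int(devices)
    except ValueError:
        try:
            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
        except ValueError:
            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")


assert parse_tpu_devices_str("8") == 8         # use all cores via multi-processing
assert parse_tpu_devices_str(" 1 ") == 1       # whitespace is stripped
assert parse_tpu_devices_str("5,") == [5]      # a trailing comma selects the specific core index 5
assert parse_tpu_devices_str("0,3") == [0, 3]  # parses, but is rejected downstream (see note below)
```

Note that `_check_tpu_devices_valid` then only accepts `1`, the full `auto_device_count()`, or a single-element index list, so the multi-element result in the last line would still be rejected by `_parse_tpu_devices`.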