55#include < cstddef>
66#include < cstdint>
77#include < exception>
8+ #include < iterator>
89#include < memory>
910#include < string>
1011#include < string_view>
11- #include < utility>
1212#include < vector>
1313
1414#include < torch/csrc/lazy/backend/backend_device.h>
1515
16- #include " absl/log/absl_check.h"
1716#include " absl/status/status.h"
1817#include " absl/status/statusor.h"
1918#include " absl/strings/str_cat.h"
@@ -30,14 +29,13 @@ namespace {
3029static bool spmd_config_is_locked = false ;
3130static bool use_virtual_device = false ;
3231
33- constexpr std::size_t kNativeXlaDeviceTypeNumber = 5 ;
34- constexpr std::array<std::pair<XlaDeviceType, std::string_view>,
35- kNativeXlaDeviceTypeNumber >
36- kNativeXlaDeviceTypeWithName = {{
37- #define XLA_DEVICE_NAME_PAIR (name, _ ) {XlaDeviceType::name, #name},
38- XLA_FOR_ALL_NATIVE_DEVICE_TYPES_ (XLA_DEVICE_NAME_PAIR)
39- #undef XLA_DEVICE_NAME_PAIR
40- }};
32+ constexpr int8_t kNativeXlaDeviceTypeNumber =
33+ static_cast <int8_t >(XlaDeviceType::PLUGIN);
34+
35+ // The elements in this array should match the order in the XlaDeviceType enum
36+ // declaration. So, if you modify one of them, make sure to keep them in sync.
37+ constexpr std::array<std::string_view, kNativeXlaDeviceTypeNumber >
38+ kNativeXlaDeviceTypeNames = {" CPU" , " CUDA" , " TPU" , " NEURON" , " SPMD" };
4139
4240absl::Status CheckIsNativeXlaDeviceType (int8_t value) {
4341 if (value < 0 || value >= kNativeXlaDeviceTypeNumber ) {
@@ -62,21 +60,20 @@ std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) {
6260 // 2. The XlaDeviceType::PLUGIN enum, since it's not considered a "native"
6361 // device type
6462 XLA_CHECK_OK (CheckIsNativeXlaDeviceType (value));
65- return kNativeXlaDeviceTypeWithName [value]. second ;
63+ return kNativeXlaDeviceTypeNames [value];
6664}
6765
6866XlaDeviceType StringToXlaDeviceType (std::string_view type_name) {
69- std::array<std::pair<XlaDeviceType, std::string_view>,
70- kNativeXlaDeviceTypeNumber >::const_iterator it =
71- std::find_if (kNativeXlaDeviceTypeWithName .begin (),
72- kNativeXlaDeviceTypeWithName .end (),
73- [=](const std::pair<XlaDeviceType, std::string_view>& pair) {
74- return pair.second == type_name;
75- });
76- if (it != kNativeXlaDeviceTypeWithName .end ()) {
77- return it->first ;
67+ std::array<std::string_view, kNativeXlaDeviceTypeNumber >::const_iterator it =
68+ std::find (kNativeXlaDeviceTypeNames .begin (),
69+ kNativeXlaDeviceTypeNames .end (), type_name);
70+
71+ if (it == kNativeXlaDeviceTypeNames .end ()) {
72+ return XlaDeviceType::PLUGIN;
7873 }
79- return XlaDeviceType::PLUGIN;
74+
75+ std::size_t index = std::distance (kNativeXlaDeviceTypeNames .begin (), it);
76+ return static_cast <XlaDeviceType>(index);
8077}
8178
8279} // namespace
0 commit comments