Skip to content

Commit dd8e417

Browse files
committed
Address review.
1 parent 51d7c3b commit dd8e417

File tree

3 files changed

+33
-62
lines changed

3 files changed

+33
-62
lines changed

test/cpp/test_device.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#include <gtest/gtest.h>
2-
#include <torch/csrc/lazy/backend/backend_device.h>
32

43
#include <cstdint>
54
#include <string>
65
#include <string_view>
76

7+
#include <torch/csrc/lazy/backend/backend_device.h>
8+
9+
#include "absl/status/statusor.h"
810
#include "absl/strings/str_cat.h"
911

1012
#include "torch_xla/csrc/device.h"
@@ -47,11 +49,11 @@ static void CheckDeviceTypeConstructionWithString(
4749
}
4850

4951
TEST(DeviceTest, ConstructNativeDeviceTypeWithString) {
50-
#define XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING(type, _) \
51-
CheckDeviceTypeConstructionWithString(XlaDeviceType::type, #type);
52-
XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(
53-
XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING)
54-
#undef XLA_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING
52+
CheckDeviceTypeConstructionWithString(XlaDeviceType::CPU, "CPU");
53+
CheckDeviceTypeConstructionWithString(XlaDeviceType::CUDA, "CUDA");
54+
CheckDeviceTypeConstructionWithString(XlaDeviceType::TPU, "TPU");
55+
CheckDeviceTypeConstructionWithString(XlaDeviceType::NEURON, "NEURON");
56+
CheckDeviceTypeConstructionWithString(XlaDeviceType::SPMD, "SPMD");
5557
}
5658

5759
TEST(DeviceTest, ConstructPluginDeviceTypeWithString) {
@@ -68,11 +70,11 @@ static void CheckDeviceTypeConstructionWithEnum(
6870
}
6971

7072
TEST(DeviceTest, ConstructNativeDeviceTypeWithEnum) {
71-
#define XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_ENUM(type, _) \
72-
CheckDeviceTypeConstructionWithEnum(XlaDeviceType::type, #type);
73-
XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(
74-
XLA_NATIVE_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_ENUM)
75-
#undef XLA_DEVICE_TYPE_CHECK_CONSTRUCTION_WITH_STRING
73+
CheckDeviceTypeConstructionWithEnum(XlaDeviceType::CPU, "CPU");
74+
CheckDeviceTypeConstructionWithEnum(XlaDeviceType::CUDA, "CUDA");
75+
CheckDeviceTypeConstructionWithEnum(XlaDeviceType::TPU, "TPU");
76+
CheckDeviceTypeConstructionWithEnum(XlaDeviceType::NEURON, "NEURON");
77+
CheckDeviceTypeConstructionWithEnum(XlaDeviceType::SPMD, "SPMD");
7678
}
7779

7880
TEST(DeviceTest, ConstructPluginDeviceTypeWithEnumError) {

torch_xla/csrc/device.cpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
#include <cstddef>
66
#include <cstdint>
77
#include <exception>
8+
#include <iterator>
89
#include <memory>
910
#include <string>
1011
#include <string_view>
11-
#include <utility>
1212
#include <vector>
1313

1414
#include <torch/csrc/lazy/backend/backend_device.h>
1515

16-
#include "absl/log/absl_check.h"
1716
#include "absl/status/status.h"
1817
#include "absl/status/statusor.h"
1918
#include "absl/strings/str_cat.h"
@@ -30,14 +29,13 @@ namespace {
3029
static bool spmd_config_is_locked = false;
3130
static bool use_virtual_device = false;
3231

33-
constexpr std::size_t kNativeXlaDeviceTypeNumber = 5;
34-
constexpr std::array<std::pair<XlaDeviceType, std::string_view>,
35-
kNativeXlaDeviceTypeNumber>
36-
kNativeXlaDeviceTypeWithName = {{
37-
#define XLA_DEVICE_NAME_PAIR(name, _) {XlaDeviceType::name, #name},
38-
XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(XLA_DEVICE_NAME_PAIR)
39-
#undef XLA_DEVICE_NAME_PAIR
40-
}};
32+
constexpr int8_t kNativeXlaDeviceTypeNumber =
33+
static_cast<int8_t>(XlaDeviceType::PLUGIN);
34+
35+
// The elements in this array should match the order in the XlaDeviceType enum
36+
// declaration. So, if you modify one of them, make sure to keep them in sync.
37+
constexpr std::array<std::string_view, kNativeXlaDeviceTypeNumber>
38+
kNativeXlaDeviceTypeNames = {"CPU", "CUDA", "TPU", "NEURON", "SPMD"};
4139

4240
absl::Status CheckIsNativeXlaDeviceType(int8_t value) {
4341
if (value < 0 || value >= kNativeXlaDeviceTypeNumber) {
@@ -62,21 +60,20 @@ std::string_view NativeXlaDeviceTypeToString(XlaDeviceType type) {
6260
// 2. The XlaDeviceType::PLUGIN enum, since it's not considered a "native"
6361
// device type
6462
XLA_CHECK_OK(CheckIsNativeXlaDeviceType(value));
65-
return kNativeXlaDeviceTypeWithName[value].second;
63+
return kNativeXlaDeviceTypeNames[value];
6664
}
6765

6866
XlaDeviceType StringToXlaDeviceType(std::string_view type_name) {
69-
std::array<std::pair<XlaDeviceType, std::string_view>,
70-
kNativeXlaDeviceTypeNumber>::const_iterator it =
71-
std::find_if(kNativeXlaDeviceTypeWithName.begin(),
72-
kNativeXlaDeviceTypeWithName.end(),
73-
[=](const std::pair<XlaDeviceType, std::string_view>& pair) {
74-
return pair.second == type_name;
75-
});
76-
if (it != kNativeXlaDeviceTypeWithName.end()) {
77-
return it->first;
67+
std::array<std::string_view, kNativeXlaDeviceTypeNumber>::const_iterator it =
68+
std::find(kNativeXlaDeviceTypeNames.begin(),
69+
kNativeXlaDeviceTypeNames.end(), type_name);
70+
71+
if (it == kNativeXlaDeviceTypeNames.end()) {
72+
return XlaDeviceType::PLUGIN;
7873
}
79-
return XlaDeviceType::PLUGIN;
74+
75+
std::size_t index = std::distance(kNativeXlaDeviceTypeNames.begin(), it);
76+
return static_cast<XlaDeviceType>(index);
8077
}
8178

8279
} // namespace

torch_xla/csrc/device.h

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,20 @@
11
#ifndef XLA_TORCH_XLA_CSRC_DEVICE_H_
22
#define XLA_TORCH_XLA_CSRC_DEVICE_H_
33

4+
#include <cstdint>
45
#include <string>
56
#include <string_view>
67

78
#include <torch/csrc/lazy/backend/backend_device.h>
8-
#include <torch/csrc/lazy/core/hash.h>
9-
#include <torch/csrc/lazy/core/util.h>
109

1110
#include "absl/status/statusor.h"
1211

1312
namespace torch_xla {
1413

15-
// Convenient macro for applying another macro to all native device types.
16-
//
17-
// Add new device type
18-
// ===================
19-
//
20-
// Add a new line to the macro below:
21-
//
22-
// _(<DEVICE>, <INDEX>)
23-
//
24-
// Where <DEVICE> is the enum of the given device, and <INDEX> is the
25-
// previous number plus 1.
26-
//
27-
#define XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(_) \
28-
_(CPU, 0) \
29-
_(CUDA, 1) \
30-
_(TPU, 2) \
31-
_(NEURON, 3) \
32-
_(SPMD, 4)
33-
3414
// TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToDevice`
3515
// until after the paritioning pass. This avoids transfering the full input
3616
// tensor to the device.
37-
enum class XlaDeviceType : int8_t {
38-
#define XLA_DECLARE_ENUM(name, value) name = value,
39-
XLA_FOR_ALL_NATIVE_DEVICE_TYPES_(XLA_DECLARE_ENUM)
40-
#undef XLA_DECLARE_ENUM
41-
42-
// Plugin is not considered a native device type.
43-
// It has a special treatment for some functions.
44-
PLUGIN,
45-
};
17+
enum class XlaDeviceType : int8_t { CPU = 0, CUDA, TPU, NEURON, SPMD, PLUGIN };
4618

4719
struct DeviceType : public torch::lazy::BackendDeviceType {
4820
DeviceType(XlaDeviceType xla_device_type);

0 commit comments

Comments
 (0)