@@ -66,10 +66,13 @@ TORCH_LIBRARY_IMPL(c10d, XPU, m) {
6666 m.impl (" broadcast_" , broadcast_xpu_);
6767}
6868
69+ #if TORCH_VERSION_MAJOR > 1 && TORCH_VERSION_MINOR >= 1
70+ // PyTorch 2.1 allreduce support sparse tensor
6971std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu_ (
7072 at::TensorList tensors,
7173 const c10::intrusive_ptr<ProcessGroup>& process_group,
7274 const c10::intrusive_ptr<ReduceOp>& reduce_op,
75+ const c10::optional<at::Tensor>& sparse_indices,
7376 int64_t timeout) {
7477 auto tensor_vec = tensors.vec ();
7578 auto work =
@@ -85,6 +88,28 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu
8588 return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>>(
8689 std::move (tensor_vec), work);
8790}
91+ #else
92+ // TODO: Remove after updating to PyTorch 2.1
93+ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>> allreduce_xpu_ (
94+ at::TensorList tensors,
95+ const c10::intrusive_ptr<ProcessGroup>& process_group,
96+ const c10::intrusive_ptr<ReduceOp>& reduce_op,
97+ int64_t timeout) {
98+ auto tensor_vec = tensors.vec ();
99+ auto work =
100+ process_group->getBackend (c10::DeviceType::XPU)
101+ ->allreduce (
102+ tensor_vec,
103+ c10d::AllreduceOptions{
104+ *reduce_op.get (), std::chrono::milliseconds (timeout)});
105+
106+ // Return input tensors as output tensors to make inplace allreduce look like
107+ // a functional API, so that make_fx can correctly build the dependencies in
108+ // the graph later.
109+ return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<C10D_Work>>(
110+ std::move (tensor_vec), work);
111+ }
112+ #endif
88113
89114TORCH_LIBRARY_IMPL (c10d, XPU, m) {
90115 m.impl (" allreduce_" , allreduce_xpu_);
0 commit comments