diff --git a/.gitmodules b/.gitmodules
index 262aaf3..67e4d7b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
[submodule "third_party/oneCCL"]
path = third_party/oneCCL
- url = https://github.com/intel-innersource/libraries.performance.communication.oneccl.git
+ url = https://github.com/oneapi-src/oneCCL.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 70288df..9a72c46 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,14 +6,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wformat")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=cpp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wformat-security")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector")
-# Since 2016 Debian start using RUNPATH instead of normally RPATH, which gave the annoy effect that
-# allow LD_LIBRARY_PATH to override dynamic linking path. Depends on intention of linking priority,
-# change below for best outcome: disable, using RPATH, enable, using RUNPATH
-if (ENABLE_LINKER_RUNPATH)
-  set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--enable-new-dtags")
-else()
-  set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--disable-new-dtags")
-endif()
set(LINUX TRUE)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
diff --git a/README.md b/README.md
index 0f82a86..41ce27e 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ This repository holds PyTorch bindings maintained by Intel for the Intel® oneAP

[PyTorch](https://github.com/pytorch/pytorch) is an open-source machine learning framework.

-[Intel® oneCCL](https://github.com/oneapi-src/oneCCL) (collective communications library) is a library for efficient distributed deep learning training implementing such collectives like `allreduce`, `allgather`, `alltoall`. For more information on oneCCL, please refer to the [oneCCL documentation](https://spec.oneapi.com/versions/latest/elements/oneCCL/source/index.html) and [oneCCL specification](https://spec.oneapi.com/versions/latest/elements/oneCCL/source/index.html).
+[Intel® oneCCL](https://github.com/oneapi-src/oneCCL) (collective communications library) is a library for efficient distributed deep learning training that implements collectives such as `allreduce`, `allgather`, and `alltoall`. For more information on oneCCL, please refer to the [oneCCL documentation](https://spec.oneapi.com/versions/latest/elements/oneCCL/source/index.html).

`oneccl_bindings_for_pytorch` module implements PyTorch C10D ProcessGroup API and can be dynamically loaded as external ProcessGroup and only works on Linux platform now.

@@ -16,26 +16,28 @@ The table below shows which functions are available for use with CPU / Intel dGP

| | CPU | GPU |
| :--------------- | :---: | :---: |
-| `send` | × | × |
-| `recv` | × | × |
+| `send` | × | √ |
+| `recv` | × | √ |
| `broadcast` | √ | √ |
| `all_reduce` | √ | √ |
| `reduce` | √ | √ |
| `all_gather` | √ | √ |
| `gather` | √ | √ |
| `scatter` | × | × |
-| `reduce_scatter` | × | × |
+| `reduce_scatter` | √ | √ |
| `all_to_all` | √ | √ |
| `barrier` | √ | √ |

## PyTorch API Align

-We recommend Anaconda as Python package management system. The following is the corresponding branches (tags) of `oneccl_bindings_for_pytorch` and supported Pytorch.
+We recommend using Anaconda as the Python package management system. The following are the corresponding branches (tags) of `oneccl_bindings_for_pytorch` and the supported PyTorch versions.
| `torch` | `oneccl_bindings_for_pytorch` |
| :-------------------------------------------------------------: | :-----------------------------------------------------------------------: |
| `master` | `master` |
+ | [v2.1.0](https://github.com/pytorch/pytorch/tree/v2.1.0) | [ccl_torch2.1.0+cpu](https://github.com/intel/torch-ccl/tree/ccl_torch2.1.0%2Bcpu) |
+ | [v2.0.1](https://github.com/pytorch/pytorch/tree/v2.0.1) | [ccl_torch2.0.100](https://github.com/intel/torch-ccl/tree/ccl_torch2.0.100) |
| [v1.13](https://github.com/pytorch/pytorch/tree/v1.13) | [ccl_torch1.13](https://github.com/intel/torch-ccl/tree/ccl_torch1.13) |
| [v1.12.1](https://github.com/pytorch/pytorch/tree/v1.12.1) | [ccl_torch1.12.100](https://github.com/intel/torch-ccl/tree/ccl_torch1.12.100) |
| [v1.12.0](https://github.com/pytorch/pytorch/tree/v1.12.0) | [ccl_torch1.12](https://github.com/intel/torch-ccl/tree/ccl_torch1.12) |
@@ -45,33 +47,34 @@ We recommend Anaconda as Python package management system. The following is the
| [v1.8.1](https://github.com/pytorch/pytorch/tree/v1.8.1) | [ccl_torch1.8](https://github.com/intel/torch-ccl/tree/ccl_torch1.8) |
| [v1.7.1](https://github.com/pytorch/pytorch/tree/v1.7.1) | [ccl_torch1.7](https://github.com/intel/torch-ccl/tree/ccl_torch1.7) |
| [v1.6.0](https://github.com/pytorch/pytorch/tree/v1.6.0) | [ccl_torch1.6](https://github.com/intel/torch-ccl/tree/ccl_torch1.6) |
- | [v1.5-rc3](https://github.com/pytorch/pytorch/tree/v1.5.0-rc3) | [beta09](https://github.com/intel/torch-ccl/tree/beta09) |
+ | [v1.5-rc3](https://github.com/pytorch/pytorch/tree/v1.5.0-rc3) | [beta09](https://github.com/intel/torch-ccl/tree/beta09) |

The usage details can be found in the README of the corresponding branch. The following part is about the usage of the v1.9 tag. If you want to use another version of torch-ccl, please check out that branch (tag). For pytorch-1.5.0-rc3, the [#PR28068](https://github.com/pytorch/pytorch/pull/28068) and [#PR32361](https://github.com/pytorch/pytorch/pull/32361) are needed to dynamically register the external ProcessGroup and enable the `alltoall` collective communication primitive. The patch file covering these two PRs is in the `patches` directory and you can use it directly.

## Requirements

-- Python 3.6 or later and a C++17 compiler
+- Python 3.8 or later and a C++17 compiler

-- PyTorch v1.13.0
+- PyTorch v2.0.1

## Build Option List

The following build options are supported in Intel® oneCCL Bindings for PyTorch*.

| Build Option | Default Value | Description |
-| :---------------------------------: | :------------: | :-------------------------------------------------------------------------------------------------: |
+| :---------------------------------- | :------------- | :-------------------------------------------------------------------------------------------------- |
| COMPUTE_BACKEND | | Set oneCCL `COMPUTE_BACKEND`, set to `dpcpp` and use the DPC++ Compiler to enable support for Intel XPU |
+| USE_SYSTEM_ONECCL | OFF | Use the oneCCL library installed on the system |
| CCL_PACKAGE_NAME | oneccl-bind-pt | Set Wheel Name |
| ONECCL_BINDINGS_FOR_PYTORCH_BACKEND | cpu | Set BACKEND |
-| CCL_SHA_VERSION | False |add git head sha version to Wheel name |
+| CCL_SHA_VERSION | False | Add the git HEAD SHA version to the wheel name |

## Launch Option List

The following launch options are supported in Intel® oneCCL Bindings for PyTorch*.
| Launch Option | Default Value | Description |
-| :--------------------------------------: | :-----------: | :-------------------------------------------------------------------: |
+| :--------------------------------------- | :------------ | :-------------------------------------------------------------------- |
| ONECCL_BINDINGS_FOR_PYTORCH_ENV_VERBOSE | 0 | Set the verbose level in ONECCL_BINDINGS_FOR_PYTORCH |
| ONECCL_BINDINGS_FOR_PYTORCH_ENV_WAIT_GDB | 0 | Set to 1 to force oneccl_bindings_for_pytorch to wait for GDB attaching |
@@ -96,23 +99,48 @@ The following lunch options are supported in Intel® oneCCL Bindings for PyTorch
# build with oneCCL from third party
COMPUTE_BACKEND=dpcpp python setup.py install
# build without oneCCL
- BUILD_NO_ONECCL_PACKAGE=ON COMPUTE_BACKEND=dpcpp python setup.py install
+ export INTELONEAPIROOT=${HOME}/intel/oneapi
+ USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py install
```
### Install PreBuilt Wheel
Wheel files are available for the following Python versions.
-| Extension Version | Python 3.6 | Python 3.7 | Python 3.8 | Python 3.9 | Python 3.10 |
-| :---------------: | :--------: | :--------: | :--------: | :--------: | :---------: |
-| 1.13 | | √ | √ | √ | √ |
-| 1.12.100 | | √ | √ | √ | √ |
-| 1.12.0 | | √ | √ | √ | √ |
-| 1.11.0 | | √ | √ | √ | √ |
-| 1.10.0 | √ | √ | √ | √ | |
+| Extension Version | Python 3.6 | Python 3.7 | Python 3.8 | Python 3.9 | Python 3.10 | Python 3.11 |
+| :---------------: | :--------: | :--------: | :--------: | :--------: | :---------: | :---------: |
+| 2.0.100 | | | √ | √ | √ | √ |
+| 1.13 | | √ | √ | √ | √ | |
+| 1.12.100 | | √ | √ | √ | √ | |
+| 1.12.0 | | √ | √ | √ | √ | |
+| 1.11.0 | | √ | √ | √ | √ | |
+| 1.10.0 | √ | √ | √ | √ | | |
```bash
-python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
+python -m pip install oneccl_bind_pt==2.0.100 -f https://developer.intel.com/ipex-whl-stable-xpu
+```
+
+### Runtime Dynamic Linking
+
+- If oneccl_bindings_for_pytorch is built without oneCCL and uses the oneCCL library installed on the system, dynamically link oneCCL from the oneAPI Base Toolkit (recommended usage):
+
+```bash
+source $basekit_root/ccl/latest/env/vars.sh
+```
+
+Note: Make sure you have installed the [oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/toolkits.html#base-kit) when using Intel® oneCCL Bindings for PyTorch\* on Intel® GPUs.
+
+- If oneccl_bindings_for_pytorch is built with oneCCL from the third-party directory or installed from a prebuilt wheel:
+Dynamically link oneCCL and the Intel MPI libraries:
+
+```bash
+source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)")/env/setvars.sh
+```
+
+Dynamically link oneCCL only (without Intel MPI):
+
+```bash
+source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)")/env/vars.sh
```
## Usage
@@ -145,16 +173,11 @@ model = torch.nn.parallel.DistributedDataParallel(model, ...)
...
```
-(oneccl_bindings_for_pytorch is installed along with the MPI tool set.)
+(when oneccl_bindings_for_pytorch is built without oneCCL, use the oneCCL and MPI (if needed) libraries installed on the system)
```bash
-
-source /env/setvars.sh
-
-# eg:
-# $ oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
-# $ source $oneccl_bindings_for_pytorch_path/env/setvars.sh
-
+source $basekit_root/ccl/latest/env/vars.sh
+source $basekit_root/mpi/latest/env/vars.sh
mpirun -n <N> -ppn <PPN> -f <hostfile> python example.py
```
@@ -226,6 +249,10 @@ mpirun -n 2 -l python profiling.py
```
+## Known Issues
+
+For point-to-point communication, calling dist.send/recv directly after initializing the process group in the launch script triggers a runtime error: in the current implementation, all ranks of the group are expected to participate in the call that creates the communicators, while dist.send/recv involves only a pair of ranks. As a result, dist.send/recv should be used after a collective call, which ensures that all ranks participate. A solution that supports calling dist.send/recv directly after initializing the process group is still under investigation.
+
## License
[BSD License](https://github.com/intel/torch-ccl/blob/master/LICENSE)
diff --git a/setup.py b/setup.py
index 84718f0..77c9b1c 100644
--- a/setup.py
+++ b/setup.py
@@ -102,16 +102,11 @@ def build_cmake(self, extension: CMakeExtension):
if _check_env_flag('DEBUG'):
build_type = 'Debug'
- run_path = 'OFF'
- if _check_env_flag('RUNPATH'):
- run_path = 'ON'
build_options = {
'CMAKE_BUILD_TYPE': build_type,
# The value cannot be easily obtained in CMakeLists.txt.
'CMAKE_PREFIX_PATH': torch.utils.cmake_prefix_path,
- # Enable the RPATH of the oneCCL and torchCCL
- 'ENABLE_LINKER_RUNPATH': run_path,
# skip the example and test code in oneCCL
'BUILD_EXAMPLES': 'OFF',
'BUILD_CONFIG': 'OFF',
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index cd56bac..d2a9f56 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -10,6 +10,7 @@ target_compile_options(oneccl_bindings_for_pytorch PUBLIC -Wall
if(COMPUTE_BACKEND STREQUAL "dpcpp")
add_subdirectory(./gpu)
+ add_definitions (-DUSE_GPU)
endif()
target_include_directories(oneccl_bindings_for_pytorch PUBLIC ./)
diff --git a/src/ProcessGroupCCL.cpp b/src/ProcessGroupCCL.cpp
index 5bf7507..4a3cd87 100644
--- a/src/ProcessGroupCCL.cpp
+++ b/src/ProcessGroupCCL.cpp
@@ -42,6 +42,354 @@
namespace c10d {
+namespace ops {
+
+std::tuple, c10::intrusive_ptr> broadcast_xpu_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr& process_group,
+    int64_t root_rank,
+    int64_t root_tensor,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  auto work =
+      process_group->getBackend(c10::DeviceType::XPU)
+          ->broadcast(
+              tensor_vec,
+              BroadcastOptions{
+                  root_rank, root_tensor, std::chrono::milliseconds(timeout)});
+
+  return std::tuple, c10::intrusive_ptr>(
+      std::move(tensor_vec), work);
+}
+
+TORCH_LIBRARY_IMPL(c10d, XPU, m) {
+  m.impl("broadcast_", broadcast_xpu_);
+}
+
+#if TORCH_VERSION_MAJOR > 1 && TORCH_VERSION_MINOR >= 1
+// PyTorch 2.1 allreduce support sparse tensor
+std::tuple, c10::intrusive_ptr> allreduce_xpu_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr& process_group,
+    const c10::intrusive_ptr& reduce_op,
+    const c10::optional& sparse_indices,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  auto work =
+      process_group->getBackend(c10::DeviceType::XPU)
+          ->allreduce(
+              tensor_vec,
+              c10d::AllreduceOptions{
+                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+
+  // Return input tensors as output tensors to make inplace allreduce look
like + // a functional API, so that make_fx can correctly build the dependencies in + // the graph later. + return std::tuple, c10::intrusive_ptr>( + std::move(tensor_vec), work); +} +#else +// TODO: Remove after updating to PyTorch 2.1 +std::tuple, c10::intrusive_ptr> allreduce_xpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->allreduce( + tensor_vec, + c10d::AllreduceOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + + // Return input tensors as output tensors to make inplace allreduce look like + // a functional API, so that make_fx can correctly build the dependencies in + // the graph later. + return std::tuple, c10::intrusive_ptr>( + std::move(tensor_vec), work); +} +#endif + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("allreduce_", allreduce_xpu_); +} + +c10::intrusive_ptr allreduce_coalesced_xpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->getBackend(c10::DeviceType::XPU) + ->allreduce_coalesced(tensor_vec, opts); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_xpu_); +} + +c10::intrusive_ptr reduce_xpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t root_rank, + int64_t root_tensor, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->reduce( + tensor_vec, + ReduceOptions{ + *reduce_op.get(), + root_rank, + root_tensor, + std::chrono::milliseconds(timeout)}); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("reduce_", reduce_xpu_); +} + +std::tuple>, c10::intrusive_ptr> +allgather_xpu_( + const std::vector>& output_tensors, + at::TensorList input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto input_tensors_vec = input_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->allgather( + const_cast>&>(output_tensors), + input_tensors_vec, + AllgatherOptions{std::chrono::milliseconds(timeout)}); + + // Copy output tensors (not storage) so that this can be used in a functional + // manner + return std:: + tuple>, c10::intrusive_ptr>( + output_tensors, work); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("allgather_", allgather_xpu_); +} + +std::tuple> _allgather_base_xpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->_allgather_base(output_tensor, input_tensor); + + return std::tuple>(output_tensor, work); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("_allgather_base_", _allgather_base_xpu_); +} + +c10::intrusive_ptr allgather_coalesced_xpu_( + const std::vector>& output_lists, + const at::TensorList& input_list, + const c10::intrusive_ptr& process_group) { + auto input_list_vec = input_list.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->allgather_coalesced( + const_cast>&>(output_lists), + input_list_vec); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("allgather_coalesced_", 
allgather_coalesced_xpu_); +} + +c10::intrusive_ptr gather_xpu_( + const std::vector>& output_tensors, + const at::TensorList& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + auto input_tensors_vec = input_tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->gather( + const_cast>&>(output_tensors), + input_tensors_vec, + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("gather_", gather_xpu_); +} + +std::tuple, c10::intrusive_ptr> scatter_xpu_( + const at::TensorList& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->scatter( + output_tensors_vec, + const_cast>&>(input_tensors), + ScatterOptions{root_rank, std::chrono::milliseconds(timeout)}); + + return std::tuple, c10::intrusive_ptr>( + std::move(output_tensors_vec), work); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("scatter_", scatter_xpu_); +} + +std::tuple, c10::intrusive_ptr> +reduce_scatter_xpu_( + const at::TensorList& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->reduce_scatter( + output_tensors_vec, + const_cast>&>(input_tensors), + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + + return std::tuple, c10::intrusive_ptr>( + output_tensors_vec, work); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("reduce_scatter_", reduce_scatter_xpu_); +} + +std::tuple> _reduce_scatter_base_xpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + + return std::tuple>(output_tensor, work); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("_reduce_scatter_base_", _reduce_scatter_base_xpu_); +} + +c10::intrusive_ptr alltoall_base_xpu_( + at::Tensor& output, + at::Tensor& input, + const c10::intrusive_ptr& process_group, + std::vector output_split_sizes, + std::vector input_split_sizes, + int64_t timeout) { + return process_group->getBackend(c10::DeviceType::XPU) + ->alltoall_base( + output, + input, + output_split_sizes, + input_split_sizes, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("alltoall_base_", alltoall_base_xpu_); +} + +std::tuple, c10::intrusive_ptr> alltoall_xpu_( + const at::TensorList& output_tensors, + const at::TensorList& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto input_tensors_vec = input_tensors.vec(); + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->alltoall( + output_tensors_vec, + input_tensors_vec, + AllToAllOptions{std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + std::move(output_tensors_vec), work); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("alltoall_", alltoall_xpu_); +} + +c10::intrusive_ptr 
send_xpu( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t dstRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->send(tensor_vec, static_cast(dstRank), static_cast(tag)); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("send", send_xpu); +} + +c10::intrusive_ptr recv_xpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t srcRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->recv(tensor_vec, static_cast(srcRank), static_cast(tag)); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("recv_", recv_xpu_); +} + +c10::intrusive_ptr recv_any_source_xpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->recvAnysource(tensor_vec, static_cast(tag)); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("recv_any_source_", recv_any_source_xpu_); +} + +c10::intrusive_ptr barrier_xpu( + at::Tensor /* unused */, + const c10::intrusive_ptr& process_group, + const std::vector& device_ids, + int64_t timeout) { + return process_group->getBackend(c10::DeviceType::XPU) + ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("barrier", barrier_xpu); +} +} // namespace ops + + using oneccl_bindings_for_pytorch::DispatchStub; using oneccl_bindings_for_pytorch::call_with_lock; using oneccl_bindings_for_pytorch::format_tensors_param; diff --git a/src/ProcessGroupCCL.hpp b/src/ProcessGroupCCL.hpp index 474bffb..ef71cac 100644 --- a/src/ProcessGroupCCL.hpp +++ b/src/ProcessGroupCCL.hpp @@ -41,6 +41,7 @@ #if TORCH_VERSION_MAJOR > 1 || TORCH_VERSION_MINOR >= 13 #if TORCH_VERSION_MAJOR > 1 #include + #include #else #include #endif diff --git a/src/dispatch_stub.cpp b/src/dispatch_stub.cpp index c46695e..9f0340a 100644 --- a/src/dispatch_stub.cpp +++ b/src/dispatch_stub.cpp @@ -548,7 +548,13 @@ c10::intrusive_ptr DispatchStub::recv(std::vector c10::intrusive_ptr DispatchStub::barrier(const BarrierOptions& opts, ProcessGroupCCL& pg_ccl) { +#ifdef USE_GPU + std::cout << "Barrier: using xpu" << std::endl; + c10::DeviceType dev_type = c10::DeviceType::XPU; +#else + std::cout << "Barrier: using cpu" << std::endl; c10::DeviceType dev_type = c10::DeviceType::CPU; +#endif return get_ccl_stub(dev_type)->barrier_(opts, pg_ccl); } diff --git a/src/gpu/dpcpp_ccl.cpp b/src/gpu/dpcpp_ccl.cpp index cd946a3..9237aa6 100644 --- a/src/gpu/dpcpp_ccl.cpp +++ b/src/gpu/dpcpp_ccl.cpp @@ -401,6 +401,9 @@ class XPUCCLStubs final: public DispatchStub { int tag, ProcessGroupCCL& pg) override; + c10::intrusive_ptr barrier_(const BarrierOptions& opts, + ProcessGroupCCL& pg) override; + void destroy(); void reset() override {} void runLoop(); @@ -1018,6 +1021,8 @@ c10::intrusive_ptr XPUCCLStubs::alltoall_(std::ve outputs[i].view({-1}).copy_(flatOutputSplits[i]); } } + + torch_stream.synchronize(); return ret_evt; }, c10d::OpType::ALLTOALL); @@ -1114,6 +1119,35 @@ c10::intrusive_ptr XPUCCLStubs::recv_(std::vector return work; } +c10::intrusive_ptr XPUCCLStubs::barrier_(const BarrierOptions& opts, + ProcessGroupCCL& pg) { + + c10::intrusive_ptr work = c10::make_intrusive(); + + if (pg.ccl_member_->ccl_comms.size() == 0) { + std::vector xpu_devices{at::Device(at::kXPU)}; + const auto key = get_key_from_devs(xpu_devices); + get_ccl_comms(pg, key, xpu_devices); + 
} + + auto& comms_map = pg.ccl_member_->ccl_comms; + for(auto iter = comms_map.begin(); iter != comms_map.end(); iter++){ + for(size_t i =0 ; i < iter->second->comms.size(); i++){ + work->getEvents().emplace_back( + call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&](){ + if (i < iter->second->streams.size()) { + CCL_CHECK(return ccl::barrier(iter->second->comms[i], + iter->second->streams[i]);); + } else { + CCL_CHECK(return ccl::barrier(iter->second->comms[i]);); + } + }) + ); + } + } + return work; +} + RegisterXPUMethods xpu_register; } diff --git a/src/utils.h b/src/utils.h index 60b10a4..37036fd 100644 --- a/src/utils.h +++ b/src/utils.h @@ -683,7 +683,7 @@ c10::intrusive_ptr pointToPoint( attr_t attr = ccl::create_operation_attr(); const auto devices = get_device_list(inputs); std::string key; - int p2pRank = 0, p2pTargetRank = 0; + int p2pRank = 0; bool isSendRecvSelf = false; int rank_ = pg_ccl.getRank(); @@ -696,14 +696,12 @@ c10::intrusive_ptr pointToPoint( // rank and my peer key = get_key_from_devs(devices); p2pRank = rank_; - p2pTargetRank = peer; } else { // TODO: single P2P // For single P2P, preserve the old two-rank behavior (to avoid perf diff) key = get_key_send_recv(rank_, peer); p2pRank = rank_ <= peer ? 0 : 1; isSendRecvSelf = rank_ == peer; - p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; } auto& comms = get_ccl_fn(pg_ccl, key, devices, op_type, p2pRank, isSendRecvSelf); diff --git a/tests/README.md b/tests/README.md index c29b804..c393921 100644 --- a/tests/README.md +++ b/tests/README.md @@ -23,6 +23,19 @@ For cross-nodes p2p test, run: mpiexec -host nodeA,nodeB -np 24 -ppn 12 python -u test_p2p_crossnodes.py --dist_url $NODE_IP --world_size 24 ``` +## functionality validation of barrier +For cpu barrier, run: + +```bash +mpirun -np 2 python test_barrier.py +``` + +For xpu barrier (built with "COMPUTE_BACKEND=dpcpp"), run: + +```bash +mpirun -np 2 python test_barrier.py --device xpu +``` + ## broadcast/allreduce profiling To start the test_allreduce.py test, run: diff --git a/tests/test_barrier.py b/tests/test_barrier.py new file mode 100644 index 0000000..4a51a52 --- /dev/null +++ b/tests/test_barrier.py @@ -0,0 +1,28 @@ +import torch +import intel_extension_for_pytorch +import oneccl_bindings_for_pytorch +import torch.distributed as dist +import os + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--device', '-dev', type=str, default='cpu', help='Device type to use: cpu, xpu') +args = parser.parse_args() + +os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0)) +os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1)) +os.environ['MASTER_ADDR'] = '127.0.0.1' +os.environ['MASTER_PORT'] = '29500' + +dist.init_process_group("ccl") +rank = dist.get_rank() +size = dist.get_world_size() + +if args.device == 'xpu': + device = "xpu:{}".format(rank) +else: + device = 'cpu' + +print("Barrier using device: ", args.device) +dist.barrier() +print("Finish") diff --git a/third_party/oneCCL b/third_party/oneCCL index 9fc00e3..b4c31ba 160000 --- a/third_party/oneCCL +++ b/third_party/oneCCL @@ -1 +1 @@ -Subproject commit 9fc00e3509c8554b8dd2e30d3e9481b39ae8eb9f +Subproject commit b4c31ba6b42b5bcc7228820e44d2d3f3d519fdc7 diff --git a/version.txt b/version.txt index 61b11cb..7ec1d6d 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -1.13.100+gpu +2.1.0
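To make the README's Usage section above concrete, here is a minimal, self-contained sketch of CPU training with the `ccl` backend. It is only a sketch under assumptions: the script is launched with Intel MPI's `mpirun` (so `PMI_RANK`/`PMI_SIZE` are set, as the existing examples assume), and the model, data, and master address are placeholders.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import oneccl_bindings_for_pytorch  # noqa: F401  registers the "ccl" backend

# Map Intel MPI's PMI variables onto the torch.distributed ones,
# mirroring tests/test_barrier.py in this patch.
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')  # placeholder master address
os.environ.setdefault('MASTER_PORT', '29500')

dist.init_process_group(backend='ccl')

model = nn.Linear(16, 4)                      # placeholder model
model = nn.parallel.DistributedDataParallel(model)

data = torch.randn(8, 16)                     # placeholder batch
loss = model(data).sum()
loss.backward()                               # gradients are allreduced through oneCCL

dist.destroy_process_group()
```

Launched as `mpirun -n 2 python example.py`, each rank ends up with identical gradients after `backward()`.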
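The new `TORCH_LIBRARY_IMPL(c10d, XPU, ...)` registrations in src/ProcessGroupCCL.cpp are what let ordinary `torch.distributed` calls run on XPU tensors, and the README table now marks `reduce_scatter` (and GPU `send`/`recv`) as supported. Below is a hedged sketch of exercising two of those paths from Python; it assumes Intel® Extension for PyTorch* is installed so the `xpu` device exists, and that rank *i* uses device `xpu:i`.

```python
import os

import torch
import intel_extension_for_pytorch  # noqa: F401  assumption: provides the XPU device
import oneccl_bindings_for_pytorch  # noqa: F401
import torch.distributed as dist

os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

dist.init_process_group('ccl')
rank = dist.get_rank()
world_size = dist.get_world_size()
device = f'xpu:{rank}'

# all_reduce on an XPU tensor dispatches to the allreduce wrapper
# registered for the XPU dispatch key in this patch.
x = torch.ones(8, device=device) * (rank + 1)
dist.all_reduce(x, op=dist.ReduceOp.SUM)

# reduce_scatter: every rank contributes world_size chunks and
# receives the reduction of its own chunk.
output = torch.empty(8, device=device)
inputs = [torch.full((8,), float(r), device=device) for r in range(world_size)]
dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)

dist.destroy_process_group()
```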
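The Known Issues entry added above says that `dist.send`/`dist.recv` only work after a collective has created communicators covering all ranks. A sketch of that ordering follows (hedged: it assumes two ranks on XPU devices, since the support table lists point-to-point as GPU-only).

```python
import os

import torch
import intel_extension_for_pytorch  # noqa: F401  assumption: provides the XPU device
import oneccl_bindings_for_pytorch  # noqa: F401
import torch.distributed as dist

os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

dist.init_process_group('ccl')
rank = dist.get_rank()
device = f'xpu:{rank}'

# A collective involving every rank must run first; calling send/recv
# straight after init_process_group triggers the runtime error described above.
dist.all_reduce(torch.zeros(1, device=device))

t = torch.zeros(4, device=device)
if rank == 0:
    dist.send(t + 42.0, dst=1)   # point-to-point is fine after the collective
elif rank == 1:
    dist.recv(t, src=0)

dist.destroy_process_group()
```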
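The README keeps the `mpirun -n 2 -l python profiling.py` example, but that script is not part of this diff. As a rough, hypothetical stand-in (not the repository's script), the stock `torch.profiler` can be wrapped around a loop of `all_reduce` calls to get per-op CPU timings for the `ccl` backend:

```python
import os

import torch
import oneccl_bindings_for_pytorch  # noqa: F401
import torch.distributed as dist
from torch.profiler import ProfilerActivity, profile

os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

dist.init_process_group('ccl')

x = torch.randn(1024, 1024)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        dist.all_reduce(x)

# Each rank prints its own table; with `mpirun -l` the lines are rank-prefixed.
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))

dist.destroy_process_group()
```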