From 41a6373e6c201603e97d699c20516e9192c08bf8 Mon Sep 17 00:00:00 2001
From: Vasudevan Rengasamy
Date: Sat, 3 Jan 2026 05:43:50 -0800
Subject: [PATCH 1/2] Initial commit to pass scale as Tensor for multi_tensor_scale op

---
 .../include/transformer_engine/multi_tensor.h | 16 ++++
 .../common/multi_tensor/scale.cu              | 93 +++++++++++++++++++
 transformer_engine/pytorch/csrc/extensions.h  |  3 +
 .../csrc/extensions/multi_tensor/scale.cpp    | 10 +++
 .../pytorch/csrc/extensions/pybind.cpp        |  3 +
 .../pytorch/optimizers/__init__.py            |  1 +
 6 files changed, 126 insertions(+)

diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h
index 1bea4cb21f..70fe979400 100644
--- a/transformer_engine/common/include/transformer_engine/multi_tensor.h
+++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h
@@ -244,6 +244,22 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                   float scale, cudaStream_t stream);
 
+/*! \brief Check overflow and scale a list of tensors, with the scale provided as a tensor.
+ *
+ * \warning This API is **experimental** and subject to change.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     scale                 Single-element tensor holding the scaling factor.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
+void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+                                         const size_t num_tensor_lists, const size_t num_tensors_per_list,
+                                         NVTETensor scale, cudaStream_t stream);
+
 /*! \brief Check overflow and scale a list of tensors.
  *
  * \warning This API is **experimental** and subject to change.
diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu
index b3266200c4..3a7e4d5554 100644
--- a/transformer_engine/common/multi_tensor/scale.cu
+++ b/transformer_engine/common/multi_tensor/scale.cu
@@ -102,6 +102,75 @@ struct ScaleFunctor {
   }
 };
 
+template <typename in_t, typename out_t>
+struct ScalePtrFunctor {
+  __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
+                                             TensorListMetadata<2> &tl,  // NOLINT(*)
+                                             float *scale_ptr) {
+    // I'd like this kernel to propagate infs/nans.
+    // if(*noop_gmem == 1)
+    //   return;
+    float scale = *scale_ptr;
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
+
+    in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]);
+    in += chunk_idx * chunk_size;
+
+    out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]);
+    out += chunk_idx * chunk_size;
+
+    n -= chunk_idx * chunk_size;
+
+    bool finite = true;
+    in_t r_in[ILP];
+    out_t r_out[ILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
+      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_in, in, 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
+        }
+        // store
+        load_store(out, r_out, i_start, 0);
+      }
+    } else {
+      // Non-divergent exit condition for __syncthreads, not necessary here
+      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          r_in[ii] = 0.f;
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) r_in[ii] = in[i];
+        }
+        // note for clarification to future michael:
+        // From a pure memory dependency perspective, there's likely no point unrolling
+        // the write loop, since writes just fire off once their LDGs arrive.
+        // Put another way, the STGs are dependent on the LDGs, but not on each other.
+        // There is still compute ILP benefit from unrolling the loop though.
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
+        }
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) out[i] = r_out[ii];
+        }
+      }
+    }
+    if (!finite) *noop_gmem = 1;  // Blindly fire off a write. These will race but that's ok.
+  }
+};
+
 void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
                              std::vector<std::vector<Tensor *>> tensor_lists, float scale,
                              cudaStream_t stream) {
@@ -114,6 +183,18 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
+void multi_tensor_scale_tensor_cuda(int chunk_size, Tensor noop_flag,
+                                    std::vector<std::vector<Tensor *>> tensor_lists, float *scale,
+                                    cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), p_in_type,
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          tensor_lists[1][0]->dtype(), g_in_type,
+          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                                ScalePtrFunctor<p_in_type, g_in_type>(), stream, scale);))
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 }  // namespace multi_tensor_scale
 
 }  // namespace transformer_engine
@@ -127,3 +208,15 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
 }
+
+void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
+                                         const size_t num_tensor_lists, const size_t num_tensors_per_list,
+                                         NVTETensor scale, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda);
+  using namespace transformer_engine;
+
+  Tensor *scale_tensor = convertNVTETensorCheck(scale);
+  multi_tensor_scale::multi_tensor_scale_tensor_cuda(
+      chunk_size, *convertNVTETensorCheck(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
+}
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 52ef02a347..a4cad031d2 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -411,6 +411,9 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
 void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                              std::vector<std::vector<at::Tensor>> tensor_lists, float scale);
 
+void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
+                                    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor scale);
+
 std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
     at::optional<bool> per_tensor_python);
diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
index 4bb83bfeed..5199975043 100644
--- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
@@ -18,4 +18,14 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                              num_tensors, scale, at::cuda::getCurrentCUDAStream());
 }
 
+void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
+                                    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor scale) {
+  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
+  auto scale_cu = makeTransformerEngineTensor(scale);
+  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
+      makeTransformerEngineTensorList(tensor_lists);
+  nvte_multi_tensor_scale_tensor_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
+                                      num_tensors, scale_cu.data(), at::cuda::getCurrentCUDAStream());
+}
+
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index e73eca7861..9fb91ecd09 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -402,6 +402,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_scale_tensor", &transformer_engine::pytorch::multi_tensor_scale_tensor_cuda,
+        "Fused overflow check + scale for a list of contiguous tensors with the scale passed as a tensor",
+        py::call_guard<py::gil_scoped_release>());
   m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py
index 792eab094a..7220f1924a 100644
--- a/transformer_engine/pytorch/optimizers/__init__.py
+++ b/transformer_engine/pytorch/optimizers/__init__.py
@@ -5,6 +5,7 @@
 """Fused optimizers and multi-tensor kernels."""
 from transformer_engine_torch import (
     multi_tensor_scale,
+    multi_tensor_scale_tensor,
     multi_tensor_l2norm,
     multi_tensor_unscale_l2norm,
     multi_tensor_adam,

From ab07859128a9beab4dcb552882513c349fa3ccee Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 13 Jan 2026 22:29:10 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../common/include/transformer_engine/multi_tensor.h |  7 ++++---
 transformer_engine/common/multi_tensor/scale.cu      | 10 ++++++----
 transformer_engine/pytorch/csrc/extensions.h         |  3 ++-
 .../pytorch/csrc/extensions/multi_tensor/scale.cpp   |  8 +++++---
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h
index 70fe979400..73616ce88e 100644
--- a/transformer_engine/common/include/transformer_engine/multi_tensor.h
+++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h
@@ -256,9 +256,10 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
 *  \param[in]     scale                 Single-element tensor holding the scaling factor.
 *  \param[in]     stream                CUDA stream used for this operation.
  */
-void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
-                                         const size_t num_tensor_lists, const size_t num_tensors_per_list,
-                                         NVTETensor scale, cudaStream_t stream);
+void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag,
+                                         NVTETensor **tensor_lists, const size_t num_tensor_lists,
+                                         const size_t num_tensors_per_list, NVTETensor scale,
+                                         cudaStream_t stream);
 
 /*! \brief Check overflow and scale a list of tensors.
 *
diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu
index 3a7e4d5554..c68b935ce5 100644
--- a/transformer_engine/common/multi_tensor/scale.cu
+++ b/transformer_engine/common/multi_tensor/scale.cu
@@ -209,14 +209,16 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
 }
 
-void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
-                                         const size_t num_tensor_lists, const size_t num_tensors_per_list,
-                                         NVTETensor scale, cudaStream_t stream) {
+void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag,
+                                         NVTETensor **tensor_lists, const size_t num_tensor_lists,
+                                         const size_t num_tensors_per_list, NVTETensor scale,
+                                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda);
   using namespace transformer_engine;
 
   Tensor *scale_tensor = convertNVTETensorCheck(scale);
   multi_tensor_scale::multi_tensor_scale_tensor_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
-      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
+      reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
 }
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index a4cad031d2..656ec299ca 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -412,7 +412,8 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                              std::vector<std::vector<at::Tensor>> tensor_lists, float scale);
 
 void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
-                                    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor scale);
+                                    std::vector<std::vector<at::Tensor>> tensor_lists,
+                                    at::Tensor scale);
 
 std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
index 5199975043..de957a901a 100644
--- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
@@ -19,13 +19,15 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
 }
 
 void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
-                                    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor scale) {
+                                    std::vector<std::vector<at::Tensor>> tensor_lists,
+                                    at::Tensor scale) {
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto scale_cu = makeTransformerEngineTensor(scale);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
-  nvte_multi_tensor_scale_tensor_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
-                                      num_tensors, scale_cu.data(), at::cuda::getCurrentCUDAStream());
+  nvte_multi_tensor_scale_tensor_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
+                                      num_lists, num_tensors, scale_cu.data(),
+                                      at::cuda::getCurrentCUDAStream());
 }
 
 }  // namespace transformer_engine::pytorch
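
---
Usage note (reviewer sketch, not part of the patch): the new binding mirrors
multi_tensor_scale, but the kernel dereferences the scale on the device
(float scale = *scale_ptr), so a GPU-resident scale can be applied without a
host synchronization. Below is a minimal sketch of a call from Python,
assuming a CUDA build of transformer_engine_torch; the chunk size and tensor
shapes are illustrative, not prescribed by this patch:

    import torch
    from transformer_engine.pytorch.optimizers import multi_tensor_scale_tensor

    noop_flag = torch.zeros(1, dtype=torch.int32, device="cuda")
    inputs = [torch.randn(1024, device="cuda") for _ in range(4)]
    outputs = [torch.empty_like(t) for t in inputs]
    # nvte_multi_tensor_scale_tensor_cuda reads the tensor's data pointer as
    # float*, so the scale must be a single-element float32 CUDA tensor.
    scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

    # Scales each input into the matching output; if any input element is
    # non-finite, the kernel also sets noop_flag to 1.
    multi_tensor_scale_tensor(2048 * 32, noop_flag, [inputs, outputs], scale)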