Skip to content

Conversation

@rdspring1
Copy link
Collaborator

@rdspring1 rdspring1 commented Jan 8, 2026

Library Size -- Nanobind is 52.6% smaller than PyBind11

Size (MB) Library
1.060 Nanobind
2.238 Pybind11

@rdspring1 rdspring1 added the Direct Bindings Python extension with direct mapping to NvFuser CPP objects. label Jan 8, 2026
@github-actions
Copy link

github-actions bot commented Jan 8, 2026

Review updated until commit 9462d03

Description

  • Migrate from pybind11 to nanobind library for 52.6% size reduction

  • Fix github workflow configuration for new binding system

  • Resolve define_tensor_error_generator compatibility issues

  • Optimize from_pysequence usage and handle IntegerProxy cases

  • Create get_shape_from_pysequence utility function

  • Remove lambda expressions for improved performance

Changes walkthrough

Relevant files

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
API Migration Completeness

The migration from pybind11 to nanobind appears comprehensive in the ops.cpp file, but the reviewer should verify that all pybind11 API calls have been correctly replaced with nanobind equivalents. Pay special attention to return value policies, argument handling, and type conversions to ensure no functionality was lost or broken during the migration.

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#include <nanobind/stl/complex.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <ranges>

#include <bindings.h>
#include <direct_utils.h>
#include <ops/all_ops.h>
#include <ops/arith.h>
#include <utils.h>

namespace nvfuser::python {

namespace {

// ScalarVariant represents a NvFuser Val or python scalar.
using ScalarVariant = std::variant<Val*, PolymorphicValue::VariantType>;

// This function converts a ScalarVariant to a nvfuser Val.
Val* convertToVal(
    ScalarVariant value,
    std::optional<PrimDataType> dtype = std::nullopt) {
  // short-circuit: already a NvFuser val
  if (std::holds_alternative<Val*>(value)) {
    Val* v = std::get<Val*>(value);
    NVF_ERROR(
        dtype == std::nullopt || v == nullptr ||
        std::get<PrimDataType>(v->dtype().type) == dtype.value());
    return std::get<Val*>(value);
  }

  PolymorphicValue::VariantType pv =
      std::get<PolymorphicValue::VariantType>(value);

  // short-circuit: PolymorphicValue is empty
  if (std::holds_alternative<std::monostate>(pv)) {
    return nullptr;
  }

  // Create NvFuser Val with desired dtype
  PolymorphicValue cast_value(
      dtype.has_value() ? castToDtype(std::move(pv), dtype.value())
                        : std::move(pv));
  PrimDataType value_dtype(
      dtype.has_value() ? dtype.value()
                        : std::get<PrimDataType>(getDataType(cast_value).type));
  return IrBuilder::create<Val>(cast_value, value_dtype);
}

#define NVFUSER_DIRECT_BINDING_UNARY_OP(NAME, OP_NAME, DOCSTRING)      \
  ops.def(                                                             \
      NAME,                                                            \
      [](ScalarVariant v) -> Val* {                                    \
        return static_cast<Val* (*)(Val*)>(OP_NAME)(convertToVal(v));  \
      },                                                               \
      nb::rv_policy::reference);                                       \
  ops.def(                                                             \
      NAME,                                                            \
      [](TensorView* tv) -> TensorView* {                              \
        return static_cast<TensorView* (*)(TensorView*)>(OP_NAME)(tv); \
      },                                                               \
      DOCSTRING,                                                       \
      nb::rv_policy::reference);

#define NVFUSER_DIRECT_BINDING_BINARY_OP(NAME, OP_NAME, DOCSTRING)       \
  ops.def(                                                               \
      NAME,                                                              \
      [](ScalarVariant lhs, ScalarVariant rhs) -> Val* {                 \
        return static_cast<Val* (*)(Val*, Val*)>(OP_NAME)(               \
            convertToVal(lhs), convertToVal(rhs));                       \
      },                                                                 \
      nb::rv_policy::reference);                                         \
  ops.def(                                                               \
      NAME,                                                              \
      [](TensorView* lhs, ScalarVariant rhs) -> TensorView* {            \
        return static_cast<TensorView* (*)(TensorView*, Val*)>(OP_NAME)( \
            lhs, convertToVal(rhs));                                     \
      },                                                                 \
      nb::rv_policy::reference);                                         \
  ops.def(                                                               \
      NAME,                                                              \
      [](ScalarVariant lhs, TensorView* rhs) -> TensorView* {            \
        return static_cast<TensorView* (*)(Val*, TensorView*)>(OP_NAME)( \
            convertToVal(lhs), rhs);                                     \
      },                                                                 \
      nb::rv_policy::reference);                                         \
  ops.def(                                                               \
      NAME,                                                              \
      [](TensorView* lhs, TensorView* rhs) -> TensorView* {              \
        return static_cast<TensorView* (*)(TensorView*, TensorView*)>(   \
            OP_NAME)(lhs, rhs);                                          \
      },                                                                 \
      DOCSTRING,                                                         \
      nb::rv_policy::reference);

#define NVFUSER_DIRECT_BINDING_TERNARY_OP(NAME, OP_NAME, DOCSTRING)            \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1, ScalarVariant arg2, ScalarVariant arg3) -> Val* { \
        return static_cast<Val* (*)(Val*, Val*, Val*)>(OP_NAME)(               \
            convertToVal(arg1), convertToVal(arg2), convertToVal(arg3));       \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         TensorView* arg2,                                                     \
         TensorView* arg3) -> TensorView* {                                    \
        return static_cast<                                                    \
            TensorView* (*)(TensorView*, TensorView*, TensorView*)>(OP_NAME)(  \
            arg1, arg2, arg3);                                                 \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         TensorView* arg2,                                                     \
         ScalarVariant arg3) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(TensorView*, TensorView*, Val*)>(   \
            OP_NAME)(arg1, arg2, convertToVal(arg3));                          \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         ScalarVariant arg2,                                                   \
         TensorView* arg3) -> TensorView* {                                    \
        return static_cast<TensorView* (*)(TensorView*, Val*, TensorView*)>(   \
            OP_NAME)(arg1, convertToVal(arg2), arg3);                          \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         TensorView* arg2,                                                     \
         TensorView* arg3) -> TensorView* {                                    \
        return static_cast<TensorView* (*)(Val*, TensorView*, TensorView*)>(   \
            OP_NAME)(convertToVal(arg1), arg2, arg3);                          \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         ScalarVariant arg2,                                                   \
         TensorView* arg3) -> TensorView* {                                    \
        return static_cast<TensorView* (*)(Val*, Val*, TensorView*)>(OP_NAME)( \
            convertToVal(arg1), convertToVal(arg2), arg3);                     \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         ScalarVariant arg2,                                                   \
         ScalarVariant arg3) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(TensorView*, Val*, Val*)>(OP_NAME)( \
            arg1, convertToVal(arg2), convertToVal(arg3));                     \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         TensorView* arg2,                                                     \
         ScalarVariant arg3) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(Val*, TensorView*, Val*)>(OP_NAME)( \
            convertToVal(arg1), arg2, convertToVal(arg3));                     \
      },                                                                       \
      DOCSTRING,                                                               \
      nb::rv_policy::reference);

#define NVFUSER_DIRECT_BINDING_THRESHOLD_LIKE_OP(NAME, OP_NAME, DOCSTRING)     \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1, ScalarVariant arg2, ScalarVariant arg3) -> Val* { \
        return static_cast<Val* (*)(Val*, Val*, Val*)>(OP_NAME)(               \
            convertToVal(arg1), convertToVal(arg2), convertToVal(arg3));       \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         ScalarVariant arg2,                                                   \
         ScalarVariant arg3) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(TensorView*, Val*, Val*)>(OP_NAME)( \
            arg1, convertToVal(arg2), convertToVal(arg3));                     \
      },                                                                       \
      DOCSTRING,                                                               \
      nb::rv_policy::reference);

#define NVFUSER_DIRECT_BINDING_TERNARY_WITH_ALPHA_OP(NAME, OP_NAME, DOCSTRING) \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         ScalarVariant arg2,                                                   \
         ScalarVariant arg3,                                                   \
         ScalarVariant arg4) -> Val* {                                         \
        return static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(OP_NAME)(         \
            convertToVal(arg1),                                                \
            convertToVal(arg2),                                                \
            convertToVal(arg3),                                                \
            convertToVal(arg4));                                               \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         TensorView* arg2,                                                     \
         TensorView* arg3,                                                     \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<                                                    \
            TensorView* (*)(TensorView*, TensorView*, TensorView*, Val*)>(     \
            OP_NAME)(arg1, arg2, arg3, convertToVal(arg4));                    \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         TensorView* arg2,                                                     \
         ScalarVariant arg3,                                                   \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<                                                    \
            TensorView* (*)(TensorView*, TensorView*, Val*, Val*)>(OP_NAME)(   \
            arg1, arg2, convertToVal(arg3), convertToVal(arg4));               \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         ScalarVariant arg2,                                                   \
         TensorView* arg3,                                                     \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(TensorView*, Val*, Val*, Val*)>(    \
            OP_NAME)(arg1, convertToVal(arg2), arg3, convertToVal(arg4));      \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         TensorView* arg2,                                                     \
         TensorView* arg3,                                                     \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<                                                    \
            TensorView* (*)(Val*, TensorView*, TensorView*, Val*)>(OP_NAME)(   \
            convertToVal(arg1), arg2, arg3, convertToVal(arg4));               \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         ScalarVariant arg2,                                                   \
         TensorView* arg3,                                                     \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(Val*, Val*, TensorView*, Val*)>(    \
            OP_NAME)(                                                          \
            convertToVal(arg1), convertToVal(arg2), arg3, convertToVal(arg4)); \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](TensorView* arg1,                                                     \
         ScalarVariant arg2,                                                   \
         ScalarVariant arg3,                                                   \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(TensorView*, Val*, Val*, Val*)>(    \
            OP_NAME)(                                                          \
            arg1, convertToVal(arg2), convertToVal(arg3), convertToVal(arg4)); \
      },                                                                       \
      nb::rv_policy::reference);                                               \
  ops.def(                                                                     \
      NAME,                                                                    \
      [](ScalarVariant arg1,                                                   \
         TensorView* arg2,                                                     \
         ScalarVariant arg3,                                                   \
         ScalarVariant arg4) -> TensorView* {                                  \
        return static_cast<TensorView* (*)(Val*, TensorView*, Val*, Val*)>(    \
            OP_NAME)(                                                          \
            convertToVal(arg1), arg2, convertToVal(arg3), convertToVal(arg4)); \
      },                                                                       \
      DOCSTRING,                                                               \
      nb::rv_policy::reference);

#define NVFUSER_DIRECT_BINDING_REDUCTION_OP(NAME, OP_NAME, DOCSTRING)   \
  ops.def(                                                              \
      NAME,                                                             \
      [](TensorView* arg, PrimDataType dtype) -> TensorView* {          \
        std::vector<int64_t> dims(arg->nDims());                        \
        std::iota(dims.begin(), dims.end(), 0);                         \
        return static_cast<TensorView* (*)(TensorView*,                 \
                                           const std::vector<int64_t>&, \
                                           bool,                        \
                                           DataType)>(OP_NAME)(         \
            arg, dims, /*keepdim=*/false, dtype);                       \
      },                                                                \
      nb::arg("arg"),                                                   \
      nb::arg("dtype") = DataType::Null,                                \
      nb::rv_policy::reference);                                        \
  ops.def(                                                              \
      NAME,                                                             \
      [](TensorView* arg, int dim, bool keepdim, PrimDataType dtype)    \
          -> TensorView* {                                              \
        return static_cast<TensorView* (*)(TensorView*,                 \
                                           const std::vector<int64_t>&, \
                                           bool,                        \
                                           DataType)>(OP_NAME)(         \
            arg, {dim}, keepdim, dtype);                                \
      },                                                                \
      nb::arg("arg"),                                                   \
      nb::arg("dim"),                                                   \
      nb::arg("keepdim") = false,                                       \
      nb::arg("dtype") = DataType::Null,                                \
      nb::rv_policy::reference);                                        \
  ops.def(                                                              \
      NAME,                                                             \
      [](TensorView* arg,                                               \
         const std::vector<int64_t>& dims,                              \
         bool keepdim,                                                  \
         PrimDataType dtype) -> TensorView* {                           \
        return static_cast<TensorView* (*)(TensorView*,                 \
                                           const std::vector<int64_t>&, \
                                           bool,                        \
                                           DataType)>(OP_NAME)(         \
            arg, dims, keepdim, dtype);                                 \
      },                                                                \
      nb::arg("arg"),                                                   \
      nb::arg("dims"),                                                  \
      nb::arg("keepdim") = false,                                       \
      nb::arg("dtype") = DataType::Null,                                \
      DOCSTRING,                                                        \
      nb::rv_policy::reference);

#define NVFUSER_DIRECT_BINDING_SCAN_OP(NAME, OP_NAME, OP_TYPE, DOCSTRING) \
  ops.def(                                                                \
      NAME,                                                               \
      [](TensorView* arg, int dim, Val* init) -> TensorView* {            \
        BinaryOpType op_type = OP_TYPE;                                   \
        return static_cast<                                               \
            TensorView* (*)(TensorView*, int64_t, BinaryOpType, Val*)>(   \
            OP_NAME)(arg, dim, op_type, init);                            \
      },                                                                  \
      nb::arg("arg"),                                                     \
      nb::arg("dim"),                                                     \
      nb::arg("init").none(true) = nb::none(),                            \
      DOCSTRING,                                                          \
      nb::rv_policy::reference);

void bindUnaryOps(nb::module_& ops) {
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "abs",
      abs,
      R"(
Element-wise absolute value.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The absolute value of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "acos",
      acos,
      R"(
Element-wise inverse cosine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse cosine of the input in radians.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "acosh",
      acosh,
      R"(
Element-wise inverse hyperbolic cosine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse hyperbolic cosine of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "asin",
      asin,
      R"(
Element-wise inverse sine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse sine of the input in radians.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "asinh",
      asinh,
      R"(
Element-wise inverse hyperbolic sine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse hyperbolic sine of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "atan",
      atan,
      R"(
Element-wise inverse tangent.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse tangent of the input in radians.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "atanh",
      atanh,
      R"(
Element-wise inverse hyperbolic tangent.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse hyperbolic tangent of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "ceil",
      ceil,
      R"(
Element-wise ceiling function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The smallest integer greater than or equal to each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "cos",
      cos,
      R"(
Element-wise cosine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The cosine of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "cosh",
      cosh,
      R"(
Element-wise hyperbolic cosine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The hyperbolic cosine of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "exp",
      exp,
      R"(
Element-wise exponential function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    e raised to the power of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "exp2",
      exp2,
      R"(
Element-wise base-2 exponential function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    2 raised to the power of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "expm1",
      expm1,
      R"(
Element-wise exponential minus 1.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    exp(x) - 1 for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "erf",
      erf,
      R"(
Element-wise error function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The error function of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "erfc",
      erfc,
      R"(
Element-wise complementary error function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    1 - erf(x) for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "erfinv",
      erfinv,
      R"(
Element-wise inverse error function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse error function of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "erfcinv",
      erfcinv,
      R"(
Element-wise inverse complementary error function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The inverse complementary error function of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "floor",
      floor,
      R"(
Element-wise floor function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The largest integer less than or equal to each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "frac",
      frac,
      R"(
Element-wise fractional part.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The fractional part of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "lgamma",
      lgamma,
      R"(
Element-wise natural logarithm of the absolute value of the gamma function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The natural logarithm of the absolute value of the gamma function.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "log",
      log,
      R"(
Element-wise natural logarithm.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The natural logarithm of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "log10",
      log10,
      R"(
Element-wise base-10 logarithm.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The base-10 logarithm of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "log1p",
      log1p,
      R"(
Element-wise natural logarithm of 1 plus x.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    log(1 + x) for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "log2",
      log2,
      R"(
Element-wise base-2 logarithm.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The base-2 logarithm of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "neg",
      neg,
      R"(
Element-wise negation.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The negative of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "logical_not",
      logical_not,
      R"(
Element-wise logical NOT.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where input is False, False where input is True.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "bitwise_not",
      bitwise_not,
      R"(
Element-wise bitwise NOT.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The bitwise NOT of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "relu",
      relu,
      R"(
Element-wise rectified linear unit.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    max(0, x) for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "rand_like",
      rand_like,
      R"(
Generate random values with the same shape as input.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    Random values with the same shape as input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "randn_like",
      randn_like,
      R"(
Generate random values from a normal distribution with the same shape as input.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    Random values from a normal distribution with the same shape as input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "reciprocal",
      reciprocal,
      R"(
Element-wise reciprocal.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    1/x for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "round",
      round,
      R"(
Element-wise rounding to nearest integer.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    Each element rounded to the nearest integer.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "rsqrt",
      rsqrt,
      R"(
Element-wise reciprocal square root.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    1/sqrt(x) for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "set",
      set,
      R"(
Element-wise identity operation.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    A copy of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "segment_set",
      segment_set,
      R"(
Element-wise identity operation, forces a segmentation between the producer and
consumer in generated kernel.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    Tensor with values set in the specified segment.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "sign",
      sign,
      R"(
Element-wise sign function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    1 for positive values, -1 for negative values, 0 for zero.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "sigmoid",
      sigmoid,
      R"(
Element-wise sigmoid function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    1/(1 + exp(-x)) for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "signbit",
      signbit,
      R"(
Element-wise sign bit.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the sign bit is set, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "silu",
      silu,
      R"(
Element-wise SiLU (Swish) activation function.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    x * sigmoid(x) for each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "sin",
      sin,
      R"(
Element-wise sine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The sine of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "sinh",
      sinh,
      R"(
Element-wise hyperbolic sine.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The hyperbolic sine of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "sqrt",
      sqrt,
      R"(
Element-wise square root.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The square root of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "tan",
      tan,
      R"(
Element-wise tangent.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The tangent of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "tanh",
      tanh,
      R"(
Element-wise hyperbolic tangent.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The hyperbolic tangent of the input.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "trunc",
      trunc,
      R"(
Element-wise truncation.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The truncated value of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "isfinite",
      isfinite,
      R"(
Element-wise finite check.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the element is finite, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "isinf",
      isinf,
      R"(
Element-wise infinity check.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the element is infinite, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "isnan",
      isnan,
      R"(
Element-wise NaN check.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the element is NaN, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "isneginf",
      isneginf,
      R"(
Element-wise negative infinity check.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the element is negative infinity, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "isposinf",
      isposinf,
      R"(
Element-wise positive infinity check.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the element is positive infinity, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "isreal",
      isreal,
      R"(
Element-wise real number check.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    True where the element is a real number, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "real",
      real,
      R"(
Element-wise real part of complex number.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The real part of each element.
)")
  NVFUSER_DIRECT_BINDING_UNARY_OP(
      "imag",
      imag,
      R"(
Element-wise imaginary part of complex number.

Parameters
----------
x : Val or TensorView

Returns
-------
Val or TensorView
    The imaginary part of each element.
)")
}

void bindBinaryOps(nb::module_& ops) {
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "add",
      add,
      R"(
Element-wise addition of two operands.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The sum of the inputs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "atan2",
      atan2,
      R"(
Element-wise arctangent of lhs/rhs choosing the quadrant.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The angles in radians between the positive x-axis and a line to the (x, y) point.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "div",
      div,
      R"(
Element-wise division.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The quotient of the division, truncated towards zero as per C++'s operator /.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "truediv",
      truediv,
      R"(
Element-wise true (floating point) division.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The floating point quotient.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "fmod",
      fmod,
      R"(
Element-wise floating-point mod.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The floating-point mod.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "mul",
      mul,
      R"(
Element-wise multiplication.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The product of the inputs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "nextafter",
      nextafter,
      R"(
Return the next floating-point value after lhs towards rhs.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The next representable values after lhs in the direction of rhs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "pow",
      pow,
      R"(
Element-wise power function.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The bases raised to the exponents.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "remainder",
      remainder,
      R"(
Element-wise IEEE remainder.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The IEEE remainder of the division.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "sub",
      sub,
      R"(
Element-wise subtraction.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The difference of the inputs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "minimum",
      minimum,
      R"(
Element-wise minimum.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The smaller of each pair of elements.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "maximum",
      maximum,
      R"(
Element-wise maximum.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The larger of each pair of elements.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "mod",
      mod,
      R"(
Element-wise modulo operation.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    The remainder after division.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "eq",
      eq,
      R"(
Element-wise equality comparison.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where elements are equal, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "ge",
      ge,
      R"(
Element-wise greater than or equal comparison.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where lhs >= rhs, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "gt",
      gt,
      R"(
Element-wise greater than comparison.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where lhs > rhs, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "le",
      le,
      R"(
Element-wise less than or equal comparison.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where lhs <= rhs, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "lt",
      lt,
      R"(
Element-wise less than comparison.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where lhs < rhs, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "ne",
      ne,
      R"(
Element-wise not equal comparison.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where elements are not equal, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "logical_and",
      logical_and,
      R"(
Element-wise logical AND.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where both inputs are True, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "logical_or",
      logical_or,
      R"(
Element-wise logical OR.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    True where either input is True, False otherwise.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "bitwise_and",
      bitwise_and,
      R"(
Element-wise bitwise AND.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Bitwise AND of the inputs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "bitwise_or",
      bitwise_or,
      R"(
Element-wise bitwise OR.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Bitwise OR of the inputs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "bitwise_xor",
      bitwise_xor,
      R"(
Element-wise bitwise XOR.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Bitwise XOR of the inputs.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "bitwise_left_shift",
      bitwise_left_shift,
      R"(
Element-wise bitwise left shift.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Values shifted left by specified amounts.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "bitwise_right_shift",
      bitwise_right_shift,
      R"(
Element-wise bitwise right shift.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Values shifted right by specified amounts.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "logical_right_shift",
      logical_right_shift,
      R"(
Element-wise logical right shift.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Values logically shifted right by specified amounts.
)")
  NVFUSER_DIRECT_BINDING_BINARY_OP(
      "gcd",
      gcd,
      R"(
Element-wise greatest common divisor.

Parameters
----------
lhs : Val or TensorView
rhs : Val or TensorView

Returns
-------
Val or TensorView
    Greatest common divisor of each pair of elements.
)")
  // complex does not support (TV, Val) and (Val, TV) argument combinations.
  ops.def(
      "complex",
      [](Val* lhs, Val* rhs) -> Val* {
        return static_cast<Val* (*)(Val*, Val*)>(complex)(lhs, rhs);
      },
      nb::rv_policy::reference);
  ops.def(
      "complex",
      [](TensorView* lhs, TensorView* rhs) -> TensorView* {
        return static_cast<TensorView* (*)(TensorView*, TensorView*)>(complex)(
            lhs, rhs);
      },
      R"(
Create a complex number from real and imaginary parts.

Parameters
----------
real : Val or TensorView
imag : Val or TensorView

Returns
-------
Val or TensorView
    A complex number.
)",
      nb::rv_policy::reference);
};

void bindTernaryOps(nb::module_& ops) {
  NVFUSER_DIRECT_BINDING_TERNARY_OP("lerp", lerp, R"(
Element-wise linear interpolation.

Parameters
----------
x : Val or TensorView
y : Val or TensorView
weight : Val or TensorView

Returns
-------
Val or TensorView
    Linear interpolation of the inputs.
)")

  NVFUSER_DIRECT_BINDING_TERNARY_OP("where", where, R"(
Select elements from either input or other tensors based on condition.

Parameters
----------
condition : Val or TensorView
x : Val or TensorView
y : Val or TensorView

Returns
-------
Val or TensorView
    Elements from x if condition is True, otherwise elements from y.
)")

  NVFUSER_DIRECT_BINDING_THRESHOLD_LIKE_OP("clamp", clamp, R"(
Clamps all elements in input into the range [ min, max ]

Parameters
----------
input : Val or TensorView
min : Val or TensorView
max : Val or TensorView

Returns
-------
Val or TensorView
    Clamped values.
)")

  NVFUSER_DIRECT_BINDING_THRESHOLD_LIKE_OP("threshold", threshold, R"(
Thresholds each element of the input Tensor.

Parameters
----------
input : Val or TensorView
threshold : Val or TensorView
value : Val or TensorView

Returns
-------
Val or TensorView
    Thresholded values.
)")

  NVFUSER_DIRECT_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul, R"(
Element-wise multiplication of input1 and input2,
then adds alpha * input3 to the result.

Parameters
----------
input1 : Val or TensorView
input2 : Val or TensorView
input3 : Val or TensorView
alpha : Val

Returns
-------
Val or TensorView
    The result of the element-wise multiplication and addition.
)")
}

void bindReductionOps(nb::module_& ops) {
  NVFUSER_DIRECT_BINDING_REDUCTION_OP(
      "max",
      max,
      R"(
Reduce a tensor by computing the maximum value along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dim : int, optional
    Dimension to reduce over. If not specified, reduces over all dimensions.
keepdim : bool, optional
    Whether to keep the reduced dimensions with size 1. Default is False.
dtype : PrimDataType, optional
    This argument is not used for max.

Returns
-------
TensorView
    A new tensor containing the maximum values along the specified dimensions.
)")
  NVFUSER_DIRECT_BINDING_REDUCTION_OP(
      "min",
      min,
      R"(
Reduce a tensor by computing the minimum value along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dim : int, optional
    Dimension to reduce over. If not specified, reduces over all dimensions.
keepdim : bool, optional
    Whether to keep the reduced dimensions with size 1. Default is False.
dtype : PrimDataType, optional
    This argument is not used for min.

Returns
-------
TensorView
    A new tensor containing the minimum values along the specified dimensions.
)")
  NVFUSER_DIRECT_BINDING_REDUCTION_OP(
      "prod",
      prod,
      R"(
Reduce a tensor by computing the product of elements along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dim : int, optional
    Dimension to reduce over. If not specified, reduces over all dimensions.
keepdim : bool, optional
    Whether to keep the reduced dimensions with size 1. Default is False.
dtype : PrimDataType, optional
    The data type to cast the arg to before computation. If the dtype argument
    is None, use the input data type if it is floating point. Otherwise, it is
    DataType::Int for boolean or integral input.

Returns
-------
TensorView
    A new tensor containing the product of elements along the specified dimensions.
)")
  NVFUSER_DIRECT_BINDING_REDUCTION_OP(
      "sum",
      sum,
      R"(
Reduce a tensor by computing the sum of elements along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dim : int, optional
    Dimension to reduce over. If not specified, reduces over all dimensions.
keepdim : bool, optional
    Whether to keep the reduced dimensions with size 1. Default is False.
dtype : PrimDataType, optional
    The data type to cast the arg to before computation. If the dtype argument
    is None, use the input data type if it is floating point. Otherwise, it is
    DataType::Int for boolean or integral input.

Returns
-------
TensorView
    A new tensor containing the sum of elements along the specified dimensions.
)")
  ops.def(
      "var",
      [](TensorView* arg,
         const std::vector<int64_t>& dims,
         int64_t correction,
         bool keepdim) -> TensorView* {
        return variance(arg, dims, correction, keepdim);
      },
      nb::arg("arg"),
      nb::arg("dims"),
      nb::arg("correction") = 1,
      nb::arg("keepdim") = false,
      R"(
Reduce a tensor by computing the variance along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dims : list or tuple
    Dimensions to reduce over.
correction : int, optional
    The correction factor to apply to the variance. Default is 1.
keepdim : bool, optional
    Whether to keep the reduced dimensions with size 1. Default is False.

Returns
-------
TensorView
    A tensor containing the variance along the specified dimensions.
)",
      nb::rv_policy::reference);
  ops.def(
      "var_mean",
      [](TensorView* arg,
         const std::vector<int64_t>& dims,
         int64_t correction,
         bool keepdim) -> std::tuple<TensorView*, TensorView*> {
        VarMeanResult output = variance_mean(arg, dims, correction, keepdim);
        return std::make_tuple(output.var, output.mean);
      },
      nb::arg("arg"),
      nb::arg("dims"),
      nb::arg("correction") = 1,
      nb::arg("keepdim") = false,
      R"(
Reduce a tensor by computing the mean and variance along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dims : list or tuple
    Dimensions to reduce over.
correction : int, optional
    The correction factor to apply to the variance. Default is 1.
keepdim : bool, optional
    Whether to keep the reduced dimensions with size 1. Default is False.

Returns
-------
tuple
    A tuple containing the variance and mean along the specified dimensions.
)",
      nb::rv_policy::reference);
  ops.def(
      "welford",
      [](TensorView* arg, const std::vector<int64_t>& dims) -> decltype(auto) {
        WelfordResult output = WelfordRaw(arg, dims);
        return std::make_tuple(output.avg, output.var_sum, output.n);
      },
      nb::arg("arg"),
      nb::arg("dims"),
      R"(
Reduce a tensor by computing the mean and variance along specified dimensions.

Parameters
----------
arg : TensorView
    Input tensor to reduce.
dims : list or tuple
    Dimensions to reduce over.

Returns
-------
tuple
    A tuple containing the mean, variance, and count along the specified dimensions.
)",
      nb::rv_policy::reference);
}

void bindScanOps(nb::module_& ops) {
  // cumsum (prefix sum) along a dimension
  NVFUSER_DIRECT_BINDING_SCAN_OP(
      "cumsum",
      scan,
      BinaryOpType::Add,
      R"(
Cumulative sum along a dimension.

Parameters
----------
arg : TensorView
    Input tensor to compute cumulative sum.
dim : int
    Dimension to compute cumulative sum over.

Returns
-------
TensorView
    A new tensor containing the cumulative sum along the specified dimension.
)");

  NVFUSER_DIRECT_BINDING_SCAN_OP(
      "cumprod",
      scan,
      BinaryOpType::Mul,
      R"(
Cumulative product along a dimension.

Parameters
----------
arg : TensorView
    Input tensor to compute cumulative product.
dim : int
    Dimension to compute cumulative product over.

Returns
-------
TensorView
    A new tensor containing the cumulative product along the specified dimension.
)");

  NVFUSER_DIRECT_BINDING_SCAN_OP(
      "cummin",
      scan,
      BinaryOpType::Min,
      R"(
Cumulative minimum along a dimension.

Parameters
----------
arg : TensorView
    Input tensor to compute cumulative minimum.
dim : int
    Dimension to compute cumulative minimum over.

Returns
-------
TensorView
    A new tensor containing the cumulative minimum along the specified dimension.
)");
  NVFUSER_DIRECT_BINDING_SCAN_OP(
      "cummax",
      scan,
      BinaryOpType::Max,
      R"(
Cumulative maximum along a dimension.

Parameters
----------
arg : TensorView
    Input tensor to compute cumulative maximum.
dim : int
    Dimension to compute cumulative maximum over.

Returns
-------
TensorView
    A new tensor containing the cumulative maximum along the specified dimension.
)");
}

void bindCastOps(nb::module_& ops) {
  ops.def(
      "cast",
      [](TensorView* arg, PrimDataType dtype) -> TensorView* {
        return static_cast<TensorView* (*)(DataType, TensorView*)>(castOp)(
            dtype, arg);
      },
      nb::arg("arg"),
      nb::arg("dtype"),
      nb::rv_policy::reference);
  ops.def(
      "cast",
      [](ScalarVariant arg, PrimDataType dtype) -> Val* {
        return static_cast<Val* (*)(DataType, Val*)>(castOp)(
            dtype, convertToVal(arg));
      },
      nb::arg("arg"),
      nb::arg("dtype"),
      R"(
Cast a scalar value to a different data type.

Parameters
----------
arg : Val
    Input scalar value to cast.
dtype : PrimDataType
    Target data type for the cast operation.

Returns
-------
Val
    A new scalar value with the specified data type.
)",
      nb::rv_policy::reference);
}

void bindCompositeOps(nb::module_& ops) {
  ops.def(
      "triu",
      [](TensorView* arg, int64_t diagonal) -> TensorView* {
        Val* diagonal_val =
            IrBuilder::create<nvfuser::Val>(diagonal, DataType::Int);
        return triu(arg, diagonal_val);
      },
      nb::arg("arg"),
      nb::arg("diagonal") = 0,
      R"(
Get the upper triangular part of a tensor.

Parameters
----------
arg : TensorView
diagonal : int
    Offset of the diagonal relative to the main diagonal.

Returns
-------
TensorView
    The upper triangular part of the tensor.
)",
      nb::rv_policy::reference);
}

void bindMatmulOps(nb::module_& ops) {
  ops.def(
      "matmul",
      static_cast<TensorView* (*)(TensorView*, TensorView*)>(matmul),
      nb::arg("arg1"),
      nb::arg("arg2"),
      R"(
The matrix product of two tensors.

Parameters
----------
arg1 : TensorView
arg2 : TensorView

Returns
-------
TensorView
    The result of the matrix multiplication.
)",
      nb::rv_policy::reference);
  ops.def(
      "linear",
      [](TensorView* arg1, TensorView* arg2, TensorView* bias) -> TensorView* {
        return static_cast<
            TensorView* (*)(TensorView*, TensorView*, TensorView*)>(linear)(
            arg1, arg2, bias);
      },
      nb::arg("arg1"),
      nb::arg("arg2"),
      nb::arg("bias").none(true) = nb::none(),
      R"(
Applies an affine linear transformation to the incoming data:
output = arg1 @ transpose(arg2) + bias.

Parameters
----------
arg1 : TensorView
arg2 : TensorView
bias : TensorView, optional
    The bias vector to add to the output. If not provided, the bias is not added.

Returns
-------
TensorView
    The result of the affine linear transformation.
)",
      nb::rv_policy::reference);
  ops.def(
      "grouped_mm",
      [](TensorView* mat1,
         TensorView* mat2,
         TensorView* offsets) -> TensorView* {
        // Calculate output dimensions based on mat1 & mat2 rank
        ScaledTensorView scaled_out = grouped_mm(mat1, mat2, offsets);
        return scaled_out.tv;
      },
      nb::arg("mat1"),
      nb::arg("mat2"),
      nb::arg("offsets"),
      R"(
Grouped matrix multiplication.

Performs matrix multiplication on grouped sets of matrices using offsets
to define variable-sized groups.

Parameters
----------
mat1 : TensorView
    First set of matrices
mat2 : TensorView
    Second set of matrices
offsets : TensorView
    Offsets tensor defining group boundaries

Returns
-------
TensorView
    Result of grouped matrix multiplication
)",
      nb::rv_policy::reference);
  ops.def(
      "grouped_mm",
      [](TensorView* mat1,
         TensorView* mat2,
         TensorView* offsets,
         TensorView* scale1,
         TensorView* scale2,
         TensorView* alpha,
         TensorView* bias,
         TensorView* beta,
         PrimDataType dtype,
         int64_t output_block_scale_size,
         PrimDataType output_block_scale_dtype,
         bool output_gamma)
          -> std::tuple<
              TensorView*,
              std::optional<TensorView*>,
              std::optional<TensorView*>> {
        auto [output, block_scaling_factor, global_scaling_factor] = grouped_mm(
            mat1,
            mat2,
            offsets,
            scale1,
            scale2,
            alpha,
            bias,
            beta,
            dtype,
            output_block_scale_size,
            output_block_scale_dtype,
            output_gamma);

        if (output_gamma) {
          NVF_CHECK(
              output_block_scale_size > 0,
              "output_block_scale_size must be greater than 0 when "
              "output_gamma is true");
          return std::make_tuple(
              output, block_scaling_factor, global_scaling_factor);
        } else if (output_block_scale_size > 0) {
          return std::make_tuple(output, block_scaling_factor, std::nullopt);
        }
        return std::make_tuple(output, std::nullopt, std::nullopt);
      },
      nb::arg("mat1"),
      nb::arg("mat2"),
      nb::arg("offsets"),
      nb::arg("scale1"),
      nb::arg("scale2"),
      nb::arg("alpha").none(true) = nb::none(),
      nb::arg("bias").none(true) = nb::none(),
      nb::arg("beta").none(true) = nb::none(),
      nb::arg("dtype") = DataType::BFloat16,
      nb::arg("output_block_scale_size") = 0,
      nb::arg("output_block_scale_dtype") = DataType::BFloat16,
      nb::arg("output_gamma") = false,
      R"(
Scaled Grouped matrix multiplication.

Performs matrix multiplication on grouped sets of matrices using offsets
to define variable-sized groups.

The math operation is roughly two steps:
    out = alpha * grouped_mm(dequant(mat1, scale1), dequant(mat2, scale2), offsets) + beta * bias

    (out_mat, out_scale, out_gamma) = Quantization(
        out,
        dtype,
        output_block_scale_size,
        output_block_scale_dtype,
        output_gamma)

Note 1: The post quantization only applies when output_block_scale_size > 0,
        which would produce out_scale tensor. Otherwise, None will be returned;
Note 2: When output_gamma is set to True, it should produce global scaling factor out_gamma tensor.
        Otherwise, None will be returned.

Parameters
----------
mat1 : TensorView
    First set of matrices
mat2 : TensorView
    Second set of matrices
offsets : TensorView
    Offsets tensor defining group boundaries
scale1 : TensorView
    Scale tensor for mat1
scale2 : TensorView
    Scale tensor for mat2
alpha : TensorView, optional
    Alpha tensor
bias : TensorView, optional
    Bias tensor
beta : TensorView, optional
    Beta tensor
dtype : PrimDataType, optional
    Output tensor type [default: DataType::BFloat16]
output_block_scale_size : int, optional
    Output block scale size
output_block_scale_dtype : PrimDataType, optional
    Output block scale dtype
output_gamma : bool, optional
    Output gamma [default: False]

Returns
-------
tuple
    A tuple containing the result of matrix multiplication, output block scale tensor, and output gamma tensor.
)",
      nb::rv_policy::reference);
  ops.def(
      "cutlass_nvfp4_grouped_mm",
      [](TensorView* mat1,
         TensorView* mat2,
         TensorView* scale1,
         TensorView* scale2,
         TensorView* alpha,
         TensorView* problem_sizes,
         TensorView* expert_offsets,
         TensorView* sf_offsets,
         PrimDataType dtype) -> TensorView* {
        return cutlass_nvfp4_grouped_mm(
            mat1,
            mat2,
            scale1,
            scale2,
            alpha,
            problem_sizes,
            expert_offsets,
            sf_offsets,
            dtype);
      },
      R"(
Cutlass NVFP4 Grouped Matrix Multiplication.

Parameters
----------
mat1 : TensorView
    First set of matrices
mat2 : TensorView
    Second set of matrices
scale1 : TensorView
    Scale tensor for mat1
scale2 : TensorView
    Scale tensor for mat2
alpha : TensorView
    Alpha tensor
problem_sizes : TensorView
    Problem sizes tensor
expert_offsets : TensorView
    Expert offsets tensor
sf_offsets : TensorView
    SF offsets tensor
dtype : PrimDataType
    Output tensor type

Returns
-------
TensorView
    Result of grouped matrix multiplication
)",
      nb::arg("mat1"),
      nb::arg("mat2"),
      nb::arg("scale1"),
      nb::arg("scale2"),
      nb::arg("alpha"),
      nb::arg("problem_sizes"),
      nb::arg("expert_offsets"),
      nb::arg("sf_offsets"),
      nb::arg("dtype") = DataType::BFloat16,
      nb::rv_policy::reference);
  ops.def(
      "scaled_mm",
      [](TensorView* mat1,
         TensorView* mat2,
         TensorView* scale1,
         TensorView* scale2,
         TensorView* alpha,
         TensorView* bias,
         TensorView* beta,
         PrimDataType dtype,
         int64_t output_block_scale_size,
         PrimDataType output_block_scale_dtype,
         bool output_gamma)
          -> std::tuple<
              TensorView*,
              std::optional<TensorView*>,
              std::optional<TensorView*>> {
        /* Per https://pytorch.org/docs/stable/generated/torch.matmul.html */
        auto [output, block_scaling_factor, global_scaling_factor] = scaled_mm(
            mat1,
            mat2,
            scale1,
            scale2,
            alpha,
            bias,
            beta,
            dtype,
            output_block_scale_size,
            output_block_scale_dtype,
            output_gamma);

        if (output_gamma) {
          NVF_CHECK(
              output_block_scale_size > 0,
              "output_block_scale_size must be greater than 0 when "
              "output_gamma is true");
          return std::make_tuple(
              output, block_scaling_factor, global_scaling_factor);
        } else if (output_block_scale_size > 0) {
          return std::make_tuple(output, block_scaling_factor, std::nullopt);
        }
        return std::make_tuple(output, std::nullopt, std::nullopt);
      },
      nb::arg("mat1"),
      nb::arg("mat2"),
      nb::ar...

@rdspring1 rdspring1 force-pushed the nanobind_direct branch 5 times, most recently from 87e8c35 to 02c761e Compare January 9, 2026 19:43
@rdspring1 rdspring1 force-pushed the nanobind_direct branch 2 times, most recently from 3baae0d to ae4b1d8 Compare January 12, 2026 18:25
@rdspring1
Copy link
Collaborator Author

!test

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Direct Bindings Python extension with direct mapping to NvFuser CPP objects.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants