From f679e345ad0b27348c3f9ecd827b0fb1e0caa7b3 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Mon, 3 Nov 2025 03:12:20 -0600 Subject: [PATCH 01/27] modifying runtime registion part --- .bazelrc | 2 +- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 96 +++++++++++++++---- 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/.bazelrc b/.bazelrc index 0074cf53fe..90ec3c6a54 100644 --- a/.bazelrc +++ b/.bazelrc @@ -32,4 +32,4 @@ common --define=allow_oversize_protos=true # See https://github.com/bazel-contrib/rules_python/issues/2445 build --@rules_python//python/config_settings:precompile=force_disabled -build -c opt +build -c dbg diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 6fdc25982c..78eb23f4a3 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -1284,7 +1284,7 @@ struct LowerGPUAlternativesOp ops, floatOps, intOps, loads, stores, branches, }; }; - + #if POLYGEIST_ENABLE_CUDA if (gpuTarget == "cuda") { char cuErrorBuffer[4096] = {0}; @@ -1793,6 +1793,29 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, moduleIDPrefix = "__hip_"; fatMagic = HIPFatMagic; } + std::string registerFatBinaryFuncName; + std::string registerFunctionFuncName; + std::string registerVarFuncName; + std::string unregisterFatBinaryFuncName; + std::string registerFatBinaryEndFuncName; + bool requiresRegisterEnd; + + if (gpuTarget == "cuda") { + registerFatBinaryFuncName = "__cudaRegisterFatBinary"; + registerFunctionFuncName = "__cudaRegisterFunction"; + registerVarFuncName = "__cudaRegisterVar"; + unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; + registerFatBinaryEndFuncName = "__cudaRegisterFatBinaryEnd"; + requiresRegisterEnd = true; + } else { + registerFatBinaryFuncName = "__hipRegisterFatBinary"; + registerFunctionFuncName = "__hipRegisterFunction"; + registerVarFuncName = "__hipRegisterVar"; + 
unregisterFatBinaryFuncName = "__hipUnregisterFatBinary"; + registerFatBinaryEndFuncName = ""; + requiresRegisterEnd = false; + } + (void)fatbinConstantName; (void)moduleIDSectionName; @@ -1855,16 +1878,27 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto bitcastOfWrapper = ctorBuilder.create( ctorloc, llvmPointerType, addressOfWrapper); - auto cudaRegisterFatbinFn = - LLVM::lookupOrCreateFn(rewriter, moduleOp, "__cudaRegisterFatBinary", - llvmPointerType, llvmPointerType); - if (failed(cudaRegisterFatbinFn)) { - llvm::errs() << " cudamalloc already exists with different types\n"; + // auto cudaRegisterFatbinFn = + // LLVM::lookupOrCreateFn(rewriter, moduleOp, "__cudaRegisterFatBinary", + // llvmPointerType, llvmPointerType); + // if (failed(cudaRegisterFatbinFn)) { + // llvm::errs() << " cudamalloc already exists with different types\n"; + // return failure(); + // } + + // auto module = rewriter.create( + // ctorloc, cudaRegisterFatbinFn.value(), ValueRange(bitcastOfWrapper)); + + auto registerFatbinFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, llvmPointerType, llvmPointerType); + + if (failed(registerFatbinFn)) { + llvm::errs() << "register fatbin function already exists with different types\n"; return failure(); } auto module = rewriter.create( - ctorloc, cudaRegisterFatbinFn.value(), ValueRange(bitcastOfWrapper)); + ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper) + ); auto moduleGlobalName = std::string(llvm::formatv("polygeist_{0}_module_ptr", moduleName)); @@ -1919,12 +1953,22 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType}; - auto cudaRegisterFn = LLVM::lookupOrCreateFn( - rewriter, moduleOp, "__cudaRegisterFunction", tys, llvmInt32Type); - if (failed(cudaRegisterFn)) { - llvm::errs() << " cudamalloc already exists with different types\n"; + // auto 
cudaRegisterFn = LLVM::lookupOrCreateFn( + // rewriter, moduleOp, "__cudaRegisterFunction", tys, llvmInt32Type); + // if (failed(cudaRegisterFn)) { + // llvm::errs() << " cudamalloc already exists with different types\n"; + // return failure(); + // } + + auto registerFunctionFn = LLVM::lookupOrCreateFn( + rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type + ); + + if (failed(registerFunctionFn)) { + llvm::errs() << " register function already exists with different types\n"; return failure(); } + Value args[] = { module.getResult(), bitcast, @@ -1937,7 +1981,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, nullPtr, nullPtr}; - rewriter.create(ctorloc, cudaRegisterFn.value(), args); + // rewriter.create(ctorloc, cudaRegisterFn.value(), args); + rewriter.create(ctorloc, registerFunctionFn.value(), args); } else if (LLVM::GlobalOp g = dyn_cast(op)) { int addrSpace = g.getAddrSpace(); if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */) @@ -1968,6 +2013,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, // to pass the GPU DL in here DataLayout DLI(moduleOp); auto size = DLI.getTypeSize(globalTy); + // why 'mgpu' rtRegisterVarCallBuilder.create( ctorloc, ctorBuilder, {module.getResult(), bitcast, symbolName, symbolName, @@ -1985,7 +2031,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, 0)}); } } - // TODO this has to happen only for some CUDA versions + // TODO this has to happen only for some CUDA versions, hip does not need finialize if (gpuTarget == "cuda") { auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, "__cudaRegisterFatBinaryEnd", llvmPointerType, @@ -2017,16 +2063,26 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto module = dtorBuilder.create( ctorloc, llvmPointerPointerType, aoo->getResult(0)); - auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn( - rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType, + 
// auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn( + // rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType, + // llvmVoidType); + // if (failed(cudaUnRegisterFatbinFn)) { + // llvm::errs() << " cudamalloc already exists with different types\n"; + // return failure(); + // } + + // rewriter.create(ctorloc, cudaUnRegisterFatbinFn.value(), + // ValueRange(module)); + auto unregisterFatbinFn = LLVM::lookupOrCreateFn( + rewriter, moduleOp, unregisterFatBinaryFuncName, llvmPointerType, llvmVoidType); - if (failed(cudaUnRegisterFatbinFn)) { - llvm::errs() << " cudamalloc already exists with different types\n"; - return failure(); + if (failed(unregisterFatbinFn)) { + llvm::errs() << " unregister fatbin function already exists with different types\n"; + return failure(); } + rewriter.create(ctorloc, unregisterFatbinFn.value(), + ValueRange(module)); - rewriter.create(ctorloc, cudaUnRegisterFatbinFn.value(), - ValueRange(module)); dtorBuilder.create(ctorloc, ValueRange()); auto dtorSymbol = FlatSymbolRefAttr::get(dtor); { From e99c86b85727f770171e2ce4f9f87bbbf0d38d6e Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Mon, 3 Nov 2025 15:24:46 -0600 Subject: [PATCH 02/27] add cuda rocm wrappers, and temporarily exclude in BUILD --- src/enzyme_ad/jax/BUILD | 7 +- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 16 +- .../jax/Passes/CudaRuntimeWrappers.cpp | 209 ++++++++++++++++ .../jax/Passes/RocmRuntimeWrappers.cpp | 230 ++++++++++++++++++ 4 files changed, 454 insertions(+), 8 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp create mode 100644 src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 24ac998d4d..26465ca838 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -720,7 +720,12 @@ cc_library( "Dialect/*.cpp", "Dialect/Distributed/*.cpp", "Dialect/Tessera/*.cpp", - ]) + [ + ], + exclude = [ + "Passes/CudaRuntimeWrappers.cpp", + 
"Passes/RocmRuntimeWrappers.cpp", + ], + ) + [ "Utils.cpp", ], hdrs = glob([ diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 78eb23f4a3..92a756f812 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -1727,6 +1727,14 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto loc = kernelModule.getLoc(); auto ctorloc = rewriter.getUnknownLoc(); + + std::string registerFatBinaryFuncName; + std::string registerFunctionFuncName; + std::string registerVarFuncName; + std::string unregisterFatBinaryFuncName; + std::string registerFatBinaryEndFuncName; + bool requiresRegisterEnd; + rewriter.modifyOpInPlace(kernelModule, [&]() { kernelModule->setAttr("polygeist_stubs", rewriter.getUnitAttr()); }); @@ -1793,13 +1801,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, moduleIDPrefix = "__hip_"; fatMagic = HIPFatMagic; } - std::string registerFatBinaryFuncName; - std::string registerFunctionFuncName; - std::string registerVarFuncName; - std::string unregisterFatBinaryFuncName; - std::string registerFatBinaryEndFuncName; - bool requiresRegisterEnd; - + if (gpuTarget == "cuda") { registerFatBinaryFuncName = "__cudaRegisterFatBinary"; registerFunctionFuncName = "__cudaRegisterFunction"; diff --git a/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp b/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp new file mode 100644 index 0000000000..c0494cd4bb --- /dev/null +++ b/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp @@ -0,0 +1,209 @@ +//===- EnzymeCudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements C wrappers around the CUDA library for easy linking in ORC jit. +// Also adds some debugging helpers that are helpful when writing MLIR code to +// run on GPUs. +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "cuda.h" +#include "cuda_runtime.h" + +#include "PGORuntime.h" + +#ifdef _WIN32 +#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) __attribute__((weak)) +#else +#define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((weak)) +#endif // _WIN32 + +#define CUDART_REPORT_IF_ERROR(expr) \ + [](auto result) { \ + if (!result) \ + return result; \ + const char *name = cudaGetErrorString(result); \ + if (!name) \ + name = ""; \ + fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ + return result; \ + }(expr) + +#define CUDA_REPORT_IF_ERROR(expr) \ + [](CUresult result) { \ + if (!result) \ + return result; \ + const char *name = nullptr; \ + cuGetErrorName(result, &name); \ + if (!name) \ + name = ""; \ + fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ + return result; \ + }(expr) + +thread_local static int32_t defaultDevice = 0; + +// Make the primary context of the current default device current for the +// duration +// of the instance and restore the previous context on destruction. +class ScopedContext { +public: + ScopedContext() { + // Static reference to CUDA primary context for device ordinal + // defaultDevice. + static CUcontext context = [] { + CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); + CUdevice device; + CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); + CUcontext ctx; + // Note: this does not affect the current context. 
+ CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device)); + return ctx; + }(); + + CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context)); + } + + ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); } +}; + +//========= CUDA RUNTIME API =========// + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +mgpurtLaunchKernel(void *function, intptr_t gridX, intptr_t gridY, + intptr_t gridZ, intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, cudaStream_t stream, + void **params) { + CUDART_REPORT_IF_ERROR(cudaLaunchKernel(function, dim3(gridX, gridY, gridZ), + dim3(blockX, blockY, blockZ), params, + smem, stream)); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpurtLaunchKernelErr( + void *function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, + cudaStream_t stream, void **params) { + return CUDART_REPORT_IF_ERROR( + cudaLaunchKernel(function, dim3(gridX, gridY, gridZ), + dim3(blockX, blockY, blockZ), params, smem, stream)); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void * +mgpurtMemAlloc(uint64_t sizeBytes, cudaStream_t /*stream*/) { + void *ptr; + CUDART_REPORT_IF_ERROR(cudaMalloc(&ptr, sizeBytes)); + return reinterpret_cast(ptr); +} + +extern "C" void mgpurtMemcpyErr(void *dst, void *src, size_t sizeBytes) { + CUDART_REPORT_IF_ERROR(cudaMemcpy(dst, src, sizeBytes, cudaMemcpyDefault)); +} + +extern "C" void mgpurtMemcpyAsyncErr(void *dst, void *src, size_t sizeBytes, + cudaStream_t stream) { + CUDART_REPORT_IF_ERROR( + cudaMemcpyAsync(dst, src, sizeBytes, cudaMemcpyDefault, stream)); +} + +//========= CUDA DRIVER API =========// + +// The wrapper uses intptr_t instead of CUDA's unsigned int to match +// the type of MLIR's index type. This avoids the need for casts in the +// generated MLIR code. 
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, + intptr_t gridZ, intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, CUstream stream, void **params, + void **extra) { + ScopedContext scopedContext; + CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, + blockY, blockZ, smem, stream, params, + extra)); +} + +// The wrapper uses intptr_t instead of CUDA's unsigned int to match +// the type of MLIR's index type. This avoids the need for casts in the +// generated MLIR code. +extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpuLaunchKernelErr( + CUfunction function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, + CUstream stream, void **params, void **extra) { + ScopedContext scopedContext; + return CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, + blockX, blockY, blockZ, smem, + stream, params, extra)); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) { + ScopedContext scopedContext; + CUmodule module = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); + return module; +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) { + CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction +mgpuModuleGetFunction(CUmodule module, const char *name) { + CUfunction function = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); + return function; +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpurtDeviceSynchronizeErr(void) { + return CUDART_REPORT_IF_ERROR(cudaDeviceSynchronize()); +} + +extern "C" void __cudaRegisterFunction(void **fatCubinHandle, void *hostFun, + void *deviceFun, void *deviceName, + int32_t thread_limit, void *tid, + void *bid, void *bDim, void *gDim, + void *wSize); +extern "C" void __cudaRegisterVar(void **fatCubinHandle, char 
*hostVar, + char *deviceAddress, const char *deviceName, + int ext, size_t size, int constant, + int global); +extern "C" void **__cudaRegisterFatBinary(void *fatCubin); +extern "C" void __cudaRegisterFatBinaryEnd(void **fatCubinHandle); +extern "C" void __cudaUnregisterFatBinary(void **fatCubinHandle); + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +__mgpurtRegisterFunction(void **fatCubinHandle, void *hostFun, void *deviceFun, + void *deviceName, int32_t thread_limit, void *tid, + void *bid, void *bDim, void *gDim, void *wSize) { + __cudaRegisterFunction(fatCubinHandle, hostFun, deviceFun, deviceName, + thread_limit, tid, bid, bDim, gDim, wSize); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +__mgpurtRegisterVar(void **fatCubinHandle, char *hostVar, char *deviceAddress, + const char *deviceName, int ext, size_t size, int constant, + int global) { + __cudaRegisterVar(fatCubinHandle, hostVar, deviceAddress, deviceName, ext, + size, constant, global); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void ** +__mgpurtRegisterFatBinary(void *fatCubin) { + return __cudaRegisterFatBinary(fatCubin); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +__mgpurtRegisterFatBinaryEnd(void **fatCubinHandle) { + __cudaRegisterFatBinaryEnd(fatCubinHandle); +} + +extern "C" MLIR_CUDA_WRAPPERS_EXPORT void +__mgpurtUnregisterFatBinary(void **fatCubinHandle) { + __cudaUnregisterFatBinary(fatCubinHandle); +} diff --git a/src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp b/src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp new file mode 100644 index 0000000000..5cae5c4948 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp @@ -0,0 +1,230 @@ +//===- PolygeistRocmRuntimeWrappers.cpp - MLIR ROCM API wrapper library ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements C wrappers around the ROCM library for easy linking in ORC jit. +// Also adds some debugging helpers that are helpful when writing MLIR code to +// run on GPUs. +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "hip/hip_runtime.h" + +#include "PGORuntime.h" + +#ifdef _WIN32 +#define MLIR_HIP_WRAPPERS_EXPORT __declspec(dllexport) __attribute__((weak)) +#else +#define MLIR_HIP_WRAPPERS_EXPORT __attribute__((weak)) +#endif // _WIN32 + +#define HIP_REPORT_IF_ERROR(expr) \ + [](hipError_t result) { \ + if (!result) \ + return; \ + const char *name = hipGetErrorName(result); \ + if (!name) \ + name = ""; \ + fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ + }(expr) + +#define ERR_HIP_REPORT_IF_ERROR(expr) \ + [](hipError_t result) -> hipError_t { \ + if (!result) \ + return result; \ + const char *name = hipGetErrorName(result); \ + if (!name) \ + name = ""; \ + fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ + return result; \ + }(expr) + +extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t +mgpurtMemAllocErr(void **mem, uint64_t sizeBytes) { + return ERR_HIP_REPORT_IF_ERROR(hipMalloc(mem, sizeBytes)); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT void * +mgpurtMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) { + void *ptr; + HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes)); + return reinterpret_cast(ptr); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT void mgpuMemFree(void *ptr, + hipStream_t /*stream*/) { + HIP_REPORT_IF_ERROR(hipFree(ptr)); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t +mgpurtMemcpyErr(void *dst, void *src, intptr_t sizeBytes) { + return ERR_HIP_REPORT_IF_ERROR( + hipMemcpy(dst, src, sizeBytes, hipMemcpyDefault)); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t mgpurtMemcpyAsyncErr( + void *dst, void *src, 
intptr_t sizeBytes, hipStream_t stream) { + return ERR_HIP_REPORT_IF_ERROR( + hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream)); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t mgpurtDeviceSynchronizeErr(void) { + return ERR_HIP_REPORT_IF_ERROR(hipDeviceSynchronize()); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t mgpurtLaunchKernelErr( + void *function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, + hipStream_t stream, void **params) { + return ERR_HIP_REPORT_IF_ERROR( + hipLaunchKernel(function, dim3(gridX, gridY, gridZ), + dim3(blockX, blockY, blockZ), params, smem, stream)); +} + +extern "C" void __hipRegisterFunction(void **fatCubinHandle, void *hostFun, + void *deviceFun, void *deviceName, + int32_t thread_limit, void *tid, + void *bid, void *bDim, void *gDim, + void *wSize); +extern "C" void __hipRegisterVar(void **fatCubinHandle, char *hostVar, + char *deviceAddress, const char *deviceName, + int ext, size_t size, int constant, + int global); +extern "C" void **__hipRegisterFatBinary(void *fatCubin); +extern "C" void __hipRegisterFatBinaryEnd(void **fatCubinHandle); +extern "C" void __hipUnregisterFatBinary(void **fatCubinHandle); + +extern "C" MLIR_HIP_WRAPPERS_EXPORT void +__mgpurtRegisterFunction(void **fatCubinHandle, void *hostFun, void *deviceFun, + void *deviceName, int32_t thread_limit, void *tid, + void *bid, void *bDim, void *gDim, void *wSize) { + __hipRegisterFunction(fatCubinHandle, hostFun, deviceFun, deviceName, + thread_limit, tid, bid, bDim, gDim, wSize); +} +extern "C" MLIR_HIP_WRAPPERS_EXPORT void +__mgpurtRegisterVar(void **fatCubinHandle, char *hostVar, char *deviceAddress, + const char *deviceName, int ext, size_t size, int constant, + int global) { + __hipRegisterVar(fatCubinHandle, hostVar, deviceAddress, deviceName, ext, + size, constant, global); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT void ** +__mgpurtRegisterFatBinary(void *fatCubin) { + 
return __hipRegisterFatBinary(fatCubin); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT void +__mgpurtRegisterFatBinaryEnd(void **fatCubinHandle) { + return __hipRegisterFatBinaryEnd(fatCubinHandle); +} + +extern "C" MLIR_HIP_WRAPPERS_EXPORT void +__mgpurtUnregisterFatBinary(void **fatCubinHandle) { + return __hipUnregisterFatBinary(fatCubinHandle); +} + +#if POLYGEIST_ENABLE_CUDA + +#pragma push_macro("__forceinline__") +#define __VECTOR_TYPES_H__ +#include +#undef __VECTOR_TYPES_H__ +#pragma pop_macro("__forceinline__") + +extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t +mgpurtCudaGetDeviceProperties(struct cudaDeviceProp *cudaProp, int device) { + struct hipDeviceProp_t hipProp; + int err = ERR_HIP_REPORT_IF_ERROR(hipGetDeviceProperties(&hipProp, device)); + + // Reassign all corresponding fields to the hip props, the commented ones dont + // exist in hip one-for-one +#define __polygeist_assign_field(f) \ + memcpy(&(cudaProp->f), &(hipProp.f), sizeof(cudaProp->f)) + __polygeist_assign_field(name); + // __polygeist_assign_field(uuid); + __polygeist_assign_field(totalGlobalMem); + __polygeist_assign_field(sharedMemPerBlock); + __polygeist_assign_field(regsPerBlock); + __polygeist_assign_field(warpSize); + __polygeist_assign_field(memPitch); + __polygeist_assign_field(maxThreadsPerBlock); + __polygeist_assign_field(maxThreadsDim); + __polygeist_assign_field(maxGridSize); + __polygeist_assign_field(clockRate); + __polygeist_assign_field(totalConstMem); + __polygeist_assign_field(major); + __polygeist_assign_field(minor); + __polygeist_assign_field(textureAlignment); + __polygeist_assign_field(texturePitchAlignment); + // __polygeist_assign_field(deviceOverlap); + __polygeist_assign_field(multiProcessorCount); + __polygeist_assign_field(kernelExecTimeoutEnabled); + __polygeist_assign_field(integrated); + __polygeist_assign_field(canMapHostMemory); + __polygeist_assign_field(computeMode); + __polygeist_assign_field(maxTexture1D); + // 
__polygeist_assign_field(maxTexture1DMipmap); + __polygeist_assign_field(maxTexture1DLinear); + __polygeist_assign_field(maxTexture2D); + // __polygeist_assign_field(maxTexture2DMipmap); + // __polygeist_assign_field(maxTexture2DLinear); + // __polygeist_assign_field(maxTexture2DGather); + __polygeist_assign_field(maxTexture3D); + // __polygeist_assign_field(maxTexture3DAlt); + // __polygeist_assign_field(maxTextureCubemap); + // __polygeist_assign_field(maxTexture1DLayered); + // __polygeist_assign_field(maxTexture2DLayered); + // __polygeist_assign_field(maxTextureCubemapLayered); + // __polygeist_assign_field(maxSurface1D); + // __polygeist_assign_field(maxSurface2D); + // __polygeist_assign_field(maxSurface3D); + // __polygeist_assign_field(maxSurface1DLayered); + // __polygeist_assign_field(maxSurface2DLayered); + // __polygeist_assign_field(maxSurfaceCubemap); + // __polygeist_assign_field(maxSurfaceCubemapLayered); + // __polygeist_assign_field(surfaceAlignment); + __polygeist_assign_field(concurrentKernels); + __polygeist_assign_field(ECCEnabled); + __polygeist_assign_field(pciBusID); + __polygeist_assign_field(pciDeviceID); + __polygeist_assign_field(pciDomainID); + __polygeist_assign_field(tccDriver); + // __polygeist_assign_field(asyncEngineCount); + // __polygeist_assign_field(unifiedAddressing); + __polygeist_assign_field(memoryClockRate); + __polygeist_assign_field(memoryBusWidth); + __polygeist_assign_field(l2CacheSize); + // __polygeist_assign_field(persistingL2CacheMaxSize); + __polygeist_assign_field(maxThreadsPerMultiProcessor); + // __polygeist_assign_field(streamPrioritiesSupported); + // __polygeist_assign_field(globalL1CacheSupported); + // __polygeist_assign_field(localL1CacheSupported); + // __polygeist_assign_field(sharedMemPerMultiprocessor); + // __polygeist_assign_field(regsPerMultiprocessor); + __polygeist_assign_field(managedMemory); + __polygeist_assign_field(isMultiGpuBoard); + // __polygeist_assign_field(multiGpuBoardGroupID); + // 
__polygeist_assign_field(singleToDoublePrecisionPerfRatio); + __polygeist_assign_field(pageableMemoryAccess); + __polygeist_assign_field(concurrentManagedAccess); + // __polygeist_assign_field(computePreemptionSupported); + // __polygeist_assign_field(canUseHostPointerForRegisteredMem); + __polygeist_assign_field(cooperativeLaunch); + __polygeist_assign_field(cooperativeMultiDeviceLaunch); + __polygeist_assign_field(pageableMemoryAccessUsesHostPageTables); + __polygeist_assign_field(directManagedMemAccessFromHost); + // __polygeist_assign_field(accessPolicyMaxWindowSize); +#undef __polygeist_assign_field + + return err; +} + +#endif From c20aa366bd240658b58935c877ca34fed7ee8d22 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Mon, 3 Nov 2025 16:51:11 -0600 Subject: [PATCH 03/27] add rocm support for enzymexla::MemcpyOp --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 92a756f812..051efd59d6 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -923,20 +923,32 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { auto ptrty = LLVM::LLVMPointerType::get(op.getContext()); + auto i32 = rewriter.getIntegerType(32); + + std::string memcpyFuncName; + + bool xla = backend.starts_with("xla"); + + if (xla) { + memcpyFuncName = "reactantXLAMemcpy"; + } else if (backend == "cuda") { + memcpyFuncName = "cudaMemcpy"; + } else if (backend == "rocm") { + memcpyFuncName = "hipMemcpy"; + } + SmallVector tys = {ptrty, ptrty, size.getType(), rewriter.getIntegerType(32)}; if (backend.starts_with("xla")) { tys.insert(tys.begin(), ptrty); } - auto i32 = rewriter.getIntegerType(32); - bool xla = backend.starts_with("xla"); - auto cudaMemcpyFn = LLVM::lookupOrCreateFn( - rewriter, moduleOp, xla ? 
"reactantXLAMemcpy" : "cudaMemcpy", tys, - xla ? (mlir::Type)LLVM::LLVMVoidType::get(rewriter.getContext()) - : (mlir::Type)i32); - if (failed(cudaMemcpyFn)) + auto memcpyFn = LLVM::lookupOrCreateFn( + rewriter, moduleOp, memcpyFuncName, tys, xla ? (mlir::Type)LLVM::LLVMVoidType::get(rewriter.getContext()) + : (mlir::Type)i32); + if (failed(memcpyFn)) { return failure(); + } SmallVector args = {dst, src, size, rewriter.create( @@ -951,7 +963,7 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { args.insert(args.begin(), xdata); } - rewriter.create(op.getLoc(), cudaMemcpyFn.value(), args); + rewriter.create(op.getLoc(), memcpyFn.value(), args); rewriter.eraseOp(op); return success(); } @@ -1809,7 +1821,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; registerFatBinaryEndFuncName = "__cudaRegisterFatBinaryEnd"; requiresRegisterEnd = true; - } else { + } else if (gpuTarget == "rocm") { registerFatBinaryFuncName = "__hipRegisterFatBinary"; registerFunctionFuncName = "__hipRegisterFunction"; registerVarFuncName = "__hipRegisterVar"; From ad759cf65c1949cfe67424817d716b13c8efb5a5 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 5 Nov 2025 03:41:06 -0600 Subject: [PATCH 04/27] add rocm support --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 438 ++++++++++++++++-- 1 file changed, 397 insertions(+), 41 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 051efd59d6..766f2196ed 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -25,10 +25,12 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include 
"mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -944,8 +946,9 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { } auto memcpyFn = LLVM::lookupOrCreateFn( - rewriter, moduleOp, memcpyFuncName, tys, xla ? (mlir::Type)LLVM::LLVMVoidType::get(rewriter.getContext()) - : (mlir::Type)i32); + rewriter, moduleOp, memcpyFuncName, tys, + xla ? (mlir::Type)LLVM::LLVMVoidType::get(rewriter.getContext()) + : (mlir::Type)i32); if (failed(memcpyFn)) { return failure(); } @@ -1296,7 +1299,7 @@ struct LowerGPUAlternativesOp ops, floatOps, intOps, loads, stores, branches, }; }; - + #if POLYGEIST_ENABLE_CUDA if (gpuTarget == "cuda") { char cuErrorBuffer[4096] = {0}; @@ -1813,7 +1816,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, moduleIDPrefix = "__hip_"; fatMagic = HIPFatMagic; } - + if (gpuTarget == "cuda") { registerFatBinaryFuncName = "__cudaRegisterFatBinary"; registerFunctionFuncName = "__cudaRegisterFunction"; @@ -1893,7 +1896,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, ctorloc, llvmPointerType, addressOfWrapper); // auto cudaRegisterFatbinFn = - // LLVM::lookupOrCreateFn(rewriter, moduleOp, "__cudaRegisterFatBinary", + // LLVM::lookupOrCreateFn(rewriter, moduleOp, + // "__cudaRegisterFatBinary", // llvmPointerType, llvmPointerType); // if (failed(cudaRegisterFatbinFn)) { // llvm::errs() << " cudamalloc already exists with different types\n"; @@ -1901,18 +1905,21 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, // } // auto module = rewriter.create( - // ctorloc, cudaRegisterFatbinFn.value(), ValueRange(bitcastOfWrapper)); + // ctorloc, cudaRegisterFatbinFn.value(), + // ValueRange(bitcastOfWrapper)); - 
auto registerFatbinFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, llvmPointerType, llvmPointerType); + auto registerFatbinFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, + llvmPointerType, llvmPointerType); if (failed(registerFatbinFn)) { - llvm::errs() << "register fatbin function already exists with different types\n"; + llvm::errs() + << "register fatbin function already exists with different types\n"; return failure(); } auto module = rewriter.create( - ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper) - ); + ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper)); auto moduleGlobalName = std::string(llvm::formatv("polygeist_{0}_module_ptr", moduleName)); @@ -1968,18 +1975,19 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType}; // auto cudaRegisterFn = LLVM::lookupOrCreateFn( - // rewriter, moduleOp, "__cudaRegisterFunction", tys, llvmInt32Type); + // rewriter, moduleOp, "__cudaRegisterFunction", tys, + // llvmInt32Type); // if (failed(cudaRegisterFn)) { - // llvm::errs() << " cudamalloc already exists with different types\n"; - // return failure(); + // llvm::errs() << " cudamalloc already exists with different + // types\n"; return failure(); // } auto registerFunctionFn = LLVM::lookupOrCreateFn( - rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type - ); + rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type); if (failed(registerFunctionFn)) { - llvm::errs() << " register function already exists with different types\n"; + llvm::errs() + << " register function already exists with different types\n"; return failure(); } @@ -1995,8 +2003,10 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, nullPtr, nullPtr}; - // rewriter.create(ctorloc, cudaRegisterFn.value(), args); - rewriter.create(ctorloc, registerFunctionFn.value(), args); + // 
rewriter.create(ctorloc, cudaRegisterFn.value(), + // args); + rewriter.create(ctorloc, registerFunctionFn.value(), + args); } else if (LLVM::GlobalOp g = dyn_cast(op)) { int addrSpace = g.getAddrSpace(); if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */) @@ -2045,7 +2055,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, 0)}); } } - // TODO this has to happen only for some CUDA versions, hip does not need finialize + // TODO this has to happen only for some CUDA versions, hip does not need + // finialize if (gpuTarget == "cuda") { auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, "__cudaRegisterFatBinaryEnd", llvmPointerType, @@ -2091,11 +2102,12 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, rewriter, moduleOp, unregisterFatBinaryFuncName, llvmPointerType, llvmVoidType); if (failed(unregisterFatbinFn)) { - llvm::errs() << " unregister fatbin function already exists with different types\n"; - return failure(); + llvm::errs() << " unregister fatbin function already exists with " + "different types\n"; + return failure(); } rewriter.create(ctorloc, unregisterFatbinFn.value(), - ValueRange(module)); + ValueRange(module)); dtorBuilder.create(ctorloc, ValueRange()); auto dtorSymbol = FlatSymbolRefAttr::get(dtor); @@ -2220,10 +2232,16 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; - auto launchCall = rewriter.create( - loc, TypeRange(i32), "cudaLaunchKernel", - args); // FlatSymbolRefAttr::get(rewriter.getStringAttr("cudaLaunchKernel")), - // args); + std::string launchFuncName; + if (gpuTarget == "cuda") { + launchFuncName = "cudaLaunchKernel"; + } else if (gpuTarget == "rocm") { + launchFuncName = "hipLaunchKernel"; + } + + auto launchCall = + rewriter.create(loc, TypeRange(i32), launchFuncName, args); + if (launchOp.getAsyncToken()) { 
// Async launch: make dependent ops use the same stream. rewriter.replaceOp(launchOp, {stream}); @@ -2511,6 +2529,27 @@ class ConvertAllocOpToGpuRuntimeCallPattern }; rewriter.create(loc, cudaMallocFn.value(), args); allocatedPtr = rewriter.create(loc, ptr1ty, ptr); + } else if (backend == "rocm") { + auto one = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(1)); + + auto ptr = rewriter.create(loc, ptrty, ptr1ty, one); + Type tys[] = {ptrty, i64}; + + auto hipMallocFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipMalloc", tys, i32); + if (failed(hipMallocFn)) { + llvm::errs() << " hipMalloc already exists with different types\n"; + return failure(); + } + + Value args[] = { + ptr, + sizeBytes, + }; + rewriter.create(loc, hipMallocFn.value(), args); + allocatedPtr = rewriter.create(loc, ptr1ty, ptr); + } else if (backend.starts_with("cpu")) { Type convertedIndex = typeConverter->convertType(rewriter.getIndexType()); @@ -2648,7 +2687,7 @@ class ConvertOccupancyOp if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); - if (backend != "cuda") + if (backend != "cuda" && backend != "rocm") return rewriter.notifyMatchFailure( op, "Occupancy op lowering only supported for CUDA"); @@ -2663,12 +2702,20 @@ class ConvertOccupancyOp Type tys[] = {ptrty, ptrty, intty, adaptor.getDynamicSMemSize().getType(), adaptor.getFlags().getType()}; - auto cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn = - LLVM::lookupOrCreateFn( - rewriter, moduleOp, - "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", tys, i32); - if (failed(cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn)) { - llvm::errs() << " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags " + std::string occupancyFuncName; + if (backend == "cuda") { + occupancyFuncName = + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"; + } else if (backend == "rocm") { + occupancyFuncName = + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"; + } + + auto 
occupancyFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, occupancyFuncName, tys, i32); + + if (failed(occupancyFn)) { + llvm::errs() << " occupancyMaxActiveBlocksPerMultiprocessorWithFlags " "already exists with different types\n"; return failure(); } @@ -2684,9 +2731,7 @@ class ConvertOccupancyOp auto addr = rewriter.create(loc, ptrty, funcStubName); Value args[] = {ptr, addr, adaptor.getBlockSize(), adaptor.getDynamicSMemSize(), adaptor.getFlags()}; - rewriter.create( - loc, cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn.value(), - args); + rewriter.create(loc, occupancyFn.value(), args); rewriter.replaceOpWithNewOp(op, intty, ptr); return success(); @@ -2709,7 +2754,7 @@ class ConvertGPUKernelAddressOp matchAndRewrite(enzymexla::GPUKernelAddressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (backend != "cuda") + if (backend != "cuda" && backend != "rocm") return rewriter.notifyMatchFailure( op, "KernelAddress lowering only supported for CUDA"); @@ -2781,6 +2826,20 @@ class ConvertDeallocOpToGpuRuntimeCallPattern ptr, }; rewriter.create(loc, cudaFreeFn.value(), args); + } else if (backend == "rocm") { + Type tys[] = {ptr1ty}; + auto hipFreeFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipFree", tys, i32); + + if (failed(hipFreeFn)) { + llvm::errs() << " hipfree already exists with different types\n"; + return failure(); + } + Value args[] = { + ptr, + }; + rewriter.create(loc, hipFreeFn.value(), args); + } else if (backend.starts_with("cpu")) { FailureOr freeFunc = @@ -3618,6 +3677,222 @@ struct OpLowering : public OpConversionPattern { } // namespace gpu } // namespace mlir +// https://rocm.docs.amd.com/projects/HIP/en/docs-6.4.0/reference/hardware_features.html +struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override 
{ + auto loc = op->getLoc(); + MLIRContext *context = rewriter.getContext(); + auto int32Type = rewriter.getI32Type(); + + Value minusOne = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(-1)); + Value zero = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(0)); + + Value laneIdLo = rewriter.create(loc, int32Type, minusOne, + zero, nullptr, nullptr); + Value laneId = rewriter.create( + loc, int32Type, minusOne, laneIdLo, nullptr, nullptr); + + LLVM::ConstantRangeAttr bounds = nullptr; + if (std::optional upperBound = op.getUpperBound()) + bounds = rewriter.getAttr( + /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue()); + else + bounds = rewriter.getAttr( + /*bitWidth=*/32, /*lower=*/0, /*upper=*/64); + + if (bounds) { + laneId.getDefiningOp()->setAttr("range", bounds); + } + + const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); + + if (indexBitwidth > 32) { + laneId = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), laneId); + } else if (indexBitwidth < 32) { + laneId = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), laneId); + } + + rewriter.replaceOp(op, {laneId}); + + return success(); + } +}; + +struct GPUShuffleOpToROCDL : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + auto valueTy = adaptor.getValue().getType(); + auto value = adaptor.getValue(); + auto int32Type = IntegerType::get(rewriter.getContext(), 32); + + Value minusOne = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(-1)); + Value zero = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(0)); + + Value laneIdLo = rewriter.create(loc, int32Type, minusOne, + zero, nullptr, nullptr); + Value laneId = rewriter.create( + loc, int32Type, minusOne, laneIdLo, nullptr, nullptr); + + Value 
targetLane; + Value offset = adaptor.getOffset(); + + switch (op.getMode()) { + case gpu::ShuffleMode::XOR: + targetLane = rewriter.create(loc, int32Type, laneId, offset); + break; + case gpu::ShuffleMode::UP: + targetLane = rewriter.create(loc, int32Type, laneId, offset); + break; + case gpu::ShuffleMode::DOWN: + targetLane = rewriter.create(loc, int32Type, laneId, offset); + break; + case gpu::ShuffleMode::IDX: + targetLane = offset; + break; + } + + Value width = adaptor.getWidth(); + + auto isNonNegative = rewriter.create( + loc, LLVM::ICmpPredicate::sge, targetLane, zero); + auto isWithinWidth = rewriter.create( + loc, LLVM::ICmpPredicate::slt, targetLane, width); + auto isValid = + rewriter.create(loc, isNonNegative, isWithinWidth); + + Value maskAndClamp; + + Value widthMinusone = rewriter.create( + loc, width, + rewriter.create(loc, int32Type, + rewriter.getI32IntegerAttr(1))); + Value minResult = rewriter.create( + loc, + rewriter.create(loc, LLVM::ICmpPredicate::slt, targetLane, + widthMinusone), + targetLane, widthMinusone); + maskAndClamp = rewriter.create( + loc, + rewriter.create(loc, LLVM::ICmpPredicate::sgt, minResult, + zero), + minResult, zero); + + Value four = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(4)); + Value byteIndex = rewriter.create(loc, maskAndClamp, four); + + Value shuffleResult; + if (valueTy.isF32()) { + Value valueAsInt = + rewriter.create(loc, int32Type, value); + + Value resultInt = rewriter.create( + loc, int32Type, byteIndex, valueAsInt); + + shuffleResult = rewriter.create(loc, valueTy, resultInt); + + } else if (valueTy.isInteger(32)) { + shuffleResult = rewriter.create(loc, int32Type, + byteIndex, value); + } + // } else if (valueTy.isF64() || valueTy.isInteger(64)) { + // shuffleResult = shuffle64BitValue(loc, rewriter, value, byteIndex, + // valueTy); + // } + + bool predIsUsed = !op->getResult(1).use_empty(); + if (predIsUsed) { + rewriter.replaceOp(op, {shuffleResult, isValid}); + } else { + 
rewriter.replaceOp(op, {shuffleResult, nullptr}); + } + + return success(); + } +}; + +struct GPUBarrierToROCDL : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +struct ClusterIdOpToROCDL : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::ClusterIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto indexType = getTypeConverter()->getIndexType(); + Value zero = rewriter.create(loc, indexType, + rewriter.getIndexAttr(0)); + + rewriter.replaceOp(op, zero); + return success(); + } +}; + +struct ClusterDimOpToROCDL : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::ClusterDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto indexType = getTypeConverter()->getIndexType(); + Value one = rewriter.create(loc, indexType, + rewriter.getIndexAttr(1)); + + rewriter.replaceOp(op, one); + return success(); + } +}; + +struct ClusterBlockIdToBlockIdLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ClusterBlockIdOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getDimension()); + return success(); + } +}; + +struct ClusterDimBlocksToGridDimLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ClusterDimBlocksOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getDimension()); + return success(); + } +}; + /// 
Appends the patterns lowering operations from the Func dialect to the LLVM /// dialect using the C-style type conversion, i.e. converting memrefs to /// pointer to arrays of arrays. @@ -3675,6 +3950,41 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, populateLibDeviceConversionPatterns(typeConverter, patterns, benefit); patterns.add(typeConverter, benefit); + } else if (gpuTarget == "rocm") { + using namespace mlir::gpu::index_lowering; + PatternBenefit benefit(1); + PatternBenefit highBenefit(2); + patterns.add>(typeConverter, IndexKind::Block, IntrType::Id, + benefit); + patterns.add>(typeConverter, IndexKind::Block, IntrType::Dim, + benefit); + patterns.add>(typeConverter, IndexKind::Grid, IntrType::Id, + benefit); + patterns.add>(typeConverter, IndexKind::Grid, IntrType::Dim, + benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + + populateMathToLLVMConversionPatterns(typeConverter, patterns); + populateMathToROCDLConversionPatterns(typeConverter, patterns, + std::nullopt); + + patterns.add(typeConverter, highBenefit); + patterns.add(typeConverter, highBenefit); + patterns.add(&typeConverter.getContext(), + highBenefit); + patterns.add( + &typeConverter.getContext(), highBenefit); } } } @@ -3710,6 +4020,28 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { return resumeOp; } +static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) { + const char fname[] = "fake_rocm_dispatch"; + + MLIRContext *ctx = module->getContext(); + auto loc = module->getLoc(); + auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); + + for (auto fn : module.getBody()->getOps()) { + if (fn.getName() == fname) + return fn; + } + + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto ptrTy = LLVM::LLVMPointerType::get(ctx); + + auto resumeOp = moduleBuilder.create( + fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, 
ptrTy, streamTy})); + resumeOp.setPrivate(); + + return resumeOp; +} + struct NoAsyncOpLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -3733,6 +4065,11 @@ struct NoAsyncOpLowering : public OpConversionPattern { struct AsyncOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + std::string backend; + AsyncOpLowering(LLVMTypeConverter &converter, std::string backend) + : ConvertOpToLLVMPattern(converter), + backend(std::move(backend)) {} + LogicalResult matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -3933,8 +4270,14 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { } assert(vals.size() == 3); - auto f = addMocCUDAFunction(execute->getParentOfType(), - vals.back().getType()); + // auto f = addMocCUDAFunction(execute->getParentOfType(), + // vals.back().getType()); + + auto f = (backend == "cuda") + ? addMocCUDAFunction(execute->getParentOfType(), + vals.back().getType()) + : addMocROCmFunction(execute->getParentOfType(), + vals.back().getType()); rewriter.create(execute.getLoc(), f, vals); rewriter.eraseOp(execute); @@ -4041,7 +4384,11 @@ struct ConvertPolygeistToLLVMPass auto i64 = rewriter.getIntegerType(64); auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; - LLVM::lookupOrCreateFn(rewriter, m, "cudaLaunchKernel", tys, i32); + if (backend == "cuda") { + LLVM::lookupOrCreateFn(rewriter, m, "cudaLaunchKernel", tys, i32); + } else if (backend == "rocm") { + LLVM::lookupOrCreateFn(rewriter, m, "hipLaunchKernel", tys, i32); + } } for (auto mod : gmods) { @@ -4057,6 +4404,11 @@ struct ConvertPolygeistToLLVMPass target.addLegalOp(); target.addLegalDialect(); target.addLegalOp(); + } else if (backend == "rocm") { + target.addIllegalDialect(); + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); } } @@ -4106,7 
+4458,7 @@ struct ConvertPolygeistToLLVMPass if (backend == "cpu") { if (use_async) - patterns.add(converter); + patterns.add(converter, gpuTarget); else patterns.add(patterns.getContext()); } @@ -4228,12 +4580,16 @@ struct ConvertPolygeistToLLVMPass if (auto callee = call.getCallee()) { if (callee == "cudaDeviceSynchronize") { call->erase(); + } else if (callee == "hipDeviceSynchronize") { + call->erase(); } } }); m->walk([](LLVM::LLVMFuncOp call) { if (call.getName() == "cudaDeviceSynchronize") { call->erase(); + } else if (call.getName() == "hipDeviceSynchronize") { + call->erase(); } }); } From a85b028af59e254716ba81d2c3ad0ac91e94eda5 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 5 Nov 2025 21:37:54 -0600 Subject: [PATCH 05/27] fix --- .bazelrc | 2 +- src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.bazelrc b/.bazelrc index 90ec3c6a54..0074cf53fe 100644 --- a/.bazelrc +++ b/.bazelrc @@ -32,4 +32,4 @@ common --define=allow_oversize_protos=true # See https://github.com/bazel-contrib/rules_python/issues/2445 build --@rules_python//python/config_settings:precompile=force_disabled -build -c dbg +build -c opt diff --git a/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp b/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp index c0494cd4bb..cf4e00783f 100644 --- a/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp +++ b/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp @@ -129,9 +129,6 @@ mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, extra)); } -// The wrapper uses intptr_t instead of CUDA's unsigned int to match -// the type of MLIR's index type. This avoids the need for casts in the -// generated MLIR code. 
extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpuLaunchKernelErr( CUfunction function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, From 81f3777d83f60185548e89efc80e1b138c870415 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 6 Nov 2025 01:59:33 -0600 Subject: [PATCH 06/27] add lit tests, but not for landOp, shuffleOp, Cluster*Op --- test/lit_tests/lowering/rocm.mlir | 45 +++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 test/lit_tests/lowering/rocm.mlir diff --git a/test/lit_tests/lowering/rocm.mlir b/test/lit_tests/lowering/rocm.mlir new file mode 100644 index 0000000000..48bea8ccb5 --- /dev/null +++ b/test/lit_tests/lowering/rocm.mlir @@ -0,0 +1,45 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(convert-polygeist-to-llvm{backend=rocm})" | FileCheck %s + +module attributes {gpu.container_module} { + llvm.func @test_rocm_launch(%arg0: !llvm.ptr) { + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %c1_i64 = arith.constant 1 : i64 + %stream = llvm.inttoptr %c1_i64 : i64 to !llvm.ptr + %token = "enzymexla.stream2token"(%stream) : (!llvm.ptr) -> !gpu.async.token + gpu.launch_func [%token] @test_module::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1) args(%arg0 : !llvm.ptr) + llvm.return + } + + func.func @test_rocm_alloc() { + %alloc = gpu.alloc() : memref<256xf32, 1> + gpu.dealloc %alloc : memref<256xf32, 1> + return + } + + func.func @test_rocm_memcpy(%src: memref<256xf32>, %dst: memref<256xf32, 1>) { + %c1024 = arith.constant 1024 : index + "enzymexla.memcpy"(%dst, %src, %c1024) : (memref<256xf32, 1>, memref<256xf32>, index) -> () + return + } + + gpu.module @test_module { + gpu.func @test_kernel(%arg0: !llvm.ptr) kernel { + gpu.return + } + } +} + +// CHECK-LABEL: llvm.func @test_rocm_launch +// CHECK: llvm.call @hipLaunchKernel +// CHECK-NOT: cudaLaunchKernel + +// CHECK-LABEL: llvm.func 
@test_rocm_alloc +// CHECK: llvm.call @hipMalloc +// CHECK: llvm.call @hipFree +// CHECK-NOT: cudaMalloc +// CHECK-NOT: cudaFree + +// CHECK-LABEL: llvm.func @test_rocm_memcpy +// CHECK: llvm.call @hipMemcpy +// CHECK-NOT: cudaMemcpy From 01900058d5fdce452a62f46e4e4f1931e144a074 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 6 Nov 2025 02:12:36 -0600 Subject: [PATCH 07/27] add --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 68 +++++++------------ 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index bd25c3db1f..c1115cb9ef 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -969,11 +969,7 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { args.insert(args.begin(), xdata); } -<<<<<<< HEAD rewriter.create(op.getLoc(), memcpyFn.value(), args); -======= - LLVM::CallOp::create(rewriter, op.getLoc(), cudaMemcpyFn.value(), args); ->>>>>>> origin/main rewriter.eraseOp(op); return success(); } @@ -1941,14 +1937,9 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, return failure(); } -<<<<<<< HEAD - auto module = rewriter.create( - ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper)); -======= auto module = - LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFatbinFn.value(), + LLVM::CallOp::create(rewriter, ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper)); ->>>>>>> origin/main auto moduleGlobalName = std::string(llvm::formatv("polygeist_{0}_module_ptr", moduleName)); @@ -2032,14 +2023,10 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, nullPtr, nullPtr}; -<<<<<<< HEAD - // rewriter.create(ctorloc, cudaRegisterFn.value(), + // LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFn.value(), // args); rewriter.create(ctorloc, registerFunctionFn.value(), args); -======= - LLVM::CallOp::create(rewriter, 
ctorloc, cudaRegisterFn.value(), args); ->>>>>>> origin/main } else if (LLVM::GlobalOp g = dyn_cast(op)) { int addrSpace = g.getAddrSpace(); if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */) @@ -2139,17 +2126,10 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, "different types\n"; return failure(); } -<<<<<<< HEAD rewriter.create(ctorloc, unregisterFatbinFn.value(), ValueRange(module)); dtorBuilder.create(ctorloc, ValueRange()); -======= - - LLVM::CallOp::create(rewriter, ctorloc, cudaUnRegisterFatbinFn.value(), - ValueRange(module)); - LLVM::ReturnOp::create(dtorBuilder, ctorloc, ValueRange()); ->>>>>>> origin/main auto dtorSymbol = FlatSymbolRefAttr::get(dtor); { PatternRewriter::InsertionGuard B(rewriter); @@ -2273,7 +2253,6 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; -<<<<<<< HEAD std::string launchFuncName; if (gpuTarget == "cuda") { launchFuncName = "cudaLaunchKernel"; @@ -2284,12 +2263,6 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto launchCall = rewriter.create(loc, TypeRange(i32), launchFuncName, args); -======= - auto launchCall = LLVM::CallOp::create( - rewriter, loc, TypeRange(i32), "cudaLaunchKernel", - args); // FlatSymbolRefAttr::get(rewriter.getStringAttr("cudaLaunchKernel")), - // args); ->>>>>>> origin/main if (launchOp.getAsyncToken()) { // Async launch: make dependent ops use the same stream. 
rewriter.replaceOp(launchOp, {stream}); @@ -2576,7 +2549,6 @@ class ConvertAllocOpToGpuRuntimeCallPattern ptr, sizeBytes, }; -<<<<<<< HEAD rewriter.create(loc, cudaMallocFn.value(), args); allocatedPtr = rewriter.create(loc, ptr1ty, ptr); } else if (backend == "rocm") { @@ -2600,10 +2572,6 @@ class ConvertAllocOpToGpuRuntimeCallPattern rewriter.create(loc, hipMallocFn.value(), args); allocatedPtr = rewriter.create(loc, ptr1ty, ptr); -======= - LLVM::CallOp::create(rewriter, loc, cudaMallocFn.value(), args); - allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); ->>>>>>> origin/main } else if (backend.starts_with("cpu")) { Type convertedIndex = typeConverter->convertType(rewriter.getIndexType()); @@ -2787,13 +2755,7 @@ class ConvertOccupancyOp auto addr = LLVM::AddressOfOp::create(rewriter, loc, ptrty, funcStubName); Value args[] = {ptr, addr, adaptor.getBlockSize(), adaptor.getDynamicSMemSize(), adaptor.getFlags()}; -<<<<<<< HEAD rewriter.create(loc, occupancyFn.value(), args); -======= - LLVM::CallOp::create( - rewriter, loc, - cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn.value(), args); ->>>>>>> origin/main rewriter.replaceOpWithNewOp(op, intty, ptr); return success(); @@ -2887,7 +2849,6 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = { ptr, }; -<<<<<<< HEAD rewriter.create(loc, cudaFreeFn.value(), args); } else if (backend == "rocm") { Type tys[] = {ptr1ty}; @@ -2903,9 +2864,6 @@ class ConvertDeallocOpToGpuRuntimeCallPattern }; rewriter.create(loc, hipFreeFn.value(), args); -======= - LLVM::CallOp::create(rewriter, loc, cudaFreeFn.value(), args); ->>>>>>> origin/main } else if (backend.starts_with("cpu")) { FailureOr freeFunc = @@ -4114,6 +4072,28 @@ static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) { return resumeOp; } +static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) { + const char fname[] = "fake_rocm_dispatch"; + + MLIRContext *ctx = module->getContext(); + auto loc = 
module->getLoc(); + auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); + + for (auto fn : module.getBody()->getOps()) { + if (fn.getName() == fname) + return fn; + } + + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto ptrTy = LLVM::LLVMPointerType::get(ctx); + + auto resumeOp = moduleBuilder.create( + fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); + resumeOp.setPrivate(); + + return resumeOp; +} + struct NoAsyncOpLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; From 7f7d6ea8fd685a029bc72ab7c077eacce3ad1798 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 6 Nov 2025 02:16:26 -0600 Subject: [PATCH 08/27] fix --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 683 ++++++++---------- 1 file changed, 306 insertions(+), 377 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index c1115cb9ef..db6079492d 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -133,40 +133,40 @@ static Value insertXLAInitDeinit(mlir::ModuleOp moduleOp, StringRef backend, if (ctor) { assert(dtor && "xla module constructor does not exist but destructor does"); assert(data && "xla module constructor does not exist but data does"); - return LLVM::AddressOfOp::create(rewriter, loc, ptrty, - data.getSymNameAttr()); + return rewriter.create(loc, ptrty, + data.getSymNameAttr()); } { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - ctor = LLVM::LLVMFuncOp::create( - rewriter, loc, ctorNameBuffer, + ctor = rewriter.create( + loc, ctorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); - dtor = LLVM::LLVMFuncOp::create( - rewriter, loc, dtorNameBuffer, + dtor = rewriter.create( + loc, dtorNameBuffer, LLVM::LLVMFunctionType::get( 
LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); auto ctorSymbol = FlatSymbolRefAttr::get(ctor); - LLVM::GlobalCtorsOp::create( - rewriter, loc, rewriter.getArrayAttr({std::move(ctorSymbol)}), + rewriter.create( + loc, rewriter.getArrayAttr({std::move(ctorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr({LLVM::ZeroAttr::get(rewriter.getContext())})); auto dtorSymbol = FlatSymbolRefAttr::get(dtor); - LLVM::GlobalDtorsOp::create( - rewriter, loc, rewriter.getArrayAttr({std::move(dtorSymbol)}), + rewriter.create( + loc, rewriter.getArrayAttr({std::move(dtorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr({LLVM::ZeroAttr::get(rewriter.getContext())})); - data = LLVM::GlobalOp::create(rewriter, loc, ptrty, /*constant*/ false, - LLVM::Linkage::Internal, dataNameBuffer, - /* initValue */ mlir::Attribute(), - /* alignment */ 8, /* addrSpace */ 0); + data = rewriter.create( + loc, ptrty, /*constant*/ false, LLVM::Linkage::Internal, dataNameBuffer, + /* initValue */ mlir::Attribute(), + /* alignment */ 8, /* addrSpace */ 0); } // device id, ptr @@ -201,10 +201,10 @@ static Value insertXLAInitDeinit(mlir::ModuleOp moduleOp, StringRef backend, loc, rewriter, "xlabackend", bstr, LLVM::Linkage::Internal); auto glob = - LLVM::AddressOfOp::create(rewriter, loc, ptrty, data.getSymNameAttr()); + rewriter.create(loc, ptrty, data.getSymNameAttr()); Value args[] = {glob, stringval}; - LLVM::CallOp::create(rewriter, loc, xlaInitFn.value(), args); - LLVM::ReturnOp::create(rewriter, loc, ValueRange()); + rewriter.create(loc, xlaInitFn.value(), args); + rewriter.create(loc, ValueRange()); } { @@ -212,13 +212,13 @@ static Value insertXLAInitDeinit(mlir::ModuleOp moduleOp, StringRef backend, rewriter.setInsertionPointToEnd(dtor.addEntryBlock(rewriter)); auto glob = - LLVM::AddressOfOp::create(rewriter, loc, ptrty, data.getSymNameAttr()); + rewriter.create(loc, ptrty, data.getSymNameAttr()); Value args[] = {glob}; - 
LLVM::CallOp::create(rewriter, loc, xlaDeInitFn.value(), args); - LLVM::ReturnOp::create(rewriter, loc, ValueRange()); + rewriter.create(loc, xlaDeInitFn.value(), args); + rewriter.create(loc, ValueRange()); } - return LLVM::AddressOfOp::create(rewriter, loc, ptrty, data.getSymNameAttr()); + return rewriter.create(loc, ptrty, data.getSymNameAttr()); } struct Stream2TokenOpLowering : public ConvertOpToLLVMPattern { @@ -246,7 +246,7 @@ struct Memref2PointerOpLowering if (isa(transformed.getSource().getType())) { mlir::Value ptr = transformed.getSource(); if (space0 != LPT.getAddressSpace()) - ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, LPT, ptr); + ptr = rewriter.create(loc, LPT, ptr); rewriter.replaceOp(op, {ptr}); return success(); } @@ -260,10 +260,10 @@ struct Memref2PointerOpLowering Value baseOffset = targetMemRef.offset(rewriter, loc); Value ptr = targetMemRef.alignedPtr(rewriter, loc); Value idxs[] = {baseOffset}; - ptr = LLVM::GEPOp::create(rewriter, loc, ptr.getType(), - rewriter.getI8Type(), ptr, idxs); + ptr = rewriter.create(loc, ptr.getType(), rewriter.getI8Type(), + ptr, idxs); if (space0 != LPT.getAddressSpace()) - ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, LPT, ptr); + ptr = rewriter.create(loc, LPT, ptr); rewriter.replaceOp(op, {ptr}); return success(); @@ -287,7 +287,7 @@ struct Pointer2MemrefOpLowering mlir::Value ptr = adaptor.getSource(); if (space1 != cast(op.getOperand().getType()) .getAddressSpace()) - ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, PT, ptr); + ptr = rewriter.create(loc, PT, ptr); rewriter.replaceOp(op, {ptr}); return success(); } @@ -297,8 +297,8 @@ struct Pointer2MemrefOpLowering if (space1 != cast(op.getOperand().getType()) .getAddressSpace()) - ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, - descr.getElementPtrType(), ptr); + ptr = rewriter.create( + loc, descr.getElementPtrType(), ptr); // Extract all strides and offsets and verify they are static. 
int64_t offset; @@ -575,7 +575,7 @@ struct CAllocOpLowering : public AllocLikeOpLowering { if (auto F = module.lookupSymbol("malloc")) { Value allocated = - func::CallOp::create(rewriter, loc, F, sizeBytes).getResult(0); + rewriter.create(loc, F, sizeBytes).getResult(0); rewriter.replaceOpWithNewOp( allocOp, convertedType, allocated); } else { @@ -587,7 +587,7 @@ struct CAllocOpLowering : public AllocLikeOpLowering { if (failed(mallocFunc)) return failure(); Value allocated = - LLVM::CallOp::create(rewriter, loc, mallocFunc.value(), sizeBytes) + rewriter.create(loc, mallocFunc.value(), sizeBytes) .getResult(); rewriter.replaceOpWithNewOp(allocOp, convertedType, allocated); @@ -606,9 +606,9 @@ struct CDeallocOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto module = deallocOp->getParentOfType(); if (auto F = module.lookupSymbol("free")) { - Value casted = enzymexla::Pointer2MemrefOp::create( - rewriter, deallocOp->getLoc(), - MemRefType::get({-1}, rewriter.getI8Type()), adaptor.getMemref()); + Value casted = rewriter.create( + deallocOp->getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), + adaptor.getMemref()); rewriter.replaceOpWithNewOp(deallocOp, F, casted); } else { FailureOr freeFunc = @@ -691,8 +691,8 @@ struct GlobalOpLowering : public ConvertOpToLLVMPattern { newGlobal.getInitializerRegion().begin()); rewriter.setInsertionPointToStart(block); Value undef = - LLVM::UndefOp::create(rewriter, globalOp->getLoc(), convertedType); - LLVM::ReturnOp::create(rewriter, globalOp->getLoc(), undef); + rewriter.create(globalOp->getLoc(), convertedType); + rewriter.create(globalOp->getLoc(), undef); } return success(); } @@ -709,8 +709,8 @@ struct GetGlobalOpLowering ConversionPatternRewriter &rewriter) const override { MemRefType originalType = getGlobalOp.getType(); Type convertedType = getTypeConverter()->convertType(originalType); - Value wholeAddress = LLVM::AddressOfOp::create( - rewriter, 
getGlobalOp->getLoc(), convertedType, getGlobalOp.getName()); + Value wholeAddress = rewriter.create( + getGlobalOp->getLoc(), convertedType, getGlobalOp.getName()); rewriter.replaceOp(getGlobalOp, wholeAddress); return success(); @@ -744,8 +744,8 @@ struct CLoadStoreOpLowering : public ConvertOpToLLVMPattern { (void)rewriter.notifyMatchFailure(loc, "unsupported memref type"); return nullptr; } - return LLVM::GEPOp::create( - rewriter, loc, + return rewriter.create( + loc, LLVM::LLVMPointerType::get(op.getContext(), originalType.getMemorySpaceAsInt()), elTy, adaptor.getMemref(), args); @@ -890,18 +890,16 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { if (dstType.getMemorySpaceAsInt() == 0 && srcType.getMemorySpaceAsInt() == 0) { - LLVM::MemcpyOp::create(rewriter, op.getLoc(), dst, src, size, false); + rewriter.create(op.getLoc(), dst, src, size, false); rewriter.eraseOp(op); return success(); } if (backend == "cpu") { - dst = LLVM::AddrSpaceCastOp::create( - rewriter, op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), - dst); - src = LLVM::AddrSpaceCastOp::create( - rewriter, op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), - src); - LLVM::MemcpyOp::create(rewriter, op.getLoc(), dst, src, size, false); + dst = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), dst); + src = rewriter.create( + op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), src); + rewriter.create(op.getLoc(), dst, src, size, false); rewriter.eraseOp(op); return success(); } @@ -956,13 +954,12 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { } SmallVector args = {dst, src, size, - LLVM::ConstantOp::create(rewriter, op.getLoc(), - tys[3 + xla], - direction)}; + rewriter.create( + op.getLoc(), tys[3 + xla], direction)}; for (int i = 0; i < 2; i++) if (args[i].getType() != tys[i]) - args[i] = LLVM::AddrSpaceCastOp::create(rewriter, op.getLoc(), - tys[i + xla], args[i]); + args[i] = rewriter.create(op.getLoc(), + tys[i + xla], 
args[i]); if (backend.starts_with("xla")) { auto xdata = insertXLAInitDeinit(moduleOp, backend, rewriter); @@ -1561,7 +1558,7 @@ struct LowerGPUAlternativesOp auto kernelId = LLVM::createGlobalString( loc, rewriter, std::string("kernelId.") + std::to_string(num++), nullTermLocStr, LLVM::Linkage::Internal, /*opaquePointers*/ true); - auto totalAlternatives = LLVM::ConstantOp::create(rewriter, + auto totalAlternatives = rewriter.create( loc, llvmInt32Type, gao->getNumRegions()); auto alternative = rtPGOGetAlternativeCallBuilder @@ -1570,10 +1567,10 @@ struct LowerGPUAlternativesOp int i = 0; for (auto ®ion : gao->getRegions()) { - auto cmpOp = arith::CmpIOp::create(rewriter, + auto cmpOp = rewriter.create( loc, arith::CmpIPredicate::eq, alternative, - arith::ConstantIntOp::create(rewriter, loc, i, 32)); - auto ifOp = scf::IfOp::create(rewriter, loc, cmpOp, /* hasElse */ true); + rewriter.create(loc, i, 32)); + auto ifOp = rewriter.create(loc, cmpOp, /* hasElse */ true); auto block = ®ion.front(); rewriter.eraseOp(block->getTerminator()); rewriter.inlineBlockBefore( @@ -1693,46 +1690,31 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( argumentTypes.push_back(argument.getType()); auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), argumentTypes); - Value structPtr, arrayPtr, one; + Value structPtr, arrayPtr; { PatternRewriter::InsertionGuard B(builder); builder.setInsertionPointToStart(allocaBlock); - one = LLVM::ConstantOp::create(builder, loc, llvmInt32Type, 1); - structPtr = LLVM::AllocaOp::create( - builder, loc, LLVM::LLVMPointerType::get(builder.getContext()), - structType, one, + auto one = builder.create(loc, llvmInt32Type, 1); + structPtr = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), structType, one, /*alignment=*/0); auto arraySize = - LLVM::ConstantOp::create(builder, loc, llvmInt32Type, numArguments); - arrayPtr = LLVM::AllocaOp::create(builder, loc, llvmPointerPointerType, - 
llvmPointerType, arraySize, - /*alignment=*/0); + builder.create(loc, llvmInt32Type, numArguments); + arrayPtr = builder.create(loc, llvmPointerPointerType, + llvmPointerType, arraySize, + /*alignment=*/0); } - auto argAttrss = - dyn_cast_or_null(launchOp->getAttr("reactant.arg_attrs")); for (const auto &en : llvm::enumerate(arguments)) { - bool isByVal = - argAttrss && cast(argAttrss[en.index()]) - .getNamed(LLVM::LLVMDialect::getByValAttrName()); - Value fieldPtr; - if (isByVal) { - fieldPtr = en.value(); - } else { - { - PatternRewriter::InsertionGuard B(builder); - builder.setInsertionPointToStart(allocaBlock); - fieldPtr = LLVM::AllocaOp::create(builder, loc, llvmPointerPointerType, - en.value().getType(), one, - /*alignment=*/0); - } - LLVM::StoreOp::create(builder, loc, en.value(), fieldPtr); - } - auto elementPtr = LLVM::GEPOp::create(builder, loc, llvmPointerType, - llvmPointerPointerType, arrayPtr, - ArrayRef{en.index()}); + auto fieldPtr = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), structType, + structPtr, ArrayRef{0, en.index()}); + builder.create(loc, en.value(), fieldPtr); + auto elementPtr = builder.create( + loc, llvmPointerType, llvmPointerPointerType, arrayPtr, + ArrayRef{en.index()}); auto casted = - LLVM::BitcastOp::create(builder, loc, llvmPointerType, fieldPtr); - LLVM::StoreOp::create(builder, loc, casted, elementPtr); + builder.create(loc, llvmPointerType, fieldPtr); + builder.create(loc, casted, elementPtr); } return arrayPtr; } @@ -1783,13 +1765,13 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - ctor = LLVM::LLVMFuncOp::create( - rewriter, ctorloc, ctorNameBuffer, + ctor = rewriter.create( + ctorloc, ctorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); - dtor = LLVM::LLVMFuncOp::create( - rewriter, ctorloc, 
dtorNameBuffer, + dtor = rewriter.create( + ctorloc, dtorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); @@ -1866,9 +1848,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - fatBinWrapper = LLVM::GlobalOp::create( - rewriter, loc, fatBinWrapperType, /*constant*/ true, - LLVM::Linkage::Internal, + fatBinWrapper = rewriter.create( + loc, fatBinWrapperType, /*constant*/ true, LLVM::Linkage::Internal, std::string( llvm::formatv("__polygeist_{0}_fatbin_wrapper", moduleName)), /* initValue */ mlir::Attribute(), @@ -1880,39 +1861,39 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, fatBinWrapper.getRegion().push_back(new Block); globalBuilder.setInsertionPointToStart(fatBinWrapper.getBody()); auto fatbinMagicVal = - LLVM::ConstantOp::create(globalBuilder, loc, llvmInt32Type, fatMagic); + globalBuilder.create(loc, llvmInt32Type, fatMagic); auto fatbinVersionVal = - LLVM::ConstantOp::create(globalBuilder, loc, llvmInt32Type, 1); - auto nullPtr = LLVM::ZeroOp::create(globalBuilder, loc, llvmPointerType); + globalBuilder.create(loc, llvmInt32Type, 1); + auto nullPtr = globalBuilder.create(loc, llvmPointerType); Value constructedStruct = - LLVM::UndefOp::create(globalBuilder, loc, fatBinWrapperType); + globalBuilder.create(loc, fatBinWrapperType); { int i = 0; - constructedStruct = LLVM::InsertValueOp::create( - globalBuilder, loc, fatBinWrapperType, constructedStruct, - fatbinMagicVal, globalBuilder.getDenseI64ArrayAttr(i++)); - constructedStruct = LLVM::InsertValueOp::create( - globalBuilder, loc, fatBinWrapperType, constructedStruct, - fatbinVersionVal, globalBuilder.getDenseI64ArrayAttr(i++)); + constructedStruct = globalBuilder.create( + loc, fatBinWrapperType, constructedStruct, fatbinMagicVal, + globalBuilder.getDenseI64ArrayAttr(i++)); + constructedStruct = 
globalBuilder.create( + loc, fatBinWrapperType, constructedStruct, fatbinVersionVal, + globalBuilder.getDenseI64ArrayAttr(i++)); // TODO do we need to specify the section name here...? // data.setSectionAttr(moduleBuilder.getStringAttr(fatbinSectionName)); Value data = LLVM::createGlobalString( loc, globalBuilder, nameBuffer.str(), "binaryAttr", // loc, globalBuilder, nameBuffer.str(), binaryAttr.getValue(), LLVM::Linkage::Internal); - constructedStruct = LLVM::InsertValueOp::create( - globalBuilder, loc, fatBinWrapperType, constructedStruct, data, + constructedStruct = globalBuilder.create( + loc, fatBinWrapperType, constructedStruct, data, globalBuilder.getDenseI64ArrayAttr(i++)); - constructedStruct = LLVM::InsertValueOp::create( - globalBuilder, loc, fatBinWrapperType, constructedStruct, nullPtr, + constructedStruct = globalBuilder.create( + loc, fatBinWrapperType, constructedStruct, nullPtr, globalBuilder.getDenseI64ArrayAttr(i++)); } - LLVM::ReturnOp::create(globalBuilder, loc, constructedStruct); + globalBuilder.create(loc, constructedStruct); auto addressOfWrapper = - LLVM::AddressOfOp::create(ctorBuilder, ctorloc, fatBinWrapper); - auto bitcastOfWrapper = LLVM::AddrSpaceCastOp::create( - ctorBuilder, ctorloc, llvmPointerType, addressOfWrapper); + ctorBuilder.create(ctorloc, fatBinWrapper); + auto bitcastOfWrapper = ctorBuilder.create( + ctorloc, llvmPointerType, addressOfWrapper); // auto cudaRegisterFatbinFn = // LLVM::lookupOrCreateFn(rewriter, moduleOp, @@ -1937,24 +1918,23 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, return failure(); } - auto module = - LLVM::CallOp::create(rewriter, ctorloc, registerFatbinFn.value(), - ValueRange(bitcastOfWrapper)); + auto module = rewriter.create( + ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper)); auto moduleGlobalName = std::string(llvm::formatv("polygeist_{0}_module_ptr", moduleName)); { PatternRewriter::InsertionGuard B(rewriter); 
rewriter.setInsertionPointToEnd(moduleOp.getBody()); - moduleGlobal = LLVM::GlobalOp::create( - rewriter, ctorloc, llvmPointerPointerType, /* isConstant */ false, + moduleGlobal = rewriter.create( + ctorloc, llvmPointerPointerType, /* isConstant */ false, LLVM::Linkage::Internal, moduleGlobalName, /* initValue */ mlir::Attribute(), /* alignment */ 8, /* addrSpace */ 0); } - auto aoo = LLVM::AddressOfOp::create(ctorBuilder, ctorloc, moduleGlobal); - LLVM::StoreOp::create(ctorBuilder, loc, module->getResult(0), - aoo->getResult(0)); + auto aoo = ctorBuilder.create(ctorloc, moduleGlobal); + ctorBuilder.create(loc, module->getResult(0), + aoo->getResult(0)); for (Operation &op : kernelModule->getRegion(0).front()) { if (auto f = dyn_cast(op)) { if (!f->getAttr("gpu.kernel")) @@ -1963,7 +1943,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, kernelModule.getName(), f.getName(), ctorloc, ctorBuilder); auto nullPtr = - LLVM::ZeroOp::create(ctorBuilder, ctorloc, llvmPointerType); + ctorBuilder.create(ctorloc, llvmPointerType); // TODO second param should be ptr to the the original function stub // here like clang does it: e.g. 
kernel_name_device_stub // @@ -1976,19 +1956,19 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - stub = LLVM::LLVMFuncOp::create( - rewriter, ctorloc, getFuncStubName(moduleName, f.getName()), + stub = rewriter.create( + ctorloc, getFuncStubName(moduleName, f.getName()), LLVM::LLVMFunctionType::get(llvmVoidType, {}), LLVM::Linkage::Internal); } { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(stub.addEntryBlock(rewriter)); - LLVM::ReturnOp::create(rewriter, ctorloc, ValueRange()); + rewriter.create(ctorloc, ValueRange()); } - auto aoo = LLVM::AddressOfOp::create(ctorBuilder, ctorloc, stub); - auto bitcast = LLVM::AddrSpaceCastOp::create(ctorBuilder, ctorloc, - llvmPointerType, aoo); + auto aoo = ctorBuilder.create(ctorloc, stub); + auto bitcast = ctorBuilder.create( + ctorloc, llvmPointerType, aoo); Type tys[] = {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType, @@ -2016,14 +1996,14 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, bitcast, kernelName, kernelName, - LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, -1), + ctorBuilder.create(ctorloc, llvmInt32Type, -1), nullPtr, nullPtr, nullPtr, nullPtr, nullPtr}; - // LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFn.value(), + // rewriter.create(ctorloc, cudaRegisterFn.value(), // args); rewriter.create(ctorloc, registerFunctionFn.value(), args); @@ -2046,9 +2026,9 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, // TODO could this be a memref global op? 
auto stub = moduleOp.lookupSymbol(g.getName()); assert(stub); - auto aoo = LLVM::AddressOfOp::create(ctorBuilder, ctorloc, stub); - auto bitcast = LLVM::AddrSpaceCastOp::create(ctorBuilder, ctorloc, - llvmPointerType, aoo); + auto aoo = ctorBuilder.create(ctorloc, stub); + auto bitcast = ctorBuilder.create( + ctorloc, llvmPointerType, aoo); auto globalTy = stub.getGlobalType(); // TODO This should actually be the GPUModuleOp's data layout I // believe, there were problems with assigning the data layout to @@ -2062,17 +2042,17 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, ctorloc, ctorBuilder, {module.getResult(), bitcast, symbolName, symbolName, /*isExtern*/ - LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, - /* TODO */ 0), + ctorBuilder.create(ctorloc, llvmInt32Type, + /* TODO */ 0), /*varSize*/ - LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmIntPtrType, - size), + ctorBuilder.create(ctorloc, llvmIntPtrType, + size), /*isConstant*/ - LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, - /* TODO */ 0), + ctorBuilder.create(ctorloc, llvmInt32Type, + /* TODO */ 0), /* just a 0? 
*/ - LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, - 0)}); + ctorBuilder.create(ctorloc, llvmInt32Type, + 0)}); } } // TODO this has to happen only for some CUDA versions, hip does not need @@ -2086,17 +2066,17 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, return failure(); } - LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFatbinFn.value(), - ValueRange(module->getResult(0))); + rewriter.create(ctorloc, cudaRegisterFatbinFn.value(), + ValueRange(module->getResult(0))); } - LLVM::ReturnOp::create(ctorBuilder, ctorloc, ValueRange()); + ctorBuilder.create(ctorloc, ValueRange()); } auto ctorSymbol = FlatSymbolRefAttr::get(ctor); { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - LLVM::GlobalCtorsOp::create( - rewriter, ctorloc, rewriter.getArrayAttr({std::move(ctorSymbol)}), + rewriter.create( + ctorloc, rewriter.getArrayAttr({std::move(ctorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr({LLVM::ZeroAttr::get(rewriter.getContext())})); } @@ -2104,9 +2084,9 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, PatternRewriter::InsertionGuard B(rewriter); OpBuilder &dtorBuilder = rewriter; dtorBuilder.setInsertionPointToEnd(dtor.addEntryBlock(dtorBuilder)); - auto aoo = LLVM::AddressOfOp::create(dtorBuilder, ctorloc, moduleGlobal); - auto module = LLVM::LoadOp::create( - dtorBuilder, ctorloc, llvmPointerPointerType, aoo->getResult(0)); + auto aoo = dtorBuilder.create(ctorloc, moduleGlobal); + auto module = dtorBuilder.create( + ctorloc, llvmPointerPointerType, aoo->getResult(0)); // auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn( // rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType, @@ -2135,8 +2115,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); Attribute attrs[] = 
{LLVM::ZeroAttr::get(rewriter.getContext())}; - LLVM::GlobalDtorsOp::create( - rewriter, ctorloc, rewriter.getArrayAttr({std::move(dtorSymbol)}), + rewriter.create( + ctorloc, rewriter.getArrayAttr({std::move(dtorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr(attrs)); } } @@ -2208,10 +2188,10 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( launchOp.getKernelName().getValue()); auto bitcast = - LLVM::AddressOfOp::create(rewriter, loc, llvmPointerType, funcStubName); + rewriter.create(loc, llvmPointerType, funcStubName); - Value zero = LLVM::ConstantOp::create(rewriter, loc, llvmInt32Type, 0); - auto nullpointer = LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); + Value zero = rewriter.create(loc, llvmInt32Type, 0); + auto nullpointer = rewriter.create(loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? nullpointer : adaptor.getAsyncDependencies().front(); @@ -2228,17 +2208,16 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto i32 = rewriter.getIntegerType(32); auto i64 = rewriter.getIntegerType(64); auto dim3 = [&](Value x, Value y, Value z) { - x = LLVM::TruncOp::create(rewriter, x.getLoc(), i32, x); - y = LLVM::TruncOp::create(rewriter, y.getLoc(), i32, y); - z = LLVM::TruncOp::create(rewriter, z.getLoc(), i32, z); + x = rewriter.create(x.getLoc(), i32, x); + y = rewriter.create(y.getLoc(), i32, y); + z = rewriter.create(z.getLoc(), i32, z); - x = LLVM::ZExtOp::create(rewriter, x.getLoc(), i64, x); - y = LLVM::ZExtOp::create(rewriter, y.getLoc(), i64, y); + x = rewriter.create(x.getLoc(), i64, x); + y = rewriter.create(y.getLoc(), i64, y); - y = LLVM::ShlOp::create( - rewriter, y.getLoc(), y, - LLVM::ConstantOp::create(rewriter, y.getLoc(), i64, 32)); - args.push_back(LLVM::OrOp::create(rewriter, x.getLoc(), x, y)); + y = rewriter.create( + y.getLoc(), y, rewriter.create(y.getLoc(), i64, 32)); + args.push_back(rewriter.create(x.getLoc(), x, y)); 
args.push_back(z); }; dim3(adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()); @@ -2247,7 +2226,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( args.push_back(kernelParams); args.push_back( - LLVM::ZExtOp::create(rewriter, loc, i64, dynamicSharedMemorySize)); + rewriter.create(loc, i64, dynamicSharedMemorySize)); args.push_back(stream); auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); @@ -2272,8 +2251,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( if (errOp) { rewriter.setInsertionPoint(errOp); - auto reg = scf::ExecuteRegionOp::create(rewriter, errOp.getLoc(), - launchCall->getResultTypes()[0]); + auto reg = rewriter.create( + errOp.getLoc(), launchCall->getResultTypes()[0]); rewriter.inlineRegionBefore(errOp.getRegion(), reg.getRegion(), reg.getRegion().begin()); rewriter.createBlock(&errOp.getRegion()); @@ -2282,36 +2261,35 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto one = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create(loc, i64, + rewriter.getI64IntegerAttr(1)); - auto alloca = LLVM::AllocaOp::create(rewriter, launchOp.getLoc(), ptrty, - launchCall->getResultTypes()[0], one); - auto zero = arith::ConstantIntOp::create( - rewriter, loc, launchCall->getResultTypes()[0], 0); + auto alloca = rewriter.create( + launchOp.getLoc(), ptrty, launchCall->getResultTypes()[0], one); + auto zero = rewriter.create( + loc, launchCall->getResultTypes()[0], 0); rewriter.setInsertionPoint(errOp); - LLVM::StoreOp::create(rewriter, launchOp.getLoc(), zero, alloca); + rewriter.create(launchOp.getLoc(), zero, alloca); rewriter.setInsertionPointAfter(launchCall); - LLVM::StoreOp::create(rewriter, launchOp.getLoc(), launchCall->getResult(0), - alloca); + rewriter.create(launchOp.getLoc(), launchCall->getResult(0), + alloca); 
for (auto &block : reg.getRegion()) { if (auto terminator = dyn_cast(block.getTerminator())) { rewriter.setInsertionPointToEnd(&block); - auto load = - LLVM::LoadOp::create(rewriter, launchOp.getLoc(), - launchCall->getResultTypes()[0], alloca); + auto load = rewriter.create( + launchOp.getLoc(), launchCall->getResultTypes()[0], alloca); rewriter.replaceOpWithNewOp(terminator, load->getResults()); } } rewriter.setInsertionPointAfter(errOp); - auto cast = arith::IndexCastOp::create( - rewriter, loc, rewriter.getIndexType(), reg->getResult(0)); + auto cast = rewriter.create( + loc, rewriter.getIndexType(), reg->getResult(0)); rewriter.replaceOp(errOp, cast->getResults()); } @@ -2414,8 +2392,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast(bitwidth / 8) * static_cast(memrefTy.getNumElements()); - Value sizeArg = LLVM::ConstantOp::create( - rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); + Value sizeArg = rewriter.create( + loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); } @@ -2427,8 +2405,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - gpu::LaunchFuncOp::create( - rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), + rewriter.create( + launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), @@ -2517,7 +2495,7 @@ class ConvertAllocOpToGpuRuntimeCallPattern // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. 
- auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); + auto nullPtr = rewriter.create(loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? nullPtr : adaptor.getAsyncDependencies().front(); @@ -2533,10 +2511,10 @@ class ConvertAllocOpToGpuRuntimeCallPattern auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1); if (backend == "cuda") { - auto one = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(1)); - auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one); + auto ptr = rewriter.create(loc, ptrty, ptr1ty, one); Type tys[] = {ptrty, i64}; auto cudaMallocFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, "cudaMalloc", tys, i32); @@ -2588,43 +2566,42 @@ class ConvertAllocOpToGpuRuntimeCallPattern sizeBytes, }; allocatedPtr = - LLVM::CallOp::create(rewriter, loc, mallocFunc.value(), args) + rewriter.create(loc, mallocFunc.value(), args) ->getResult(0); allocatedPtr = - LLVM::AddrSpaceCastOp::create(rewriter, loc, ptr1ty, allocatedPtr); + rewriter.create(loc, ptr1ty, allocatedPtr); } else if (backend.starts_with("xla")) { - auto zero = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(0)); + auto zero = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(0)); - auto one = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(1)); - auto tyid = - LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(xla_type_id( - memRefType.getElementType()))); + auto tyid = rewriter.create( + loc, i64, + rewriter.getI64IntegerAttr( + xla_type_id(memRefType.getElementType()))); Type convertedIndex = typeConverter->convertType(rewriter.getIndexType()); - auto shapeDim = LLVM::ConstantOp::create( - rewriter, loc, i64, - rewriter.getI64IntegerAttr(memRefType.getShape().size())); + 
auto shapeDim = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(memRefType.getShape().size())); auto AT = LLVM::LLVMArrayType::get(i64, memRefType.getShape().size()); - auto shapePtr = LLVM::AllocaOp::create(rewriter, loc, ptrty, AT, one); + auto shapePtr = rewriter.create(loc, ptrty, AT, one); int dynIdx = 0; for (int i = 0; i < memRefType.getShape().size(); i++) { - auto idx = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(i)); + auto idx = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(i)); Value idxs[] = {zero, idx}; auto gep = - LLVM::GEPOp::create(rewriter, loc, ptrty, AT, shapePtr, idxs); + rewriter.create(loc, ptrty, AT, shapePtr, idxs); Value val; @@ -2632,12 +2609,11 @@ class ConvertAllocOpToGpuRuntimeCallPattern val = adaptor.getDynamicSizes()[dynIdx]; dynIdx++; } else { - val = LLVM::ConstantOp::create( - rewriter, loc, i64, - rewriter.getI64IntegerAttr(memRefType.getShape()[i])); + val = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(memRefType.getShape()[i])); } - LLVM::StoreOp::create(rewriter, loc, val, gep); + rewriter.create(loc, val, gep); } // handle, type id, shape len, shape ptr @@ -2653,19 +2629,19 @@ class ConvertAllocOpToGpuRuntimeCallPattern auto xdata = insertXLAInitDeinit(moduleOp, backend, rewriter); Value args[] = {xdata, tyid, shapeDim, shapePtr}; allocatedPtr = - LLVM::CallOp::create(rewriter, loc, xlaMallocFn.value(), args) + rewriter.create(loc, xlaMallocFn.value(), args) ->getResult(0); allocatedPtr = - LLVM::AddrSpaceCastOp::create(rewriter, loc, ptr1ty, allocatedPtr); + rewriter.create(loc, ptr1ty, allocatedPtr); } else { llvm_unreachable("unknown backend"); } } else { - auto isHostShared = mlir::LLVM::ConstantOp::create( - rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); + auto isHostShared = rewriter.create( + loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); allocatedPtr = allocCallBuilder .create(loc, rewriter, {sizeBytes, stream, isHostShared}) @@ 
-2744,15 +2720,15 @@ class ConvertOccupancyOp return failure(); } - auto one = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create(loc, i64, + rewriter.getI64IntegerAttr(1)); - auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, intty, one); + auto ptr = rewriter.create(loc, ptrty, intty, one); std::string funcStubName = getFuncStubName(op.getFn().getRootReference().getValue(), op.getFn().getLeafReference().getValue()); - auto addr = LLVM::AddressOfOp::create(rewriter, loc, ptrty, funcStubName); + auto addr = rewriter.create(loc, ptrty, funcStubName); Value args[] = {ptr, addr, adaptor.getBlockSize(), adaptor.getDynamicSMemSize(), adaptor.getFlags()}; rewriter.create(loc, occupancyFn.value(), args); @@ -2835,8 +2811,8 @@ class ConvertDeallocOpToGpuRuntimeCallPattern auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1); if (backend == "cuda") { - auto one = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(1)); Type tys[] = {ptr1ty}; auto cudaFreeFn = @@ -2877,7 +2853,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = { ptr, }; - LLVM::CallOp::create(rewriter, loc, freeFunc.value(), args); + rewriter.create(loc, freeFunc.value(), args); } else if (backend.starts_with("xla")) { auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); @@ -2896,7 +2872,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = {xdata, ptr}; - LLVM::CallOp::create(rewriter, loc, xlaFreeFn.value(), args); + rewriter.create(loc, xlaFreeFn.value(), args); } else { llvm::errs() << " unknown backend: " << backend << "\n"; return failure(); @@ -2944,28 +2920,27 @@ class ConvertXLAWrapperPattern auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto zero = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(0)); + auto zero = rewriter.create( + loc, i64, 
rewriter.getI64IntegerAttr(0)); - auto one = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create(loc, i64, + rewriter.getI64IntegerAttr(1)); - auto nargs = LLVM::ConstantOp::create( - rewriter, loc, i64, - rewriter.getI64IntegerAttr(adaptor.getInputs().size())); + auto nargs = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(adaptor.getInputs().size())); auto AT = LLVM::LLVMArrayType::get(i64, adaptor.getInputs().size()); - auto argsPtr = LLVM::AllocaOp::create(rewriter, loc, ptrty, AT, one); + auto argsPtr = rewriter.create(loc, ptrty, AT, one); for (int i = 0; i < adaptor.getInputs().size(); i++) { - auto idx = LLVM::ConstantOp::create(rewriter, loc, i64, - rewriter.getI64IntegerAttr(i)); + auto idx = rewriter.create( + loc, i64, rewriter.getI64IntegerAttr(i)); Value idxs[] = {zero, idx}; - auto gep = LLVM::GEPOp::create(rewriter, loc, ptrty, AT, argsPtr, idxs); + auto gep = rewriter.create(loc, ptrty, AT, argsPtr, idxs); - LLVM::StoreOp::create(rewriter, loc, adaptor.getInputs()[i], gep); + rewriter.create(loc, adaptor.getInputs()[i], gep); } // handle, module, nargs, argptr @@ -2983,7 +2958,7 @@ class ConvertXLAWrapperPattern auto xdata = insertXLAInitDeinit(moduleOp, backend, rewriter); Value args[4] = {xdata, stringval, nargs, argsPtr}; - LLVM::CallOp::create(rewriter, loc, xlaExecFn.value(), args); + rewriter.create(loc, xlaExecFn.value(), args); wrap.setFnAttr( FlatSymbolRefAttr::get(rewriter.getStringAttr(""))); @@ -3022,21 +2997,21 @@ struct ReplaceErrOpWithSuccess : public OpRewritePattern { auto ®ion = errOp.getRegion(); rewriter.setInsertionPointToEnd(condBlock); - cf::BranchOp::create(rewriter, errOp.getLoc(), ®ion.front()); + rewriter.create(errOp.getLoc(), ®ion.front()); for (Block &block : errOp->getRegions()[0]) { if (auto terminator = dyn_cast(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); - 
cf::BranchOp::create(rewriter, errOp.getLoc(), remainingOpsBlock, - terminatorOperands); + rewriter.create(errOp.getLoc(), remainingOpsBlock, + terminatorOperands); rewriter.eraseOp(terminator); } } rewriter.inlineRegionBefore(region, remainingOpsBlock); } - auto zero = arith::ConstantIndexOp::create(rewriter, errOp->getLoc(), 0); + auto zero = rewriter.create(errOp->getLoc(), 0); rewriter.replaceOp(errOp, zero->getResults()); return success(); } @@ -3080,8 +3055,8 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); - auto globalOp = LLVM::GlobalOp::create( - rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + auto globalOp = rewriter.create( + gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), /*alignment=*/0, static_cast(gpu::GPUDialect::getWorkgroupAddressSpace())); @@ -3111,8 +3086,8 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { // latter is expected by gpu.launch_func. if (gpuFuncOp.isKernel()) attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); - auto llvmFuncOp = LLVM::LLVMFuncOp::create( - rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + auto llvmFuncOp = rewriter.create( + gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -3130,7 +3105,7 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { for (const auto &en : llvm::enumerate(workgroupBuffers)) { LLVM::GlobalOp global = en.value(); - Value memory = LLVM::AddressOfOp::create(rewriter, loc, global); + Value memory = rewriter.create(loc, global); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. 
This may use more registers than @@ -3159,12 +3134,11 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { // memory space and does not support `alloca`s with addrspace(5). auto ptrType = LLVM::LLVMPointerType::get(type.getContext(), allocaAddrSpace); - Value numElements = LLVM::ConstantOp::create( - rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); - Value allocated = - LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, - type.getElementType(), numElements, - /*alignment=*/0); + Value numElements = rewriter.create( + gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); + Value allocated = rewriter.create( + gpuFuncOp.getLoc(), ptrType, type.getElementType(), numElements, + /*alignment=*/0); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( @@ -3243,8 +3217,8 @@ struct FuncOpLowering : public ConvertOpToLLVMPattern { cast(funcOp->getAttr(kLLVMCConvAttrName)); cconv = attr.getCallingConv(); } - auto newFuncOp = LLVM::LLVMFuncOp::create( - rewriter, funcOp.getLoc(), funcOp.getName(), convertedType, linkage, + auto newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), convertedType, linkage, /*dsoLocal=*/false, /*cconv=*/cconv, /*comdat=*/nullptr, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); @@ -3278,9 +3252,9 @@ struct CallOpLowering : public ConvertOpToLLVMPattern { } } - auto newCallOp = - LLVM::CallOp::create(rewriter, callOp->getLoc(), callResultTypes, - callOp.getCallee(), adaptor.getOperands()); + auto newCallOp = rewriter.create( + callOp->getLoc(), callResultTypes, callOp.getCallee(), + adaptor.getOperands()); newCallOp->setAttrs(callOp->getAttrs()); if (numResults <= 1) { @@ -3291,8 +3265,8 @@ struct CallOpLowering : public ConvertOpToLLVMPattern { SmallVector results; results.reserve(numResults); for (auto index : llvm::seq(0, numResults)) { - 
results.push_back(LLVM::ExtractValueOp::create( - rewriter, callOp->getLoc(), newCallOp->getResult(0), index)); + results.push_back(rewriter.create( + callOp->getLoc(), newCallOp->getResult(0), index)); } rewriter.replaceOp(callOp, results); return success(); @@ -3317,10 +3291,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { returnOp->getContext(), llvm::to_vector(adaptor.getOperands().getTypes())); Value packed = - LLVM::UndefOp::create(rewriter, returnOp->getLoc(), returnedType); + rewriter.create(returnOp->getLoc(), returnedType); for (const auto &[index, value] : llvm::enumerate(adaptor.getOperands())) { - packed = LLVM::InsertValueOp::create(rewriter, returnOp->getLoc(), packed, - value, index); + packed = rewriter.create(returnOp->getLoc(), packed, + value, index); } rewriter.replaceOpWithNewOp(returnOp, packed); return success(); @@ -3345,10 +3319,10 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern { returnOp->getContext(), llvm::to_vector(adaptor.getOperands().getTypes())); Value packed = - LLVM::UndefOp::create(rewriter, returnOp->getLoc(), returnedType); + rewriter.create(returnOp->getLoc(), returnedType); for (const auto &[index, value] : llvm::enumerate(adaptor.getOperands())) { - packed = LLVM::InsertValueOp::create(rewriter, returnOp->getLoc(), packed, - value, index); + packed = rewriter.create(returnOp->getLoc(), packed, + value, index); } rewriter.replaceOpWithNewOp(returnOp, packed); return success(); @@ -3399,7 +3373,7 @@ struct AllocaScopeOpLowering remainingOpsBlock, allocaScopeOp.getResultTypes(), SmallVector(allocaScopeOp->getNumResults(), allocaScopeOp.getLoc())); - LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock); + rewriter.create(loc, ValueRange(), remainingOpsBlock); } // Inline body region. @@ -3410,8 +3384,8 @@ struct AllocaScopeOpLowering // Save stack and then branch into the body of the region. 
rewriter.setInsertionPointToEnd(currentBlock); auto stackSaveOp = - LLVM::StackSaveOp::create(rewriter, loc, getVoidPtrType()); - LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody); + rewriter.create(loc, getVoidPtrType()); + rewriter.create(loc, ValueRange(), beforeBody); // Replace the alloca_scope return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -3423,7 +3397,7 @@ struct AllocaScopeOpLowering // Insert stack restore before jumping out the body of the region. rewriter.setInsertionPoint(branchOp); - LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); + rewriter.create(loc, stackSaveOp); // Replace the op with values return from the body region. rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); @@ -3491,22 +3465,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { auto int32Type = IntegerType::get(rewriter.getContext(), 32); auto predTy = IntegerType::get(rewriter.getContext(), 1); - Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1); - Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); - Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32); - Value numLeadInactiveLane = LLVM::SubOp::create( - rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth()); + Value one = rewriter.create(loc, int32Type, 1); + Value minusOne = rewriter.create(loc, int32Type, -1); + Value thirtyTwo = rewriter.create(loc, int32Type, 32); + Value numLeadInactiveLane = rewriter.create( + loc, int32Type, thirtyTwo, adaptor.getWidth()); // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. 
- Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne, - numLeadInactiveLane); + Value activeMask = rewriter.create(loc, int32Type, minusOne, + numLeadInactiveLane); Value maskAndClamp; if (op.getMode() == gpu::ShuffleMode::UP) { // Clamp lane: `32 - activeWidth` maskAndClamp = numLeadInactiveLane; } else { // Clamp lane: `activeWidth - 1` - maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type, - adaptor.getWidth(), one); + maskAndClamp = + rewriter.create(loc, int32Type, adaptor.getWidth(), one); } bool predIsUsed = !op->getResult(1).use_empty(); @@ -3517,14 +3491,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {valueTy, predTy}); } - Value shfl = NVVM::ShflOp::create( - rewriter, loc, resultTy, activeMask, adaptor.getValue(), - adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()), - returnValueAndIsValidAttr); + Value shfl = rewriter.create( + loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(), + maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr); if (predIsUsed) { - Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0); + Value shflValue = rewriter.create(loc, shfl, 0); Value isActiveSrcLane = - LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1); + rewriter.create(loc, shfl, 1); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); } else { rewriter.replaceOp(op, {shfl, nullptr}); @@ -3549,16 +3522,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { bounds = rewriter.getAttr( /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize); Value newOp = - NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds); + rewriter.create(loc, rewriter.getI32Type(), bounds); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. 
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - newOp = LLVM::SExtOp::create( - rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { - newOp = LLVM::TruncOp::create( - rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); return success(); @@ -3623,13 +3596,13 @@ struct OpLowering : public OpConversionPattern { Operation *newOp; switch (op.getDimension()) { case gpu::Dimension::x: - newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32)); + newOp = rewriter.create(loc, IntegerType::get(context, 32)); break; case gpu::Dimension::y: - newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32)); + newOp = rewriter.create(loc, IntegerType::get(context, 32)); break; case gpu::Dimension::z: - newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32)); + newOp = rewriter.create(loc, IntegerType::get(context, 32)); break; } @@ -3688,13 +3661,11 @@ struct OpLowering : public OpConversionPattern { rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { - newOp = LLVM::SExtOp::create(rewriter, loc, - IntegerType::get(context, indexBitwidth), - newOp->getResult(0)); + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); } else if (indexBitwidth < 32) { - newOp = LLVM::TruncOp::create(rewriter, loc, - IntegerType::get(context, indexBitwidth), - newOp->getResult(0)); + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); } rewriter.replaceOpWithNewOp( @@ -4042,29 +4013,6 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto ptrTy = LLVM::LLVMPointerType::get(ctx); - auto resumeOp 
= LLVM::LLVMFuncOp::create( - moduleBuilder, fname, - LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); - resumeOp.setPrivate(); - - return resumeOp; -} - -static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) { - const char fname[] = "fake_rocm_dispatch"; - - MLIRContext *ctx = module->getContext(); - auto loc = module->getLoc(); - auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); - - for (auto fn : module.getBody()->getOps()) { - if (fn.getName() == fname) - return fn; - } - - auto voidTy = LLVM::LLVMVoidType::get(ctx); - auto ptrTy = LLVM::LLVMPointerType::get(ctx); - auto resumeOp = moduleBuilder.create( fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); resumeOp.setPrivate(); @@ -4101,7 +4049,7 @@ struct NoAsyncOpLowering : public OpConversionPattern { matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto exec = - scf::ExecuteRegionOp::create(rewriter, execute->getLoc(), TypeRange()); + rewriter.create(execute->getLoc(), TypeRange()); rewriter.inlineRegionBefore(execute.getRegion(), exec.getRegion(), exec.getRegion().begin()); for (auto &blk : exec.getRegion()) { @@ -4169,8 +4117,8 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { rewriter.setInsertionPointToEnd(module.getBody()); static int off = 0; off++; - func = LLVM::LLVMFuncOp::create( - rewriter, execute.getLoc(), + func = rewriter.create( + execute.getLoc(), "kernelbody." + std::to_string((long long int)&execute) + "." 
+ std::to_string(off), funcType); @@ -4199,50 +4147,49 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { converter->convertType(functionInputs[0].getType()))) { valueMapping.map( functionInputs[0], - LLVM::BitcastOp::create( - rewriter, execute.getLoc(), + rewriter.create( + execute.getLoc(), converter->convertType(functionInputs[0].getType()), arg)); } else if (functionInputs.size() == 1 && isa( converter->convertType(functionInputs[0].getType()))) { valueMapping.map( functionInputs[0], - LLVM::PtrToIntOp::create( - rewriter, execute.getLoc(), + rewriter.create( + execute.getLoc(), converter->convertType(functionInputs[0].getType()), arg)); } else { SmallVector types; for (auto v : functionInputs) types.push_back(converter->convertType(v.getType())); auto ST = LLVM::LLVMStructType::getLiteral(ctx, types); - auto alloc = LLVM::BitcastOp::create( - rewriter, execute.getLoc(), LLVM::LLVMPointerType::get(ctx), arg); + auto alloc = rewriter.create( + execute.getLoc(), LLVM::LLVMPointerType::get(ctx), arg); for (auto idx : llvm::enumerate(functionInputs)) { mlir::Value idxs[] = { - arith::ConstantIntOp::create(rewriter, loc, 0, 32), - arith::ConstantIntOp::create(rewriter, loc, idx.index(), 32), + rewriter.create(loc, 0, 32), + rewriter.create(loc, idx.index(), 32), }; - Value next = LLVM::GEPOp::create(rewriter, loc, - LLVM::LLVMPointerType::get(ctx), + Value next = + rewriter.create(loc, LLVM::LLVMPointerType::get(ctx), idx.value().getType(), alloc, idxs); - valueMapping.map( - idx.value(), - LLVM::LoadOp::create(rewriter, loc, idx.value().getType(), next)); + valueMapping.map(idx.value(), rewriter.create( + loc, idx.value().getType(), next)); } auto freef = getTypeConverter()->getOptions().useGenericFunctions ? 
LLVM::lookupOrCreateGenericFreeFn(rewriter, module) : LLVM::lookupOrCreateFreeFn(rewriter, module); Value args[] = {arg}; - LLVM::CallOp::create(rewriter, loc, freef.value(), args); + rewriter.create(loc, freef.value(), args); } // Clone all operations from the execute operation body into the outlined // function body. rewriter.cloneRegionBefore(execute.getBodyRegion(), func.getRegion(), func.getRegion().end(), valueMapping); - LLVM::BrOp::create(rewriter, execute.getLoc(), ValueRange(), - &*std::next(func.getRegion().begin())); + rewriter.create(execute.getLoc(), ValueRange(), + &*std::next(func.getRegion().begin())); for (Block &b : func.getRegion()) { auto term = b.getTerminator(); if (isa(term)) { @@ -4264,17 +4211,17 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { SmallVector vals; if (crossing.size() == 0) { vals.push_back( - LLVM::ZeroOp::create(rewriter, execute.getLoc(), voidPtr)); + rewriter.create(execute.getLoc(), voidPtr)); } else if (crossing.size() == 1 && isa( converter->convertType(crossing[0].getType()))) { - vals.push_back(LLVM::BitcastOp::create(rewriter, execute.getLoc(), - voidPtr, crossing[0])); + vals.push_back(rewriter.create(execute.getLoc(), + voidPtr, crossing[0])); } else if (crossing.size() == 1 && isa( converter->convertType(crossing[0].getType()))) { - vals.push_back(LLVM::IntToPtrOp::create(rewriter, execute.getLoc(), - voidPtr, crossing[0])); + vals.push_back(rewriter.create(execute.getLoc(), + voidPtr, crossing[0])); } else { SmallVector types; for (auto v : crossing) @@ -4286,36 +4233,36 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { DataLayout DLI(execute->getParentOfType()); - Value arg = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI64Type(), DLI.getTypeSize(ST)); + Value arg = rewriter.create( + loc, rewriter.getI64Type(), DLI.getTypeSize(ST)); auto mallocFunc = LLVM::lookupOrCreateMallocFn(rewriter, module, getIndexType()); mlir::Value alloc = - LLVM::CallOp::create(rewriter, loc, 
mallocFunc.value(), arg) + rewriter.create(loc, mallocFunc.value(), arg) .getResult(); rewriter.setInsertionPoint(execute); for (auto idx : llvm::enumerate(crossing)) { mlir::Value idxs[] = { - arith::ConstantIntOp::create(rewriter, loc, 0, 32), - arith::ConstantIntOp::create(rewriter, loc, idx.index(), 32), + rewriter.create(loc, 0, 32), + rewriter.create(loc, idx.index(), 32), }; - Value next = LLVM::GEPOp::create( - rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + Value next = rewriter.create( + loc, LLVM::LLVMPointerType::get(rewriter.getContext()), idx.value().getType(), alloc, idxs); - LLVM::StoreOp::create(rewriter, loc, idx.value(), next); + rewriter.create(loc, idx.value(), next); } - vals.push_back(LLVM::BitcastOp::create(rewriter, execute.getLoc(), - voidPtr, alloc)); + vals.push_back( + rewriter.create(execute.getLoc(), voidPtr, alloc)); } - vals.push_back(LLVM::BitcastOp::create( - rewriter, execute.getLoc(), voidPtr, - LLVM::AddressOfOp::create(rewriter, execute.getLoc(), func))); + vals.push_back(rewriter.create( + execute.getLoc(), voidPtr, + rewriter.create(execute.getLoc(), func))); for (auto dep : execute.getDependencies()) { auto src = dep.getDefiningOp().getSource(); if (auto MT = dyn_cast(src.getType())) - src = enzymexla::Memref2PointerOp::create( - rewriter, dep.getDefiningOp()->getLoc(), + src = rewriter.create( + dep.getDefiningOp()->getLoc(), LLVM::LLVMPointerType::get(rewriter.getContext(), MT.getMemorySpaceAsInt()), src); @@ -4332,7 +4279,7 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { : addMocROCmFunction(execute->getParentOfType(), vals.back().getType()); - LLVM::CallOp::create(rewriter, execute.getLoc(), f, vals); + rewriter.create(execute.getLoc(), f, vals); rewriter.eraseOp(execute); } @@ -4671,24 +4618,6 @@ struct ConvertPolygeistToLLVMPass signalPassFailure(); return; } - - { - const char *GetDeviceFromHostFuncName = "__reactant$get_device_from_host"; - SmallVector toHandle; - 
m->walk([&](LLVM::CallOp call) { - CallInterfaceCallable callable = call.getCallableForCallee(); - auto callee = dyn_cast(callable); - if (!callee) - return; - if (callee.getLeafReference() == GetDeviceFromHostFuncName) - toHandle.push_back(call); - }); - for (auto call : toHandle) { - assert(call->getNumResults() == 1 && call.getNumOperands() == 1); - call.getResult().replaceAllUsesWith(call.getArgOperands()[0]); - call->erase(); - } - } } void runOnOperation() override { @@ -4696,4 +4625,4 @@ struct ConvertPolygeistToLLVMPass convertModule(m, /* gpuModule */ false); } }; -} // namespace +} // namespace \ No newline at end of file From c2e48f4cd73b6abed6c9215731fddb8fad5886e5 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 6 Nov 2025 04:27:28 -0600 Subject: [PATCH 09/27] fix --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 852 +++++++++--------- 1 file changed, 436 insertions(+), 416 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index db6079492d..26b5ef0be3 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -133,40 +133,40 @@ static Value insertXLAInitDeinit(mlir::ModuleOp moduleOp, StringRef backend, if (ctor) { assert(dtor && "xla module constructor does not exist but destructor does"); assert(data && "xla module constructor does not exist but data does"); - return rewriter.create(loc, ptrty, - data.getSymNameAttr()); + return LLVM::AddressOfOp::create(rewriter, loc, ptrty, + data.getSymNameAttr()); } { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - ctor = rewriter.create( - loc, ctorNameBuffer, + ctor = LLVM::LLVMFuncOp::create( + rewriter, loc, ctorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); - dtor = rewriter.create( - loc, dtorNameBuffer, + dtor = 
LLVM::LLVMFuncOp::create( + rewriter, loc, dtorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); auto ctorSymbol = FlatSymbolRefAttr::get(ctor); - rewriter.create( - loc, rewriter.getArrayAttr({std::move(ctorSymbol)}), + LLVM::GlobalCtorsOp::create( + rewriter, loc, rewriter.getArrayAttr({std::move(ctorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr({LLVM::ZeroAttr::get(rewriter.getContext())})); auto dtorSymbol = FlatSymbolRefAttr::get(dtor); - rewriter.create( - loc, rewriter.getArrayAttr({std::move(dtorSymbol)}), + LLVM::GlobalDtorsOp::create( + rewriter, loc, rewriter.getArrayAttr({std::move(dtorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr({LLVM::ZeroAttr::get(rewriter.getContext())})); - data = rewriter.create( - loc, ptrty, /*constant*/ false, LLVM::Linkage::Internal, dataNameBuffer, - /* initValue */ mlir::Attribute(), - /* alignment */ 8, /* addrSpace */ 0); + data = LLVM::GlobalOp::create(rewriter, loc, ptrty, /*constant*/ false, + LLVM::Linkage::Internal, dataNameBuffer, + /* initValue */ mlir::Attribute(), + /* alignment */ 8, /* addrSpace */ 0); } // device id, ptr @@ -201,10 +201,10 @@ static Value insertXLAInitDeinit(mlir::ModuleOp moduleOp, StringRef backend, loc, rewriter, "xlabackend", bstr, LLVM::Linkage::Internal); auto glob = - rewriter.create(loc, ptrty, data.getSymNameAttr()); + LLVM::AddressOfOp::create(rewriter, loc, ptrty, data.getSymNameAttr()); Value args[] = {glob, stringval}; - rewriter.create(loc, xlaInitFn.value(), args); - rewriter.create(loc, ValueRange()); + LLVM::CallOp::create(rewriter, loc, xlaInitFn.value(), args); + LLVM::ReturnOp::create(rewriter, loc, ValueRange()); } { @@ -212,13 +212,13 @@ static Value insertXLAInitDeinit(mlir::ModuleOp moduleOp, StringRef backend, rewriter.setInsertionPointToEnd(dtor.addEntryBlock(rewriter)); auto glob = - rewriter.create(loc, ptrty, data.getSymNameAttr()); + 
LLVM::AddressOfOp::create(rewriter, loc, ptrty, data.getSymNameAttr()); Value args[] = {glob}; - rewriter.create(loc, xlaDeInitFn.value(), args); - rewriter.create(loc, ValueRange()); + LLVM::CallOp::create(rewriter, loc, xlaDeInitFn.value(), args); + LLVM::ReturnOp::create(rewriter, loc, ValueRange()); } - return rewriter.create(loc, ptrty, data.getSymNameAttr()); + return LLVM::AddressOfOp::create(rewriter, loc, ptrty, data.getSymNameAttr()); } struct Stream2TokenOpLowering : public ConvertOpToLLVMPattern { @@ -246,7 +246,7 @@ struct Memref2PointerOpLowering if (isa(transformed.getSource().getType())) { mlir::Value ptr = transformed.getSource(); if (space0 != LPT.getAddressSpace()) - ptr = rewriter.create(loc, LPT, ptr); + ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, LPT, ptr); rewriter.replaceOp(op, {ptr}); return success(); } @@ -260,10 +260,10 @@ struct Memref2PointerOpLowering Value baseOffset = targetMemRef.offset(rewriter, loc); Value ptr = targetMemRef.alignedPtr(rewriter, loc); Value idxs[] = {baseOffset}; - ptr = rewriter.create(loc, ptr.getType(), rewriter.getI8Type(), - ptr, idxs); + ptr = LLVM::GEPOp::create(rewriter, loc, ptr.getType(), + rewriter.getI8Type(), ptr, idxs); if (space0 != LPT.getAddressSpace()) - ptr = rewriter.create(loc, LPT, ptr); + ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, LPT, ptr); rewriter.replaceOp(op, {ptr}); return success(); @@ -287,7 +287,7 @@ struct Pointer2MemrefOpLowering mlir::Value ptr = adaptor.getSource(); if (space1 != cast(op.getOperand().getType()) .getAddressSpace()) - ptr = rewriter.create(loc, PT, ptr); + ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, PT, ptr); rewriter.replaceOp(op, {ptr}); return success(); } @@ -297,8 +297,8 @@ struct Pointer2MemrefOpLowering if (space1 != cast(op.getOperand().getType()) .getAddressSpace()) - ptr = rewriter.create( - loc, descr.getElementPtrType(), ptr); + ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, + descr.getElementPtrType(), ptr); // Extract all 
strides and offsets and verify they are static. int64_t offset; @@ -575,7 +575,7 @@ struct CAllocOpLowering : public AllocLikeOpLowering { if (auto F = module.lookupSymbol("malloc")) { Value allocated = - rewriter.create(loc, F, sizeBytes).getResult(0); + func::CallOp::create(rewriter, loc, F, sizeBytes).getResult(0); rewriter.replaceOpWithNewOp( allocOp, convertedType, allocated); } else { @@ -587,7 +587,7 @@ struct CAllocOpLowering : public AllocLikeOpLowering { if (failed(mallocFunc)) return failure(); Value allocated = - rewriter.create(loc, mallocFunc.value(), sizeBytes) + LLVM::CallOp::create(rewriter, loc, mallocFunc.value(), sizeBytes) .getResult(); rewriter.replaceOpWithNewOp(allocOp, convertedType, allocated); @@ -606,9 +606,9 @@ struct CDeallocOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto module = deallocOp->getParentOfType(); if (auto F = module.lookupSymbol("free")) { - Value casted = rewriter.create( - deallocOp->getLoc(), MemRefType::get({-1}, rewriter.getI8Type()), - adaptor.getMemref()); + Value casted = enzymexla::Pointer2MemrefOp::create( + rewriter, deallocOp->getLoc(), + MemRefType::get({-1}, rewriter.getI8Type()), adaptor.getMemref()); rewriter.replaceOpWithNewOp(deallocOp, F, casted); } else { FailureOr freeFunc = @@ -691,8 +691,8 @@ struct GlobalOpLowering : public ConvertOpToLLVMPattern { newGlobal.getInitializerRegion().begin()); rewriter.setInsertionPointToStart(block); Value undef = - rewriter.create(globalOp->getLoc(), convertedType); - rewriter.create(globalOp->getLoc(), undef); + LLVM::UndefOp::create(rewriter, globalOp->getLoc(), convertedType); + LLVM::ReturnOp::create(rewriter, globalOp->getLoc(), undef); } return success(); } @@ -709,8 +709,8 @@ struct GetGlobalOpLowering ConversionPatternRewriter &rewriter) const override { MemRefType originalType = getGlobalOp.getType(); Type convertedType = getTypeConverter()->convertType(originalType); - Value wholeAddress = 
rewriter.create( - getGlobalOp->getLoc(), convertedType, getGlobalOp.getName()); + Value wholeAddress = LLVM::AddressOfOp::create( + rewriter, getGlobalOp->getLoc(), convertedType, getGlobalOp.getName()); rewriter.replaceOp(getGlobalOp, wholeAddress); return success(); @@ -744,8 +744,8 @@ struct CLoadStoreOpLowering : public ConvertOpToLLVMPattern { (void)rewriter.notifyMatchFailure(loc, "unsupported memref type"); return nullptr; } - return rewriter.create( - loc, + return LLVM::GEPOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(op.getContext(), originalType.getMemorySpaceAsInt()), elTy, adaptor.getMemref(), args); @@ -890,16 +890,18 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { if (dstType.getMemorySpaceAsInt() == 0 && srcType.getMemorySpaceAsInt() == 0) { - rewriter.create(op.getLoc(), dst, src, size, false); + LLVM::MemcpyOp::create(rewriter, op.getLoc(), dst, src, size, false); rewriter.eraseOp(op); return success(); } if (backend == "cpu") { - dst = rewriter.create( - op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), dst); - src = rewriter.create( - op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), src); - rewriter.create(op.getLoc(), dst, src, size, false); + dst = LLVM::AddrSpaceCastOp::create( + rewriter, op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), + dst); + src = LLVM::AddrSpaceCastOp::create( + rewriter, op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), + src); + LLVM::MemcpyOp::create(rewriter, op.getLoc(), dst, src, size, false); rewriter.eraseOp(op); return success(); } @@ -954,19 +956,20 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { } SmallVector args = {dst, src, size, - rewriter.create( - op.getLoc(), tys[3 + xla], direction)}; + LLVM::ConstantOp::create(rewriter, op.getLoc(), + tys[3 + xla], + direction)}; for (int i = 0; i < 2; i++) if (args[i].getType() != tys[i]) - args[i] = rewriter.create(op.getLoc(), - tys[i + xla], args[i]); + args[i] = 
LLVM::AddrSpaceCastOp::create(rewriter, op.getLoc(), + tys[i + xla], args[i]); if (backend.starts_with("xla")) { auto xdata = insertXLAInitDeinit(moduleOp, backend, rewriter); args.insert(args.begin(), xdata); } - rewriter.create(op.getLoc(), memcpyFn.value(), args); + LLVM::CallOp::create(rewriter, op.getLoc(), memcpyFn.value(), args); rewriter.eraseOp(op); return success(); } @@ -1558,7 +1561,7 @@ struct LowerGPUAlternativesOp auto kernelId = LLVM::createGlobalString( loc, rewriter, std::string("kernelId.") + std::to_string(num++), nullTermLocStr, LLVM::Linkage::Internal, /*opaquePointers*/ true); - auto totalAlternatives = rewriter.create( + auto totalAlternatives = LLVM::ConstantOp::create(rewriter, loc, llvmInt32Type, gao->getNumRegions()); auto alternative = rtPGOGetAlternativeCallBuilder @@ -1567,10 +1570,10 @@ struct LowerGPUAlternativesOp int i = 0; for (auto ®ion : gao->getRegions()) { - auto cmpOp = rewriter.create( + auto cmpOp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, alternative, - rewriter.create(loc, i, 32)); - auto ifOp = rewriter.create(loc, cmpOp, /* hasElse */ true); + arith::ConstantIntOp::create(rewriter, loc, i, 32)); + auto ifOp = scf::IfOp::create(rewriter, loc, cmpOp, /* hasElse */ true); auto block = ®ion.front(); rewriter.eraseOp(block->getTerminator()); rewriter.inlineBlockBefore( @@ -1690,31 +1693,46 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( argumentTypes.push_back(argument.getType()); auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), argumentTypes); - Value structPtr, arrayPtr; + Value structPtr, arrayPtr, one; { PatternRewriter::InsertionGuard B(builder); builder.setInsertionPointToStart(allocaBlock); - auto one = builder.create(loc, llvmInt32Type, 1); - structPtr = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), structType, one, + one = LLVM::ConstantOp::create(builder, loc, llvmInt32Type, 1); + structPtr = 
LLVM::AllocaOp::create( + builder, loc, LLVM::LLVMPointerType::get(builder.getContext()), + structType, one, /*alignment=*/0); auto arraySize = - builder.create(loc, llvmInt32Type, numArguments); - arrayPtr = builder.create(loc, llvmPointerPointerType, - llvmPointerType, arraySize, - /*alignment=*/0); + LLVM::ConstantOp::create(builder, loc, llvmInt32Type, numArguments); + arrayPtr = LLVM::AllocaOp::create(builder, loc, llvmPointerPointerType, + llvmPointerType, arraySize, + /*alignment=*/0); } + auto argAttrss = + dyn_cast_or_null(launchOp->getAttr("reactant.arg_attrs")); for (const auto &en : llvm::enumerate(arguments)) { - auto fieldPtr = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), structType, - structPtr, ArrayRef{0, en.index()}); - builder.create(loc, en.value(), fieldPtr); - auto elementPtr = builder.create( - loc, llvmPointerType, llvmPointerPointerType, arrayPtr, - ArrayRef{en.index()}); + bool isByVal = + argAttrss && cast(argAttrss[en.index()]) + .getNamed(LLVM::LLVMDialect::getByValAttrName()); + Value fieldPtr; + if (isByVal) { + fieldPtr = en.value(); + } else { + { + PatternRewriter::InsertionGuard B(builder); + builder.setInsertionPointToStart(allocaBlock); + fieldPtr = LLVM::AllocaOp::create(builder, loc, llvmPointerPointerType, + en.value().getType(), one, + /*alignment=*/0); + } + LLVM::StoreOp::create(builder, loc, en.value(), fieldPtr); + } + auto elementPtr = LLVM::GEPOp::create(builder, loc, llvmPointerType, + llvmPointerPointerType, arrayPtr, + ArrayRef{en.index()}); auto casted = - builder.create(loc, llvmPointerType, fieldPtr); - builder.create(loc, casted, elementPtr); + LLVM::BitcastOp::create(builder, loc, llvmPointerType, fieldPtr); + LLVM::StoreOp::create(builder, loc, casted, elementPtr); } return arrayPtr; } @@ -1765,13 +1783,13 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - ctor = 
rewriter.create( - ctorloc, ctorNameBuffer, + ctor = LLVM::LLVMFuncOp::create( + rewriter, ctorloc, ctorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); - dtor = rewriter.create( - ctorloc, dtorNameBuffer, + dtor = LLVM::LLVMFuncOp::create( + rewriter, ctorloc, dtorNameBuffer, LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(moduleOp.getContext()), {}), LLVM::Linkage::Private); @@ -1848,8 +1866,9 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - fatBinWrapper = rewriter.create( - loc, fatBinWrapperType, /*constant*/ true, LLVM::Linkage::Internal, + fatBinWrapper = LLVM::GlobalOp::create( + rewriter, loc, fatBinWrapperType, /*constant*/ true, + LLVM::Linkage::Internal, std::string( llvm::formatv("__polygeist_{0}_fatbin_wrapper", moduleName)), /* initValue */ mlir::Attribute(), @@ -1861,52 +1880,39 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, fatBinWrapper.getRegion().push_back(new Block); globalBuilder.setInsertionPointToStart(fatBinWrapper.getBody()); auto fatbinMagicVal = - globalBuilder.create(loc, llvmInt32Type, fatMagic); + LLVM::ConstantOp::create(globalBuilder, loc, llvmInt32Type, fatMagic); auto fatbinVersionVal = - globalBuilder.create(loc, llvmInt32Type, 1); - auto nullPtr = globalBuilder.create(loc, llvmPointerType); + LLVM::ConstantOp::create(globalBuilder, loc, llvmInt32Type, 1); + auto nullPtr = LLVM::ZeroOp::create(globalBuilder, loc, llvmPointerType); Value constructedStruct = - globalBuilder.create(loc, fatBinWrapperType); + LLVM::UndefOp::create(globalBuilder, loc, fatBinWrapperType); { int i = 0; - constructedStruct = globalBuilder.create( - loc, fatBinWrapperType, constructedStruct, fatbinMagicVal, - globalBuilder.getDenseI64ArrayAttr(i++)); - constructedStruct = globalBuilder.create( - loc, fatBinWrapperType, constructedStruct, 
fatbinVersionVal, - globalBuilder.getDenseI64ArrayAttr(i++)); + constructedStruct = LLVM::InsertValueOp::create( + globalBuilder, loc, fatBinWrapperType, constructedStruct, + fatbinMagicVal, globalBuilder.getDenseI64ArrayAttr(i++)); + constructedStruct = LLVM::InsertValueOp::create( + globalBuilder, loc, fatBinWrapperType, constructedStruct, + fatbinVersionVal, globalBuilder.getDenseI64ArrayAttr(i++)); // TODO do we need to specify the section name here...? // data.setSectionAttr(moduleBuilder.getStringAttr(fatbinSectionName)); Value data = LLVM::createGlobalString( loc, globalBuilder, nameBuffer.str(), "binaryAttr", // loc, globalBuilder, nameBuffer.str(), binaryAttr.getValue(), LLVM::Linkage::Internal); - constructedStruct = globalBuilder.create( - loc, fatBinWrapperType, constructedStruct, data, + constructedStruct = LLVM::InsertValueOp::create( + globalBuilder, loc, fatBinWrapperType, constructedStruct, data, globalBuilder.getDenseI64ArrayAttr(i++)); - constructedStruct = globalBuilder.create( - loc, fatBinWrapperType, constructedStruct, nullPtr, + constructedStruct = LLVM::InsertValueOp::create( + globalBuilder, loc, fatBinWrapperType, constructedStruct, nullPtr, globalBuilder.getDenseI64ArrayAttr(i++)); } - globalBuilder.create(loc, constructedStruct); + LLVM::ReturnOp::create(globalBuilder, loc, constructedStruct); auto addressOfWrapper = - ctorBuilder.create(ctorloc, fatBinWrapper); - auto bitcastOfWrapper = ctorBuilder.create( - ctorloc, llvmPointerType, addressOfWrapper); - - // auto cudaRegisterFatbinFn = - // LLVM::lookupOrCreateFn(rewriter, moduleOp, - // "__cudaRegisterFatBinary", - // llvmPointerType, llvmPointerType); - // if (failed(cudaRegisterFatbinFn)) { - // llvm::errs() << " cudamalloc already exists with different types\n"; - // return failure(); - // } - - // auto module = rewriter.create( - // ctorloc, cudaRegisterFatbinFn.value(), - // ValueRange(bitcastOfWrapper)); + LLVM::AddressOfOp::create(ctorBuilder, ctorloc, fatBinWrapper); + auto 
bitcastOfWrapper = LLVM::AddrSpaceCastOp::create( + ctorBuilder, ctorloc, llvmPointerType, addressOfWrapper); auto registerFatbinFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, @@ -1918,23 +1924,24 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, return failure(); } - auto module = rewriter.create( - ctorloc, registerFatbinFn.value(), ValueRange(bitcastOfWrapper)); + auto module = + LLVM::CallOp::create(rewriter, ctorloc, registerFatbinFn.value(), + ValueRange(bitcastOfWrapper)); auto moduleGlobalName = std::string(llvm::formatv("polygeist_{0}_module_ptr", moduleName)); { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - moduleGlobal = rewriter.create( - ctorloc, llvmPointerPointerType, /* isConstant */ false, + moduleGlobal = LLVM::GlobalOp::create( + rewriter, ctorloc, llvmPointerPointerType, /* isConstant */ false, LLVM::Linkage::Internal, moduleGlobalName, /* initValue */ mlir::Attribute(), /* alignment */ 8, /* addrSpace */ 0); } - auto aoo = ctorBuilder.create(ctorloc, moduleGlobal); - ctorBuilder.create(loc, module->getResult(0), - aoo->getResult(0)); + auto aoo = LLVM::AddressOfOp::create(ctorBuilder, ctorloc, moduleGlobal); + LLVM::StoreOp::create(ctorBuilder, loc, module->getResult(0), + aoo->getResult(0)); for (Operation &op : kernelModule->getRegion(0).front()) { if (auto f = dyn_cast(op)) { if (!f->getAttr("gpu.kernel")) @@ -1943,7 +1950,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, kernelModule.getName(), f.getName(), ctorloc, ctorBuilder); auto nullPtr = - ctorBuilder.create(ctorloc, llvmPointerType); + LLVM::ZeroOp::create(ctorBuilder, ctorloc, llvmPointerType); // TODO second param should be ptr to the the original function stub // here like clang does it: e.g. 
kernel_name_device_stub // @@ -1956,31 +1963,25 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - stub = rewriter.create( - ctorloc, getFuncStubName(moduleName, f.getName()), + stub = moduleOp.lookupSymbol(stubName); + stub = LLVM::LLVMFuncOp::create( + rewriter, ctorloc, getFuncStubName(moduleName, f.getName()), LLVM::LLVMFunctionType::get(llvmVoidType, {}), LLVM::Linkage::Internal); } { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(stub.addEntryBlock(rewriter)); - rewriter.create(ctorloc, ValueRange()); + LLVM::ReturnOp::create(rewriter, ctorloc, ValueRange()); } - auto aoo = ctorBuilder.create(ctorloc, stub); - auto bitcast = ctorBuilder.create( - ctorloc, llvmPointerType, aoo); + auto aoo = LLVM::AddressOfOp::create(ctorBuilder, ctorloc, stub); + auto bitcast = LLVM::AddrSpaceCastOp::create(ctorBuilder, ctorloc, + llvmPointerType, aoo); Type tys[] = {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType}; - // auto cudaRegisterFn = LLVM::lookupOrCreateFn( - // rewriter, moduleOp, "__cudaRegisterFunction", tys, - // llvmInt32Type); - // if (failed(cudaRegisterFn)) { - // llvm::errs() << " cudamalloc already exists with different - // types\n"; return failure(); - // } auto registerFunctionFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type); @@ -1996,17 +1997,15 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, bitcast, kernelName, kernelName, - ctorBuilder.create(ctorloc, llvmInt32Type, -1), + LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, -1), nullPtr, nullPtr, nullPtr, nullPtr, nullPtr}; - // rewriter.create(ctorloc, cudaRegisterFn.value(), - // args); - rewriter.create(ctorloc, registerFunctionFn.value(), - args); + 
LLVM::CallOp::create(rewriter, ctorloc, registerFunctionFn.value(), + args); } else if (LLVM::GlobalOp g = dyn_cast(op)) { int addrSpace = g.getAddrSpace(); if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */) @@ -2026,9 +2025,9 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, // TODO could this be a memref global op? auto stub = moduleOp.lookupSymbol(g.getName()); assert(stub); - auto aoo = ctorBuilder.create(ctorloc, stub); - auto bitcast = ctorBuilder.create( - ctorloc, llvmPointerType, aoo); + auto aoo = LLVM::AddressOfOp::create(ctorBuilder, ctorloc, stub); + auto bitcast = LLVM::AddrSpaceCastOp::create(ctorBuilder, ctorloc, + llvmPointerType, aoo); auto globalTy = stub.getGlobalType(); // TODO This should actually be the GPUModuleOp's data layout I // believe, there were problems with assigning the data layout to @@ -2042,21 +2041,21 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, ctorloc, ctorBuilder, {module.getResult(), bitcast, symbolName, symbolName, /*isExtern*/ - ctorBuilder.create(ctorloc, llvmInt32Type, - /* TODO */ 0), + LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, + /* TODO */ 0), /*varSize*/ - ctorBuilder.create(ctorloc, llvmIntPtrType, - size), + LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmIntPtrType, + size), /*isConstant*/ - ctorBuilder.create(ctorloc, llvmInt32Type, - /* TODO */ 0), + LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, + /* TODO */ 0), /* just a 0? 
*/ - ctorBuilder.create(ctorloc, llvmInt32Type, - 0)}); + LLVM::ConstantOp::create(ctorBuilder, ctorloc, llvmInt32Type, + 0)}); } } // TODO this has to happen only for some CUDA versions, hip does not need - // finialize + // finialize cuda 11.X if (gpuTarget == "cuda") { auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, "__cudaRegisterFatBinaryEnd", llvmPointerType, @@ -2066,17 +2065,17 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, return failure(); } - rewriter.create(ctorloc, cudaRegisterFatbinFn.value(), - ValueRange(module->getResult(0))); + LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFatbinFn.value(), + ValueRange(module->getResult(0))); } - ctorBuilder.create(ctorloc, ValueRange()); + LLVM::ReturnOp::create(ctorBuilder, ctorloc, ValueRange()); } auto ctorSymbol = FlatSymbolRefAttr::get(ctor); { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - rewriter.create( - ctorloc, rewriter.getArrayAttr({std::move(ctorSymbol)}), + LLVM::GlobalCtorsOp::create( + rewriter, ctorloc, rewriter.getArrayAttr({std::move(ctorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr({LLVM::ZeroAttr::get(rewriter.getContext())})); } @@ -2084,20 +2083,10 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, PatternRewriter::InsertionGuard B(rewriter); OpBuilder &dtorBuilder = rewriter; dtorBuilder.setInsertionPointToEnd(dtor.addEntryBlock(dtorBuilder)); - auto aoo = dtorBuilder.create(ctorloc, moduleGlobal); - auto module = dtorBuilder.create( - ctorloc, llvmPointerPointerType, aoo->getResult(0)); - - // auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn( - // rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType, - // llvmVoidType); - // if (failed(cudaUnRegisterFatbinFn)) { - // llvm::errs() << " cudamalloc already exists with different types\n"; - // return failure(); - // } - - // rewriter.create(ctorloc, cudaUnRegisterFatbinFn.value(), 
- // ValueRange(module)); + auto aoo = LLVM::AddressOfOp::create(dtorBuilder, ctorloc, moduleGlobal); + auto module = LLVM::LoadOp::create( + dtorBuilder, ctorloc, llvmPointerPointerType, aoo->getResult(0)); + auto unregisterFatbinFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, unregisterFatBinaryFuncName, llvmPointerType, llvmVoidType); @@ -2106,17 +2095,17 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, "different types\n"; return failure(); } - rewriter.create(ctorloc, unregisterFatbinFn.value(), - ValueRange(module)); + LLVM::CallOp::create(rewriter, ctorloc, unregisterFatbinFn.value(), + ValueRange(module)); - dtorBuilder.create(ctorloc, ValueRange()); + LLVM::ReturnOp::create(dtorBuilder, ctorloc, ValueRange()); auto dtorSymbol = FlatSymbolRefAttr::get(dtor); { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); Attribute attrs[] = {LLVM::ZeroAttr::get(rewriter.getContext())}; - rewriter.create( - ctorloc, rewriter.getArrayAttr({std::move(dtorSymbol)}), + LLVM::GlobalDtorsOp::create( + rewriter, ctorloc, rewriter.getArrayAttr({std::move(dtorSymbol)}), rewriter.getI32ArrayAttr({65535}), rewriter.getArrayAttr(attrs)); } } @@ -2188,10 +2177,10 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( launchOp.getKernelName().getValue()); auto bitcast = - rewriter.create(loc, llvmPointerType, funcStubName); + LLVM::AddressOfOp::create(rewriter, loc, llvmPointerType, funcStubName); - Value zero = rewriter.create(loc, llvmInt32Type, 0); - auto nullpointer = rewriter.create(loc, llvmPointerType); + Value zero = LLVM::ConstantOp::create(rewriter, loc, llvmInt32Type, 0); + auto nullpointer = LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? 
nullpointer : adaptor.getAsyncDependencies().front(); @@ -2208,16 +2197,17 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto i32 = rewriter.getIntegerType(32); auto i64 = rewriter.getIntegerType(64); auto dim3 = [&](Value x, Value y, Value z) { - x = rewriter.create(x.getLoc(), i32, x); - y = rewriter.create(y.getLoc(), i32, y); - z = rewriter.create(z.getLoc(), i32, z); + x = LLVM::TruncOp::create(rewriter, x.getLoc(), i32, x); + y = LLVM::TruncOp::create(rewriter, y.getLoc(), i32, y); + z = LLVM::TruncOp::create(rewriter, z.getLoc(), i32, z); - x = rewriter.create(x.getLoc(), i64, x); - y = rewriter.create(y.getLoc(), i64, y); + x = LLVM::ZExtOp::create(rewriter, x.getLoc(), i64, x); + y = LLVM::ZExtOp::create(rewriter, y.getLoc(), i64, y); - y = rewriter.create( - y.getLoc(), y, rewriter.create(y.getLoc(), i64, 32)); - args.push_back(rewriter.create(x.getLoc(), x, y)); + y = LLVM::ShlOp::create( + rewriter, y.getLoc(), y, + LLVM::ConstantOp::create(rewriter, y.getLoc(), i64, 32)); + args.push_back(LLVM::OrOp::create(rewriter, x.getLoc(), x, y)); args.push_back(z); }; dim3(adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()); @@ -2226,7 +2216,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( args.push_back(kernelParams); args.push_back( - rewriter.create(loc, i64, dynamicSharedMemorySize)); + LLVM::ZExtOp::create(rewriter, loc, i64, dynamicSharedMemorySize)); args.push_back(stream); auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); @@ -2240,7 +2230,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( } auto launchCall = - rewriter.create(loc, TypeRange(i32), launchFuncName, args); + LLVM::CallOp::create(rewriter, loc, TypeRange(i32), launchFuncName, args); if (launchOp.getAsyncToken()) { // Async launch: make dependent ops use the same stream. 
@@ -2251,8 +2241,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( if (errOp) { rewriter.setInsertionPoint(errOp); - auto reg = rewriter.create( - errOp.getLoc(), launchCall->getResultTypes()[0]); + auto reg = scf::ExecuteRegionOp::create(rewriter, errOp.getLoc(), + launchCall->getResultTypes()[0]); rewriter.inlineRegionBefore(errOp.getRegion(), reg.getRegion(), reg.getRegion().begin()); rewriter.createBlock(&errOp.getRegion()); @@ -2261,35 +2251,36 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto one = rewriter.create(loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); - auto alloca = rewriter.create( - launchOp.getLoc(), ptrty, launchCall->getResultTypes()[0], one); - auto zero = rewriter.create( - loc, launchCall->getResultTypes()[0], 0); + auto alloca = LLVM::AllocaOp::create(rewriter, launchOp.getLoc(), ptrty, + launchCall->getResultTypes()[0], one); + auto zero = arith::ConstantIntOp::create( + rewriter, loc, launchCall->getResultTypes()[0], 0); rewriter.setInsertionPoint(errOp); - rewriter.create(launchOp.getLoc(), zero, alloca); + LLVM::StoreOp::create(rewriter, launchOp.getLoc(), zero, alloca); rewriter.setInsertionPointAfter(launchCall); - rewriter.create(launchOp.getLoc(), launchCall->getResult(0), - alloca); + LLVM::StoreOp::create(rewriter, launchOp.getLoc(), launchCall->getResult(0), + alloca); for (auto &block : reg.getRegion()) { if (auto terminator = dyn_cast(block.getTerminator())) { rewriter.setInsertionPointToEnd(&block); - auto load = rewriter.create( - launchOp.getLoc(), launchCall->getResultTypes()[0], alloca); + auto load = + LLVM::LoadOp::create(rewriter, launchOp.getLoc(), + launchCall->getResultTypes()[0], alloca); rewriter.replaceOpWithNewOp(terminator, load->getResults()); } } rewriter.setInsertionPointAfter(errOp); - auto 
cast = rewriter.create( - loc, rewriter.getIndexType(), reg->getResult(0)); + auto cast = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), reg->getResult(0)); rewriter.replaceOp(errOp, cast->getResults()); } @@ -2392,8 +2383,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast(bitwidth / 8) * static_cast(memrefTy.getNumElements()); - Value sizeArg = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(staticSize)); + Value sizeArg = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); } @@ -2405,8 +2396,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - rewriter.create( - launchOp.getLoc(), launchOp.getKernelAttr(), + gpu::LaunchFuncOp::create( + rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), @@ -2495,7 +2486,7 @@ class ConvertAllocOpToGpuRuntimeCallPattern // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto nullPtr = rewriter.create(loc, llvmPointerType); + auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? 
nullPtr : adaptor.getAsyncDependencies().front(); @@ -2511,10 +2502,10 @@ class ConvertAllocOpToGpuRuntimeCallPattern auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1); if (backend == "cuda") { - auto one = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); - auto ptr = rewriter.create(loc, ptrty, ptr1ty, one); + auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one); Type tys[] = {ptrty, i64}; auto cudaMallocFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, "cudaMalloc", tys, i32); @@ -2527,13 +2518,13 @@ class ConvertAllocOpToGpuRuntimeCallPattern ptr, sizeBytes, }; - rewriter.create(loc, cudaMallocFn.value(), args); - allocatedPtr = rewriter.create(loc, ptr1ty, ptr); + LLVM::CallOp::create(rewriter, loc, cudaMallocFn.value(), args); + allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); } else if (backend == "rocm") { - auto one = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); - auto ptr = rewriter.create(loc, ptrty, ptr1ty, one); + auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one); Type tys[] = {ptrty, i64}; auto hipMallocFn = @@ -2547,9 +2538,8 @@ class ConvertAllocOpToGpuRuntimeCallPattern ptr, sizeBytes, }; - rewriter.create(loc, hipMallocFn.value(), args); - allocatedPtr = rewriter.create(loc, ptr1ty, ptr); - + LLVM::CallOp::create(rewriter, loc, hipMallocFn.value(), args); + allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); } else if (backend.starts_with("cpu")) { Type convertedIndex = typeConverter->convertType(rewriter.getIndexType()); @@ -2566,42 +2556,43 @@ class ConvertAllocOpToGpuRuntimeCallPattern sizeBytes, }; allocatedPtr = - rewriter.create(loc, mallocFunc.value(), args) + LLVM::CallOp::create(rewriter, loc, mallocFunc.value(), args) ->getResult(0); allocatedPtr = - 
rewriter.create(loc, ptr1ty, allocatedPtr); + LLVM::AddrSpaceCastOp::create(rewriter, loc, ptr1ty, allocatedPtr); } else if (backend.starts_with("xla")) { - auto zero = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(0)); + auto zero = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(0)); - auto one = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); - auto tyid = rewriter.create( - loc, i64, - rewriter.getI64IntegerAttr( - xla_type_id(memRefType.getElementType()))); + auto tyid = + LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(xla_type_id( + memRefType.getElementType()))); Type convertedIndex = typeConverter->convertType(rewriter.getIndexType()); - auto shapeDim = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(memRefType.getShape().size())); + auto shapeDim = LLVM::ConstantOp::create( + rewriter, loc, i64, + rewriter.getI64IntegerAttr(memRefType.getShape().size())); auto AT = LLVM::LLVMArrayType::get(i64, memRefType.getShape().size()); - auto shapePtr = rewriter.create(loc, ptrty, AT, one); + auto shapePtr = LLVM::AllocaOp::create(rewriter, loc, ptrty, AT, one); int dynIdx = 0; for (int i = 0; i < memRefType.getShape().size(); i++) { - auto idx = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(i)); + auto idx = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(i)); Value idxs[] = {zero, idx}; auto gep = - rewriter.create(loc, ptrty, AT, shapePtr, idxs); + LLVM::GEPOp::create(rewriter, loc, ptrty, AT, shapePtr, idxs); Value val; @@ -2609,11 +2600,12 @@ class ConvertAllocOpToGpuRuntimeCallPattern val = adaptor.getDynamicSizes()[dynIdx]; dynIdx++; } else { - val = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(memRefType.getShape()[i])); + val = LLVM::ConstantOp::create( + rewriter, loc, i64, + rewriter.getI64IntegerAttr(memRefType.getShape()[i])); } - 
rewriter.create(loc, val, gep); + LLVM::StoreOp::create(rewriter, loc, val, gep); } // handle, type id, shape len, shape ptr @@ -2629,19 +2621,19 @@ class ConvertAllocOpToGpuRuntimeCallPattern auto xdata = insertXLAInitDeinit(moduleOp, backend, rewriter); Value args[] = {xdata, tyid, shapeDim, shapePtr}; allocatedPtr = - rewriter.create(loc, xlaMallocFn.value(), args) + LLVM::CallOp::create(rewriter, loc, xlaMallocFn.value(), args) ->getResult(0); allocatedPtr = - rewriter.create(loc, ptr1ty, allocatedPtr); + LLVM::AddrSpaceCastOp::create(rewriter, loc, ptr1ty, allocatedPtr); } else { llvm_unreachable("unknown backend"); } } else { - auto isHostShared = rewriter.create( - loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); + auto isHostShared = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); allocatedPtr = allocCallBuilder .create(loc, rewriter, {sizeBytes, stream, isHostShared}) @@ -2720,18 +2712,18 @@ class ConvertOccupancyOp return failure(); } - auto one = rewriter.create(loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); - auto ptr = rewriter.create(loc, ptrty, intty, one); + auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, intty, one); std::string funcStubName = getFuncStubName(op.getFn().getRootReference().getValue(), op.getFn().getLeafReference().getValue()); - auto addr = rewriter.create(loc, ptrty, funcStubName); + auto addr = LLVM::AddressOfOp::create(rewriter, loc, ptrty, funcStubName); Value args[] = {ptr, addr, adaptor.getBlockSize(), adaptor.getDynamicSMemSize(), adaptor.getFlags()}; - rewriter.create(loc, occupancyFn.value(), args); + LLVM::CallOp::create(rewriter, loc, occupancyFn.value(), args); rewriter.replaceOpWithNewOp(op, intty, ptr); return success(); @@ -2811,8 +2803,8 @@ class ConvertDeallocOpToGpuRuntimeCallPattern auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1); if 
(backend == "cuda") { - auto one = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); Type tys[] = {ptr1ty}; auto cudaFreeFn = @@ -2825,7 +2817,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = { ptr, }; - rewriter.create(loc, cudaFreeFn.value(), args); + LLVM::CallOp::create(rewriter, loc, cudaFreeFn.value(), args); } else if (backend == "rocm") { Type tys[] = {ptr1ty}; auto hipFreeFn = @@ -2838,7 +2830,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = { ptr, }; - rewriter.create(loc, hipFreeFn.value(), args); + LLVM::CallOp::create(rewriter, loc, hipFreeFn.value(), args); } else if (backend.starts_with("cpu")) { @@ -2853,7 +2845,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = { ptr, }; - rewriter.create(loc, freeFunc.value(), args); + LLVM::CallOp::create(rewriter, loc, freeFunc.value(), args); } else if (backend.starts_with("xla")) { auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); @@ -2872,7 +2864,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern Value args[] = {xdata, ptr}; - rewriter.create(loc, xlaFreeFn.value(), args); + LLVM::CallOp::create(rewriter, loc, xlaFreeFn.value(), args); } else { llvm::errs() << " unknown backend: " << backend << "\n"; return failure(); @@ -2920,27 +2912,28 @@ class ConvertXLAWrapperPattern auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto zero = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(0)); + auto zero = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(0)); - auto one = rewriter.create(loc, i64, - rewriter.getI64IntegerAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); - auto nargs = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(adaptor.getInputs().size())); + auto nargs = LLVM::ConstantOp::create( + rewriter, loc, i64, + 
rewriter.getI64IntegerAttr(adaptor.getInputs().size())); auto AT = LLVM::LLVMArrayType::get(i64, adaptor.getInputs().size()); - auto argsPtr = rewriter.create(loc, ptrty, AT, one); + auto argsPtr = LLVM::AllocaOp::create(rewriter, loc, ptrty, AT, one); for (int i = 0; i < adaptor.getInputs().size(); i++) { - auto idx = rewriter.create( - loc, i64, rewriter.getI64IntegerAttr(i)); + auto idx = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(i)); Value idxs[] = {zero, idx}; - auto gep = rewriter.create(loc, ptrty, AT, argsPtr, idxs); + auto gep = LLVM::GEPOp::create(rewriter, loc, ptrty, AT, argsPtr, idxs); - rewriter.create(loc, adaptor.getInputs()[i], gep); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInputs()[i], gep); } // handle, module, nargs, argptr @@ -2958,7 +2951,7 @@ class ConvertXLAWrapperPattern auto xdata = insertXLAInitDeinit(moduleOp, backend, rewriter); Value args[4] = {xdata, stringval, nargs, argsPtr}; - rewriter.create(loc, xlaExecFn.value(), args); + LLVM::CallOp::create(rewriter, loc, xlaExecFn.value(), args); wrap.setFnAttr( FlatSymbolRefAttr::get(rewriter.getStringAttr(""))); @@ -2997,21 +2990,21 @@ struct ReplaceErrOpWithSuccess : public OpRewritePattern { auto ®ion = errOp.getRegion(); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(errOp.getLoc(), ®ion.front()); + cf::BranchOp::create(rewriter, errOp.getLoc(), ®ion.front()); for (Block &block : errOp->getRegions()[0]) { if (auto terminator = dyn_cast(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); - rewriter.create(errOp.getLoc(), remainingOpsBlock, - terminatorOperands); + cf::BranchOp::create(rewriter, errOp.getLoc(), remainingOpsBlock, + terminatorOperands); rewriter.eraseOp(terminator); } } rewriter.inlineRegionBefore(region, remainingOpsBlock); } - auto zero = rewriter.create(errOp->getLoc(), 0); + auto zero = arith::ConstantIndexOp::create(rewriter, 
errOp->getLoc(), 0); rewriter.replaceOp(errOp, zero->getResults()); return success(); } @@ -3055,8 +3048,8 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); - auto globalOp = rewriter.create( - gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + auto globalOp = LLVM::GlobalOp::create( + rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), /*alignment=*/0, static_cast(gpu::GPUDialect::getWorkgroupAddressSpace())); @@ -3086,8 +3079,8 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { // latter is expected by gpu.launch_func. if (gpuFuncOp.isKernel()) attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); - auto llvmFuncOp = rewriter.create( - gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + auto llvmFuncOp = LLVM::LLVMFuncOp::create( + rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -3105,7 +3098,7 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { for (const auto &en : llvm::enumerate(workgroupBuffers)) { LLVM::GlobalOp global = en.value(); - Value memory = rewriter.create(loc, global); + Value memory = LLVM::AddressOfOp::create(rewriter, loc, global); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than @@ -3134,11 +3127,12 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { // memory space and does not support `alloca`s with addrspace(5). 
auto ptrType = LLVM::LLVMPointerType::get(type.getContext(), allocaAddrSpace); - Value numElements = rewriter.create( - gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); - Value allocated = rewriter.create( - gpuFuncOp.getLoc(), ptrType, type.getElementType(), numElements, - /*alignment=*/0); + Value numElements = LLVM::ConstantOp::create( + rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); + Value allocated = + LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, + type.getElementType(), numElements, + /*alignment=*/0); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( @@ -3217,8 +3211,8 @@ struct FuncOpLowering : public ConvertOpToLLVMPattern { cast(funcOp->getAttr(kLLVMCConvAttrName)); cconv = attr.getCallingConv(); } - auto newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), convertedType, linkage, + auto newFuncOp = LLVM::LLVMFuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), convertedType, linkage, /*dsoLocal=*/false, /*cconv=*/cconv, /*comdat=*/nullptr, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); @@ -3252,9 +3246,9 @@ struct CallOpLowering : public ConvertOpToLLVMPattern { } } - auto newCallOp = rewriter.create( - callOp->getLoc(), callResultTypes, callOp.getCallee(), - adaptor.getOperands()); + auto newCallOp = + LLVM::CallOp::create(rewriter, callOp->getLoc(), callResultTypes, + callOp.getCallee(), adaptor.getOperands()); newCallOp->setAttrs(callOp->getAttrs()); if (numResults <= 1) { @@ -3265,8 +3259,8 @@ struct CallOpLowering : public ConvertOpToLLVMPattern { SmallVector results; results.reserve(numResults); for (auto index : llvm::seq(0, numResults)) { - results.push_back(rewriter.create( - callOp->getLoc(), newCallOp->getResult(0), index)); + results.push_back(LLVM::ExtractValueOp::create( + rewriter, callOp->getLoc(), newCallOp->getResult(0), index)); } 
rewriter.replaceOp(callOp, results); return success(); @@ -3291,10 +3285,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { returnOp->getContext(), llvm::to_vector(adaptor.getOperands().getTypes())); Value packed = - rewriter.create(returnOp->getLoc(), returnedType); + LLVM::UndefOp::create(rewriter, returnOp->getLoc(), returnedType); for (const auto &[index, value] : llvm::enumerate(adaptor.getOperands())) { - packed = rewriter.create(returnOp->getLoc(), packed, - value, index); + packed = LLVM::InsertValueOp::create(rewriter, returnOp->getLoc(), packed, + value, index); } rewriter.replaceOpWithNewOp(returnOp, packed); return success(); @@ -3319,10 +3313,10 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern { returnOp->getContext(), llvm::to_vector(adaptor.getOperands().getTypes())); Value packed = - rewriter.create(returnOp->getLoc(), returnedType); + LLVM::UndefOp::create(rewriter, returnOp->getLoc(), returnedType); for (const auto &[index, value] : llvm::enumerate(adaptor.getOperands())) { - packed = rewriter.create(returnOp->getLoc(), packed, - value, index); + packed = LLVM::InsertValueOp::create(rewriter, returnOp->getLoc(), packed, + value, index); } rewriter.replaceOpWithNewOp(returnOp, packed); return success(); @@ -3373,7 +3367,7 @@ struct AllocaScopeOpLowering remainingOpsBlock, allocaScopeOp.getResultTypes(), SmallVector(allocaScopeOp->getNumResults(), allocaScopeOp.getLoc())); - rewriter.create(loc, ValueRange(), remainingOpsBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock); } // Inline body region. @@ -3384,8 +3378,8 @@ struct AllocaScopeOpLowering // Save stack and then branch into the body of the region. 
rewriter.setInsertionPointToEnd(currentBlock); auto stackSaveOp = - rewriter.create(loc, getVoidPtrType()); - rewriter.create(loc, ValueRange(), beforeBody); + LLVM::StackSaveOp::create(rewriter, loc, getVoidPtrType()); + LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody); // Replace the alloca_scope return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -3397,7 +3391,7 @@ struct AllocaScopeOpLowering // Insert stack restore before jumping out the body of the region. rewriter.setInsertionPoint(branchOp); - rewriter.create(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); // Replace the op with values return from the body region. rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); @@ -3465,22 +3459,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { auto int32Type = IntegerType::get(rewriter.getContext(), 32); auto predTy = IntegerType::get(rewriter.getContext(), 1); - Value one = rewriter.create(loc, int32Type, 1); - Value minusOne = rewriter.create(loc, int32Type, -1); - Value thirtyTwo = rewriter.create(loc, int32Type, 32); - Value numLeadInactiveLane = rewriter.create( - loc, int32Type, thirtyTwo, adaptor.getWidth()); + Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); + Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32); + Value numLeadInactiveLane = LLVM::SubOp::create( + rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth()); // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. 
- Value activeMask = rewriter.create(loc, int32Type, minusOne, - numLeadInactiveLane); + Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne, + numLeadInactiveLane); Value maskAndClamp; if (op.getMode() == gpu::ShuffleMode::UP) { // Clamp lane: `32 - activeWidth` maskAndClamp = numLeadInactiveLane; } else { // Clamp lane: `activeWidth - 1` - maskAndClamp = - rewriter.create(loc, int32Type, adaptor.getWidth(), one); + maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type, + adaptor.getWidth(), one); } bool predIsUsed = !op->getResult(1).use_empty(); @@ -3491,13 +3485,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {valueTy, predTy}); } - Value shfl = rewriter.create( - loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(), - maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr); + Value shfl = NVVM::ShflOp::create( + rewriter, loc, resultTy, activeMask, adaptor.getValue(), + adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()), + returnValueAndIsValidAttr); if (predIsUsed) { - Value shflValue = rewriter.create(loc, shfl, 0); + Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0); Value isActiveSrcLane = - rewriter.create(loc, shfl, 1); + LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); } else { rewriter.replaceOp(op, {shfl, nullptr}); @@ -3522,16 +3517,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { bounds = rewriter.getAttr( /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize); Value newOp = - rewriter.create(loc, rewriter.getI32Type(), bounds); + NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. 
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); return success(); @@ -3596,13 +3591,13 @@ struct OpLowering : public OpConversionPattern { Operation *newOp; switch (op.getDimension()) { case gpu::Dimension::x: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::y: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::z: - newOp = rewriter.create(loc, IntegerType::get(context, 32)); + newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32)); break; } @@ -3661,11 +3656,13 @@ struct OpLowering : public OpConversionPattern { rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::SExtOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } else if (indexBitwidth < 32) { - newOp = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::TruncOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } rewriter.replaceOpWithNewOp( @@ -3688,16 +3685,15 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { MLIRContext *context = rewriter.getContext(); auto int32Type = rewriter.getI32Type(); - Value minusOne = rewriter.create( - loc, 
int32Type, rewriter.getI32IntegerAttr(-1)); - Value zero = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(0)); - - Value laneIdLo = rewriter.create(loc, int32Type, minusOne, - zero, nullptr, nullptr); - Value laneId = rewriter.create( - loc, int32Type, minusOne, laneIdLo, nullptr, nullptr); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, + rewriter.getI32IntegerAttr(-1)); + Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, + rewriter.getI32IntegerAttr(0)); + Value laneIdLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, + minusOne, zero, nullptr, nullptr); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, minusOne, + laneIdLo, nullptr, nullptr); LLVM::ConstantRangeAttr bounds = nullptr; if (std::optional upperBound = op.getUpperBound()) bounds = rewriter.getAttr( @@ -3713,11 +3709,11 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - laneId = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), laneId); + laneId = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } else if (indexBitwidth < 32) { - laneId = rewriter.create( - loc, IntegerType::get(context, indexBitwidth), laneId); + laneId = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } rewriter.replaceOp(op, {laneId}); @@ -3738,28 +3734,31 @@ struct GPUShuffleOpToROCDL : public ConvertOpToLLVMPattern { auto value = adaptor.getValue(); auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value minusOne = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(-1)); - Value zero = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(0)); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, + rewriter.getI32IntegerAttr(-1)); + Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, + 
rewriter.getI32IntegerAttr(0)); - Value laneIdLo = rewriter.create(loc, int32Type, minusOne, - zero, nullptr, nullptr); - Value laneId = rewriter.create( - loc, int32Type, minusOne, laneIdLo, nullptr, nullptr); + Value laneIdLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, + minusOne, zero, nullptr, nullptr); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, minusOne, + laneIdLo, nullptr, nullptr); Value targetLane; Value offset = adaptor.getOffset(); switch (op.getMode()) { case gpu::ShuffleMode::XOR: - targetLane = rewriter.create(loc, int32Type, laneId, offset); + targetLane = + LLVM::XOrOp::create(rewriter, loc, int32Type, laneId, offset); break; case gpu::ShuffleMode::UP: - targetLane = rewriter.create(loc, int32Type, laneId, offset); + targetLane = + LLVM::SubOp::create(rewriter, loc, int32Type, laneId, offset); break; case gpu::ShuffleMode::DOWN: - targetLane = rewriter.create(loc, int32Type, laneId, offset); + targetLane = + LLVM::AddOp::create(rewriter, loc, int32Type, laneId, offset); break; case gpu::ShuffleMode::IDX: targetLane = offset; @@ -3768,47 +3767,48 @@ struct GPUShuffleOpToROCDL : public ConvertOpToLLVMPattern { Value width = adaptor.getWidth(); - auto isNonNegative = rewriter.create( - loc, LLVM::ICmpPredicate::sge, targetLane, zero); - auto isWithinWidth = rewriter.create( - loc, LLVM::ICmpPredicate::slt, targetLane, width); + auto isNonNegative = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::sge, targetLane, zero); + auto isWithinWidth = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::slt, targetLane, width); auto isValid = - rewriter.create(loc, isNonNegative, isWithinWidth); + LLVM::AndOp::create(rewriter, loc, isNonNegative, isWithinWidth); Value maskAndClamp; - Value widthMinusone = rewriter.create( - loc, width, - rewriter.create(loc, int32Type, - rewriter.getI32IntegerAttr(1))); - Value minResult = rewriter.create( - loc, - rewriter.create(loc, LLVM::ICmpPredicate::slt, targetLane, - 
widthMinusone), + Value widthMinusone = LLVM::SubOp::create( + rewriter, loc, width, + LLVM::ConstantOp::create(rewriter, loc, int32Type, + rewriter.getI32IntegerAttr(1))); + Value minResult = LLVM::SelectOp::create( + rewriter, loc, + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, + targetLane, widthMinusone), targetLane, widthMinusone); - maskAndClamp = rewriter.create( - loc, - rewriter.create(loc, LLVM::ICmpPredicate::sgt, minResult, - zero), + maskAndClamp = LLVM::SelectOp::create( + rewriter, loc, + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, minResult, + zero), minResult, zero); - Value four = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(4)); - Value byteIndex = rewriter.create(loc, maskAndClamp, four); + Value four = LLVM::ConstantOp::create(rewriter, loc, int32Type, + rewriter.getI32IntegerAttr(4)); + Value byteIndex = LLVM::MulOp::create(rewriter, loc, maskAndClamp, four); Value shuffleResult; if (valueTy.isF32()) { Value valueAsInt = - rewriter.create(loc, int32Type, value); + LLVM::BitcastOp::create(rewriter, loc, int32Type, value); - Value resultInt = rewriter.create( - loc, int32Type, byteIndex, valueAsInt); + Value resultInt = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, + byteIndex, valueAsInt); - shuffleResult = rewriter.create(loc, valueTy, resultInt); + shuffleResult = + LLVM::BitcastOp::create(rewriter, loc, valueTy, resultInt); } else if (valueTy.isInteger(32)) { - shuffleResult = rewriter.create(loc, int32Type, - byteIndex, value); + shuffleResult = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, + byteIndex, value); } // } else if (valueTy.isF64() || valueTy.isInteger(64)) { // shuffleResult = shuffle64BitValue(loc, rewriter, value, byteIndex, @@ -3845,8 +3845,8 @@ struct ClusterIdOpToROCDL : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto indexType = getTypeConverter()->getIndexType(); - Value zero = 
rewriter.create(loc, indexType, - rewriter.getIndexAttr(0)); + Value zero = LLVM::ConstantOp::create(rewriter, loc, indexType, + rewriter.getIndexAttr(0)); rewriter.replaceOp(op, zero); return success(); @@ -3861,8 +3861,8 @@ struct ClusterDimOpToROCDL : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto indexType = getTypeConverter()->getIndexType(); - Value one = rewriter.create(loc, indexType, - rewriter.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(rewriter, loc, indexType, + rewriter.getIndexAttr(1)); rewriter.replaceOp(op, one); return success(); @@ -4013,9 +4013,9 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto ptrTy = LLVM::LLVMPointerType::get(ctx); - auto resumeOp = moduleBuilder.create( - fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); - resumeOp.setPrivate(); + auto resumeOp = LLVM::LLVMFuncOp::create( + moduleBuilder, fname, + LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); return resumeOp; } @@ -4035,9 +4035,9 @@ static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto ptrTy = LLVM::LLVMPointerType::get(ctx); - auto resumeOp = moduleBuilder.create( - fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); - resumeOp.setPrivate(); + auto resumeOp = LLVM::LLVMFuncOp::create( + moduleBuilder, fname, + LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); return resumeOp; } @@ -4049,7 +4049,7 @@ struct NoAsyncOpLowering : public OpConversionPattern { matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto exec = - rewriter.create(execute->getLoc(), TypeRange()); + scf::ExecuteRegionOp::create(rewriter, execute->getLoc(), TypeRange()); rewriter.inlineRegionBefore(execute.getRegion(), exec.getRegion(), 
exec.getRegion().begin()); for (auto &blk : exec.getRegion()) { @@ -4117,8 +4117,8 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { rewriter.setInsertionPointToEnd(module.getBody()); static int off = 0; off++; - func = rewriter.create( - execute.getLoc(), + func = LLVM::LLVMFuncOp::create( + rewriter, execute.getLoc(), "kernelbody." + std::to_string((long long int)&execute) + "." + std::to_string(off), funcType); @@ -4147,49 +4147,50 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { converter->convertType(functionInputs[0].getType()))) { valueMapping.map( functionInputs[0], - rewriter.create( - execute.getLoc(), + LLVM::BitcastOp::create( + rewriter, execute.getLoc(), converter->convertType(functionInputs[0].getType()), arg)); } else if (functionInputs.size() == 1 && isa( converter->convertType(functionInputs[0].getType()))) { valueMapping.map( functionInputs[0], - rewriter.create( - execute.getLoc(), + LLVM::PtrToIntOp::create( + rewriter, execute.getLoc(), converter->convertType(functionInputs[0].getType()), arg)); } else { SmallVector types; for (auto v : functionInputs) types.push_back(converter->convertType(v.getType())); auto ST = LLVM::LLVMStructType::getLiteral(ctx, types); - auto alloc = rewriter.create( - execute.getLoc(), LLVM::LLVMPointerType::get(ctx), arg); + auto alloc = LLVM::BitcastOp::create( + rewriter, execute.getLoc(), LLVM::LLVMPointerType::get(ctx), arg); for (auto idx : llvm::enumerate(functionInputs)) { mlir::Value idxs[] = { - rewriter.create(loc, 0, 32), - rewriter.create(loc, idx.index(), 32), + arith::ConstantIntOp::create(rewriter, loc, 0, 32), + arith::ConstantIntOp::create(rewriter, loc, idx.index(), 32), }; - Value next = - rewriter.create(loc, LLVM::LLVMPointerType::get(ctx), + Value next = LLVM::GEPOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(ctx), idx.value().getType(), alloc, idxs); - valueMapping.map(idx.value(), rewriter.create( - loc, idx.value().getType(), next)); + valueMapping.map( + 
idx.value(), + LLVM::LoadOp::create(rewriter, loc, idx.value().getType(), next)); } auto freef = getTypeConverter()->getOptions().useGenericFunctions ? LLVM::lookupOrCreateGenericFreeFn(rewriter, module) : LLVM::lookupOrCreateFreeFn(rewriter, module); Value args[] = {arg}; - rewriter.create(loc, freef.value(), args); + LLVM::CallOp::create(rewriter, loc, freef.value(), args); } // Clone all operations from the execute operation body into the outlined // function body. rewriter.cloneRegionBefore(execute.getBodyRegion(), func.getRegion(), func.getRegion().end(), valueMapping); - rewriter.create(execute.getLoc(), ValueRange(), - &*std::next(func.getRegion().begin())); + LLVM::BrOp::create(rewriter, execute.getLoc(), ValueRange(), + &*std::next(func.getRegion().begin())); for (Block &b : func.getRegion()) { auto term = b.getTerminator(); if (isa(term)) { @@ -4211,17 +4212,17 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { SmallVector vals; if (crossing.size() == 0) { vals.push_back( - rewriter.create(execute.getLoc(), voidPtr)); + LLVM::ZeroOp::create(rewriter, execute.getLoc(), voidPtr)); } else if (crossing.size() == 1 && isa( converter->convertType(crossing[0].getType()))) { - vals.push_back(rewriter.create(execute.getLoc(), - voidPtr, crossing[0])); + vals.push_back(LLVM::BitcastOp::create(rewriter, execute.getLoc(), + voidPtr, crossing[0])); } else if (crossing.size() == 1 && isa( converter->convertType(crossing[0].getType()))) { - vals.push_back(rewriter.create(execute.getLoc(), - voidPtr, crossing[0])); + vals.push_back(LLVM::IntToPtrOp::create(rewriter, execute.getLoc(), + voidPtr, crossing[0])); } else { SmallVector types; for (auto v : crossing) @@ -4233,36 +4234,36 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { DataLayout DLI(execute->getParentOfType()); - Value arg = rewriter.create( - loc, rewriter.getI64Type(), DLI.getTypeSize(ST)); + Value arg = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), 
DLI.getTypeSize(ST)); auto mallocFunc = LLVM::lookupOrCreateMallocFn(rewriter, module, getIndexType()); mlir::Value alloc = - rewriter.create(loc, mallocFunc.value(), arg) + LLVM::CallOp::create(rewriter, loc, mallocFunc.value(), arg) .getResult(); rewriter.setInsertionPoint(execute); for (auto idx : llvm::enumerate(crossing)) { mlir::Value idxs[] = { - rewriter.create(loc, 0, 32), - rewriter.create(loc, idx.index(), 32), + arith::ConstantIntOp::create(rewriter, loc, 0, 32), + arith::ConstantIntOp::create(rewriter, loc, idx.index(), 32), }; - Value next = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + Value next = LLVM::GEPOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), idx.value().getType(), alloc, idxs); - rewriter.create(loc, idx.value(), next); + LLVM::StoreOp::create(rewriter, loc, idx.value(), next); } - vals.push_back( - rewriter.create(execute.getLoc(), voidPtr, alloc)); + vals.push_back(LLVM::BitcastOp::create(rewriter, execute.getLoc(), + voidPtr, alloc)); } - vals.push_back(rewriter.create( - execute.getLoc(), voidPtr, - rewriter.create(execute.getLoc(), func))); + vals.push_back(LLVM::BitcastOp::create( + rewriter, execute.getLoc(), voidPtr, + LLVM::AddressOfOp::create(rewriter, execute.getLoc(), func))); for (auto dep : execute.getDependencies()) { auto src = dep.getDefiningOp().getSource(); if (auto MT = dyn_cast(src.getType())) - src = rewriter.create( - dep.getDefiningOp()->getLoc(), + src = enzymexla::Memref2PointerOp::create( + rewriter, dep.getDefiningOp()->getLoc(), LLVM::LLVMPointerType::get(rewriter.getContext(), MT.getMemorySpaceAsInt()), src); @@ -4279,7 +4280,7 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { : addMocROCmFunction(execute->getParentOfType(), vals.back().getType()); - rewriter.create(execute.getLoc(), f, vals); + LLVM::CallOp::create(rewriter, execute.getLoc(), f, vals); rewriter.eraseOp(execute); } @@ -4573,6 +4574,25 @@ struct 
ConvertPolygeistToLLVMPass signalPassFailure(); return; } + + { + const char *GetDeviceFromHostFuncName = + "__reactant$get_device_from_host"; + SmallVector toHandle; + m->walk([&](LLVM::CallOp call) { + CallInterfaceCallable callable = call.getCallableForCallee(); + auto callee = dyn_cast(callable); + if (!callee) + return; + if (callee.getLeafReference() == GetDeviceFromHostFuncName) + toHandle.push_back(call); + }); + for (auto call : toHandle) { + assert(call->getNumResults() == 1 && call.getNumOperands() == 1); + call.getResult().replaceAllUsesWith(call.getArgOperands()[0]); + call->erase(); + } + } } if (StringRef(gpuTarget).starts_with("xla")) { From 6f5fe267d160ef1ae775022ddb859f9947671b5f Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 6 Nov 2025 04:32:06 -0600 Subject: [PATCH 10/27] fix BUILD format --- src/enzyme_ad/jax/BUILD | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 5574352463..66c7931e5e 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -715,17 +715,18 @@ cc_library( cc_library( name = "XLADerivatives", - srcs = glob([ - "Implementations/*.cpp", - "Passes/*.cpp", - "Dialect/*.cpp", - "Dialect/Distributed/*.cpp", - "Dialect/Tessera/*.cpp", - ], - exclude = [ - "Passes/CudaRuntimeWrappers.cpp", - "Passes/RocmRuntimeWrappers.cpp", - ], + srcs = glob( + [ + "Implementations/*.cpp", + "Passes/*.cpp", + "Dialect/*.cpp", + "Dialect/Distributed/*.cpp", + "Dialect/Tessera/*.cpp", + ], + exclude = [ + "Passes/CudaRuntimeWrappers.cpp", + "Passes/RocmRuntimeWrappers.cpp", + ], ) + [ "Utils.cpp", ], From 409b2cee5d565dffea20c1f8edfa1d17b694e959 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 6 Nov 2025 23:37:04 -0600 Subject: [PATCH 11/27] fix branch --- src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp 
b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 26b5ef0be3..7b9b54abe5 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -1842,7 +1842,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; registerFatBinaryEndFuncName = "__cudaRegisterFatBinaryEnd"; requiresRegisterEnd = true; - } else if (gpuTarget == "rocm") { + } else { registerFatBinaryFuncName = "__hipRegisterFatBinary"; registerFunctionFuncName = "__hipRegisterFunction"; registerVarFuncName = "__hipRegisterVar"; @@ -1963,7 +1963,6 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, { PatternRewriter::InsertionGuard B(rewriter); rewriter.setInsertionPointToEnd(moduleOp.getBody()); - stub = moduleOp.lookupSymbol(stubName); stub = LLVM::LLVMFuncOp::create( rewriter, ctorloc, getFuncStubName(moduleName, f.getName()), LLVM::LLVMFunctionType::get(llvmVoidType, {}), @@ -2225,7 +2224,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( std::string launchFuncName; if (gpuTarget == "cuda") { launchFuncName = "cudaLaunchKernel"; - } else if (gpuTarget == "rocm") { + } else { launchFuncName = "hipLaunchKernel"; } @@ -3950,7 +3949,7 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, populateLibDeviceConversionPatterns(typeConverter, patterns, benefit); patterns.add(typeConverter, benefit); - } else if (gpuTarget == "rocm") { + } else { using namespace mlir::gpu::index_lowering; PatternBenefit benefit(1); PatternBenefit highBenefit(2); @@ -4387,7 +4386,7 @@ struct ConvertPolygeistToLLVMPass Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; if (backend == "cuda") { LLVM::lookupOrCreateFn(rewriter, m, "cudaLaunchKernel", tys, i32); - } else if (backend == "rocm") { + } else { LLVM::lookupOrCreateFn(rewriter, m, "hipLaunchKernel", tys, i32); } } From 8c77cb29054da2e089a67f06b7deea84e23fbec6 
Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Fri, 7 Nov 2025 00:00:33 -0600 Subject: [PATCH 12/27] not fix --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 7b9b54abe5..1e8143044f 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -3949,7 +3949,7 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, populateLibDeviceConversionPatterns(typeConverter, patterns, benefit); patterns.add(typeConverter, benefit); - } else { + } else if (gpuTarget == "rocm"){ using namespace mlir::gpu::index_lowering; PatternBenefit benefit(1); PatternBenefit highBenefit(2); @@ -4378,17 +4378,24 @@ struct ConvertPolygeistToLLVMPass bool hasLaunch = m->walk([](gpu::LaunchFuncOp) { return WalkResult::interrupt(); }).wasInterrupted(); + + std::string launchFuncName; + if (backend == "cuda") { + launchFuncName = "cudaLaunchKernel"; + } else if (backend == "rocm") { + launchFuncName = "hipLaunchKernel"; + } else { + launchFuncName = "cudaLaunchKernel"; + } if (hasLaunch) { OpBuilder rewriter(m); auto i32 = rewriter.getIntegerType(32); auto i64 = rewriter.getIntegerType(64); auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; - if (backend == "cuda") { - LLVM::lookupOrCreateFn(rewriter, m, "cudaLaunchKernel", tys, i32); - } else { - LLVM::lookupOrCreateFn(rewriter, m, "hipLaunchKernel", tys, i32); - } + + LLVM::lookupOrCreateFn(rewriter, m, launchFuncName, tys, i32); + } for (auto mod : gmods) { @@ -4574,24 +4581,6 @@ struct ConvertPolygeistToLLVMPass return; } - { - const char *GetDeviceFromHostFuncName = - "__reactant$get_device_from_host"; - SmallVector toHandle; - m->walk([&](LLVM::CallOp call) { - CallInterfaceCallable callable = 
call.getCallableForCallee(); - auto callee = dyn_cast(callable); - if (!callee) - return; - if (callee.getLeafReference() == GetDeviceFromHostFuncName) - toHandle.push_back(call); - }); - for (auto call : toHandle) { - assert(call->getNumResults() == 1 && call.getNumOperands() == 1); - call.getResult().replaceAllUsesWith(call.getArgOperands()[0]); - call->erase(); - } - } } if (StringRef(gpuTarget).starts_with("xla")) { @@ -4637,6 +4626,24 @@ struct ConvertPolygeistToLLVMPass signalPassFailure(); return; } + { + const char *GetDeviceFromHostFuncName = + "__reactant$get_device_from_host"; + SmallVector toHandle; + m->walk([&](LLVM::CallOp call) { + CallInterfaceCallable callable = call.getCallableForCallee(); + auto callee = dyn_cast(callable); + if (!callee) + return; + if (callee.getLeafReference() == GetDeviceFromHostFuncName) + toHandle.push_back(call); + }); + for (auto call : toHandle) { + assert(call->getNumResults() == 1 && call.getNumOperands() == 1); + call.getResult().replaceAllUsesWith(call.getArgOperands()[0]); + call->erase(); + } + } } void runOnOperation() override { From a51cdadd5e73679425649567be52e5bbe64c5657 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Fri, 7 Nov 2025 00:05:16 -0600 Subject: [PATCH 13/27] fmt --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 1e8143044f..f884f4e948 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -3949,7 +3949,7 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, populateLibDeviceConversionPatterns(typeConverter, patterns, benefit); patterns.add(typeConverter, benefit); - } else if (gpuTarget == "rocm"){ + } else if (gpuTarget == "rocm") { using namespace mlir::gpu::index_lowering; PatternBenefit benefit(1); PatternBenefit 
highBenefit(2); @@ -4393,9 +4393,8 @@ struct ConvertPolygeistToLLVMPass auto i64 = rewriter.getIntegerType(64); auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; - + LLVM::lookupOrCreateFn(rewriter, m, launchFuncName, tys, i32); - } for (auto mod : gmods) { @@ -4580,7 +4579,6 @@ struct ConvertPolygeistToLLVMPass signalPassFailure(); return; } - } if (StringRef(gpuTarget).starts_with("xla")) { @@ -4627,23 +4625,22 @@ struct ConvertPolygeistToLLVMPass return; } { - const char *GetDeviceFromHostFuncName = - "__reactant$get_device_from_host"; - SmallVector toHandle; - m->walk([&](LLVM::CallOp call) { - CallInterfaceCallable callable = call.getCallableForCallee(); - auto callee = dyn_cast(callable); - if (!callee) - return; - if (callee.getLeafReference() == GetDeviceFromHostFuncName) - toHandle.push_back(call); - }); - for (auto call : toHandle) { - assert(call->getNumResults() == 1 && call.getNumOperands() == 1); - call.getResult().replaceAllUsesWith(call.getArgOperands()[0]); - call->erase(); - } + const char *GetDeviceFromHostFuncName = "__reactant$get_device_from_host"; + SmallVector toHandle; + m->walk([&](LLVM::CallOp call) { + CallInterfaceCallable callable = call.getCallableForCallee(); + auto callee = dyn_cast(callable); + if (!callee) + return; + if (callee.getLeafReference() == GetDeviceFromHostFuncName) + toHandle.push_back(call); + }); + for (auto call : toHandle) { + assert(call->getNumResults() == 1 && call.getNumOperands() == 1); + call.getResult().replaceAllUsesWith(call.getArgOperands()[0]); + call->erase(); } + } } void runOnOperation() override { From 77dfb892df1880b52aae016a6f88a43ca36a1fc7 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Tue, 11 Nov 2025 14:56:31 -0600 Subject: [PATCH 14/27] fix BUILD --- src/enzyme_ad/jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 62a41e2fb1..4bc5e48339 100644 --- 
a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -933,6 +933,7 @@ cc_library( "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MathToROCDL", "@llvm-project//mlir:MathToLibm", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefToLLVM", From 1db07cadc62ecd2fdc30684f02b90bc58651c92f Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Tue, 11 Nov 2025 15:01:42 -0600 Subject: [PATCH 15/27] fmt --- src/enzyme_ad/jax/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 4bc5e48339..4bfb8ef95d 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -933,8 +933,8 @@ cc_library( "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MathToROCDL", "@llvm-project//mlir:MathToLibm", + "@llvm-project//mlir:MathToROCDL", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefToLLVM", "@llvm-project//mlir:MemRefTransforms", From 2da26ac6e1bf75d5ec41730d94988bd2000b80c2 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Tue, 11 Nov 2025 15:16:18 -0600 Subject: [PATCH 16/27] fix lit tests --- src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index f884f4e948..7c1070a87f 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -2224,8 +2224,10 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( std::string launchFuncName; if (gpuTarget == "cuda") { launchFuncName = "cudaLaunchKernel"; - } else { + } else if (gpuTarget == "rocm") { launchFuncName = "hipLaunchKernel"; + } else { + launchFuncName = "cudaLaunchKernel"; } auto launchCall = From 
a35939193434b1bcefd17fe5c74b1f020181f089 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Tue, 18 Nov 2025 00:55:32 -0600 Subject: [PATCH 17/27] add fenceOp for GPUBarrierToROCDL --- src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 7c1070a87f..dd173b56fb 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -3833,7 +3833,11 @@ struct GPUBarrierToROCDL : ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + LLVM::FenceOp::create(rewriter, op.getLoc(), LLVM::AtomicOrdering::release, + StringRef("workgroup")); rewriter.replaceOpWithNewOp(op); + LLVM::FenceOp::create(rewriter, op.getLoc(), LLVM::AtomicOrdering::acquire, + StringRef("workgroup")); return success(); } }; From cde7153f0185569a1561252b4741abf1b5a4e631 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 19 Nov 2025 16:24:29 -0600 Subject: [PATCH 18/27] modify ROCDL::MbcntHiOp::create --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index dd173b56fb..ac4bbf3215 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -3691,10 +3691,10 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, rewriter.getI32IntegerAttr(0)); - Value laneIdLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, - minusOne, zero, nullptr, nullptr); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, minusOne, - laneIdLo, nullptr, 
nullptr); + Value laneIdLo = + ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, {minusOne, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, + {minusOne, laneIdLo}); LLVM::ConstantRangeAttr bounds = nullptr; if (std::optional upperBound = op.getUpperBound()) bounds = rewriter.getAttr( @@ -3740,10 +3740,10 @@ struct GPUShuffleOpToROCDL : public ConvertOpToLLVMPattern { Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, rewriter.getI32IntegerAttr(0)); - Value laneIdLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, - minusOne, zero, nullptr, nullptr); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, minusOne, - laneIdLo, nullptr, nullptr); + Value laneIdLo = + ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, {minusOne, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, + {minusOne, laneIdLo}); Value targetLane; Value offset = adaptor.getOffset(); From 8b336962b61ff278ce50cd21fe40b24cff9e880b Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 19 Nov 2025 16:28:50 -0600 Subject: [PATCH 19/27] rm CudaRuntimeWrappers and RocmRuntimeWrappers, modify BUILD --- src/enzyme_ad/jax/BUILD | 7 +- .../jax/Passes/CudaRuntimeWrappers.cpp | 206 ---------------- .../jax/Passes/RocmRuntimeWrappers.cpp | 230 ------------------ 3 files changed, 1 insertion(+), 442 deletions(-) delete mode 100644 src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp delete mode 100644 src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index b643284f71..8a7223ac5d 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -845,12 +845,7 @@ cc_library( "Dialect/*.cpp", "Dialect/Distributed/*.cpp", "Dialect/Tessera/*.cpp", - ], - exclude = [ - "Passes/CudaRuntimeWrappers.cpp", - "Passes/RocmRuntimeWrappers.cpp", - ], - ) + [ + ]) + [ "Utils.cpp", ], hdrs = glob([ diff --git a/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp 
b/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp deleted file mode 100644 index cf4e00783f..0000000000 --- a/src/enzyme_ad/jax/Passes/CudaRuntimeWrappers.cpp +++ /dev/null @@ -1,206 +0,0 @@ -//===- EnzymeCudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Implements C wrappers around the CUDA library for easy linking in ORC jit. -// Also adds some debugging helpers that are helpful when writing MLIR code to -// run on GPUs. -// -//===----------------------------------------------------------------------===// - -#include -#include - -#include "cuda.h" -#include "cuda_runtime.h" - -#include "PGORuntime.h" - -#ifdef _WIN32 -#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) __attribute__((weak)) -#else -#define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((weak)) -#endif // _WIN32 - -#define CUDART_REPORT_IF_ERROR(expr) \ - [](auto result) { \ - if (!result) \ - return result; \ - const char *name = cudaGetErrorString(result); \ - if (!name) \ - name = ""; \ - fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ - return result; \ - }(expr) - -#define CUDA_REPORT_IF_ERROR(expr) \ - [](CUresult result) { \ - if (!result) \ - return result; \ - const char *name = nullptr; \ - cuGetErrorName(result, &name); \ - if (!name) \ - name = ""; \ - fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ - return result; \ - }(expr) - -thread_local static int32_t defaultDevice = 0; - -// Make the primary context of the current default device current for the -// duration -// of the instance and restore the previous context on destruction. 
-class ScopedContext { -public: - ScopedContext() { - // Static reference to CUDA primary context for device ordinal - // defaultDevice. - static CUcontext context = [] { - CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); - CUdevice device; - CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); - CUcontext ctx; - // Note: this does not affect the current context. - CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device)); - return ctx; - }(); - - CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context)); - } - - ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); } -}; - -//========= CUDA RUNTIME API =========// - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void -mgpurtLaunchKernel(void *function, intptr_t gridX, intptr_t gridY, - intptr_t gridZ, intptr_t blockX, intptr_t blockY, - intptr_t blockZ, int32_t smem, cudaStream_t stream, - void **params) { - CUDART_REPORT_IF_ERROR(cudaLaunchKernel(function, dim3(gridX, gridY, gridZ), - dim3(blockX, blockY, blockZ), params, - smem, stream)); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpurtLaunchKernelErr( - void *function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, - intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, - cudaStream_t stream, void **params) { - return CUDART_REPORT_IF_ERROR( - cudaLaunchKernel(function, dim3(gridX, gridY, gridZ), - dim3(blockX, blockY, blockZ), params, smem, stream)); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void * -mgpurtMemAlloc(uint64_t sizeBytes, cudaStream_t /*stream*/) { - void *ptr; - CUDART_REPORT_IF_ERROR(cudaMalloc(&ptr, sizeBytes)); - return reinterpret_cast(ptr); -} - -extern "C" void mgpurtMemcpyErr(void *dst, void *src, size_t sizeBytes) { - CUDART_REPORT_IF_ERROR(cudaMemcpy(dst, src, sizeBytes, cudaMemcpyDefault)); -} - -extern "C" void mgpurtMemcpyAsyncErr(void *dst, void *src, size_t sizeBytes, - cudaStream_t stream) { - CUDART_REPORT_IF_ERROR( - cudaMemcpyAsync(dst, src, sizeBytes, cudaMemcpyDefault, stream)); -} - -//========= 
CUDA DRIVER API =========// - -// The wrapper uses intptr_t instead of CUDA's unsigned int to match -// the type of MLIR's index type. This avoids the need for casts in the -// generated MLIR code. -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void -mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, - intptr_t gridZ, intptr_t blockX, intptr_t blockY, - intptr_t blockZ, int32_t smem, CUstream stream, void **params, - void **extra) { - ScopedContext scopedContext; - CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, - blockY, blockZ, smem, stream, params, - extra)); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpuLaunchKernelErr( - CUfunction function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, - intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, - CUstream stream, void **params, void **extra) { - ScopedContext scopedContext; - return CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, - blockX, blockY, blockZ, smem, - stream, params, extra)); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) { - ScopedContext scopedContext; - CUmodule module = nullptr; - CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); - return module; -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) { - CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction -mgpuModuleGetFunction(CUmodule module, const char *name) { - CUfunction function = nullptr; - CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); - return function; -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT int32_t mgpurtDeviceSynchronizeErr(void) { - return CUDART_REPORT_IF_ERROR(cudaDeviceSynchronize()); -} - -extern "C" void __cudaRegisterFunction(void **fatCubinHandle, void *hostFun, - void *deviceFun, void *deviceName, - int32_t thread_limit, void *tid, - void *bid, void *bDim, void *gDim, - void *wSize); -extern "C" void 
__cudaRegisterVar(void **fatCubinHandle, char *hostVar, - char *deviceAddress, const char *deviceName, - int ext, size_t size, int constant, - int global); -extern "C" void **__cudaRegisterFatBinary(void *fatCubin); -extern "C" void __cudaRegisterFatBinaryEnd(void **fatCubinHandle); -extern "C" void __cudaUnregisterFatBinary(void **fatCubinHandle); - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void -__mgpurtRegisterFunction(void **fatCubinHandle, void *hostFun, void *deviceFun, - void *deviceName, int32_t thread_limit, void *tid, - void *bid, void *bDim, void *gDim, void *wSize) { - __cudaRegisterFunction(fatCubinHandle, hostFun, deviceFun, deviceName, - thread_limit, tid, bid, bDim, gDim, wSize); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void -__mgpurtRegisterVar(void **fatCubinHandle, char *hostVar, char *deviceAddress, - const char *deviceName, int ext, size_t size, int constant, - int global) { - __cudaRegisterVar(fatCubinHandle, hostVar, deviceAddress, deviceName, ext, - size, constant, global); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void ** -__mgpurtRegisterFatBinary(void *fatCubin) { - return __cudaRegisterFatBinary(fatCubin); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void -__mgpurtRegisterFatBinaryEnd(void **fatCubinHandle) { - __cudaRegisterFatBinaryEnd(fatCubinHandle); -} - -extern "C" MLIR_CUDA_WRAPPERS_EXPORT void -__mgpurtUnregisterFatBinary(void **fatCubinHandle) { - __cudaUnregisterFatBinary(fatCubinHandle); -} diff --git a/src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp b/src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp deleted file mode 100644 index 5cae5c4948..0000000000 --- a/src/enzyme_ad/jax/Passes/RocmRuntimeWrappers.cpp +++ /dev/null @@ -1,230 +0,0 @@ -//===- PolygeistRocmRuntimeWrappers.cpp - MLIR ROCM API wrapper library ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Implements C wrappers around the ROCM library for easy linking in ORC jit. -// Also adds some debugging helpers that are helpful when writing MLIR code to -// run on GPUs. -// -//===----------------------------------------------------------------------===// - -#include -#include - -#include "hip/hip_runtime.h" - -#include "PGORuntime.h" - -#ifdef _WIN32 -#define MLIR_HIP_WRAPPERS_EXPORT __declspec(dllexport) __attribute__((weak)) -#else -#define MLIR_HIP_WRAPPERS_EXPORT __attribute__((weak)) -#endif // _WIN32 - -#define HIP_REPORT_IF_ERROR(expr) \ - [](hipError_t result) { \ - if (!result) \ - return; \ - const char *name = hipGetErrorName(result); \ - if (!name) \ - name = ""; \ - fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ - }(expr) - -#define ERR_HIP_REPORT_IF_ERROR(expr) \ - [](hipError_t result) -> hipError_t { \ - if (!result) \ - return result; \ - const char *name = hipGetErrorName(result); \ - if (!name) \ - name = ""; \ - fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ - return result; \ - }(expr) - -extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t -mgpurtMemAllocErr(void **mem, uint64_t sizeBytes) { - return ERR_HIP_REPORT_IF_ERROR(hipMalloc(mem, sizeBytes)); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT void * -mgpurtMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) { - void *ptr; - HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes)); - return reinterpret_cast(ptr); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT void mgpuMemFree(void *ptr, - hipStream_t /*stream*/) { - HIP_REPORT_IF_ERROR(hipFree(ptr)); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t -mgpurtMemcpyErr(void *dst, void *src, intptr_t sizeBytes) { - return ERR_HIP_REPORT_IF_ERROR( - hipMemcpy(dst, src, sizeBytes, hipMemcpyDefault)); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t mgpurtMemcpyAsyncErr( - void *dst, void *src, 
intptr_t sizeBytes, hipStream_t stream) { - return ERR_HIP_REPORT_IF_ERROR( - hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream)); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t mgpurtDeviceSynchronizeErr(void) { - return ERR_HIP_REPORT_IF_ERROR(hipDeviceSynchronize()); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t mgpurtLaunchKernelErr( - void *function, intptr_t gridX, intptr_t gridY, intptr_t gridZ, - intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, - hipStream_t stream, void **params) { - return ERR_HIP_REPORT_IF_ERROR( - hipLaunchKernel(function, dim3(gridX, gridY, gridZ), - dim3(blockX, blockY, blockZ), params, smem, stream)); -} - -extern "C" void __hipRegisterFunction(void **fatCubinHandle, void *hostFun, - void *deviceFun, void *deviceName, - int32_t thread_limit, void *tid, - void *bid, void *bDim, void *gDim, - void *wSize); -extern "C" void __hipRegisterVar(void **fatCubinHandle, char *hostVar, - char *deviceAddress, const char *deviceName, - int ext, size_t size, int constant, - int global); -extern "C" void **__hipRegisterFatBinary(void *fatCubin); -extern "C" void __hipRegisterFatBinaryEnd(void **fatCubinHandle); -extern "C" void __hipUnregisterFatBinary(void **fatCubinHandle); - -extern "C" MLIR_HIP_WRAPPERS_EXPORT void -__mgpurtRegisterFunction(void **fatCubinHandle, void *hostFun, void *deviceFun, - void *deviceName, int32_t thread_limit, void *tid, - void *bid, void *bDim, void *gDim, void *wSize) { - __hipRegisterFunction(fatCubinHandle, hostFun, deviceFun, deviceName, - thread_limit, tid, bid, bDim, gDim, wSize); -} -extern "C" MLIR_HIP_WRAPPERS_EXPORT void -__mgpurtRegisterVar(void **fatCubinHandle, char *hostVar, char *deviceAddress, - const char *deviceName, int ext, size_t size, int constant, - int global) { - __hipRegisterVar(fatCubinHandle, hostVar, deviceAddress, deviceName, ext, - size, constant, global); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT void ** -__mgpurtRegisterFatBinary(void *fatCubin) { - 
return __hipRegisterFatBinary(fatCubin); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT void -__mgpurtRegisterFatBinaryEnd(void **fatCubinHandle) { - return __hipRegisterFatBinaryEnd(fatCubinHandle); -} - -extern "C" MLIR_HIP_WRAPPERS_EXPORT void -__mgpurtUnregisterFatBinary(void **fatCubinHandle) { - return __hipUnregisterFatBinary(fatCubinHandle); -} - -#if POLYGEIST_ENABLE_CUDA - -#pragma push_macro("__forceinline__") -#define __VECTOR_TYPES_H__ -#include -#undef __VECTOR_TYPES_H__ -#pragma pop_macro("__forceinline__") - -extern "C" MLIR_HIP_WRAPPERS_EXPORT int32_t -mgpurtCudaGetDeviceProperties(struct cudaDeviceProp *cudaProp, int device) { - struct hipDeviceProp_t hipProp; - int err = ERR_HIP_REPORT_IF_ERROR(hipGetDeviceProperties(&hipProp, device)); - - // Reassign all corresponding fields to the hip props, the commented ones dont - // exist in hip one-for-one -#define __polygeist_assign_field(f) \ - memcpy(&(cudaProp->f), &(hipProp.f), sizeof(cudaProp->f)) - __polygeist_assign_field(name); - // __polygeist_assign_field(uuid); - __polygeist_assign_field(totalGlobalMem); - __polygeist_assign_field(sharedMemPerBlock); - __polygeist_assign_field(regsPerBlock); - __polygeist_assign_field(warpSize); - __polygeist_assign_field(memPitch); - __polygeist_assign_field(maxThreadsPerBlock); - __polygeist_assign_field(maxThreadsDim); - __polygeist_assign_field(maxGridSize); - __polygeist_assign_field(clockRate); - __polygeist_assign_field(totalConstMem); - __polygeist_assign_field(major); - __polygeist_assign_field(minor); - __polygeist_assign_field(textureAlignment); - __polygeist_assign_field(texturePitchAlignment); - // __polygeist_assign_field(deviceOverlap); - __polygeist_assign_field(multiProcessorCount); - __polygeist_assign_field(kernelExecTimeoutEnabled); - __polygeist_assign_field(integrated); - __polygeist_assign_field(canMapHostMemory); - __polygeist_assign_field(computeMode); - __polygeist_assign_field(maxTexture1D); - // 
__polygeist_assign_field(maxTexture1DMipmap); - __polygeist_assign_field(maxTexture1DLinear); - __polygeist_assign_field(maxTexture2D); - // __polygeist_assign_field(maxTexture2DMipmap); - // __polygeist_assign_field(maxTexture2DLinear); - // __polygeist_assign_field(maxTexture2DGather); - __polygeist_assign_field(maxTexture3D); - // __polygeist_assign_field(maxTexture3DAlt); - // __polygeist_assign_field(maxTextureCubemap); - // __polygeist_assign_field(maxTexture1DLayered); - // __polygeist_assign_field(maxTexture2DLayered); - // __polygeist_assign_field(maxTextureCubemapLayered); - // __polygeist_assign_field(maxSurface1D); - // __polygeist_assign_field(maxSurface2D); - // __polygeist_assign_field(maxSurface3D); - // __polygeist_assign_field(maxSurface1DLayered); - // __polygeist_assign_field(maxSurface2DLayered); - // __polygeist_assign_field(maxSurfaceCubemap); - // __polygeist_assign_field(maxSurfaceCubemapLayered); - // __polygeist_assign_field(surfaceAlignment); - __polygeist_assign_field(concurrentKernels); - __polygeist_assign_field(ECCEnabled); - __polygeist_assign_field(pciBusID); - __polygeist_assign_field(pciDeviceID); - __polygeist_assign_field(pciDomainID); - __polygeist_assign_field(tccDriver); - // __polygeist_assign_field(asyncEngineCount); - // __polygeist_assign_field(unifiedAddressing); - __polygeist_assign_field(memoryClockRate); - __polygeist_assign_field(memoryBusWidth); - __polygeist_assign_field(l2CacheSize); - // __polygeist_assign_field(persistingL2CacheMaxSize); - __polygeist_assign_field(maxThreadsPerMultiProcessor); - // __polygeist_assign_field(streamPrioritiesSupported); - // __polygeist_assign_field(globalL1CacheSupported); - // __polygeist_assign_field(localL1CacheSupported); - // __polygeist_assign_field(sharedMemPerMultiprocessor); - // __polygeist_assign_field(regsPerMultiprocessor); - __polygeist_assign_field(managedMemory); - __polygeist_assign_field(isMultiGpuBoard); - // __polygeist_assign_field(multiGpuBoardGroupID); - // 
__polygeist_assign_field(singleToDoublePrecisionPerfRatio); - __polygeist_assign_field(pageableMemoryAccess); - __polygeist_assign_field(concurrentManagedAccess); - // __polygeist_assign_field(computePreemptionSupported); - // __polygeist_assign_field(canUseHostPointerForRegisteredMem); - __polygeist_assign_field(cooperativeLaunch); - __polygeist_assign_field(cooperativeMultiDeviceLaunch); - __polygeist_assign_field(pageableMemoryAccessUsesHostPageTables); - __polygeist_assign_field(directManagedMemAccessFromHost); - // __polygeist_assign_field(accessPolicyMaxWindowSize); -#undef __polygeist_assign_field - - return err; -} - -#endif From d9fb7400526480b23d3b0fd405ffbc67a73929e7 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 19 Nov 2025 16:30:34 -0600 Subject: [PATCH 20/27] fix BUILD format --- src/enzyme_ad/jax/BUILD | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 8a7223ac5d..4b50633827 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -845,7 +845,8 @@ cc_library( "Dialect/*.cpp", "Dialect/Distributed/*.cpp", "Dialect/Tessera/*.cpp", - ]) + [ + ], + ) + [ "Utils.cpp", ], hdrs = glob([ From 1c4e7e010c32b7c177abcdfeadc77e714b1d5953 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Sat, 22 Nov 2025 22:28:48 -0600 Subject: [PATCH 21/27] modify GPUFuncLoweringPatterns --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 182 ++---------------- 1 file changed, 16 insertions(+), 166 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index ac4bbf3215..8915d8e091 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -3024,8 +3024,9 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace, - StringAttr 
kernelAttributeName) - : ConvertOpToLLVMPattern(converter), + StringAttr kernelAttributeName, + PatternBenefit benefit = PatternBenefit(1)) + : ConvertOpToLLVMPattern(converter, benefit), allocaAddrSpace(allocaAddrSpace), kernelAttributeName(kernelAttributeName) {} @@ -3675,158 +3676,6 @@ struct OpLowering : public OpConversionPattern { } // namespace gpu } // namespace mlir -// https://rocm.docs.amd.com/projects/HIP/en/docs-6.4.0/reference/hardware_features.html -struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - MLIRContext *context = rewriter.getContext(); - auto int32Type = rewriter.getI32Type(); - - Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, - rewriter.getI32IntegerAttr(-1)); - Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, - rewriter.getI32IntegerAttr(0)); - - Value laneIdLo = - ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, {minusOne, zero}); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, - {minusOne, laneIdLo}); - LLVM::ConstantRangeAttr bounds = nullptr; - if (std::optional upperBound = op.getUpperBound()) - bounds = rewriter.getAttr( - /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue()); - else - bounds = rewriter.getAttr( - /*bitWidth=*/32, /*lower=*/0, /*upper=*/64); - - if (bounds) { - laneId.getDefiningOp()->setAttr("range", bounds); - } - - const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); - - if (indexBitwidth > 32) { - laneId = LLVM::SExtOp::create( - rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); - } else if (indexBitwidth < 32) { - laneId = LLVM::TruncOp::create( - rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); - } - - rewriter.replaceOp(op, {laneId}); - - return success(); - } -}; - 
-struct GPUShuffleOpToROCDL : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - - auto valueTy = adaptor.getValue().getType(); - auto value = adaptor.getValue(); - auto int32Type = IntegerType::get(rewriter.getContext(), 32); - - Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, - rewriter.getI32IntegerAttr(-1)); - Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, - rewriter.getI32IntegerAttr(0)); - - Value laneIdLo = - ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, {minusOne, zero}); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, - {minusOne, laneIdLo}); - - Value targetLane; - Value offset = adaptor.getOffset(); - - switch (op.getMode()) { - case gpu::ShuffleMode::XOR: - targetLane = - LLVM::XOrOp::create(rewriter, loc, int32Type, laneId, offset); - break; - case gpu::ShuffleMode::UP: - targetLane = - LLVM::SubOp::create(rewriter, loc, int32Type, laneId, offset); - break; - case gpu::ShuffleMode::DOWN: - targetLane = - LLVM::AddOp::create(rewriter, loc, int32Type, laneId, offset); - break; - case gpu::ShuffleMode::IDX: - targetLane = offset; - break; - } - - Value width = adaptor.getWidth(); - - auto isNonNegative = LLVM::ICmpOp::create( - rewriter, loc, LLVM::ICmpPredicate::sge, targetLane, zero); - auto isWithinWidth = LLVM::ICmpOp::create( - rewriter, loc, LLVM::ICmpPredicate::slt, targetLane, width); - auto isValid = - LLVM::AndOp::create(rewriter, loc, isNonNegative, isWithinWidth); - - Value maskAndClamp; - - Value widthMinusone = LLVM::SubOp::create( - rewriter, loc, width, - LLVM::ConstantOp::create(rewriter, loc, int32Type, - rewriter.getI32IntegerAttr(1))); - Value minResult = LLVM::SelectOp::create( - rewriter, loc, - LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, - targetLane, 
widthMinusone), - targetLane, widthMinusone); - maskAndClamp = LLVM::SelectOp::create( - rewriter, loc, - LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, minResult, - zero), - minResult, zero); - - Value four = LLVM::ConstantOp::create(rewriter, loc, int32Type, - rewriter.getI32IntegerAttr(4)); - Value byteIndex = LLVM::MulOp::create(rewriter, loc, maskAndClamp, four); - - Value shuffleResult; - if (valueTy.isF32()) { - Value valueAsInt = - LLVM::BitcastOp::create(rewriter, loc, int32Type, value); - - Value resultInt = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, - byteIndex, valueAsInt); - - shuffleResult = - LLVM::BitcastOp::create(rewriter, loc, valueTy, resultInt); - - } else if (valueTy.isInteger(32)) { - shuffleResult = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, - byteIndex, value); - } - // } else if (valueTy.isF64() || valueTy.isInteger(64)) { - // shuffleResult = shuffle64BitValue(loc, rewriter, value, byteIndex, - // valueTy); - // } - - bool predIsUsed = !op->getResult(1).use_empty(); - if (predIsUsed) { - rewriter.replaceOp(op, {shuffleResult, isValid}); - } else { - rewriter.replaceOp(op, {shuffleResult, nullptr}); - } - - return success(); - } -}; - struct GPUBarrierToROCDL : ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -3907,6 +3756,7 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, LLVMTypeConverter &typeConverter, std::string gpuTarget, bool func) { if (func) { + PatternBenefit highBenefit(2); patterns.add(typeConverter); patterns.add( typeConverter, @@ -3914,7 +3764,8 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, StringAttr::get(&typeConverter.getContext(), gpuTarget == "cuda" ? 
NVVM::NVVMDialect::getKernelFuncAttrName() - : ROCDL::ROCDLDialect::getKernelFuncAttrName())); + : ROCDL::ROCDLDialect::getKernelFuncAttrName()), + highBenefit); } else { if (gpuTarget == "cuda") { using namespace mlir::gpu::index_lowering; @@ -3960,30 +3811,29 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, using namespace mlir::gpu::index_lowering; PatternBenefit benefit(1); PatternBenefit highBenefit(2); + + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, + mlir::gpu::amd::Runtime::HIP, + amdgpu::Chipset()); + patterns.add>(typeConverter, IndexKind::Block, IntrType::Id, - benefit); + highBenefit); patterns.add>(typeConverter, IndexKind::Block, IntrType::Dim, - benefit); + highBenefit); patterns.add>(typeConverter, IndexKind::Grid, IntrType::Id, - benefit); + highBenefit); patterns.add>(typeConverter, IndexKind::Grid, IntrType::Dim, - benefit); - - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + highBenefit); - populateMathToLLVMConversionPatterns(typeConverter, patterns); - populateMathToROCDLConversionPatterns(typeConverter, patterns, - std::nullopt); + patterns.add(typeConverter, highBenefit); patterns.add(typeConverter, highBenefit); patterns.add(typeConverter, highBenefit); From ed5d49139ad12db5b5f2042e29b3b869a41d6f4b Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Sat, 22 Nov 2025 23:02:10 -0600 Subject: [PATCH 22/27] fix --- src/enzyme_ad/jax/BUILD | 3 +++ src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 0c3ba46dcb..de5520492a 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -905,6 +905,8 @@ cc_library( "@llvm-project//llvm:Passes", "@llvm-project//llvm:Scalar", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AMDGPUDialect", + "@llvm-project//mlir:AMDGPUToROCDL", "@llvm-project//mlir:AffineAnalysis", 
"@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -935,6 +937,7 @@ cc_library( "@llvm-project//mlir:GPUPipelines", "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexToLLVM", diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 8915d8e091..e064fedc97 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -20,6 +20,7 @@ #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -30,6 +31,7 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/DLTI/DLTI.h" @@ -3812,6 +3814,8 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit(1); PatternBenefit highBenefit(2); + typeConverter.getContext().loadDialect(); + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::Runtime::HIP, amdgpu::Chipset()); From 6f17744b4d18a9cf08b7dc405d6b0809e0e4d164 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Sun, 23 Nov 2025 15:41:27 -0600 Subject: [PATCH 23/27] fix --- src/enzyme_ad/jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/src/enzyme_ad/jax/BUILD 
b/src/enzyme_ad/jax/BUILD index de5520492a..69d8c61b82 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -907,6 +907,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AMDGPUToROCDL", + "@llvm-project//mlir:AMDGPUUtils", "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", From 4a056b4b20e3d90457ff4bb4cb6df468deac0410 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 26 Nov 2025 17:38:24 -0600 Subject: [PATCH 24/27] remove one branch --- src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index e064fedc97..c1e1e14378 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -2224,9 +2224,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; std::string launchFuncName; - if (gpuTarget == "cuda") { - launchFuncName = "cudaLaunchKernel"; - } else if (gpuTarget == "rocm") { + if (gpuTarget == "rocm") { launchFuncName = "hipLaunchKernel"; } else { launchFuncName = "cudaLaunchKernel"; @@ -4241,9 +4239,7 @@ struct ConvertPolygeistToLLVMPass }).wasInterrupted(); std::string launchFuncName; - if (backend == "cuda") { - launchFuncName = "cudaLaunchKernel"; - } else if (backend == "rocm") { + if (backend == "rocm") { launchFuncName = "hipLaunchKernel"; } else { launchFuncName = "cudaLaunchKernel"; From a1390a6afa185a5112bf7d16760f99b23f471cd9 Mon Sep 17 00:00:00 2001 From: Vimarsh Sathia Date: Thu, 4 Dec 2025 18:00:08 -0600 Subject: [PATCH 25/27] Undo changes already handled by llvm-project, just stick to enzymexla-specific ops --- .../jax/Passes/ConvertPolygeistToLLVM.cpp | 232 +++--------------- 1 file 
changed, 40 insertions(+), 192 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index c1e1e14378..54a81b1c89 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -1633,7 +1633,7 @@ struct LowerGPUAlternativesOp } else { // TODO might get some round off errors here, maybe use a better alg // or median - avgs.push_back( + avgs.push_back(1 std::accumulate(timings[i].begin(), timings[i].end(), 0.0f) / timings[i].size()); llvm::errs() << "Alternative " << i << "," << descs[i] << " is " @@ -1767,8 +1767,18 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, std::string registerFunctionFuncName; std::string registerVarFuncName; std::string unregisterFatBinaryFuncName; - std::string registerFatBinaryEndFuncName; - bool requiresRegisterEnd; + + if (gpuTarget == "cuda") { + registerFatBinaryFuncName = "__cudaRegisterFatBinary"; + registerFunctionFuncName = "__cudaRegisterFunction"; + registerVarFuncName = "__cudaRegisterVar"; + unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; + } else { + registerFatBinaryFuncName = "__hipRegisterFatBinary"; + registerFunctionFuncName = "__hipRegisterFunction"; + registerVarFuncName = "__hipRegisterVar"; + unregisterFatBinaryFuncName = "__hipUnregisterFatBinary"; + } rewriter.modifyOpInPlace(kernelModule, [&]() { kernelModule->setAttr("polygeist_stubs", rewriter.getUnitAttr()); @@ -1837,22 +1847,6 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, fatMagic = HIPFatMagic; } - if (gpuTarget == "cuda") { - registerFatBinaryFuncName = "__cudaRegisterFatBinary"; - registerFunctionFuncName = "__cudaRegisterFunction"; - registerVarFuncName = "__cudaRegisterVar"; - unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; - registerFatBinaryEndFuncName = "__cudaRegisterFatBinaryEnd"; - requiresRegisterEnd = true; - } else { - 
registerFatBinaryFuncName = "__hipRegisterFatBinary"; - registerFunctionFuncName = "__hipRegisterFunction"; - registerVarFuncName = "__hipRegisterVar"; - unregisterFatBinaryFuncName = "__hipUnregisterFatBinary"; - registerFatBinaryEndFuncName = ""; - requiresRegisterEnd = false; - } - (void)fatbinConstantName; (void)moduleIDSectionName; @@ -1916,18 +1910,17 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto bitcastOfWrapper = LLVM::AddrSpaceCastOp::create( ctorBuilder, ctorloc, llvmPointerType, addressOfWrapper); - auto registerFatbinFn = + auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, llvmPointerType, llvmPointerType); - if (failed(registerFatbinFn)) { - llvm::errs() - << "register fatbin function already exists with different types\n"; + if (failed(cudaRegisterFatbinFn)) { + llvm::errs() << "cudamalloc already exists with different types\n"; return failure(); } auto module = - LLVM::CallOp::create(rewriter, ctorloc, registerFatbinFn.value(), + LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFatbinFn.value(), ValueRange(bitcastOfWrapper)); auto moduleGlobalName = @@ -1984,12 +1977,11 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType}; - auto registerFunctionFn = LLVM::lookupOrCreateFn( + auto cudaRegisterFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type); - if (failed(registerFunctionFn)) { - llvm::errs() - << " register function already exists with different types\n"; + if (failed(cudaRegisterFn)) { + llvm::errs() << " cudamalloc already exists with different types\n"; return failure(); } @@ -2005,8 +1997,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, nullPtr, nullPtr}; - LLVM::CallOp::create(rewriter, ctorloc, registerFunctionFn.value(), - args); + LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFn.value(), args); } else if 
(LLVM::GlobalOp g = dyn_cast(op)) { int addrSpace = g.getAddrSpace(); if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */) @@ -2037,7 +2028,6 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, // to pass the GPU DL in here DataLayout DLI(moduleOp); auto size = DLI.getTypeSize(globalTy); - // why 'mgpu' rtRegisterVarCallBuilder.create( ctorloc, ctorBuilder, {module.getResult(), bitcast, symbolName, symbolName, @@ -2088,15 +2078,14 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto module = LLVM::LoadOp::create( dtorBuilder, ctorloc, llvmPointerPointerType, aoo->getResult(0)); - auto unregisterFatbinFn = LLVM::lookupOrCreateFn( + auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn( rewriter, moduleOp, unregisterFatBinaryFuncName, llvmPointerType, llvmVoidType); - if (failed(unregisterFatbinFn)) { - llvm::errs() << " unregister fatbin function already exists with " - "different types\n"; + if (failed(cudaUnRegisterFatbinFn)) { + llvm::errs() << " cudamalloc already exists with different types\n"; return failure(); } - LLVM::CallOp::create(rewriter, ctorloc, unregisterFatbinFn.value(), + LLVM::CallOp::create(rewriter, ctorloc, cudaUnRegisterFatbinFn.value(), ValueRange(module)); LLVM::ReturnOp::create(dtorBuilder, ctorloc, ValueRange()); @@ -2223,12 +2212,9 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; - std::string launchFuncName; - if (gpuTarget == "rocm") { - launchFuncName = "hipLaunchKernel"; - } else { - launchFuncName = "cudaLaunchKernel"; - } + // Create LLVM call to launch kernel + std::string launchFuncName = + (gpuTarget == "rocm") ? 
"hipLaunchKernel" : "cudaLaunchKernel"; auto launchCall = LLVM::CallOp::create(rewriter, loc, TypeRange(i32), launchFuncName, args); @@ -2505,7 +2491,6 @@ class ConvertAllocOpToGpuRuntimeCallPattern if (backend == "cuda") { auto one = LLVM::ConstantOp::create(rewriter, loc, i64, rewriter.getI64IntegerAttr(1)); - auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one); Type tys[] = {ptrty, i64}; auto cudaMallocFn = @@ -2519,17 +2504,17 @@ class ConvertAllocOpToGpuRuntimeCallPattern ptr, sizeBytes, }; + LLVM::CallOp::create(rewriter, loc, cudaMallocFn.value(), args); allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); } else if (backend == "rocm") { auto one = LLVM::ConstantOp::create(rewriter, loc, i64, rewriter.getI64IntegerAttr(1)); - auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one); Type tys[] = {ptrty, i64}; - auto hipMallocFn = LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipMalloc", tys, i32); + if (failed(hipMallocFn)) { llvm::errs() << " hipMalloc already exists with different types\n"; return failure(); @@ -2539,6 +2524,7 @@ class ConvertAllocOpToGpuRuntimeCallPattern ptr, sizeBytes, }; + LLVM::CallOp::create(rewriter, loc, hipMallocFn.value(), args); allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); } else if (backend.starts_with("cpu")) { @@ -2682,7 +2668,7 @@ class ConvertOccupancyOp if (backend != "cuda" && backend != "rocm") return rewriter.notifyMatchFailure( - op, "Occupancy op lowering only supported for CUDA"); + op, "Occupancy op lowering only supported for CUDA and ROCM"); auto moduleOp = op->getParentOfType(); auto i64 = rewriter.getIntegerType(64); @@ -2749,7 +2735,7 @@ class ConvertGPUKernelAddressOp if (backend != "cuda" && backend != "rocm") return rewriter.notifyMatchFailure( - op, "KernelAddress lowering only supported for CUDA"); + op, "KernelAddress lowering only supported for CUDA and ROCM"); std::string funcStubName = getFuncStubName(op.getFn().getRootReference().getValue(), 
@@ -2763,7 +2749,7 @@ class ConvertGPUKernelAddressOp }; /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime -/// call. Currently it supports CUDA, CPU, and XLA. +/// call. Currently it supports CUDA, ROCM, CPU, and XLA. template class ConvertDeallocOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { @@ -3676,78 +3662,6 @@ struct OpLowering : public OpConversionPattern { } // namespace gpu } // namespace mlir -struct GPUBarrierToROCDL : ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - LLVM::FenceOp::create(rewriter, op.getLoc(), LLVM::AtomicOrdering::release, - StringRef("workgroup")); - rewriter.replaceOpWithNewOp(op); - LLVM::FenceOp::create(rewriter, op.getLoc(), LLVM::AtomicOrdering::acquire, - StringRef("workgroup")); - rewriter.eraseOp(op); - return success(); - } -}; - -struct ClusterIdOpToROCDL : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::ClusterIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto indexType = getTypeConverter()->getIndexType(); - Value zero = LLVM::ConstantOp::create(rewriter, loc, indexType, - rewriter.getIndexAttr(0)); - - rewriter.replaceOp(op, zero); - return success(); - } -}; - -struct ClusterDimOpToROCDL : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::ClusterDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto indexType = getTypeConverter()->getIndexType(); - Value one = LLVM::ConstantOp::create(rewriter, loc, indexType, - rewriter.getIndexAttr(1)); - - rewriter.replaceOp(op, one); - return success(); - } -}; - -struct 
ClusterBlockIdToBlockIdLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(gpu::ClusterBlockIdOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getDimension()); - return success(); - } -}; - -struct ClusterDimBlocksToGridDimLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(gpu::ClusterDimBlocksOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getDimension()); - return success(); - } -}; - /// Appends the patterns lowering operations from the Func dialect to the LLVM /// dialect using the C-style type conversion, i.e. converting memrefs to /// pointer to arrays of arrays. @@ -3756,7 +3670,6 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, LLVMTypeConverter &typeConverter, std::string gpuTarget, bool func) { if (func) { - PatternBenefit highBenefit(2); patterns.add(typeConverter); patterns.add( typeConverter, @@ -3764,8 +3677,7 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, StringAttr::get(&typeConverter.getContext(), gpuTarget == "cuda" ? 
NVVM::NVVMDialect::getKernelFuncAttrName() - : ROCDL::ROCDLDialect::getKernelFuncAttrName()), - highBenefit); + : ROCDL::ROCDLDialect::getKernelFuncAttrName())); } else { if (gpuTarget == "cuda") { using namespace mlir::gpu::index_lowering; @@ -3808,41 +3720,10 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, populateLibDeviceConversionPatterns(typeConverter, patterns, benefit); patterns.add(typeConverter, benefit); } else if (gpuTarget == "rocm") { - using namespace mlir::gpu::index_lowering; - PatternBenefit benefit(1); - PatternBenefit highBenefit(2); - typeConverter.getContext().loadDialect(); - mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::Runtime::HIP, amdgpu::Chipset()); - - patterns.add>(typeConverter, IndexKind::Block, IntrType::Id, - highBenefit); - patterns.add>(typeConverter, IndexKind::Block, IntrType::Dim, - highBenefit); - patterns.add>(typeConverter, IndexKind::Grid, IntrType::Id, - highBenefit); - patterns.add>(typeConverter, IndexKind::Grid, IntrType::Dim, - highBenefit); - - patterns.add(typeConverter, highBenefit); - - patterns.add(typeConverter, highBenefit); - patterns.add(typeConverter, highBenefit); - patterns.add(&typeConverter.getContext(), - highBenefit); - patterns.add( - &typeConverter.getContext(), highBenefit); } } } @@ -3874,29 +3755,7 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { auto resumeOp = LLVM::LLVMFuncOp::create( moduleBuilder, fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); - - return resumeOp; -} - -static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) { - const char fname[] = "fake_rocm_dispatch"; - - MLIRContext *ctx = module->getContext(); - auto loc = module->getLoc(); - auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); - - for (auto fn : module.getBody()->getOps()) { - if (fn.getName() == fname) - return fn; - } - - auto voidTy = 
LLVM::LLVMVoidType::get(ctx); - auto ptrTy = LLVM::LLVMPointerType::get(ctx); - - auto resumeOp = LLVM::LLVMFuncOp::create( - moduleBuilder, fname, - LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); - + resumeOp.setPrivate(); return resumeOp; } @@ -3923,11 +3782,6 @@ struct NoAsyncOpLowering : public OpConversionPattern { struct AsyncOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - std::string backend; - AsyncOpLowering(LLVMTypeConverter &converter, std::string backend) - : ConvertOpToLLVMPattern(converter), - backend(std::move(backend)) {} - LogicalResult matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -4129,14 +3983,8 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { } assert(vals.size() == 3); - // auto f = addMocCUDAFunction(execute->getParentOfType(), - // vals.back().getType()); - - auto f = (backend == "cuda") - ? addMocCUDAFunction(execute->getParentOfType(), - vals.back().getType()) - : addMocROCmFunction(execute->getParentOfType(), - vals.back().getType()); + auto f = addMocCUDAFunction(execute->getParentOfType(), + vals.back().getType()); LLVM::CallOp::create(rewriter, execute.getLoc(), f, vals); rewriter.eraseOp(execute); @@ -4321,7 +4169,7 @@ struct ConvertPolygeistToLLVMPass if (backend == "cpu") { if (use_async) - patterns.add(converter, gpuTarget); + patterns.add(converter); else patterns.add(patterns.getContext()); } @@ -4505,4 +4353,4 @@ struct ConvertPolygeistToLLVMPass convertModule(m, /* gpuModule */ false); } }; -} // namespace \ No newline at end of file +} // namespace From f4a9e6ee1366490fc8394e183c9114f07e2e61b3 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 4 Dec 2025 20:52:16 -0600 Subject: [PATCH 26/27] fix --- src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp 
b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index bc0aff87a7..6dbcf4f05f 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -4310,8 +4310,7 @@ struct ConvertPolygeistToLLVMPass for (auto e : toErase) { if (call.getName() == e) { call->erase(); - } else if (callee == "hipDeviceSynchronize") { - call->erase(); + return; } } }); @@ -4320,8 +4319,7 @@ struct ConvertPolygeistToLLVMPass call->erase(); } else if (call.getName() == "hipDeviceSynchronize") { call->erase(); - return; - } + return; } }); } From fdbb5b2a27c5446370ac211b04bbee27f2ef04cf Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 4 Dec 2025 23:55:20 -0600 Subject: [PATCH 27/27] rm ; for for-loop --- src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 6dbcf4f05f..01ff8bdba1 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -4140,7 +4140,7 @@ struct ConvertPolygeistToLLVMPass mod->emitError() << "failed to apply folding"; return signalPassFailure(); } - }; + } LLVMConversionTarget target(getContext()); RewritePatternSet patterns(&getContext());