diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD
index 975f166f9a..69d8c61b82 100644
--- a/src/enzyme_ad/jax/BUILD
+++ b/src/enzyme_ad/jax/BUILD
@@ -838,13 +838,15 @@ cc_library(
 
 cc_library(
     name = "XLADerivatives",
-    srcs = glob([
-        "Implementations/*.cpp",
-        "Passes/*.cpp",
-        "Dialect/*.cpp",
-        "Dialect/Distributed/*.cpp",
-        "Dialect/Tessera/*.cpp",
-    ]) + [
+    srcs = glob(
+        [
+            "Implementations/*.cpp",
+            "Passes/*.cpp",
+            "Dialect/*.cpp",
+            "Dialect/Distributed/*.cpp",
+            "Dialect/Tessera/*.cpp",
+        ],
+    ) + [
         "Utils.cpp",
     ],
     hdrs = glob([
@@ -903,6 +905,9 @@ cc_library(
        "@llvm-project//llvm:Passes",
        "@llvm-project//llvm:Scalar",
        "@llvm-project//llvm:Support",
+       "@llvm-project//mlir:AMDGPUDialect",
+       "@llvm-project//mlir:AMDGPUToROCDL",
+       "@llvm-project//mlir:AMDGPUUtils",
        "@llvm-project//mlir:AffineAnalysis",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:AffineToStandard",
@@ -933,6 +938,7 @@ cc_library(
        "@llvm-project//mlir:GPUPipelines",
        "@llvm-project//mlir:GPUToGPURuntimeTransforms",
        "@llvm-project//mlir:GPUToNVVMTransforms",
+       "@llvm-project//mlir:GPUToROCDLTransforms",
        "@llvm-project//mlir:GPUTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:IndexToLLVM",
@@ -944,6 +950,7 @@ cc_library(
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MathToLLVM",
        "@llvm-project//mlir:MathToLibm",
+       "@llvm-project//mlir:MathToROCDL",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:MemRefToLLVM",
        "@llvm-project//mlir:MemRefTransforms",
diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
index 7c8e8e9237..e064fedc97 100644
--- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
+++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
@@ -20,15 +20,19 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -925,20 +929,33 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering {
 
     auto ptrty = LLVM::LLVMPointerType::get(op.getContext());
 
+    auto i32 = rewriter.getIntegerType(32);
+
+    std::string memcpyFuncName;
+
+    bool xla = backend.starts_with("xla");
+
+    if (xla) {
+      memcpyFuncName = "reactantXLAMemcpy";
+    } else if (backend == "cuda") {
+      memcpyFuncName = "cudaMemcpy";
+    } else if (backend == "rocm") {
+      memcpyFuncName = "hipMemcpy";
+    }
+
     SmallVector<Type> tys = {ptrty, ptrty, size.getType(),
                              rewriter.getIntegerType(32)};
     if (backend.starts_with("xla")) {
       tys.insert(tys.begin(), ptrty);
     }
-    auto i32 = rewriter.getIntegerType(32);
-    bool xla = backend.starts_with("xla");
-    auto cudaMemcpyFn = LLVM::lookupOrCreateFn(
-        rewriter, moduleOp, xla ? "reactantXLAMemcpy" : "cudaMemcpy", tys,
"reactantXLAMemcpy" : "cudaMemcpy", tys, + auto memcpyFn = LLVM::lookupOrCreateFn( + rewriter, moduleOp, memcpyFuncName, tys, xla ? (mlir::Type)LLVM::LLVMVoidType::get(rewriter.getContext()) : (mlir::Type)i32); - if (failed(cudaMemcpyFn)) + if (failed(memcpyFn)) { return failure(); + } SmallVector args = {dst, src, size, LLVM::ConstantOp::create(rewriter, op.getLoc(), @@ -954,7 +971,7 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { args.insert(args.begin(), xdata); } - LLVM::CallOp::create(rewriter, op.getLoc(), cudaMemcpyFn.value(), args); + LLVM::CallOp::create(rewriter, op.getLoc(), memcpyFn.value(), args); rewriter.eraseOp(op); return success(); } @@ -1546,7 +1563,7 @@ struct LowerGPUAlternativesOp auto kernelId = LLVM::createGlobalString( loc, rewriter, std::string("kernelId.") + std::to_string(num++), nullTermLocStr, LLVM::Linkage::Internal, /*opaquePointers*/ true); - auto totalAlternatives = LLVM::ConstantOp::create(rewriter, + auto totalAlternatives = LLVM::ConstantOp::create(rewriter, loc, llvmInt32Type, gao->getNumRegions()); auto alternative = rtPGOGetAlternativeCallBuilder @@ -1555,7 +1572,7 @@ struct LowerGPUAlternativesOp int i = 0; for (auto ®ion : gao->getRegions()) { - auto cmpOp = arith::CmpIOp::create(rewriter, + auto cmpOp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, alternative, arith::ConstantIntOp::create(rewriter, loc, i, 32)); auto ifOp = scf::IfOp::create(rewriter, loc, cmpOp, /* hasElse */ true); @@ -1745,6 +1762,14 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto loc = kernelModule.getLoc(); auto ctorloc = rewriter.getUnknownLoc(); + + std::string registerFatBinaryFuncName; + std::string registerFunctionFuncName; + std::string registerVarFuncName; + std::string unregisterFatBinaryFuncName; + std::string registerFatBinaryEndFuncName; + bool requiresRegisterEnd; + rewriter.modifyOpInPlace(kernelModule, [&]() { kernelModule->setAttr("polygeist_stubs", rewriter.getUnitAttr()); }); @@ -1811,6 +1836,23 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, moduleIDPrefix = "__hip_"; fatMagic = HIPFatMagic; } + + if (gpuTarget == "cuda") { + registerFatBinaryFuncName = "__cudaRegisterFatBinary"; + registerFunctionFuncName = "__cudaRegisterFunction"; + registerVarFuncName = "__cudaRegisterVar"; + unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; + registerFatBinaryEndFuncName = "__cudaRegisterFatBinaryEnd"; + requiresRegisterEnd = true; + } else { + registerFatBinaryFuncName = "__hipRegisterFatBinary"; + registerFunctionFuncName = "__hipRegisterFunction"; + registerVarFuncName = "__hipRegisterVar"; + unregisterFatBinaryFuncName = "__hipUnregisterFatBinary"; + registerFatBinaryEndFuncName = ""; + requiresRegisterEnd = false; + } + (void)fatbinConstantName; (void)moduleIDSectionName; @@ -1874,16 +1916,18 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto bitcastOfWrapper = LLVM::AddrSpaceCastOp::create( ctorBuilder, ctorloc, llvmPointerType, addressOfWrapper); - auto cudaRegisterFatbinFn = - LLVM::lookupOrCreateFn(rewriter, moduleOp, "__cudaRegisterFatBinary", + auto registerFatbinFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, llvmPointerType, llvmPointerType); - if (failed(cudaRegisterFatbinFn)) { - llvm::errs() << " cudamalloc already exists with different types\n"; + + if (failed(registerFatbinFn)) { + llvm::errs() + << "register fatbin function already exists with different types\n"; return failure(); } auto module = - 
-          LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFatbinFn.value(),
+          LLVM::CallOp::create(rewriter, ctorloc, registerFatbinFn.value(),
                                ValueRange(bitcastOfWrapper));
 
       auto moduleGlobalName =
@@ -1939,12 +1983,16 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
                              llvmPointerType, llvmInt32Type, llvmPointerType,
                              llvmPointerType, llvmPointerType, llvmPointerType,
                              llvmPointerType};
-        auto cudaRegisterFn = LLVM::lookupOrCreateFn(
-            rewriter, moduleOp, "__cudaRegisterFunction", tys, llvmInt32Type);
-        if (failed(cudaRegisterFn)) {
-          llvm::errs() << " cudamalloc already exists with different types\n";
+
+        auto registerFunctionFn = LLVM::lookupOrCreateFn(
+            rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type);
+
+        if (failed(registerFunctionFn)) {
+          llvm::errs()
+              << "register function already exists with different types\n";
           return failure();
         }
+
        Value args[] = {
            module.getResult(),
            bitcast,
@@ -1957,7 +2005,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
            nullPtr,
            nullPtr};
 
-        LLVM::CallOp::create(rewriter, ctorloc, cudaRegisterFn.value(), args);
+        LLVM::CallOp::create(rewriter, ctorloc, registerFunctionFn.value(),
+                             args);
      } else if (LLVM::GlobalOp g = dyn_cast<LLVM::GlobalOp>(op)) {
        int addrSpace = g.getAddrSpace();
        if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */)
@@ -1988,6 +2037,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
        // to pass the GPU DL in here
        DataLayout DLI(moduleOp);
        auto size = DLI.getTypeSize(globalTy);
+        // TODO: document why the generic 'mgpu' wrapper is used here rather
+        // than the backend-specific registerVarFuncName
        rtRegisterVarCallBuilder.create(
            ctorloc, ctorBuilder,
            {module.getResult(), bitcast, symbolName, symbolName,
@@ -2005,7 +2055,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
                                        0)});
      }
    }
-    // TODO this has to happen only for some CUDA versions
-    if (gpuTarget == "cuda") {
+    // TODO this has to happen only for some CUDA versions (e.g. CUDA 11.x);
+    // HIP does not need a finalize call
+    if (requiresRegisterEnd) {
      auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn(
-          rewriter, moduleOp, "__cudaRegisterFatBinaryEnd", llvmPointerType,
+          rewriter, moduleOp, registerFatBinaryEndFuncName, llvmPointerType,
@@ -2037,16 +2088,17 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
 
      auto module = LLVM::LoadOp::create(
          dtorBuilder, ctorloc, llvmPointerPointerType, aoo->getResult(0));
-      auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn(
-          rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType,
+      auto unregisterFatbinFn = LLVM::lookupOrCreateFn(
+          rewriter, moduleOp, unregisterFatBinaryFuncName, llvmPointerType,
          llvmVoidType);
-      if (failed(cudaUnRegisterFatbinFn)) {
-        llvm::errs() << " cudamalloc already exists with different types\n";
+      if (failed(unregisterFatbinFn)) {
+        llvm::errs() << "unregister fatbin function already exists with "
+                        "different types\n";
        return failure();
      }
-
-      LLVM::CallOp::create(rewriter, ctorloc, cudaUnRegisterFatbinFn.value(),
+      LLVM::CallOp::create(rewriter, ctorloc, unregisterFatbinFn.value(),
                           ValueRange(module));
+      LLVM::ReturnOp::create(dtorBuilder, ctorloc, ValueRange());
 
      auto dtorSymbol = FlatSymbolRefAttr::get(dtor);
      {
@@ -2171,10 +2223,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
    Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty};
 
-    auto launchCall = LLVM::CallOp::create(
-        rewriter, loc, TypeRange(i32), "cudaLaunchKernel",
-        args); // FlatSymbolRefAttr::get(rewriter.getStringAttr("cudaLaunchKernel")),
-               // args);
+    std::string launchFuncName;
+    if (gpuTarget == "cuda") {
+      launchFuncName = "cudaLaunchKernel";
else if (gpuTarget == "rocm") { + launchFuncName = "hipLaunchKernel"; + } else { + launchFuncName = "cudaLaunchKernel"; + } + + auto launchCall = + LLVM::CallOp::create(rewriter, loc, TypeRange(i32), launchFuncName, args); + if (launchOp.getAsyncToken()) { // Async launch: make dependent ops use the same stream. rewriter.replaceOp(launchOp, {stream}); @@ -2463,6 +2523,26 @@ class ConvertAllocOpToGpuRuntimeCallPattern }; LLVM::CallOp::create(rewriter, loc, cudaMallocFn.value(), args); allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); + } else if (backend == "rocm") { + auto one = LLVM::ConstantOp::create(rewriter, loc, i64, + rewriter.getI64IntegerAttr(1)); + + auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one); + Type tys[] = {ptrty, i64}; + + auto hipMallocFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipMalloc", tys, i32); + if (failed(hipMallocFn)) { + llvm::errs() << " hipMalloc already exists with different types\n"; + return failure(); + } + + Value args[] = { + ptr, + sizeBytes, + }; + LLVM::CallOp::create(rewriter, loc, hipMallocFn.value(), args); + allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr); } else if (backend.starts_with("cpu")) { Type convertedIndex = typeConverter->convertType(rewriter.getIndexType()); @@ -2602,7 +2682,7 @@ class ConvertOccupancyOp if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); - if (backend != "cuda") + if (backend != "cuda" && backend != "rocm") return rewriter.notifyMatchFailure( op, "Occupancy op lowering only supported for CUDA"); @@ -2617,12 +2697,20 @@ class ConvertOccupancyOp Type tys[] = {ptrty, ptrty, intty, adaptor.getDynamicSMemSize().getType(), adaptor.getFlags().getType()}; - auto cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn = - LLVM::lookupOrCreateFn( - rewriter, moduleOp, - "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", tys, i32); - if (failed(cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn)) { - llvm::errs() << " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags " + std::string occupancyFuncName; + if (backend == "cuda") { + occupancyFuncName = + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"; + } else if (backend == "rocm") { + occupancyFuncName = + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"; + } + + auto occupancyFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, occupancyFuncName, tys, i32); + + if (failed(occupancyFn)) { + llvm::errs() << " occupancyMaxActiveBlocksPerMultiprocessorWithFlags " "already exists with different types\n"; return failure(); } @@ -2638,9 +2726,7 @@ class ConvertOccupancyOp auto addr = LLVM::AddressOfOp::create(rewriter, loc, ptrty, funcStubName); Value args[] = {ptr, addr, adaptor.getBlockSize(), adaptor.getDynamicSMemSize(), adaptor.getFlags()}; - LLVM::CallOp::create( - rewriter, loc, - cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn.value(), args); + LLVM::CallOp::create(rewriter, loc, occupancyFn.value(), args); rewriter.replaceOpWithNewOp(op, intty, ptr); return success(); @@ -2663,7 +2749,7 @@ class ConvertGPUKernelAddressOp matchAndRewrite(enzymexla::GPUKernelAddressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (backend != "cuda") + if (backend != "cuda" && backend != "rocm") return rewriter.notifyMatchFailure( op, "KernelAddress lowering only supported for CUDA"); @@ -2735,6 +2821,20 @@ class ConvertDeallocOpToGpuRuntimeCallPattern ptr, }; LLVM::CallOp::create(rewriter, loc, cudaFreeFn.value(), args); + } 
else if (backend == "rocm") { + Type tys[] = {ptr1ty}; + auto hipFreeFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipFree", tys, i32); + + if (failed(hipFreeFn)) { + llvm::errs() << " hipfree already exists with different types\n"; + return failure(); + } + Value args[] = { + ptr, + }; + LLVM::CallOp::create(rewriter, loc, hipFreeFn.value(), args); + } else if (backend.starts_with("cpu")) { FailureOr freeFunc = @@ -2926,8 +3026,9 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace, - StringAttr kernelAttributeName) - : ConvertOpToLLVMPattern(converter), + StringAttr kernelAttributeName, + PatternBenefit benefit = PatternBenefit(1)) + : ConvertOpToLLVMPattern(converter, benefit), allocaAddrSpace(allocaAddrSpace), kernelAttributeName(kernelAttributeName) {} @@ -3577,6 +3678,78 @@ struct OpLowering : public OpConversionPattern { } // namespace gpu } // namespace mlir +struct GPUBarrierToROCDL : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::FenceOp::create(rewriter, op.getLoc(), LLVM::AtomicOrdering::release, + StringRef("workgroup")); + rewriter.replaceOpWithNewOp(op); + LLVM::FenceOp::create(rewriter, op.getLoc(), LLVM::AtomicOrdering::acquire, + StringRef("workgroup")); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ClusterIdOpToROCDL : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::ClusterIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto indexType = getTypeConverter()->getIndexType(); + Value zero = LLVM::ConstantOp::create(rewriter, loc, indexType, + rewriter.getIndexAttr(0)); + + rewriter.replaceOp(op, zero); + return success(); + } +}; + +struct ClusterDimOpToROCDL : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::ClusterDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto indexType = getTypeConverter()->getIndexType(); + Value one = LLVM::ConstantOp::create(rewriter, loc, indexType, + rewriter.getIndexAttr(1)); + + rewriter.replaceOp(op, one); + return success(); + } +}; + +struct ClusterBlockIdToBlockIdLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ClusterBlockIdOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getDimension()); + return success(); + } +}; + +struct ClusterDimBlocksToGridDimLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ClusterDimBlocksOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getDimension()); + return success(); + } +}; + /// Appends the patterns lowering operations from the Func dialect to the LLVM /// dialect using the C-style type conversion, i.e. converting memrefs to /// pointer to arrays of arrays. 
@@ -3585,6 +3758,7 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns,
                                      LLVMTypeConverter &typeConverter,
                                      std::string gpuTarget, bool func) {
  if (func) {
+    PatternBenefit highBenefit(2);
    patterns.add<GPUReturnOpLowering>(typeConverter);
    patterns.add<GPUFuncOpLowering>(
        typeConverter,
@@ -3592,7 +3766,8 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns,
        StringAttr::get(&typeConverter.getContext(),
                        gpuTarget == "cuda"
                            ? NVVM::NVVMDialect::getKernelFuncAttrName()
-                            : ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+                            : ROCDL::ROCDLDialect::getKernelFuncAttrName()),
+        highBenefit);
  } else {
    if (gpuTarget == "cuda") {
      using namespace mlir::gpu::index_lowering;
@@ -3634,6 +3809,42 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns,
      populateLibDeviceConversionPatterns(typeConverter, patterns, benefit);
 
      patterns.add(typeConverter, benefit);
+    } else if (gpuTarget == "rocm") {
+      using namespace mlir::gpu::index_lowering;
+      PatternBenefit benefit(1);
+      PatternBenefit highBenefit(2);
+
+      typeConverter.getContext().loadDialect<ROCDL::ROCDLDialect>();
+
+      mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
+                                                 mlir::gpu::amd::Runtime::HIP,
+                                                 amdgpu::Chipset());
+
+      patterns.add<OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
+                              ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
+          typeConverter, IndexKind::Block, IntrType::Id,
+          highBenefit);
+      patterns.add<OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
+                              ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
+          typeConverter, IndexKind::Block, IntrType::Dim,
+          highBenefit);
+      patterns.add<OpLowering<gpu::BlockIdOp, ROCDL::BlockIdXOp,
+                              ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
+          typeConverter, IndexKind::Grid, IntrType::Id,
+          highBenefit);
+      patterns.add<OpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
+                              ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
+          typeConverter, IndexKind::Grid, IntrType::Dim,
+          highBenefit);
+
+      patterns.add<GPUBarrierToROCDL>(typeConverter, highBenefit);
+
+      patterns.add<ClusterIdOpToROCDL>(typeConverter, highBenefit);
+      patterns.add<ClusterDimOpToROCDL>(typeConverter, highBenefit);
+      patterns.add<ClusterBlockIdToBlockIdLowering>(
+          &typeConverter.getContext(), highBenefit);
+      patterns.add<ClusterDimBlocksToGridDimLowering>(
+          &typeConverter.getContext(), highBenefit);
    }
  }
 }
@@ -3665,7 +3876,28 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) {
  auto resumeOp = LLVM::LLVMFuncOp::create(
      moduleBuilder, fname,
      LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy}));
-  resumeOp.setPrivate();
+
+  return resumeOp;
+}
+
+static LLVM::LLVMFuncOp addMocROCmFunction(ModuleOp module, Type streamTy) {
+  const char fname[] = "fake_rocm_dispatch";
+
+  MLIRContext *ctx = module->getContext();
+  auto loc = module->getLoc();
+  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
+
+  for (auto fn : module.getBody()->getOps<LLVM::LLVMFuncOp>()) {
+    if (fn.getName() == fname)
+      return fn;
+  }
+
+  auto voidTy = LLVM::LLVMVoidType::get(ctx);
+  auto ptrTy = LLVM::LLVMPointerType::get(ctx);
+
+  auto resumeOp = LLVM::LLVMFuncOp::create(
+      moduleBuilder, fname,
+      LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy}));
 
  return resumeOp;
 }
@@ -3693,6 +3925,11 @@ struct NoAsyncOpLowering : public OpConversionPattern {
 struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
+  std::string backend;
+  AsyncOpLowering(LLVMTypeConverter &converter, std::string backend)
+      : ConvertOpToLLVMPattern(converter),
+        backend(std::move(backend)) {}
+
  LogicalResult
  matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
@@ -3894,8 +4131,14 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern {
      }
      assert(vals.size() == 3);
 
-      auto f = addMocCUDAFunction(execute->getParentOfType<ModuleOp>(),
-                                  vals.back().getType());
+      // auto f = addMocCUDAFunction(execute->getParentOfType<ModuleOp>(),
+      //                             vals.back().getType());
+
+      auto f = (backend == "cuda")
+                   ? addMocCUDAFunction(execute->getParentOfType<ModuleOp>(),
+                                        vals.back().getType())
+                   : addMocROCmFunction(execute->getParentOfType<ModuleOp>(),
+                                        vals.back().getType());
 
      LLVM::CallOp::create(rewriter, execute.getLoc(), f, vals);
      rewriter.eraseOp(execute);
@@ -3996,13 +4239,23 @@ struct ConvertPolygeistToLLVMPass
    bool hasLaunch = m->walk([](gpu::LaunchFuncOp) {
                        return WalkResult::interrupt();
                      }).wasInterrupted();
+
+    std::string launchFuncName;
+    if (backend == "cuda") {
+      launchFuncName = "cudaLaunchKernel";
+    } else if (backend == "rocm") {
+      launchFuncName = "hipLaunchKernel";
+    } else {
+      launchFuncName = "cudaLaunchKernel";
+    }
    if (hasLaunch) {
      OpBuilder rewriter(m);
      auto i32 = rewriter.getIntegerType(32);
      auto i64 = rewriter.getIntegerType(64);
      auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
      Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty};
-      LLVM::lookupOrCreateFn(rewriter, m, "cudaLaunchKernel", tys, i32);
+
+      LLVM::lookupOrCreateFn(rewriter, m, launchFuncName, tys, i32);
    }
 
    for (auto mod : gmods) {
@@ -4018,6 +4271,11 @@ struct ConvertPolygeistToLLVMPass
        target.addLegalOp();
        target.addLegalDialect();
        target.addLegalOp();
+      } else if (backend == "rocm") {
+        target.addIllegalDialect();
+        target.addLegalOp();
+        target.addLegalDialect();
+        target.addLegalOp();
      }
    }
@@ -4067,7 +4325,7 @@ struct ConvertPolygeistToLLVMPass
 
    if (backend == "cpu") {
      if (use_async)
-        patterns.add<AsyncOpLowering>(converter);
+        patterns.add<AsyncOpLowering>(converter, gpuTarget);
      else
        patterns.add<NoAsyncOpLowering>(patterns.getContext());
    }
@@ -4189,12 +4447,16 @@ struct ConvertPolygeistToLLVMPass
        if (auto callee = call.getCallee()) {
          if (callee == "cudaDeviceSynchronize") {
            call->erase();
+          } else if (callee == "hipDeviceSynchronize") {
+            call->erase();
          }
        }
      });
      m->walk([](LLVM::LLVMFuncOp call) {
        if (call.getName() == "cudaDeviceSynchronize") {
          call->erase();
+        } else if (call.getName() == "hipDeviceSynchronize") {
+          call->erase();
        }
      });
    }
@@ -4223,7 +4485,6 @@ struct ConvertPolygeistToLLVMPass
      signalPassFailure();
      return;
    }
-
    {
      const char *GetDeviceFromHostFuncName = "__reactant$get_device_from_host";
      SmallVector toHandle;
@@ -4248,4 +4509,4 @@ struct ConvertPolygeistToLLVMPass
    convertModule(m, /* gpuModule */ false);
  }
 };
-} // namespace
+} // namespace
\ No newline at end of file
diff --git a/test/lit_tests/lowering/rocm.mlir b/test/lit_tests/lowering/rocm.mlir
new file mode 100644
index 0000000000..48bea8ccb5
--- /dev/null
+++ b/test/lit_tests/lowering/rocm.mlir
@@ -0,0 +1,45 @@
+// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(convert-polygeist-to-llvm{backend=rocm})" | FileCheck %s
+
+module attributes {gpu.container_module} {
+  llvm.func @test_rocm_launch(%arg0: !llvm.ptr) {
+    %c32 = arith.constant 32 : index
+    %c1 = arith.constant 1 : index
+    %c1_i64 = arith.constant 1 : i64
+    %stream = llvm.inttoptr %c1_i64 : i64 to !llvm.ptr
+    %token = "enzymexla.stream2token"(%stream) : (!llvm.ptr) -> !gpu.async.token
+    gpu.launch_func [%token] @test_module::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1) args(%arg0 : !llvm.ptr)
+    llvm.return
+  }
+
+  func.func @test_rocm_alloc() {
+    %alloc = gpu.alloc() : memref<256xf32, 1>
+    gpu.dealloc %alloc : memref<256xf32, 1>
+    return
+  }
+
+  func.func @test_rocm_memcpy(%src: memref<256xf32>, %dst: memref<256xf32, 1>) {
+    %c1024 = arith.constant 1024 : index
+    "enzymexla.memcpy"(%dst, %src, %c1024) : (memref<256xf32, 1>, memref<256xf32>, index) -> ()
+    return
+  }
+
+  gpu.module @test_module {
+    gpu.func @test_kernel(%arg0: !llvm.ptr) kernel {
+      gpu.return
+    }
+  }
+}
+
+// CHECK-LABEL: llvm.func @test_rocm_launch
+// CHECK: llvm.call @hipLaunchKernel
+// CHECK-NOT: cudaLaunchKernel
+
+// CHECK-LABEL: llvm.func @test_rocm_alloc
+// CHECK: llvm.call @hipMalloc
+// CHECK: llvm.call @hipFree
+// CHECK-NOT: cudaMalloc
+// CHECK-NOT: cudaFree
+
+// CHECK-LABEL: llvm.func @test_rocm_memcpy
+// CHECK: llvm.call @hipMemcpy
+// CHECK-NOT: cudaMemcpy
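
Reviewer note (illustration only, not part of the patch): the launch lowering
above can keep a single LLVM function type {ptr, i64, i32, i64, i32, ptr, i64,
ptr} for both backends because hipLaunchKernel was designed to mirror
cudaLaunchKernel: the same parameter list and the same integer status return,
with each by-value dim3 (three 32-bit extents) coerced on x86-64 into an
{i64, i32} register pair. A minimal host-side sketch of the two entry points,
where dim3_t and the plain void* stream are hypothetical stand-ins for the
real dim3, cudaStream_t, and hipStream_t types:

    #include <cstddef>

    // Stand-in for the CUDA/HIP dim3 type: three 32-bit extents passed by
    // value; the SysV x86-64 ABI splits it into an i64 (x, y) plus an i32 (z).
    struct dim3_t {
      unsigned x, y, z;
    };

    extern "C" int cudaLaunchKernel(const void *func, dim3_t gridDim,
                                    dim3_t blockDim, void **args,
                                    size_t sharedMem, void *stream);

    // Identical shape; only the symbol differs, which is why the lowering in
    // ConvertLaunchFuncOpToGpuRuntimeCallPattern merely swaps the callee name.
    extern "C" int hipLaunchKernel(const void *func, dim3_t gridDim,
                                   dim3_t blockDim, void **args,
                                   size_t sharedMem, void *stream);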