Skip to content

Commit bb21263

Browse files
committed
cudaFree fix with addr space
1 parent 2e1f98a commit bb21263

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2717,20 +2717,23 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
27172717
auto i32 = rewriter.getIntegerType(32);
27182718
auto moduleOp = deallocOp->getParentOfType<ModuleOp>();
27192719

2720-
auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
2720+
auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
27212721

27222722
if (backend == "cuda") {
27232723
auto one = LLVM::ConstantOp::create(rewriter, loc, i64,
27242724
rewriter.getI64IntegerAttr(1));
27252725

2726-
Type tys[] = {ptr1ty};
2726+
Type tys[] = {ptrty};
27272727
auto cudaFreeFn =
27282728
LLVM::lookupOrCreateFn(rewriter, moduleOp, "cudaFree", tys, i32);
27292729
if (failed(cudaFreeFn)) {
27302730
llvm::errs() << " cudafree already exists with different types\n";
27312731
return failure();
27322732
}
27332733

2734+
if (cast<LLVM::LLVMPointerType>(ptr.getType()).getAddressSpace() != 0)
2735+
ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, ptrty, ptr);
2736+
27342737
Value args[] = {
27352738
ptr,
27362739
};
@@ -2750,8 +2753,6 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
27502753
};
27512754
LLVM::CallOp::create(rewriter, loc, freeFunc.value(), args);
27522755
} else if (backend.starts_with("xla")) {
2753-
auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
2754-
27552756
// handle, ptr
27562757
Type tys[] = {ptrty, ptrty};
27572758

0 commit comments

Comments
 (0)