Skip to content

Commit 3cef9d0

Browse files
committed
Replace calls to compile_internal with compile_subroutine
1 parent 9d31c80 commit 3cef9d0

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

python/cuda_cccl/cuda/compute/_cccl_interop.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,10 @@ def codegen(context, builder, impl_sig, args):
248248
input_vals = [builder.load(p) for p in input_ptrs]
249249
250250
# Call the original operator
251-
result = context.compile_internal(builder, op, sig, input_vals)
251+
# See NVIDIA/numba-cuda#590 for why we need compile_subroutine
252+
# vs compile_internal here:
253+
cres = context.compile_subroutine(builder, op, sig, caching=False)
254+
result = context.call_internal(builder, cres.fndesc, sig, input_vals)
252255
253256
# Store the result
254257
builder.store(result, ret_ptr)
@@ -298,7 +301,8 @@ def codegen(context, builder, impl_sig, args):
298301
offset_val = builder.load(offset_ptr)
299302
300303
sig = types.void(state_ptr_type, types.uint64)
301-
context.compile_internal(builder, advance_fn, sig, [state_ptr, offset_val])
304+
cres = context.compile_subroutine(builder, advance_fn, sig, caching=False)
305+
result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, offset_val])
302306
303307
return context.get_dummy_value()
304308
return void_sig, codegen
@@ -342,7 +346,8 @@ def codegen(context, builder, impl_sig, args):
342346
result_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer())
343347
344348
sig = types.void(state_ptr_type, types.CPointer(value_type))
345-
context.compile_internal(builder, deref_fn, sig, [state_ptr, result_ptr])
349+
cres = context.compile_subroutine(builder, deref_fn, sig, caching=False)
350+
result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, result_ptr])
346351
347352
return context.get_dummy_value()
348353
return void_sig, codegen
@@ -388,7 +393,8 @@ def codegen(context, builder, impl_sig, args):
388393
value_val = builder.load(value_ptr)
389394
390395
sig = types.void(state_ptr_type, value_type)
391-
context.compile_internal(builder, deref_fn, sig, [state_ptr, value_val])
396+
cres = context.compile_subroutine(builder, deref_fn, sig, caching=False)
397+
result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, value_val])
392398
393399
return context.get_dummy_value()
394400
return void_sig, codegen

0 commit comments

Comments
 (0)