@@ -248,7 +248,10 @@ def codegen(context, builder, impl_sig, args):
248248 input_vals = [builder.load(p) for p in input_ptrs]
249249
250250 # Call the original operator
251- result = context.compile_internal(builder, op, sig, input_vals)
251+ # See NVIDIA/numba-cuda#590 for why we need compile_subroutine
252+ # vs compile_internal here:
253+ cres = context.compile_subroutine(builder, op, sig, caching=False)
254+ result = context.call_internal(builder, cres.fndesc, sig, input_vals)
252255
253256 # Store the result
254257 builder.store(result, ret_ptr)
@@ -298,7 +301,8 @@ def codegen(context, builder, impl_sig, args):
298301 offset_val = builder.load(offset_ptr)
299302
300303 sig = types.void(state_ptr_type, types.uint64)
301- context.compile_internal(builder, advance_fn, sig, [state_ptr, offset_val])
304+ cres = context.compile_subroutine(builder, advance_fn, sig, caching=False)
305+ result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, offset_val])
302306
303307 return context.get_dummy_value()
304308 return void_sig, codegen
@@ -342,7 +346,8 @@ def codegen(context, builder, impl_sig, args):
342346 result_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer())
343347
344348 sig = types.void(state_ptr_type, types.CPointer(value_type))
345- context.compile_internal(builder, deref_fn, sig, [state_ptr, result_ptr])
349+ cres = context.compile_subroutine(builder, deref_fn, sig, caching=False)
350+ result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, result_ptr])
346351
347352 return context.get_dummy_value()
348353 return void_sig, codegen
@@ -388,7 +393,8 @@ def codegen(context, builder, impl_sig, args):
388393 value_val = builder.load(value_ptr)
389394
390395 sig = types.void(state_ptr_type, value_type)
391- context.compile_internal(builder, deref_fn, sig, [state_ptr, value_val])
396+ cres = context.compile_subroutine(builder, deref_fn, sig, caching=False)
397+ result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, value_val])
392398
393399 return context.get_dummy_value()
394400 return void_sig, codegen
0 commit comments