@@ -277,6 +277,142 @@ def wrapped_{op.__name__}({arg_str}):
277277 return wrapper_func , void_sig
278278
279279
280+ def _create_advance_wrapper (advance_fn , state_ptr_type ):
281+ """Creates a wrapper function for iterator advance that takes void* arguments.
282+
283+ The wrapper takes 2 void* arguments:
284+ - state pointer
285+ - offset pointer (points to uint64 value)
286+ """
287+ void_sig = types .void (types .voidptr , types .voidptr )
288+
289+ wrapper_src = textwrap .dedent (f"""
290+ @intrinsic
291+ def impl(typingctx, state_arg, offset_arg):
292+ def codegen(context, builder, impl_sig, args):
293+ state_type_llvm = context.get_value_type(state_ptr_type.dtype)
294+ offset_type_llvm = context.get_value_type(types.uint64)
295+
296+ state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer())
297+ offset_ptr = builder.bitcast(args[1], offset_type_llvm.as_pointer())
298+ offset_val = builder.load(offset_ptr)
299+
300+ sig = types.void(state_ptr_type, types.uint64)
301+ context.compile_internal(builder, advance_fn, sig, [state_ptr, offset_val])
302+
303+ return context.get_dummy_value()
304+ return void_sig, codegen
305+
306+ def wrapped_{ advance_fn .__name__ } (state_arg, offset_arg):
307+ return impl(state_arg, offset_arg)
308+ """ )
309+
310+ local_dict = {
311+ "types" : types ,
312+ "state_ptr_type" : state_ptr_type ,
313+ "advance_fn" : advance_fn ,
314+ "intrinsic" : intrinsic ,
315+ "void_sig" : void_sig ,
316+ }
317+ exec (wrapper_src , globals (), local_dict )
318+
319+ wrapper_func = local_dict [f"wrapped_{ advance_fn .__name__ } " ]
320+ wrapper_func .__globals__ .update (local_dict )
321+
322+ return wrapper_func , void_sig
323+
324+
325+ def _create_input_dereference_wrapper (deref_fn , state_ptr_type , value_type ):
326+ """Creates a wrapper function for input iterator dereference that takes void* arguments.
327+
328+ The wrapper takes 2 void* arguments:
329+ - state pointer
330+ - result pointer
331+ """
332+ void_sig = types .void (types .voidptr , types .voidptr )
333+
334+ wrapper_src = textwrap .dedent (f"""
335+ @intrinsic
336+ def impl(typingctx, state_arg, result_arg):
337+ def codegen(context, builder, impl_sig, args):
338+ state_type_llvm = context.get_value_type(state_ptr_type.dtype)
339+ value_type_llvm = context.get_value_type(value_type)
340+
341+ state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer())
342+ result_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer())
343+
344+ sig = types.void(state_ptr_type, types.CPointer(value_type))
345+ context.compile_internal(builder, deref_fn, sig, [state_ptr, result_ptr])
346+
347+ return context.get_dummy_value()
348+ return void_sig, codegen
349+
350+ def wrapped_{ deref_fn .__name__ } (state_arg, result_arg):
351+ return impl(state_arg, result_arg)
352+ """ )
353+
354+ local_dict = {
355+ "types" : types ,
356+ "state_ptr_type" : state_ptr_type ,
357+ "value_type" : value_type ,
358+ "deref_fn" : deref_fn ,
359+ "intrinsic" : intrinsic ,
360+ "void_sig" : void_sig ,
361+ }
362+ exec (wrapper_src , globals (), local_dict )
363+
364+ wrapper_func = local_dict [f"wrapped_{ deref_fn .__name__ } " ]
365+ wrapper_func .__globals__ .update (local_dict )
366+
367+ return wrapper_func , void_sig
368+
369+
370+ def _create_output_dereference_wrapper (deref_fn , state_ptr_type , value_type ):
371+ """Creates a wrapper function for output iterator dereference that takes void* arguments.
372+
373+ The wrapper takes 2 void* arguments:
374+ - state pointer
375+ - value pointer (points to value)
376+ """
377+ void_sig = types .void (types .voidptr , types .voidptr )
378+
379+ wrapper_src = textwrap .dedent (f"""
380+ @intrinsic
381+ def impl(typingctx, state_arg, value_arg):
382+ def codegen(context, builder, impl_sig, args):
383+ state_type_llvm = context.get_value_type(state_ptr_type.dtype)
384+ value_type_llvm = context.get_value_type(value_type)
385+
386+ state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer())
387+ value_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer())
388+ value_val = builder.load(value_ptr)
389+
390+ sig = types.void(state_ptr_type, value_type)
391+ context.compile_internal(builder, deref_fn, sig, [state_ptr, value_val])
392+
393+ return context.get_dummy_value()
394+ return void_sig, codegen
395+
396+ def wrapped_{ deref_fn .__name__ } (state_arg, value_arg):
397+ return impl(state_arg, value_arg)
398+ """ )
399+
400+ local_dict = {
401+ "types" : types ,
402+ "state_ptr_type" : state_ptr_type ,
403+ "value_type" : value_type ,
404+ "deref_fn" : deref_fn ,
405+ "intrinsic" : intrinsic ,
406+ "void_sig" : void_sig ,
407+ }
408+ exec (wrapper_src , globals (), local_dict )
409+
410+ wrapper_func = local_dict [f"wrapped_{ deref_fn .__name__ } " ]
411+ wrapper_func .__globals__ .update (local_dict )
412+
413+ return wrapper_func , void_sig
414+
415+
280416def to_cccl_op (op : Callable | OpKind , sig : Signature | None ) -> Op :
281417 """Return an `Op` object corresponding to the given callable or well-known operation.
282418
0 commit comments