Skip to content

Commit 9d31c80

Browse files
committed
Fix iterator sigs
1 parent 5fcffb8 commit 9d31c80

File tree

8 files changed

+268
-59
lines changed

8 files changed

+268
-59
lines changed

c/parallel/src/jit_templates/mappings/iterator.h

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
template <typename ValueTp>
2020
struct cccl_iterator_t_mapping
2121
{
22-
bool is_pointer = false;
23-
int size = 1;
24-
int alignment = 1;
25-
void (*advance)(void*, cuda::std::uint64_t) = nullptr;
26-
void (*dereference)(const void*, ValueTp*) = nullptr;
27-
void (*assign)(const void*, ValueTp);
22+
bool is_pointer = false;
23+
int size = 1;
24+
int alignment = 1;
25+
void (*advance)(void*, const void*) = nullptr;
26+
void (*dereference)(const void*, ValueTp*) = nullptr;
27+
void (*assign)(const void*, const void*);
2828

2929
using ValueT = ValueTp;
3030
};
@@ -68,22 +68,19 @@ struct parameter_mapping<cccl_iterator_t>
6868
{
6969
return std::format(
7070
R"output(
71-
extern "C" __device__ void {0}(void *, {1});
72-
extern "C" __device__ void {2}(const void *, {3});
71+
extern "C" __device__ void {0}(void *, const void*);
72+
extern "C" __device__ void {1}(const void *, const void*);
7373
)output",
7474
arg.advance.name,
75-
cccl_type_enum_to_name(cccl_type_enum::CCCL_UINT64),
76-
arg.dereference.name,
77-
cccl_type_enum_to_name(arg.value_type.type));
75+
arg.dereference.name);
7876
}
7977

8078
return std::format(
8179
R"input(
82-
extern "C" __device__ void {0}(void *, {1});
83-
extern "C" __device__ void {2}(const void *, {3}*);
80+
extern "C" __device__ void {0}(void *, const void*);
81+
extern "C" __device__ void {1}(const void *, {2}*);
8482
)input",
8583
arg.advance.name,
86-
cccl_type_enum_to_name(cccl_type_enum::CCCL_UINT64),
8784
arg.dereference.name,
8885
cccl_type_enum_to_name(arg.value_type.type));
8986
}

c/parallel/src/jit_templates/templates/input_iterator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct input_iterator_t
4747

4848
__device__ input_iterator_t& operator+=(difference_type diff)
4949
{
50-
Iterator.advance(&state, diff);
50+
Iterator.advance(&state, &diff);
5151
return *this;
5252
}
5353

c/parallel/src/jit_templates/templates/output_iterator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct output_iterator_proxy_t
3535
{
3636
__device__ output_iterator_proxy_t& operator=(AssignT x)
3737
{
38-
AssignF(&state, cuda::std::move(x));
38+
AssignF(&state, &x);
3939
return *this;
4040
}
4141

@@ -59,7 +59,7 @@ struct output_iterator_t
5959

6060
__device__ output_iterator_t& operator+=(difference_type diff)
6161
{
62-
Iterator.advance(&state, diff);
62+
Iterator.advance(&state, &diff);
6363
return *this;
6464
}
6565

c/parallel/src/kernels/iterators.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ std::string make_kernel_input_iterator(
4545
{
4646
const std::string iter_def = std::format(R"XXX(
4747
extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T* result);
48-
extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
48+
extern "C" __device__ void ADVANCE(void *self_ptr, const void* offset);
4949
struct __align__(OP_ALIGNMENT) {0} {{
5050
using iterator_category = cuda::std::random_access_iterator_tag;
5151
using value_type = VALUE_T;
@@ -58,7 +58,7 @@ struct __align__(OP_ALIGNMENT) {0} {{
5858
return result;
5959
}}
6060
__device__ inline {0}& operator+=(difference_type diff) {{
61-
ADVANCE(data, diff);
61+
ADVANCE(data, &diff);
6262
return *this;
6363
}}
6464
__device__ inline value_type operator[](difference_type diff) const {{
@@ -99,14 +99,14 @@ std::string make_kernel_output_iterator(
9999
std::string_view advance)
100100
{
101101
const std::string iter_def = std::format(R"XXX(
102-
extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T x);
103-
extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
102+
extern "C" __device__ void DEREF(const void *self_ptr, const void* x);
103+
extern "C" __device__ void ADVANCE(void *self_ptr, const void* offset);
104104
struct __align__(OP_ALIGNMENT) {0}_state_t {{
105105
char data[OP_SIZE];
106106
}};
107107
struct {0}_proxy_t {{
108108
__device__ {0}_proxy_t operator=(VALUE_T x) {{
109-
DEREF(&state, x);
109+
DEREF(&state, &x);
110110
return *this;
111111
}}
112112
{0}_state_t state;
@@ -119,7 +119,7 @@ struct {0} {{
119119
using reference = {0}_proxy_t;
120120
__device__ {0}_proxy_t operator*() const {{ return {{state}}; }}
121121
__device__ {0}& operator+=(difference_type diff) {{
122-
ADVANCE(&state, diff);
122+
ADVANCE(&state, &diff);
123123
return *this;
124124
}}
125125
__device__ {0}_proxy_t operator[](difference_type diff) const {{

c/parallel/test/test_util.h

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,10 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
931931
{
932932
std::string state_def_src = std::format("struct {0} {{ {1}* data; }};\n", iterator_state_name, value_type);
933933
std::string advance_fn_def_src = std::format(
934-
"extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{\n"
935-
" state->data += offset;\n"
934+
"extern \"C\" __device__ void {0}(void* state, const void* offset) {{\n"
935+
" auto* typed_state = static_cast<{1}*>(state);\n"
936+
" auto offset_val = *static_cast<const unsigned long long*>(offset);\n"
937+
" typed_state->data += offset_val;\n"
936938
"}}",
937939
advance_fn_name,
938940
iterator_state_name);
@@ -941,19 +943,22 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
941943
if (kind == iterator_kind::INPUT)
942944
{
943945
dereference_fn_def_src = std::format(
944-
"extern \"C\" __device__ void {0}({1}* state, {2}* result) {{\n"
945-
" *result = (*state->data){3};\n"
946+
"extern \"C\" __device__ void {0}(const void* state, {1}* result) {{\n"
947+
" auto* typed_state = static_cast<const {2}*>(state);\n"
948+
" *result = (*typed_state->data){3};\n"
946949
"}}",
947950
dereference_fn_name,
948-
iterator_state_name,
949951
value_type,
952+
iterator_state_name,
950953
transform);
951954
}
952955
else
953956
{
954957
dereference_fn_def_src = std::format(
955-
"extern \"C\" __device__ void {0}({1}* state, {2} x) {{\n"
956-
" *state->data = x{3};\n"
958+
"extern \"C\" __device__ void {0}(const void* state, const void* x) {{\n"
959+
" auto* typed_state = static_cast<const {1}*>(state);\n"
960+
" auto x_val = *static_cast<const {2}*>(x);\n"
961+
" *typed_state->data = x_val{3};\n"
957962
"}}",
958963
dereference_fn_name,
959964
iterator_state_name,
@@ -1033,17 +1038,16 @@ inline std::tuple<std::string, std::string, std::string> make_constant_iterator_
10331038
std::string_view dereference_fn_name)
10341039
{
10351040
std::string iterator_state_src = std::format("struct {0} {{ {1} value; }};\n", iterator_state_name, value_type);
1036-
std::string advance_fn_src = std::format(
1037-
"extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{ }}",
1038-
advance_fn_name,
1039-
iterator_state_name);
1041+
std::string advance_fn_src =
1042+
std::format("extern \"C\" __device__ void {0}(void* state, const void* offset) {{ }}", advance_fn_name);
10401043
std::string dereference_fn_src = std::format(
1041-
"extern \"C\" __device__ void {0}({1}* state, {2}* result) {{ \n"
1042-
" *result = state->value;\n"
1044+
"extern \"C\" __device__ void {0}(const void* state, {1}* result) {{ \n"
1045+
" auto* typed_state = static_cast<const {2}*>(state);\n"
1046+
" *result = typed_state->value;\n"
10431047
"}}",
10441048
dereference_fn_name,
1045-
iterator_state_name,
1046-
value_type);
1049+
value_type,
1050+
iterator_state_name);
10471051

10481052
return std::make_tuple(iterator_state_src, advance_fn_src, dereference_fn_src);
10491053
}
@@ -1124,24 +1128,27 @@ struct {0} {{
11241128
const std::string it_state_def_src = std::format(it_state_src_tmpl, state_name, index_ty_name);
11251129

11261130
static constexpr std::string_view it_def_src_tmpl = R"XXX(
1127-
extern "C" __device__ void {0}({1}* state, {2} offset)
1131+
extern "C" __device__ void {0}(void* state, const void* offset)
11281132
{{
1129-
state->linear_id += offset;
1133+
auto* typed_state = static_cast<{1}*>(state);
1134+
auto offset_val = *static_cast<const {2}*>(offset);
1135+
typed_state->linear_id += offset_val;
11301136
}}
11311137
)XXX";
11321138

11331139
const std::string it_advance_fn_def_src =
11341140
std::format(it_def_src_tmpl, /*0*/ advance_fn_name, state_name, index_ty_name);
11351141

11361142
static constexpr std::string_view it_deref_src_tmpl = R"XXX(
1137-
extern "C" __device__ void {0}({1}* state, {2}* result)
1143+
extern "C" __device__ void {0}(const void* state, {1}* result)
11381144
{{
1139-
*result = (state->linear_id) * (state->segment_size);
1145+
auto* typed_state = static_cast<const {2}*>(state);
1146+
*result = (typed_state->linear_id) * (typed_state->segment_size);
11401147
}}
11411148
)XXX";
11421149

11431150
const std::string it_deref_fn_def_src =
1144-
std::format(it_deref_src_tmpl, dereference_fn_name, state_name, index_ty_name);
1151+
std::format(it_deref_src_tmpl, dereference_fn_name, index_ty_name, state_name);
11451152

11461153
return std::make_tuple(it_state_def_src, it_advance_fn_def_src, it_deref_fn_def_src);
11471154
}
@@ -1229,8 +1236,9 @@ struct {0} {{
12291236

12301237
static constexpr std::string_view transform_it_advance_fn_src_tmpl = R"XXX(
12311238
{3}
1232-
extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offset) {{
1233-
{2}(&(transform_it_state->base_it_state), offset);
1239+
extern "C" __device__ void {0}(void* transform_it_state, const void* offset) {{
1240+
auto* typed_state = static_cast<{1}*>(transform_it_state);
1241+
{2}(&(typed_state->base_it_state), offset);
12341242
}}
12351243
)XXX";
12361244

@@ -1244,11 +1252,12 @@ extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offse
12441252
static constexpr std::string_view transform_it_dereference_fn_src_tmpl = R"XXX(
12451253
{5}
12461254
{6}
1247-
extern "C" __device__ void {0}({1} *transform_it_state, {2}* result) {{
1255+
extern "C" __device__ void {0}(const void* transform_it_state, {2}* result) {{
1256+
auto* typed_state = static_cast<const {1}*>(transform_it_state);
12481257
{7} base_result;
1249-
{4}(&(transform_it_state->base_it_state), &base_result);
1258+
{4}(&(typed_state->base_it_state), &base_result);
12501259
*result = {3}(
1251-
&(transform_it_state->functor_state),
1260+
&(typed_state->functor_state),
12521261
base_result
12531262
);
12541263
}}

python/cuda_cccl/cuda/compute/_cccl_interop.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,142 @@ def wrapped_{op.__name__}({arg_str}):
277277
return wrapper_func, void_sig
278278

279279

280+
def _create_advance_wrapper(advance_fn, state_ptr_type):
    """Create a wrapper for an iterator advance function that takes void* arguments.

    The generated wrapper takes 2 void* arguments:
    - state pointer
    - offset pointer (points to uint64 value)

    Inside the wrapper the state pointer is bitcast to a pointer to
    ``state_ptr_type.dtype``, the offset is loaded through a uint64 pointer,
    and ``advance_fn`` is compiled inline with the signature
    ``void(state_ptr_type, uint64)``.

    Args:
        advance_fn: The advance function to wrap. Its ``__name__`` is used to
            name the generated wrapper (``wrapped_<name>``).
        state_ptr_type: Pointer type of the iterator state; its ``dtype``
            attribute is the pointee type used for the bitcast.

    Returns:
        A ``(wrapper_func, void_sig)`` tuple, where ``void_sig`` is
        ``types.void(types.voidptr, types.voidptr)``.
    """
    void_sig = types.void(types.voidptr, types.voidptr)
    wrapper_name = f"wrapped_{advance_fn.__name__}"

    # Only the wrapper's name is interpolated here; every other identifier in
    # the generated source is resolved from the namespace at exec/compile time.
    wrapper_src = textwrap.dedent(f"""
    @intrinsic
    def impl(typingctx, state_arg, offset_arg):
        def codegen(context, builder, impl_sig, args):
            state_type_llvm = context.get_value_type(state_ptr_type.dtype)
            offset_type_llvm = context.get_value_type(types.uint64)

            state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer())
            offset_ptr = builder.bitcast(args[1], offset_type_llvm.as_pointer())
            offset_val = builder.load(offset_ptr)

            sig = types.void(state_ptr_type, types.uint64)
            context.compile_internal(builder, advance_fn, sig, [state_ptr, offset_val])

            return context.get_dummy_value()
        return void_sig, codegen

    def {wrapper_name}(state_arg, offset_arg):
        return impl(state_arg, offset_arg)
    """)

    # Run the generated source in a private copy of the module namespace.
    # Names such as `impl`, `advance_fn` and `state_ptr_type` are looked up
    # lazily through the wrapper's __globals__; writing them into the shared
    # module globals would let a later call to any wrapper factory silently
    # rebind them for wrappers created earlier.
    namespace = dict(globals())
    namespace.update(
        {
            "types": types,
            "state_ptr_type": state_ptr_type,
            "advance_fn": advance_fn,
            "intrinsic": intrinsic,
            "void_sig": void_sig,
        }
    )
    exec(wrapper_src, namespace)

    return namespace[wrapper_name], void_sig
323+
324+
325+
def _create_input_dereference_wrapper(deref_fn, state_ptr_type, value_type):
    """Create a wrapper for an input iterator dereference that takes void* arguments.

    The generated wrapper takes 2 void* arguments:
    - state pointer
    - result pointer

    Inside the wrapper the state pointer is bitcast to a pointer to
    ``state_ptr_type.dtype`` and the result pointer to a pointer to
    ``value_type``; ``deref_fn`` is then compiled inline with the signature
    ``void(state_ptr_type, CPointer(value_type))``.

    Args:
        deref_fn: The dereference function to wrap. Its ``__name__`` is used
            to name the generated wrapper (``wrapped_<name>``).
        state_ptr_type: Pointer type of the iterator state; its ``dtype``
            attribute is the pointee type used for the bitcast.
        value_type: Type of the value written through the result pointer.

    Returns:
        A ``(wrapper_func, void_sig)`` tuple, where ``void_sig`` is
        ``types.void(types.voidptr, types.voidptr)``.
    """
    void_sig = types.void(types.voidptr, types.voidptr)
    wrapper_name = f"wrapped_{deref_fn.__name__}"

    # Only the wrapper's name is interpolated here; every other identifier in
    # the generated source is resolved from the namespace at exec/compile time.
    wrapper_src = textwrap.dedent(f"""
    @intrinsic
    def impl(typingctx, state_arg, result_arg):
        def codegen(context, builder, impl_sig, args):
            state_type_llvm = context.get_value_type(state_ptr_type.dtype)
            value_type_llvm = context.get_value_type(value_type)

            state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer())
            result_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer())

            sig = types.void(state_ptr_type, types.CPointer(value_type))
            context.compile_internal(builder, deref_fn, sig, [state_ptr, result_ptr])

            return context.get_dummy_value()
        return void_sig, codegen

    def {wrapper_name}(state_arg, result_arg):
        return impl(state_arg, result_arg)
    """)

    # Run the generated source in a private copy of the module namespace.
    # Names such as `impl`, `deref_fn` and `value_type` are looked up lazily
    # through the wrapper's __globals__; writing them into the shared module
    # globals would let a later call to any wrapper factory silently rebind
    # them for wrappers created earlier.
    namespace = dict(globals())
    namespace.update(
        {
            "types": types,
            "state_ptr_type": state_ptr_type,
            "value_type": value_type,
            "deref_fn": deref_fn,
            "intrinsic": intrinsic,
            "void_sig": void_sig,
        }
    )
    exec(wrapper_src, namespace)

    return namespace[wrapper_name], void_sig
368+
369+
370+
def _create_output_dereference_wrapper(deref_fn, state_ptr_type, value_type):
    """Create a wrapper for an output iterator dereference that takes void* arguments.

    The generated wrapper takes 2 void* arguments:
    - state pointer
    - value pointer (points to value)

    Inside the wrapper the state pointer is bitcast to a pointer to
    ``state_ptr_type.dtype``, the value is loaded through a pointer to
    ``value_type``, and ``deref_fn`` is compiled inline with the signature
    ``void(state_ptr_type, value_type)`` (the value is passed by value).

    Args:
        deref_fn: The dereference (assignment) function to wrap. Its
            ``__name__`` is used to name the generated wrapper
            (``wrapped_<name>``).
        state_ptr_type: Pointer type of the iterator state; its ``dtype``
            attribute is the pointee type used for the bitcast.
        value_type: Type of the value read through the value pointer.

    Returns:
        A ``(wrapper_func, void_sig)`` tuple, where ``void_sig`` is
        ``types.void(types.voidptr, types.voidptr)``.
    """
    void_sig = types.void(types.voidptr, types.voidptr)
    wrapper_name = f"wrapped_{deref_fn.__name__}"

    # Only the wrapper's name is interpolated here; every other identifier in
    # the generated source is resolved from the namespace at exec/compile time.
    wrapper_src = textwrap.dedent(f"""
    @intrinsic
    def impl(typingctx, state_arg, value_arg):
        def codegen(context, builder, impl_sig, args):
            state_type_llvm = context.get_value_type(state_ptr_type.dtype)
            value_type_llvm = context.get_value_type(value_type)

            state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer())
            value_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer())
            value_val = builder.load(value_ptr)

            sig = types.void(state_ptr_type, value_type)
            context.compile_internal(builder, deref_fn, sig, [state_ptr, value_val])

            return context.get_dummy_value()
        return void_sig, codegen

    def {wrapper_name}(state_arg, value_arg):
        return impl(state_arg, value_arg)
    """)

    # Run the generated source in a private copy of the module namespace.
    # Names such as `impl`, `deref_fn` and `value_type` are looked up lazily
    # through the wrapper's __globals__; writing them into the shared module
    # globals would let a later call to any wrapper factory silently rebind
    # them for wrappers created earlier.
    namespace = dict(globals())
    namespace.update(
        {
            "types": types,
            "state_ptr_type": state_ptr_type,
            "value_type": value_type,
            "deref_fn": deref_fn,
            "intrinsic": intrinsic,
            "void_sig": void_sig,
        }
    )
    exec(wrapper_src, namespace)

    return namespace[wrapper_name], void_sig
414+
415+
280416
def to_cccl_op(op: Callable | OpKind, sig: Signature | None) -> Op:
281417
"""Return an `Op` object corresponding to the given callable or well-known operation.
282418

0 commit comments

Comments
 (0)