Skip to content

Commit 6b78065

Browse files
committed
Wrap iterator advance/deref with void* wrappers
1 parent 7d2ee97 commit 6b78065

File tree

7 files changed

+217
-60
lines changed

7 files changed

+217
-60
lines changed

c/parallel/src/jit_templates/mappings/iterator.h

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
template <typename ValueTp>
2020
struct cccl_iterator_t_mapping
2121
{
22-
bool is_pointer = false;
23-
int size = 1;
24-
int alignment = 1;
25-
void (*advance)(void*, cuda::std::uint64_t) = nullptr;
26-
void (*dereference)(const void*, ValueTp*) = nullptr;
27-
void (*assign)(const void*, ValueTp);
22+
bool is_pointer = false;
23+
int size = 1;
24+
int alignment = 1;
25+
void (*advance)(void*, const void*) = nullptr;
26+
void (*dereference)(const void*, ValueTp*) = nullptr;
27+
void (*assign)(const void*, const void*);
2828

2929
using ValueT = ValueTp;
3030
};
@@ -68,22 +68,19 @@ struct parameter_mapping<cccl_iterator_t>
6868
{
6969
return std::format(
7070
R"output(
71-
extern "C" __device__ void {0}(void *, {1});
72-
extern "C" __device__ void {2}(const void *, {3});
71+
extern "C" __device__ void {0}(void *, const void*);
72+
extern "C" __device__ void {1}(const void *, const void*);
7373
)output",
7474
arg.advance.name,
75-
cccl_type_enum_to_name(cccl_type_enum::CCCL_UINT64),
76-
arg.dereference.name,
77-
cccl_type_enum_to_name(arg.value_type.type));
75+
arg.dereference.name);
7876
}
7977

8078
return std::format(
8179
R"input(
82-
extern "C" __device__ void {0}(void *, {1});
83-
extern "C" __device__ void {2}(const void *, {3}*);
80+
extern "C" __device__ void {0}(void *, const void*);
81+
extern "C" __device__ void {1}(const void *, {2}*);
8482
)input",
8583
arg.advance.name,
86-
cccl_type_enum_to_name(cccl_type_enum::CCCL_UINT64),
8784
arg.dereference.name,
8885
cccl_type_enum_to_name(arg.value_type.type));
8986
}

c/parallel/src/jit_templates/templates/input_iterator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct input_iterator_t
4747

4848
__device__ input_iterator_t& operator+=(difference_type diff)
4949
{
50-
Iterator.advance(&state, diff);
50+
Iterator.advance(&state, &diff);
5151
return *this;
5252
}
5353

c/parallel/src/jit_templates/templates/output_iterator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct output_iterator_proxy_t
3535
{
3636
__device__ output_iterator_proxy_t& operator=(AssignT x)
3737
{
38-
AssignF(&state, cuda::std::move(x));
38+
AssignF(&state, &x);
3939
return *this;
4040
}
4141

@@ -59,7 +59,7 @@ struct output_iterator_t
5959

6060
__device__ output_iterator_t& operator+=(difference_type diff)
6161
{
62-
Iterator.advance(&state, diff);
62+
Iterator.advance(&state, &diff);
6363
return *this;
6464
}
6565

c/parallel/src/kernels/iterators.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ std::string make_kernel_input_iterator(
4545
{
4646
const std::string iter_def = std::format(R"XXX(
4747
extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T* result);
48-
extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
48+
extern "C" __device__ void ADVANCE(void *self_ptr, const void* offset);
4949
struct __align__(OP_ALIGNMENT) {0} {{
5050
using iterator_category = cuda::std::random_access_iterator_tag;
5151
using value_type = VALUE_T;
@@ -58,7 +58,7 @@ struct __align__(OP_ALIGNMENT) {0} {{
5858
return result;
5959
}}
6060
__device__ inline {0}& operator+=(difference_type diff) {{
61-
ADVANCE(data, diff);
61+
ADVANCE(data, &diff);
6262
return *this;
6363
}}
6464
__device__ inline value_type operator[](difference_type diff) const {{
@@ -99,14 +99,14 @@ std::string make_kernel_output_iterator(
9999
std::string_view advance)
100100
{
101101
const std::string iter_def = std::format(R"XXX(
102-
extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T x);
103-
extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
102+
extern "C" __device__ void DEREF(const void *self_ptr, const void* x);
103+
extern "C" __device__ void ADVANCE(void *self_ptr, const void* offset);
104104
struct __align__(OP_ALIGNMENT) {0}_state_t {{
105105
char data[OP_SIZE];
106106
}};
107107
struct {0}_proxy_t {{
108108
__device__ {0}_proxy_t operator=(VALUE_T x) {{
109-
DEREF(&state, x);
109+
DEREF(&state, &x);
110110
return *this;
111111
}}
112112
{0}_state_t state;
@@ -119,7 +119,7 @@ struct {0} {{
119119
using reference = {0}_proxy_t;
120120
__device__ {0}_proxy_t operator*() const {{ return {{state}}; }}
121121
__device__ {0}& operator+=(difference_type diff) {{
122-
ADVANCE(&state, diff);
122+
ADVANCE(&state, &diff);
123123
return *this;
124124
}}
125125
__device__ {0}_proxy_t operator[](difference_type diff) const {{

c/parallel/test/test_util.h

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,10 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
931931
{
932932
std::string state_def_src = std::format("struct {0} {{ {1}* data; }};\n", iterator_state_name, value_type);
933933
std::string advance_fn_def_src = std::format(
934-
"extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{\n"
935-
" state->data += offset;\n"
934+
"extern \"C\" __device__ void {0}(void* state, const void* offset) {{\n"
935+
" auto* typed_state = static_cast<{1}*>(state);\n"
936+
" auto offset_val = *static_cast<const unsigned long long*>(offset);\n"
937+
" typed_state->data += offset_val;\n"
936938
"}}",
937939
advance_fn_name,
938940
iterator_state_name);
@@ -941,19 +943,22 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
941943
if (kind == iterator_kind::INPUT)
942944
{
943945
dereference_fn_def_src = std::format(
944-
"extern \"C\" __device__ void {0}({1}* state, {2}* result) {{\n"
945-
" *result = (*state->data){3};\n"
946+
"extern \"C\" __device__ void {0}(const void* state, {1}* result) {{\n"
947+
" auto* typed_state = static_cast<const {2}*>(state);\n"
948+
" *result = (*typed_state->data){3};\n"
946949
"}}",
947950
dereference_fn_name,
948-
iterator_state_name,
949951
value_type,
952+
iterator_state_name,
950953
transform);
951954
}
952955
else
953956
{
954957
dereference_fn_def_src = std::format(
955-
"extern \"C\" __device__ void {0}({1}* state, {2} x) {{\n"
956-
" *state->data = x{3};\n"
958+
"extern \"C\" __device__ void {0}(const void* state, const void* x) {{\n"
959+
" auto* typed_state = static_cast<const {1}*>(state);\n"
960+
" auto x_val = *static_cast<const {2}*>(x);\n"
961+
" *typed_state->data = x_val{3};\n"
957962
"}}",
958963
dereference_fn_name,
959964
iterator_state_name,
@@ -1033,17 +1038,16 @@ inline std::tuple<std::string, std::string, std::string> make_constant_iterator_
10331038
std::string_view dereference_fn_name)
10341039
{
10351040
std::string iterator_state_src = std::format("struct {0} {{ {1} value; }};\n", iterator_state_name, value_type);
1036-
std::string advance_fn_src = std::format(
1037-
"extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{ }}",
1038-
advance_fn_name,
1039-
iterator_state_name);
1041+
std::string advance_fn_src =
1042+
std::format("extern \"C\" __device__ void {0}(void* state, const void* offset) {{ }}", advance_fn_name);
10401043
std::string dereference_fn_src = std::format(
1041-
"extern \"C\" __device__ void {0}({1}* state, {2}* result) {{ \n"
1042-
" *result = state->value;\n"
1044+
"extern \"C\" __device__ void {0}(const void* state, {1}* result) {{ \n"
1045+
" auto* typed_state = static_cast<const {2}*>(state);\n"
1046+
" *result = typed_state->value;\n"
10431047
"}}",
10441048
dereference_fn_name,
1045-
iterator_state_name,
1046-
value_type);
1049+
value_type,
1050+
iterator_state_name);
10471051

10481052
return std::make_tuple(iterator_state_src, advance_fn_src, dereference_fn_src);
10491053
}
@@ -1124,24 +1128,27 @@ struct {0} {{
11241128
const std::string it_state_def_src = std::format(it_state_src_tmpl, state_name, index_ty_name);
11251129

11261130
static constexpr std::string_view it_def_src_tmpl = R"XXX(
1127-
extern "C" __device__ void {0}({1}* state, {2} offset)
1131+
extern "C" __device__ void {0}(void* state, const void* offset)
11281132
{{
1129-
state->linear_id += offset;
1133+
auto* typed_state = static_cast<{1}*>(state);
1134+
auto offset_val = *static_cast<const {2}*>(offset);
1135+
typed_state->linear_id += offset_val;
11301136
}}
11311137
)XXX";
11321138

11331139
const std::string it_advance_fn_def_src =
11341140
std::format(it_def_src_tmpl, /*0*/ advance_fn_name, state_name, index_ty_name);
11351141

11361142
static constexpr std::string_view it_deref_src_tmpl = R"XXX(
1137-
extern "C" __device__ void {0}({1}* state, {2}* result)
1143+
extern "C" __device__ void {0}(const void* state, {1}* result)
11381144
{{
1139-
*result = (state->linear_id) * (state->segment_size);
1145+
auto* typed_state = static_cast<const {2}*>(state);
1146+
*result = (typed_state->linear_id) * (typed_state->segment_size);
11401147
}}
11411148
)XXX";
11421149

11431150
const std::string it_deref_fn_def_src =
1144-
std::format(it_deref_src_tmpl, dereference_fn_name, state_name, index_ty_name);
1151+
std::format(it_deref_src_tmpl, dereference_fn_name, index_ty_name, state_name);
11451152

11461153
return std::make_tuple(it_state_def_src, it_advance_fn_def_src, it_deref_fn_def_src);
11471154
}
@@ -1229,8 +1236,9 @@ struct {0} {{
12291236

12301237
static constexpr std::string_view transform_it_advance_fn_src_tmpl = R"XXX(
12311238
{3}
1232-
extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offset) {{
1233-
{2}(&(transform_it_state->base_it_state), offset);
1239+
extern "C" __device__ void {0}(void* transform_it_state, const void* offset) {{
1240+
auto* typed_state = static_cast<{1}*>(transform_it_state);
1241+
{2}(&(typed_state->base_it_state), offset);
12341242
}}
12351243
)XXX";
12361244

@@ -1244,11 +1252,12 @@ extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offse
12441252
static constexpr std::string_view transform_it_dereference_fn_src_tmpl = R"XXX(
12451253
{5}
12461254
{6}
1247-
extern "C" __device__ void {0}({1} *transform_it_state, {2}* result) {{
1255+
extern "C" __device__ void {0}(const void* transform_it_state, {2}* result) {{
1256+
auto* typed_state = static_cast<const {1}*>(transform_it_state);
12481257
{7} base_result;
1249-
{4}(&(transform_it_state->base_it_state), &base_result);
1258+
{4}(&(typed_state->base_it_state), &base_result);
12501259
*result = {3}(
1251-
&(transform_it_state->functor_state),
1260+
&(typed_state->functor_state),
12521261
base_result
12531262
);
12541263
}}

0 commit comments

Comments
 (0)