@@ -931,8 +931,10 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
931931{
932932 std::string state_def_src = std::format (" struct {0} {{ {1}* data; }};\n " , iterator_state_name, value_type);
933933 std::string advance_fn_def_src = std::format (
934- " extern \" C\" __device__ void {0}({1}* state, unsigned long long offset) {{\n "
935- " state->data += offset;\n "
934+ " extern \" C\" __device__ void {0}(void* state, const void* offset) {{\n "
935+ " auto* typed_state = static_cast<{1}*>(state);\n "
936+ " auto offset_val = *static_cast<const unsigned long long*>(offset);\n "
937+ " typed_state->data += offset_val;\n "
936938 " }}" ,
937939 advance_fn_name,
938940 iterator_state_name);
@@ -941,19 +943,22 @@ inline std::tuple<std::string, std::string, std::string> make_random_access_iter
941943 if (kind == iterator_kind::INPUT)
942944 {
943945 dereference_fn_def_src = std::format (
944- " extern \" C\" __device__ void {0}({1}* state, {2}* result) {{\n "
945- " *result = (*state->data){3};\n "
946+ " extern \" C\" __device__ void {0}(const void* state, {1}* result) {{\n "
947+ " auto* typed_state = static_cast<const {2}*>(state);\n "
948+ " *result = (*typed_state->data){3};\n "
946949 " }}" ,
947950 dereference_fn_name,
948- iterator_state_name,
949951 value_type,
952+ iterator_state_name,
950953 transform);
951954 }
952955 else
953956 {
954957 dereference_fn_def_src = std::format (
955- " extern \" C\" __device__ void {0}({1}* state, {2} x) {{\n "
956- " *state->data = x{3};\n "
958+ " extern \" C\" __device__ void {0}(const void* state, const void* x) {{\n "
959+ " auto* typed_state = static_cast<const {1}*>(state);\n "
960+ " auto x_val = *static_cast<const {2}*>(x);\n "
961+ " *typed_state->data = x_val{3};\n "
957962 " }}" ,
958963 dereference_fn_name,
959964 iterator_state_name,
@@ -1033,17 +1038,16 @@ inline std::tuple<std::string, std::string, std::string> make_constant_iterator_
10331038 std::string_view dereference_fn_name)
10341039{
10351040 std::string iterator_state_src = std::format (" struct {0} {{ {1} value; }};\n " , iterator_state_name, value_type);
1036- std::string advance_fn_src = std::format (
1037- " extern \" C\" __device__ void {0}({1}* state, unsigned long long offset) {{ }}" ,
1038- advance_fn_name,
1039- iterator_state_name);
1041+ std::string advance_fn_src =
1042+ std::format (" extern \" C\" __device__ void {0}(void* state, const void* offset) {{ }}" , advance_fn_name);
10401043 std::string dereference_fn_src = std::format (
1041- " extern \" C\" __device__ void {0}({1}* state, {2}* result) {{ \n "
1042- " *result = state->value;\n "
1044+ " extern \" C\" __device__ void {0}(const void* state, {1}* result) {{ \n "
1045+ " auto* typed_state = static_cast<const {2}*>(state);\n "
1046+ " *result = typed_state->value;\n "
10431047 " }}" ,
10441048 dereference_fn_name,
1045- iterator_state_name ,
1046- value_type );
1049+ value_type ,
1050+ iterator_state_name );
10471051
10481052 return std::make_tuple (iterator_state_src, advance_fn_src, dereference_fn_src);
10491053}
@@ -1124,24 +1128,27 @@ struct {0} {{
11241128 const std::string it_state_def_src = std::format (it_state_src_tmpl, state_name, index_ty_name);
11251129
11261130 static constexpr std::string_view it_def_src_tmpl = R"XXX(
1127- extern "C" __device__ void {0}({1} * state, {2} offset)
1131+ extern "C" __device__ void {0}(void * state, const void* offset)
11281132{{
1129- state->linear_id += offset;
1133+ auto* typed_state = static_cast<{1}*>(state);
1134+ auto offset_val = *static_cast<const {2}*>(offset);
1135+ typed_state->linear_id += offset_val;
11301136}}
11311137)XXX" ;
11321138
11331139 const std::string it_advance_fn_def_src =
11341140 std::format (it_def_src_tmpl, /* 0*/ advance_fn_name, state_name, index_ty_name);
11351141
11361142 static constexpr std::string_view it_deref_src_tmpl = R"XXX(
1137- extern "C" __device__ void {0}({1} * state, {2 }* result)
1143+ extern "C" __device__ void {0}(const void * state, {1 }* result)
11381144{{
1139- *result = (state->linear_id) * (state->segment_size);
1145+ auto* typed_state = static_cast<const {2}*>(state);
1146+ *result = (typed_state->linear_id) * (typed_state->segment_size);
11401147}}
11411148)XXX" ;
11421149
11431150 const std::string it_deref_fn_def_src =
1144- std::format (it_deref_src_tmpl, dereference_fn_name, state_name, index_ty_name );
1151+ std::format (it_deref_src_tmpl, dereference_fn_name, index_ty_name, state_name );
11451152
11461153 return std::make_tuple (it_state_def_src, it_advance_fn_def_src, it_deref_fn_def_src);
11471154}
@@ -1229,8 +1236,9 @@ struct {0} {{
12291236
12301237 static constexpr std::string_view transform_it_advance_fn_src_tmpl = R"XXX(
12311238{3}
1232- extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offset) {{
1233- {2}(&(transform_it_state->base_it_state), offset);
1239+ extern "C" __device__ void {0}(void* transform_it_state, const void* offset) {{
1240+ auto* typed_state = static_cast<{1}*>(transform_it_state);
1241+ {2}(&(typed_state->base_it_state), offset);
12341242}}
12351243)XXX" ;
12361244
@@ -1244,11 +1252,12 @@ extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offse
12441252 static constexpr std::string_view transform_it_dereference_fn_src_tmpl = R"XXX(
12451253{5}
12461254{6}
1247- extern "C" __device__ void {0}({1} *transform_it_state, {2}* result) {{
1255+ extern "C" __device__ void {0}(const void* transform_it_state, {2}* result) {{
1256+ auto* typed_state = static_cast<const {1}*>(transform_it_state);
12481257 {7} base_result;
1249- {4}(&(transform_it_state ->base_it_state), &base_result);
1258+ {4}(&(typed_state ->base_it_state), &base_result);
12501259 *result = {3}(
1251- &(transform_it_state ->functor_state),
1260+ &(typed_state ->functor_state),
12521261 base_result
12531262 );
12541263}}
0 commit comments