diff --git a/c/parallel/src/jit_templates/mappings/iterator.h b/c/parallel/src/jit_templates/mappings/iterator.h index b6666c37228..fbafc3889a5 100644 --- a/c/parallel/src/jit_templates/mappings/iterator.h +++ b/c/parallel/src/jit_templates/mappings/iterator.h @@ -19,12 +19,12 @@ template struct cccl_iterator_t_mapping { - bool is_pointer = false; - int size = 1; - int alignment = 1; - void (*advance)(void*, cuda::std::uint64_t) = nullptr; - void (*dereference)(const void*, ValueTp*) = nullptr; - void (*assign)(const void*, ValueTp); + bool is_pointer = false; + int size = 1; + int alignment = 1; + void (*advance)(void*, const void*) = nullptr; + void (*dereference)(const void*, ValueTp*) = nullptr; + void (*assign)(const void*, const void*); using ValueT = ValueTp; }; @@ -68,22 +68,19 @@ struct parameter_mapping { return std::format( R"output( -extern "C" __device__ void {0}(void *, {1}); -extern "C" __device__ void {2}(const void *, {3}); +extern "C" __device__ void {0}(void *, const void*); +extern "C" __device__ void {1}(const void *, const void*); )output", arg.advance.name, - cccl_type_enum_to_name(cccl_type_enum::CCCL_UINT64), - arg.dereference.name, - cccl_type_enum_to_name(arg.value_type.type)); + arg.dereference.name); } return std::format( R"input( -extern "C" __device__ void {0}(void *, {1}); -extern "C" __device__ void {2}(const void *, {3}*); +extern "C" __device__ void {0}(void *, const void*); +extern "C" __device__ void {1}(const void *, {2}*); )input", arg.advance.name, - cccl_type_enum_to_name(cccl_type_enum::CCCL_UINT64), arg.dereference.name, cccl_type_enum_to_name(arg.value_type.type)); } diff --git a/c/parallel/src/jit_templates/templates/input_iterator.h b/c/parallel/src/jit_templates/templates/input_iterator.h index f537f6aa497..812bb6c62c7 100644 --- a/c/parallel/src/jit_templates/templates/input_iterator.h +++ b/c/parallel/src/jit_templates/templates/input_iterator.h @@ -47,7 +47,7 @@ struct input_iterator_t __device__ input_iterator_t& operator+=(difference_type diff) { - Iterator.advance(&state, diff); + Iterator.advance(&state, &diff); return *this; } diff --git a/c/parallel/src/jit_templates/templates/output_iterator.h b/c/parallel/src/jit_templates/templates/output_iterator.h index 26bb2096757..a1e1e9144a0 100644 --- a/c/parallel/src/jit_templates/templates/output_iterator.h +++ b/c/parallel/src/jit_templates/templates/output_iterator.h @@ -35,7 +35,7 @@ struct output_iterator_proxy_t { __device__ output_iterator_proxy_t& operator=(AssignT x) { - AssignF(&state, cuda::std::move(x)); + AssignF(&state, &x); return *this; } @@ -59,7 +59,7 @@ struct output_iterator_t __device__ output_iterator_t& operator+=(difference_type diff) { - Iterator.advance(&state, diff); + Iterator.advance(&state, &diff); return *this; } diff --git a/c/parallel/src/kernels/iterators.cpp b/c/parallel/src/kernels/iterators.cpp index fb5f7fa48f5..e23b219486d 100644 --- a/c/parallel/src/kernels/iterators.cpp +++ b/c/parallel/src/kernels/iterators.cpp @@ -45,7 +45,7 @@ std::string make_kernel_input_iterator( { const std::string iter_def = std::format(R"XXX( extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T* result); -extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset); +extern "C" __device__ void ADVANCE(void *self_ptr, const void* offset); struct __align__(OP_ALIGNMENT) {0} {{ using iterator_category = cuda::std::random_access_iterator_tag; using value_type = VALUE_T; @@ -58,7 +58,7 @@ struct __align__(OP_ALIGNMENT) {0} {{ return result; }} __device__ inline {0}& operator+=(difference_type diff) {{ - ADVANCE(data, diff); + ADVANCE(data, &diff); return *this; }} __device__ inline value_type operator[](difference_type diff) const {{ @@ -99,14 +99,14 @@ std::string make_kernel_output_iterator( std::string_view advance) { const std::string iter_def = std::format(R"XXX( -extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T x); -extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset); +extern "C" __device__ void DEREF(const void *self_ptr, const void* x); +extern "C" __device__ void ADVANCE(void *self_ptr, const void* offset); struct __align__(OP_ALIGNMENT) {0}_state_t {{ char data[OP_SIZE]; }}; struct {0}_proxy_t {{ __device__ {0}_proxy_t operator=(VALUE_T x) {{ - DEREF(&state, x); + DEREF(&state, &x); return *this; }} {0}_state_t state; @@ -119,7 +119,7 @@ struct {0} {{ using reference = {0}_proxy_t; __device__ {0}_proxy_t operator*() const {{ return {{state}}; }} __device__ {0}& operator+=(difference_type diff) {{ - ADVANCE(&state, diff); + ADVANCE(&state, &diff); return *this; }} __device__ {0}_proxy_t operator[](difference_type diff) const {{ diff --git a/c/parallel/test/test_merge_sort.cpp b/c/parallel/test/test_merge_sort.cpp index fc87a6ee98e..dbfdebc6aac 100644 --- a/c/parallel/test/test_merge_sort.cpp +++ b/c/parallel/test/test_merge_sort.cpp @@ -244,16 +244,15 @@ struct item_pair struct DeviceMergeSort_SortPairsCopy_CustomType_Fixture_Tag; C2H_TEST("DeviceMergeSort:SortPairsCopy works with custom types", "[merge_sort]") { - const size_t num_items = GENERATE_COPY(take(2, random(1, 100000)), values({5, 10000, 100000})); - operation_t op = make_operation( - "op", - "struct key_pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n" - " key_pair* lhs = static_cast(lhs_ptr);\n" - " key_pair* rhs = static_cast(rhs_ptr);\n" - " bool* out = static_cast(out_ptr);\n" - " *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;\n" - "}"); + const size_t num_items = GENERATE_COPY(take(2, random(1, 100000)), values({5, 10000, 100000})); + operation_t op = make_operation("op", + R"(struct key_pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) { + key_pair* lhs = static_cast(lhs_ptr); + key_pair* rhs = static_cast(rhs_ptr); + bool* out = static_cast(out_ptr); + *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a; +})"); const std::vector a = generate(num_items); const std::vector b = generate(num_items); std::vector input_keys(num_items); @@ -301,16 +300,15 @@ C2H_TEST("DeviceMergeSort:SortPairsCopy works with custom types", "[merge_sort]" struct DeviceMergeSort_SortPairsCopy_CustomType_WellKnown_Fixture_Tag; C2H_TEST("DeviceMergeSort:SortPairsCopy works with custom types with well-known predicates", "[merge_sort][well_known]") { - const size_t num_items = GENERATE_COPY(take(2, random(1, 100000)), values({5, 10000, 100000})); - operation_t op_state = make_operation( - "op", - "struct key_pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n" - " key_pair* lhs = static_cast(lhs_ptr);\n" - " key_pair* rhs = static_cast(rhs_ptr);\n" - " bool* out = static_cast(out_ptr);\n" - " *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a;\n" - "}"); + const size_t num_items = GENERATE_COPY(take(2, random(1, 100000)), values({5, 10000, 100000})); + operation_t op_state = make_operation("op", + R"(struct key_pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) { + key_pair* lhs = static_cast(lhs_ptr); + key_pair* rhs = static_cast(rhs_ptr); + bool* out = static_cast(out_ptr); + *out = lhs->a == rhs->a ? lhs->b < rhs->b : lhs->a < rhs->a; +})"); cccl_op_t op = op_state; op.type = cccl_op_kind_t::CCCL_LESS; const std::vector a = generate(num_items); @@ -432,13 +430,17 @@ C2H_TEST("DeviceMergeSort::SortKeys works with output iterators", "[merge_sort]" make_iterator( {"random_access_iterator_state_t", "struct random_access_iterator_state_t { int* d_input; };\n"}, {"advance", - "extern \"C\" __device__ void advance(random_access_iterator_state_t* state, unsigned long long offset) {\n" - " state->d_input += offset;\n" - "}"}, + R"(extern "C" __device__ void advance(void* state, const void* offset) { + auto* typed_state = static_cast(state); + auto offset_val = *static_cast(offset); + typed_state->d_input += offset_val; +})"}, {"dereference", - "extern \"C\" __device__ void dereference(random_access_iterator_state_t* state, int x) {\n" - " *state->d_input = x;\n" - "}"}); + R"(extern "C" __device__ void dereference(void* state, const void* x) { + auto* typed_state = static_cast(state); + auto x_val = *static_cast(x); + *typed_state->d_input = x_val; +})"}); std::vector input_keys = make_shuffled_key_ranks_vector(num_items); std::vector expected_keys = input_keys; @@ -471,14 +473,17 @@ C2H_TEST("DeviceMergeSort::SortPairs works with output iterators for items", "[m make_iterator( "struct item_random_access_iterator_state_t { int* d_input; };\n", {"advance", - "extern \"C\" __device__ void advance(item_random_access_iterator_state_t* state, unsigned long long offset) " - "{\n" - " state->d_input += offset;\n" - "}"}, + R"(extern "C" __device__ void advance(void* state, const void* offset) { + auto* typed_state = static_cast(state); + auto offset_val = *static_cast(offset); + typed_state->d_input += offset_val; +})"}, {"dereference", - "extern \"C\" __device__ void dereference(item_random_access_iterator_state_t* state, int x) {\n" - " *state->d_input = x;\n" - "}"}); + R"(extern "C" __device__ void dereference(void* state, const void* x) { + auto* typed_state = static_cast(state); + auto x_val = *static_cast(x); + *typed_state->d_input = x_val; +})"}); pointer_t input_keys_it(input_keys); pointer_t input_items_it(input_items); @@ -650,12 +655,12 @@ C2H_TEST("MergeSort works with C++ source operations using custom headers", "[me /* C2H_TEST("DeviceMergeSort:SortPairsCopy fails to build for large types due to no vsmem", "[merge_sort]") { const size_t num_items = 1; - operation_t op = make_operation( + operation_t op = make_operation( "op", - "struct large_key_pair { int a; char c[100]; };\n" - "extern \"C\" __device__ bool op(large_key_pair lhs, large_key_pair rhs) {\n" - " return lhs.a < rhs.a;\n" - "}"); + R"(struct large_key_pair { int a; char c[100]; }; +extern "C" __device__ bool op(large_key_pair lhs, large_key_pair rhs) { + return lhs.a < rhs.a; +})"); const std::vector a = generate(num_items); std::vector input_keys(num_items); for (std::size_t i = 0; i < num_items; ++i) diff --git a/c/parallel/test/test_reduce.cpp b/c/parallel/test/test_reduce.cpp index adbb00d7395..cecbb825342 100644 --- a/c/parallel/test/test_reduce.cpp +++ b/c/parallel/test/test_reduce.cpp @@ -180,15 +180,14 @@ C2H_TEST("Reduce works with custom types", "[reduce]") { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {\n" - " pair* lhs = static_cast(lhs_ptr);\n" - " pair* rhs = static_cast(rhs_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n" - "}"); + operation_t op = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) { + pair* lhs = static_cast(lhs_ptr); + pair* rhs = static_cast(rhs_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b }; +})"); const std::vector a = generate(num_items); const std::vector b = generate(num_items); std::vector input(num_items); @@ -218,15 +217,14 @@ C2H_TEST("Reduce works with custom types with well-known operations", "[reduce][ { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op_state = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {\n" - " pair* lhs = static_cast(lhs_ptr);\n" - " pair* rhs = static_cast(rhs_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n" - "}"); + operation_t op_state = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) { + pair* lhs = static_cast(lhs_ptr); + pair* rhs = static_cast(rhs_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b }; +})"); cccl_op_t op = op_state; op.type = cccl_op_kind_t::CCCL_PLUS; const std::vector a = generate(num_items); @@ -371,14 +369,14 @@ C2H_TEST("Reduce works with stateful operators", "[reduce]") pointer_t counter(1); stateful_operation_t op = make_operation( "op", - "struct invocation_counter_state_t { int* d_counter; };\n" - "extern \"C\" __device__ void op(void* state_ptr, void* a_ptr, void* b_ptr, void* out_ptr) {\n" - " invocation_counter_state_t* state = static_cast(state_ptr);\n" - " atomicAdd(state->d_counter, 1);\n" - " int a = *static_cast(a_ptr);\n" - " int b = *static_cast(b_ptr);\n" - " *static_cast(out_ptr) = a + b;\n" - "}", + R"(struct invocation_counter_state_t { int* d_counter; }; +extern "C" __device__ void op(void* state_ptr, void* a_ptr, void* b_ptr, void* out_ptr) { + invocation_counter_state_t* state = static_cast(state_ptr); + atomicAdd(state->d_counter, 1); + int a = *static_cast(a_ptr); + int b = *static_cast(b_ptr); + *static_cast(out_ptr) = a + b; +})", invocation_counter_state_t{counter.ptr}); const std::vector input = generate(num_items); diff --git a/c/parallel/test/test_scan.cpp b/c/parallel/test/test_scan.cpp index 3ba7c54343a..56a1f0c9cea 100644 --- a/c/parallel/test/test_scan.cpp +++ b/c/parallel/test/test_scan.cpp @@ -398,15 +398,14 @@ C2H_TEST("Scan works with custom types", "[scan]") { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {\n" - " pair* lhs = static_cast(lhs_ptr);\n" - " pair* rhs = static_cast(rhs_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n" - "}"); + operation_t op = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) { + pair* lhs = static_cast(lhs_ptr); + pair* rhs = static_cast(rhs_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b }; +})"); const std::vector a = generate(num_items); const std::vector b = generate(num_items); std::vector input(num_items); @@ -439,15 +438,14 @@ C2H_TEST("Scan works with custom types with well-known operations", "[scan][well { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op_state = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) {\n" - " pair* lhs = static_cast(lhs_ptr);\n" - " pair* rhs = static_cast(rhs_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b };\n" - "}"); + operation_t op_state = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, void* out_ptr) { + pair* lhs = static_cast(lhs_ptr); + pair* rhs = static_cast(rhs_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ lhs->a + rhs->a, lhs->b + rhs->b }; +})"); cccl_op_t op = op_state; op.type = cccl_op_kind_t::CCCL_PLUS; const std::vector a = generate(num_items); diff --git a/c/parallel/test/test_segmented_reduce.cpp b/c/parallel/test/test_segmented_reduce.cpp index b2b376e6623..c8d9a1b522b 100644 --- a/c/parallel/test/test_segmented_reduce.cpp +++ b/c/parallel/test/test_segmented_reduce.cpp @@ -443,9 +443,11 @@ struct {0} {{ /* 2 */ index_type_name); static constexpr std::string_view it_advance_fn_def_src_tmpl = R"XXX( -extern "C" __device__ void {0}({1}* state, {2} offset) +extern "C" __device__ void {0}(void* state, const void* offset) {{ - state->linear_id += offset; + auto* typed_state = static_cast<{1}*>(state); + auto offset_val = *static_cast(offset); + typed_state->linear_id += offset_val; }} )XXX"; @@ -453,10 +455,11 @@ extern "C" __device__ void {0}({1}* state, {2} offset) std::format(it_advance_fn_def_src_tmpl, /*0*/ advance_fn_name, state_name, index_type_name); static constexpr std::string_view it_dereference_fn_src_tmpl = R"XXX( -extern "C" __device__ void {0}({2} *state, {1}* result) {{ - unsigned long long col_id = (state->linear_id) / (state->n_rows); - unsigned long long row_id = (state->linear_id) - col_id * (state->n_rows); - *result = *(state->ptr + row_id * (state->n_cols) + col_id); +extern "C" __device__ void {0}(const void* state, {1}* result) {{ + auto* typed_state = static_cast(state); + unsigned long long col_id = (typed_state->linear_id) / (typed_state->n_rows); + unsigned long long row_id = (typed_state->linear_id) - col_id * (typed_state->n_rows); + *result = *(typed_state->ptr + row_id * (typed_state->n_cols) + col_id); }} )XXX"; diff --git a/c/parallel/test/test_three_way_partition.cpp b/c/parallel/test/test_three_way_partition.cpp index 2bf9c65c2c8..63b175eb70f 100644 --- a/c/parallel/test/test_three_way_partition.cpp +++ b/c/parallel/test/test_three_way_partition.cpp @@ -329,21 +329,21 @@ C2H_TEST("ThreeWayPartition works with stateful operations", "[three_way_partiti selector_state_t op_state = {21}; stateful_operation_t less_op = make_operation( "less_op", - "struct selector_state_t { int comparison_value; };\n" - "extern \"C\" __device__ void less_op(void* state_ptr, void* x_ptr, void* out_ptr) {\n" - " selector_state_t* state = static_cast(state_ptr);\n" - " *static_cast(x_ptr) < state->comparison_value;\n" - " *static_cast(out_ptr) = *static_cast(x_ptr) < state->comparison_value;\n" - "}", + R"(struct selector_state_t { int comparison_value; }; +extern "C" __device__ void less_op(void* state_ptr, void* x_ptr, void* out_ptr) { + selector_state_t* state = static_cast(state_ptr); + *static_cast(x_ptr) < state->comparison_value; + *static_cast(out_ptr) = *static_cast(x_ptr) < state->comparison_value; +})", op_state); stateful_operation_t greater_or_equal_op = make_operation( "greater_or_equal_op", - "struct selector_state_t { int comparison_value; };\n" - "extern \"C\" __device__ void greater_or_equal_op(void* state_ptr, void* x_ptr, void* out_ptr) {\n" - " selector_state_t* state = static_cast(state_ptr);\n" - " *static_cast(x_ptr) >= state->comparison_value;\n" - " *static_cast(out_ptr) = *static_cast(x_ptr) >= state->comparison_value;\n" - "}", + R"(struct selector_state_t { int comparison_value; }; +extern "C" __device__ void greater_or_equal_op(void* state_ptr, void* x_ptr, void* out_ptr) { + selector_state_t* state = static_cast(state_ptr); + *static_cast(x_ptr) >= state->comparison_value; + *static_cast(out_ptr) = *static_cast(x_ptr) >= state->comparison_value; +})", op_state); const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 20))); @@ -409,21 +409,21 @@ C2H_TEST("ThreeWayPartition works with custom types", "[three_way_partition]") operation_t less_op = make_operation( "less_op", - std::format("struct pair_type {{ int a; size_t b; }};" - "extern \"C\" __device__ void less_op(void* x_ptr, void* out_ptr) {{ " - " pair_type* x = static_cast(x_ptr); " - " bool* out = static_cast(out_ptr); " - " *out = x->a < {0}; " - "}}", + std::format(R"(struct pair_type {{ int a; size_t b; }}; +extern "C" __device__ void less_op(void* x_ptr, void* out_ptr) {{ + pair_type* x = static_cast(x_ptr); + bool* out = static_cast(out_ptr); + *out = x->a < {0}; +}})", comparison_value)); operation_t greater_or_equal_op = make_operation( "greater_or_equal_op", - std::format("struct pair_type {{ int a; size_t b; }};" - "extern \"C\" __device__ void greater_or_equal_op(void* x_ptr, void* out_ptr) {{ " - " pair_type* x = static_cast(x_ptr); " - " bool* out = static_cast(out_ptr); " - " *out = x->a >= {0}; " - "}}", + std::format(R"(struct pair_type {{ int a; size_t b; }}; +extern "C" __device__ void greater_or_equal_op(void* x_ptr, void* out_ptr) {{ + pair_type* x = static_cast(x_ptr); + bool* out = static_cast(out_ptr); + *out = x->a >= {0}; +}})", comparison_value)); const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 20))); diff --git a/c/parallel/test/test_transform.cpp b/c/parallel/test/test_transform.cpp index 5816a355b44..397e753e258 100644 --- a/c/parallel/test/test_transform.cpp +++ b/c/parallel/test/test_transform.cpp @@ -281,14 +281,13 @@ C2H_TEST("Transform works with output of different type", "[transform]") { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* x_ptr, void* out_ptr) {\n" - " int* x = static_cast(x_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ short(*x), size_t(*x) };\n" - "}"); + operation_t op = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* x_ptr, void* out_ptr) { + int* x = static_cast(x_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ short(*x), size_t(*x) }; +})"); const std::vector input = generate(num_items); std::vector expected(num_items); std::vector output(num_items); @@ -315,14 +314,13 @@ C2H_TEST("Transform works with custom types", "[transform]") { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* x_ptr, void* out_ptr) {\n" - " pair* x = static_cast(x_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ x->a * 2, x->b * 2 };\n" - "}"); + operation_t op = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* x_ptr, void* out_ptr) { + pair* x = static_cast(x_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ x->a * 2, x->b * 2 }; +})"); const std::vector a = generate(num_items); const std::vector b = generate(num_items); std::vector input(num_items); @@ -354,15 +352,14 @@ C2H_TEST("Transform works with custom types with well-known operators", "[transf { const std::size_t num_items = GENERATE(0, 42, take(4, random(1 << 12, 1 << 24))); - operation_t op_state = make_operation( - "op", - "struct pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* x_ptr, void* out_ptr) {\n" - " pair* x = static_cast(x_ptr);\n" - " pair* out = static_cast(out_ptr);\n" - " *out = pair{ x->a * 2, x->b * 2 };\n" - "}"); - cccl_op_t op = op_state; + operation_t op_state = make_operation("op", + R"(struct pair { short a; size_t b; }; +extern "C" __device__ void op(void* x_ptr, void* out_ptr) { + pair* x = static_cast(x_ptr); + pair* out = static_cast(out_ptr); + *out = pair{ x->a * 2, x->b * 2 }; +})"); + cccl_op_t op = op_state; // HACK: this doesn't actually match the operation above, but that's fine, as we are supposed to not take the // well-known path anyway op.type = cccl_op_kind_t::CCCL_NEGATE; @@ -458,14 +455,13 @@ C2H_TEST("Transform with binary operator", "[transform]") pointer_t input2_ptr(input2); pointer_t output_ptr(output); - operation_t op = make_operation( - "op", - "extern \"C\" __device__ void op(void* x_ptr, void* y_ptr, void* out_ptr ) {\n" - " int* x = static_cast(x_ptr);\n" - " int* y = static_cast(y_ptr);\n" - " int* out = static_cast(out_ptr);\n" - " *out = (*x > *y) ? *x : *y;\n" - "}"); + operation_t op = make_operation("op", + R"(extern "C" __device__ void op(void* x_ptr, void* y_ptr, void* out_ptr ) { + int* x = static_cast(x_ptr); + int* y = static_cast(y_ptr); + int* out = static_cast(out_ptr); + *out = (*x > *y) ? *x : *y; +})"); auto& build_cache = get_cache(); const auto& test_key = make_key(); @@ -496,14 +492,13 @@ C2H_TEST("Binary transform with one iterator", "[transform]") pointer_t input1_ptr(input1); pointer_t output_ptr(output); - operation_t op = make_operation( - "op", - "extern \"C\" __device__ void op(void* x_ptr, void* y_ptr, void* out_ptr) {\n" - " int* x = static_cast(x_ptr);\n" - " int* y = static_cast(y_ptr);\n" - " int* out = static_cast(out_ptr);\n" - " *out = (*x > *y) ? *x : *y;\n" - "}"); + operation_t op = make_operation("op", + R"(extern "C" __device__ void op(void* x_ptr, void* y_ptr, void* out_ptr) { + int* x = static_cast(x_ptr); + int* y = static_cast(y_ptr); + int* out = static_cast(out_ptr); + *out = (*x > *y) ? *x : *y; +})"); auto& build_cache = get_cache(); const auto& test_key = make_key(); diff --git a/c/parallel/test/test_unique_by_key.cpp b/c/parallel/test/test_unique_by_key.cpp index 52913d1bde0..a59b97fe091 100644 --- a/c/parallel/test/test_unique_by_key.cpp +++ b/c/parallel/test/test_unique_by_key.cpp @@ -461,15 +461,14 @@ C2H_TEST("DeviceSelect::UniqueByKey works with custom types", "[device][select_u { const int num_items = GENERATE_COPY(take(2, random(1, 1000000))); - operation_t op = make_operation( - "op", - "struct key_pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n" - " key_pair* lhs = static_cast(lhs_ptr);\n" - " key_pair* rhs = static_cast(rhs_ptr);\n" - " bool* out = static_cast(out_ptr);\n" - " *out = (lhs->a == rhs->a && lhs->b == rhs->b);\n" - "}"); + operation_t op = make_operation("op", + R"(struct key_pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) { + key_pair* lhs = static_cast(lhs_ptr); + key_pair* rhs = static_cast(rhs_ptr); + bool* out = static_cast(out_ptr); + *out = (lhs->a == rhs->a && lhs->b == rhs->b); +})"); const std::vector a = generate(num_items); const std::vector b = generate(num_items); std::vector input_keys(num_items); @@ -533,15 +532,14 @@ C2H_TEST("DeviceSelect::UniqueByKey works with custom types with well-known oper { const int num_items = GENERATE_COPY(take(2, random(1, 1000000))); - operation_t op_state = make_operation( - "op", - "struct key_pair { short a; size_t b; };\n" - "extern \"C\" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) {\n" - " key_pair* lhs = static_cast(lhs_ptr);\n" - " key_pair* rhs = static_cast(rhs_ptr);\n" - " bool* out = static_cast(out_ptr);\n" - " *out = (lhs->a == rhs->a && lhs->b == rhs->b);\n" - "}"); + operation_t op_state = make_operation("op", + R"(struct key_pair { short a; size_t b; }; +extern "C" __device__ void op(void* lhs_ptr, void* rhs_ptr, bool* out_ptr) { + key_pair* lhs = static_cast(lhs_ptr); + key_pair* rhs = static_cast(rhs_ptr); + bool* out = static_cast(out_ptr); + *out = (lhs->a == rhs->a && lhs->b == rhs->b); +})"); cccl_op_t op = op_state; op.type = cccl_op_kind_t::CCCL_EQUAL_TO; const std::vector a = generate(num_items); @@ -693,12 +691,11 @@ C2H_TEST("DeviceSelect::UniqueByKey fails to build for large types due to no vsm { const int num_items = 1; - operation_t op = make_operation( - "op", - "struct large_key_pair { int a; char c[500]; };\n" - "extern \"C\" __device__ bool op(large_key_pair lhs, large_key_pair rhs) {\n" - " return lhs.a == rhs.a;\n" - "}"); + operation_t op = make_operation("op", + R"(struct large_key_pair { int a; char c[500]; }; +extern "C" __device__ bool op(large_key_pair lhs, large_key_pair rhs) { + return lhs.a == rhs.a; +})"); const std::vector a = generate(num_items); std::vector input_keys(num_items); for (int i = 0; i < num_items; ++i) diff --git a/c/parallel/test/test_util.h b/c/parallel/test/test_util.h index 1ba57a5a4de..995fa102aac 100644 --- a/c/parallel/test/test_util.h +++ b/c/parallel/test/test_util.h @@ -931,8 +931,10 @@ inline std::tuple make_random_access_iter { std::string state_def_src = std::format("struct {0} {{ {1}* data; }};\n", iterator_state_name, value_type); std::string advance_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{\n" - " state->data += offset;\n" + "extern \"C\" __device__ void {0}(void* state, const void* offset) {{\n" + " auto* typed_state = static_cast<{1}*>(state);\n" + " auto offset_val = *static_cast(offset);\n" + " typed_state->data += offset_val;\n" "}}", advance_fn_name, iterator_state_name); @@ -941,19 +943,22 @@ inline std::tuple make_random_access_iter if (kind == iterator_kind::INPUT) { dereference_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, {2}* result) {{\n" - " *result = (*state->data){3};\n" + "extern \"C\" __device__ void {0}(const void* state, {1}* result) {{\n" + " auto* typed_state = static_cast(state);\n" + " *result = (*typed_state->data){3};\n" "}}", dereference_fn_name, - iterator_state_name, value_type, + iterator_state_name, transform); } else { dereference_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, {2} x) {{\n" - " *state->data = x{3};\n" + "extern \"C\" __device__ void {0}(void* state, const void* x) {{\n" + " auto* typed_state = static_cast<{1}*>(state);\n" + " auto x_val = *static_cast(x);\n" + " *typed_state->data = x_val{3};\n" "}}", dereference_fn_name, iterator_state_name, @@ -991,15 +996,18 @@ inline std::tuple make_counting_iterator_ { std::string iterator_state_def_src = std::format("struct {0} {{ {1} value; }};\n", iterator_state_name, value_type); std::string advance_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{\n" - " state->value += offset;\n" + "extern \"C\" __device__ void {0}(void* state, const void* offset) {{\n" + " auto* typed_state = static_cast<{1}*>(state);\n" + " auto offset_val = *static_cast(offset);\n" + " typed_state->value += offset_val;\n" "}}", advance_fn_name, iterator_state_name); std::string dereference_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, {2}* result) {{ \n" - " *result = state->value;\n" + "extern \"C\" __device__ void {0}(const void* state, {2}* result) {{ \n" + " auto* typed_state = static_cast(state);\n" + " *result = typed_state->value;\n" "}}", dereference_fn_name, iterator_state_name, @@ -1033,17 +1041,16 @@ inline std::tuple make_constant_iterator_ std::string_view dereference_fn_name) { std::string iterator_state_src = std::format("struct {0} {{ {1} value; }};\n", iterator_state_name, value_type); - std::string advance_fn_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{ }}", - advance_fn_name, - iterator_state_name); + std::string advance_fn_src = + std::format("extern \"C\" __device__ void {0}(void* state, const void* offset) {{ }}", advance_fn_name); std::string dereference_fn_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, {2}* result) {{ \n" - " *result = state->value;\n" + "extern \"C\" __device__ void {0}(const void* state, {1}* result) {{ \n" + " auto* typed_state = static_cast(state);\n" + " *result = typed_state->value;\n" "}}", dereference_fn_name, - iterator_state_name, - value_type); + value_type, + iterator_state_name); return std::make_tuple(iterator_state_src, advance_fn_src, dereference_fn_src); } @@ -1076,8 +1083,10 @@ inline std::tuple make_reverse_iterator_s { std::string iterator_state_src = std::format("struct {0} {{ {1}* data; }};\n", iterator_state_name, value_type); std::string advance_fn_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, unsigned long long offset) {{\n" - " state->data -= offset;\n" + "extern \"C\" __device__ void {0}(void* state, const void* offset) {{\n" + " auto* typed_state = static_cast<{1}*>(state);\n" + " auto offset_val = *static_cast(offset);\n" + " typed_state->data -= offset_val;\n" "}}", advance_fn_name, iterator_state_name); @@ -1085,8 +1094,9 @@ inline std::tuple make_reverse_iterator_s if (kind == iterator_kind::INPUT) { dereference_fn_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, {2}* result) {{\n" - " *result = (*state->data){3};\n" + "extern \"C\" __device__ void {0}(const void* state, {2}* result) {{\n" + " auto* typed_state = static_cast(state);\n" + " *result = (*typed_state->data){3};\n" "}}", dereference_fn_name, iterator_state_name, @@ -1096,8 +1106,10 @@ inline std::tuple make_reverse_iterator_s else { dereference_fn_src = std::format( - "extern \"C\" __device__ void {0}({1}* state, {2} x) {{\n" - " *state->data = x{3};\n" + "extern \"C\" __device__ void {0}(void* state, const void* x) {{\n" + " auto* typed_state = static_cast<{1}*>(state);\n" + " auto x_val = *static_cast(x);\n" + " *typed_state->data = x_val{3};\n" "}}", dereference_fn_name, iterator_state_name, @@ -1124,9 +1136,11 @@ struct {0} {{ const std::string it_state_def_src = std::format(it_state_src_tmpl, state_name, index_ty_name); static constexpr std::string_view it_def_src_tmpl = R"XXX( -extern "C" __device__ void {0}({1}* state, {2} offset) +extern "C" __device__ void {0}(void* state, const void* offset) {{ - state->linear_id += offset; + auto* typed_state = static_cast<{1}*>(state); + auto offset_val = *static_cast(offset); + typed_state->linear_id += offset_val; }} )XXX"; @@ -1134,14 +1148,15 @@ extern "C" __device__ void {0}({1}* state, {2} offset) std::format(it_def_src_tmpl, /*0*/ advance_fn_name, state_name, index_ty_name); static constexpr std::string_view it_deref_src_tmpl = R"XXX( -extern "C" __device__ void {0}({1}* state, {2}* result) +extern "C" __device__ void {0}(const void* state, {1}* result) {{ - *result = (state->linear_id) * (state->segment_size); + auto* typed_state = static_cast(state); + *result = (typed_state->linear_id) * (typed_state->segment_size); }} )XXX"; const std::string it_deref_fn_def_src = - std::format(it_deref_src_tmpl, dereference_fn_name, state_name, index_ty_name); + std::format(it_deref_src_tmpl, dereference_fn_name, index_ty_name, state_name); return std::make_tuple(it_state_def_src, it_advance_fn_def_src, it_deref_fn_def_src); } @@ -1229,8 +1244,9 @@ struct {0} {{ static constexpr std::string_view transform_it_advance_fn_src_tmpl = R"XXX( {3} -extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offset) {{ - {2}(&(transform_it_state->base_it_state), offset); +extern "C" __device__ void {0}(void* transform_it_state, const void* offset) {{ + auto* typed_state = static_cast<{1}*>(transform_it_state); + {2}(&(typed_state->base_it_state), offset); }} )XXX"; @@ -1244,11 +1260,12 @@ extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offse static constexpr std::string_view transform_it_dereference_fn_src_tmpl = R"XXX( {5} {6} -extern "C" __device__ void {0}({1} *transform_it_state, {2}* result) {{ +extern "C" __device__ void {0}(const void* transform_it_state, {2}* result) {{ + auto* typed_state = static_cast(transform_it_state); {7} base_result; - {4}(&(transform_it_state->base_it_state), &base_result); + {4}(&(typed_state->base_it_state), &base_result); *result = {3}( - &(transform_it_state->functor_state), + const_castfunctor_state)*>(&(typed_state->functor_state)), base_result ); }} @@ -1332,8 +1349,9 @@ struct {0} {{ static constexpr std::string_view transform_it_advance_fn_src_tmpl = R"XXX( {3} -extern "C" __device__ void {0}({1} *transform_it_state, unsigned long long offset) {{ - {2}(&(transform_it_state->base_it_state), offset); +extern "C" __device__ void {0}(void *transform_it_state, const void* offset) {{ + auto* typed_state = static_cast<{1}*>(transform_it_state); + {2}(&(typed_state->base_it_state), offset); }} )XXX"; @@ -1411,7 +1429,7 @@ inline std::tuple make_discard_iterator_s { std::string state_def_src = std::format("struct {0} {{ {1}* data; }};\n", iterator_state_name, value_type); std::string advance_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* /*state*/, unsigned long long /*offset*/) {{\n" + "extern \"C\" __device__ void {0}(void* /*state*/, const void* /*offset*/) {{\n" "}}", advance_fn_name, iterator_state_name); @@ -1420,7 +1438,7 @@ inline std::tuple make_discard_iterator_s if (kind == iterator_kind::INPUT) { dereference_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* /*state*/, {2}* /*result*/) {{\n" + "extern \"C\" __device__ void {0}(const void* /*state*/, {2}* /*result*/) {{\n" "}}", dereference_fn_name, iterator_state_name, @@ -1429,7 +1447,7 @@ inline std::tuple make_discard_iterator_s else { dereference_fn_def_src = std::format( - "extern \"C\" __device__ void {0}({1}* /*state*/, {2} /*x*/) {{\n" + "extern \"C\" __device__ void {0}(void* /*state*/, const void* /*x*/) {{\n" "}}", dereference_fn_name, iterator_state_name, diff --git a/python/cuda_cccl/cuda/compute/_cccl_interop.py b/python/cuda_cccl/cuda/compute/_cccl_interop.py index 77d1a0e39e4..90ea4d68e00 100644 --- a/python/cuda_cccl/cuda/compute/_cccl_interop.py +++ b/python/cuda_cccl/cuda/compute/_cccl_interop.py @@ -248,7 +248,10 @@ def codegen(context, builder, impl_sig, args): input_vals = [builder.load(p) for p in input_ptrs] # Call the original operator - result = context.compile_internal(builder, op, sig, input_vals) + # See NVIDIA/numba-cuda#590 for why we need compile_subroutine + # vs compile_internal here: + cres = context.compile_subroutine(builder, op, sig, caching=False) + result = context.call_internal(builder, cres.fndesc, sig, input_vals) # Store the result builder.store(result, ret_ptr) @@ -277,6 +280,145 @@ def wrapped_{op.__name__}({arg_str}): return wrapper_func, void_sig +def _create_advance_wrapper(advance_fn, state_ptr_type): + """Creates a wrapper function for iterator advance that takes void* arguments. + + The wrapper takes 2 void* arguments: + - state pointer + - offset pointer (points to uint64 value) + """ + void_sig = types.void(types.voidptr, types.voidptr) + + wrapper_src = textwrap.dedent(f""" + @intrinsic + def impl(typingctx, state_arg, offset_arg): + def codegen(context, builder, impl_sig, args): + state_type_llvm = context.get_value_type(state_ptr_type.dtype) + offset_type_llvm = context.get_value_type(types.uint64) + + state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer()) + offset_ptr = builder.bitcast(args[1], offset_type_llvm.as_pointer()) + offset_val = builder.load(offset_ptr) + + sig = types.void(state_ptr_type, types.uint64) + cres = context.compile_subroutine(builder, advance_fn, sig, caching=False) + result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, offset_val]) + + return context.get_dummy_value() + return void_sig, codegen + + def wrapped_{advance_fn.__name__}(state_arg, offset_arg): + return impl(state_arg, offset_arg) + """) + + local_dict = { + "types": types, + "state_ptr_type": state_ptr_type, + "advance_fn": advance_fn, + "intrinsic": intrinsic, + "void_sig": void_sig, + } + exec(wrapper_src, globals(), local_dict) + + wrapper_func = local_dict[f"wrapped_{advance_fn.__name__}"] + wrapper_func.__globals__.update(local_dict) + + return wrapper_func, void_sig + + +def _create_input_dereference_wrapper(deref_fn, state_ptr_type, value_type): + """Creates a wrapper function for input iterator dereference that takes void* arguments. + + The wrapper takes 2 void* arguments: + - state pointer + - result pointer + """ + void_sig = types.void(types.voidptr, types.voidptr) + + wrapper_src = textwrap.dedent(f""" + @intrinsic + def impl(typingctx, state_arg, result_arg): + def codegen(context, builder, impl_sig, args): + state_type_llvm = context.get_value_type(state_ptr_type.dtype) + value_type_llvm = context.get_value_type(value_type) + + state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer()) + result_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer()) + + sig = types.void(state_ptr_type, types.CPointer(value_type)) + cres = context.compile_subroutine(builder, deref_fn, sig, caching=False) + result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, result_ptr]) + + return context.get_dummy_value() + return void_sig, codegen + + def wrapped_{deref_fn.__name__}(state_arg, result_arg): + return impl(state_arg, result_arg) + """) + + local_dict = { + "types": types, + "state_ptr_type": state_ptr_type, + "value_type": value_type, + "deref_fn": deref_fn, + "intrinsic": intrinsic, + "void_sig": void_sig, + } + exec(wrapper_src, globals(), local_dict) + + wrapper_func = local_dict[f"wrapped_{deref_fn.__name__}"] + wrapper_func.__globals__.update(local_dict) + + return wrapper_func, void_sig + + +def _create_output_dereference_wrapper(deref_fn, state_ptr_type, value_type): + """Creates a wrapper function for output iterator dereference that takes void* arguments. + + The wrapper takes 2 void* arguments: + - state pointer + - value pointer (points to value) + """ + void_sig = types.void(types.voidptr, types.voidptr) + + wrapper_src = textwrap.dedent(f""" + @intrinsic + def impl(typingctx, state_arg, value_arg): + def codegen(context, builder, impl_sig, args): + state_type_llvm = context.get_value_type(state_ptr_type.dtype) + value_type_llvm = context.get_value_type(value_type) + + state_ptr = builder.bitcast(args[0], state_type_llvm.as_pointer()) + value_ptr = builder.bitcast(args[1], value_type_llvm.as_pointer()) + value_val = builder.load(value_ptr) + + sig = types.void(state_ptr_type, value_type) + cres = context.compile_subroutine(builder, deref_fn, sig, caching=False) + result = context.call_internal(builder, cres.fndesc, sig, [state_ptr, value_val]) + + return context.get_dummy_value() + return void_sig, codegen + + def wrapped_{deref_fn.__name__}(state_arg, value_arg): + return impl(state_arg, value_arg) + """) + + local_dict = { + "types": types, + "state_ptr_type": state_ptr_type, + "value_type": value_type, + "deref_fn": deref_fn, + "intrinsic": intrinsic, + "void_sig": void_sig, + } + exec(wrapper_src, globals(), local_dict) + + wrapper_func = local_dict[f"wrapped_{deref_fn.__name__}"] + wrapper_func.__globals__.update(local_dict) + + return wrapper_func, void_sig + + def to_cccl_op(op: Callable | OpKind, sig: Signature | None) -> Op: """Return an `Op` object corresponding to the given callable or well-known operation. diff --git a/python/cuda_cccl/cuda/compute/iterators/_iterators.py b/python/cuda_cccl/cuda/compute/iterators/_iterators.py index 6ac1684f6d3..8a532191329 100644 --- a/python/cuda_cccl/cuda/compute/iterators/_iterators.py +++ b/python/cuda_cccl/cuda/compute/iterators/_iterators.py @@ -136,48 +136,45 @@ def is_output_iterator(self) -> bool: return self.output_dereference is not None def get_advance_ltoir(self) -> Tuple: + from .._cccl_interop import _create_advance_wrapper + abi_name = f"advance_{_get_abi_suffix(self.kind)}" - signature = ( - self.state_ptr_type, - types.uint64, # distance type + wrapped_advance, wrapper_sig = _create_advance_wrapper( + self.advance, self.state_ptr_type ) ltoir, _ = cached_compile( - self.advance, - signature, + wrapped_advance, + wrapper_sig, output="ltoir", abi_name=abi_name, ) return (abi_name, ltoir) - def get_input_dereference_signature(self): - return ( - self.state_ptr_type, - types.CPointer(self.value_type), - ) - - def get_output_dereference_signature(self): - return ( - self.state_ptr_type, - self.value_type, - ) - def get_input_dereference_ltoir(self) -> Tuple: + from .._cccl_interop import _create_input_dereference_wrapper + abi_name = f"input_dereference_{_get_abi_suffix(self.kind)}" - signature = self.get_input_dereference_signature() + wrapped_deref, wrapper_sig = _create_input_dereference_wrapper( + self.input_dereference, self.state_ptr_type, self.value_type + ) ltoir, _ = cached_compile( - self.input_dereference, - signature, + wrapped_deref, + wrapper_sig, output="ltoir", abi_name=abi_name, ) return (abi_name, ltoir) def get_output_dereference_ltoir(self) -> Tuple: + from .._cccl_interop import _create_output_dereference_wrapper + abi_name = f"output_dereference_{_get_abi_suffix(self.kind)}" - signature = self.get_output_dereference_signature() + wrapped_deref, wrapper_sig = _create_output_dereference_wrapper( + self.output_dereference, self.state_ptr_type, self.value_type + ) ltoir, _ = cached_compile( - self.output_dereference, - signature, + wrapped_deref, + wrapper_sig, output="ltoir", abi_name=abi_name, ) diff --git a/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py b/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py index ad96ce4ab1e..24de85780ac 100644 --- a/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py +++ b/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py @@ -232,7 +232,7 @@ def __init__(self, iterators_list): state_type=state_type, value_type=value_type, ) - self.kind_ = self.__class__.iterator_kind_type( + self._kind = self.__class__.iterator_kind_type( (value_type, *kinds), self.state_type ) diff --git a/python/cuda_cccl/tests/compute/examples/iterator/transform_output_iterator.py b/python/cuda_cccl/tests/compute/examples/iterator/transform_output_iterator.py index a660ef82d48..a3c938eb6f7 100644 --- a/python/cuda_cccl/tests/compute/examples/iterator/transform_output_iterator.py +++ b/python/cuda_cccl/tests/compute/examples/iterator/transform_output_iterator.py @@ -12,7 +12,7 @@ import cuda.compute from cuda.compute import ( OpKind, - TransformIterator, + TransformOutputIterator, ) # Create input and output arrays @@ -27,7 +27,7 @@ def sqrt(x: np.float32) -> np.float32: # Create transform output iterator -d_out_it = TransformIterator(d_output, sqrt) +d_out_it = TransformOutputIterator(d_output, sqrt) # Apply a sum reduction into the transform output iterator diff --git a/python/cuda_cccl/tests/compute/test_zip_iterator.py b/python/cuda_cccl/tests/compute/test_zip_iterator.py index 5a365886835..0da2c14ef3e 100644 --- a/python/cuda_cccl/tests/compute/test_zip_iterator.py +++ b/python/cuda_cccl/tests/compute/test_zip_iterator.py @@ -387,10 +387,7 @@ def sum_nested_zips(v1, v2): "dtype_map", [ {"x": np.float32, "y": np.float32}, - pytest.param( - {"x": np.float64, "y": np.float32}, - marks=pytest.mark.xfail(reason="Fails due to ODR violation (GH #4573)"), - ), + {"x": np.float64, "y": np.float32}, ], ) def test_nested_output_zip_iterator_with_scan(monkeypatch, num_items, dtype_map): @@ -448,3 +445,17 @@ def add_vec2_pairs(v1, v2): np.testing.assert_array_equal(d_out1.get(), expected_out1) np.testing.assert_array_equal(d_out2.get(), expected_out2) + + +def test_zip_iterator_of_transform_iterator_kind(): + arr = cp.arange(10, dtype=np.int64) + + def f(x): + return x + + def g(x): + return x + 1 + + it1 = ZipIterator(TransformIterator(arr, f)) + it2 = ZipIterator(TransformIterator(arr, g)) + assert it1.kind != it2.kind