diff --git a/cpp/src/groupby/sort/group_count_scan.cu b/cpp/src/groupby/sort/group_count_scan.cu index 9108fcf2782..2d08a0fe841 100644 --- a/cpp/src/groupby/sort/group_count_scan.cu +++ b/cpp/src/groupby/sort/group_count_scan.cu @@ -3,8 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "group_scan.hpp" + #include #include +#include #include #include #include @@ -18,7 +21,9 @@ namespace cudf { namespace groupby { namespace detail { -std::unique_ptr count_scan(cudf::device_span group_labels, +std::unique_ptr count_scan(column_view const& values, + null_policy nulls, + cudf::device_span group_labels, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { @@ -29,11 +34,24 @@ std::unique_ptr count_scan(cudf::device_span group_labe auto resultview = result->mutable_view(); // aggregation::COUNT_ALL - thrust::inclusive_scan_by_key(rmm::exec_policy(stream), - group_labels.begin(), - group_labels.end(), - thrust::make_constant_iterator(1), - resultview.begin()); + if (nulls == null_policy::INCLUDE) { + thrust::inclusive_scan_by_key(rmm::exec_policy(stream), + group_labels.begin(), + group_labels.end(), + thrust::make_constant_iterator(1), + resultview.begin()); + } else { // aggregation::COUNT_VALID + auto d_values = cudf::column_device_view::create(values, stream); + auto itr = cudf::detail::make_counting_transform_iterator( + 0, [d_values = *d_values] __device__(auto idx) -> cudf::size_type { + return d_values.is_valid(idx); + }); + thrust::inclusive_scan_by_key(rmm::exec_policy(stream), + group_labels.begin(), + group_labels.end(), + itr, + resultview.begin()); + } return result; } diff --git a/cpp/src/groupby/sort/group_scan.hpp b/cpp/src/groupby/sort/group_scan.hpp index f2c1ba45d23..8896b64b0b2 100644 --- a/cpp/src/groupby/sort/group_scan.hpp +++ b/cpp/src/groupby/sort/group_scan.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -82,12 +82,16 @@ std::unique_ptr max_scan(column_view const& values, /** * @brief Internal API to calculate cumulative number of values in each group * + * @param values Grouped values to get valid rows from + * @param nulls Indicates whether nulls should be included in the count or not * @param group_labels ID of group that the corresponding value belongs to * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned column's device memory * @return Column of type INT32 of count values */ -std::unique_ptr count_scan(device_span group_labels, +std::unique_ptr count_scan(column_view const& values, + null_policy nulls, + device_span group_labels, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); diff --git a/cpp/src/groupby/sort/scan.cpp b/cpp/src/groupby/sort/scan.cpp index e99f0b69cfc..e4ae039d2f1 100644 --- a/cpp/src/groupby/sort/scan.cpp +++ b/cpp/src/groupby/sort/scan.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -116,7 +116,22 @@ void scan_result_functor::operator()(aggregation const& { if (cache.has_result(values, agg)) return; - cache.add_result(values, agg, detail::count_scan(helper.group_labels(stream), stream, mr)); + cache.add_result( + values, + agg, + detail::count_scan(values, null_policy::INCLUDE, helper.group_labels(stream), stream, mr)); +} + +template <> +void scan_result_functor::operator()(aggregation const& agg) +{ + if (cache.has_result(values, agg)) return; + + cache.add_result( + values, + agg, + detail::count_scan( + get_grouped_values(), null_policy::EXCLUDE, helper.group_labels(stream), stream, mr)); } template <> diff --git a/cpp/tests/groupby/count_scan_tests.cpp b/cpp/tests/groupby/count_scan_tests.cpp index 078157593a1..f5b10d6606d 100644 --- a/cpp/tests/groupby/count_scan_tests.cpp +++ b/cpp/tests/groupby/count_scan_tests.cpp @@ -37,10 +37,9 @@ TYPED_TEST(groupby_count_scan_test, basic) result_wrapper expect_vals{1, 2, 3, 1, 2, 3, 4, 1, 2, 3}; // clang-format on - // Count groupby aggregation is only supported with cudf::null_policy::EXCLUDE - auto agg1 = cudf::make_count_aggregation(); - EXPECT_THROW(test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)), - cudf::logic_error); + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); @@ -57,8 +56,9 @@ TYPED_TEST(groupby_count_scan_test, empty_cols) key_wrapper expect_keys; result_wrapper expect_vals; - auto agg1 = cudf::make_count_aggregation(); - EXPECT_NO_THROW(test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1))); + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); @@ -75,6 +75,10 @@ TYPED_TEST(groupby_count_scan_test, zero_valid_keys) key_wrapper expect_keys{}; result_wrapper expect_vals{}; + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); + auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2)); @@ -88,8 +92,13 @@ TYPED_TEST(groupby_count_scan_test, zero_valid_values) key_wrapper keys{1, 1, 1}; value_wrapper vals({3, 4, 5}, cudf::test::iterators::all_nulls()); key_wrapper expect_keys{1, 1, 1}; - result_wrapper expect_vals{1, 2, 3}; + result_wrapper expect_vals{0, 0, 0}; + + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); + expect_vals = result_wrapper{1, 2, 3}; auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2)); @@ -106,10 +115,15 @@ TYPED_TEST(groupby_count_scan_test, null_keys_and_values) // {1, 1, 1, 2, 2, 2, 2, 3, _, 3, 4} key_wrapper expect_keys( {1, 1, 1, 2, 2, 2, 2, 3, 3, 4}, cudf::test::iterators::no_nulls()); - // {0, 3, 6, 1, 4, _, 9, 2, 7, 8, -} - result_wrapper expect_vals{1, 2, 3, 1, 2, 3, 4, 1, 2, 1}; + // {_, 3, 6, 1, 4, _, 9, 2, 7, 8, _} + result_wrapper expect_vals{0, 1, 2, 1, 2, 2, 3, 1, 2, 0}; // clang-format on + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); + + expect_vals = result_wrapper{1, 2, 3, 1, 2, 3, 4, 1, 2, 1}; auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2)); @@ -130,6 +144,10 @@ TEST_F(groupby_count_scan_string_test, basic) result_wrapper expect_vals{1, 1, 1, 2, 1, 2}; // clang-format on + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); + auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2)); @@ -157,13 +175,9 @@ TYPED_TEST(GroupByCountScanFixedPointTest, GroupByCountScan) auto const expect_keys = key_wrapper{1, 1, 1, 2, 2, 2, 2, 3, 3, 3}; auto const expect_vals = result_wrapper{1, 2, 3, 1, 2, 3, 4, 1, 2, 3}; - // Count groupby aggregation is only supported with cudf::null_policy::EXCLUDE - EXPECT_THROW(test_single_scan(keys, - vals, - expect_keys, - expect_vals, - cudf::make_count_aggregation()), - cudf::logic_error); + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); auto agg2 = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); @@ -184,14 +198,10 @@ TEST_F(groupby_dictionary_count_scan_test, basic) cudf::test::strings_column_wrapper expect_keys{"0", "1", "3", "3", "5", "5"}; result_wrapper expect_vals{1, 1, 1, 2, 1, 2}; - // Count groupby aggregation is only supported with cudf::null_policy::EXCLUDE - auto agg1 = cudf::make_count_aggregation(); - EXPECT_THROW(test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)), - cudf::logic_error); - test_single_scan( - keys, - vals, - expect_keys, - expect_vals, - cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + auto agg1 = + cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)); + auto agg2 = + cudf::make_count_aggregation(cudf::null_policy::INCLUDE); + test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2)); }