Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions cpp/src/groupby/sort/group_count_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
* SPDX-License-Identifier: Apache-2.0
*/

#include "group_scan.hpp"

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>
Expand All @@ -18,7 +21,9 @@
namespace cudf {
namespace groupby {
namespace detail {
std::unique_ptr<column> count_scan(cudf::device_span<size_type const> group_labels,
std::unique_ptr<column> count_scan(column_view const& values,
null_policy nulls,
cudf::device_span<size_type const> group_labels,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
Expand All @@ -29,11 +34,24 @@ std::unique_ptr<column> count_scan(cudf::device_span<size_type const> group_labe

auto resultview = result->mutable_view();
// aggregation::COUNT_ALL
thrust::inclusive_scan_by_key(rmm::exec_policy(stream),
group_labels.begin(),
group_labels.end(),
thrust::make_constant_iterator<size_type>(1),
resultview.begin<size_type>());
if (nulls == null_policy::INCLUDE) {
thrust::inclusive_scan_by_key(rmm::exec_policy(stream),
group_labels.begin(),
group_labels.end(),
thrust::make_constant_iterator<size_type>(1),
resultview.begin<size_type>());
} else { // aggregation::COUNT_VALID
auto d_values = cudf::column_device_view::create(values, stream);
auto itr = cudf::detail::make_counting_transform_iterator(
0, [d_values = *d_values] __device__(auto idx) -> cudf::size_type {
return d_values.is_valid(idx);
});
thrust::inclusive_scan_by_key(rmm::exec_policy(stream),
group_labels.begin(),
group_labels.end(),
itr,
resultview.begin<size_type>());
}
return result;
}

Expand Down
8 changes: 6 additions & 2 deletions cpp/src/groupby/sort/group_scan.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -82,12 +82,16 @@ std::unique_ptr<column> max_scan(column_view const& values,
/**
* @brief Internal API to calculate cumulative number of values in each group
*
* @param values Grouped values to get valid rows from
* @param nulls Indicates whether nulls should be included in the count or not
* @param group_labels ID of group that the corresponding value belongs to
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return Column of type INT32 of count values
*/
std::unique_ptr<column> count_scan(device_span<size_type const> group_labels,
std::unique_ptr<column> count_scan(column_view const& values,
null_policy nulls,
device_span<size_type const> group_labels,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

Expand Down
19 changes: 17 additions & 2 deletions cpp/src/groupby/sort/scan.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -116,7 +116,22 @@ void scan_result_functor::operator()<aggregation::COUNT_ALL>(aggregation const&
{
if (cache.has_result(values, agg)) return;

cache.add_result(values, agg, detail::count_scan(helper.group_labels(stream), stream, mr));
cache.add_result(
values,
agg,
detail::count_scan(values, null_policy::INCLUDE, helper.group_labels(stream), stream, mr));
}

template <>
void scan_result_functor::operator()<aggregation::COUNT_VALID>(aggregation const& agg)
{
if (cache.has_result(values, agg)) return;

cache.add_result(
values,
agg,
detail::count_scan(
get_grouped_values(), null_policy::EXCLUDE, helper.group_labels(stream), stream, mr));
}

template <>
Expand Down
62 changes: 36 additions & 26 deletions cpp/tests/groupby/count_scan_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ TYPED_TEST(groupby_count_scan_test, basic)
result_wrapper expect_vals{1, 2, 3, 1, 2, 3, 4, 1, 2, 3};
// clang-format on

// Count groupby aggregation is only supported with cudf::null_policy::EXCLUDE
auto agg1 = cudf::make_count_aggregation<cudf::groupby_scan_aggregation>();
EXPECT_THROW(test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)),
cudf::logic_error);
auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
Expand All @@ -57,8 +56,9 @@ TYPED_TEST(groupby_count_scan_test, empty_cols)
key_wrapper expect_keys;
result_wrapper expect_vals;

auto agg1 = cudf::make_count_aggregation<cudf::groupby_scan_aggregation>();
EXPECT_NO_THROW(test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)));
auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
Expand All @@ -75,6 +75,10 @@ TYPED_TEST(groupby_count_scan_test, zero_valid_keys)
key_wrapper expect_keys{};
result_wrapper expect_vals{};

auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2));
Expand All @@ -88,8 +92,13 @@ TYPED_TEST(groupby_count_scan_test, zero_valid_values)
key_wrapper keys{1, 1, 1};
value_wrapper vals({3, 4, 5}, cudf::test::iterators::all_nulls());
key_wrapper expect_keys{1, 1, 1};
result_wrapper expect_vals{1, 2, 3};
result_wrapper expect_vals{0, 0, 0};

auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

expect_vals = result_wrapper{1, 2, 3};
auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2));
Expand All @@ -106,10 +115,15 @@ TYPED_TEST(groupby_count_scan_test, null_keys_and_values)

// {1, 1, 1, 2, 2, 2, 2, 3, _, 3, 4}
key_wrapper expect_keys( {1, 1, 1, 2, 2, 2, 2, 3, 3, 4}, cudf::test::iterators::no_nulls());
// {0, 3, 6, 1, 4, _, 9, 2, 7, 8, -}
result_wrapper expect_vals{1, 2, 3, 1, 2, 3, 4, 1, 2, 1};
// {_, 3, 6, 1, 4, _, 9, 2, 7, 8, _}
result_wrapper expect_vals{0, 1, 2, 1, 2, 2, 3, 1, 2, 0};
// clang-format on

auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

expect_vals = result_wrapper{1, 2, 3, 1, 2, 3, 4, 1, 2, 1};
auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2));
Expand All @@ -130,6 +144,10 @@ TEST_F(groupby_count_scan_string_test, basic)
result_wrapper expect_vals{1, 1, 1, 2, 1, 2};
// clang-format on

auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2));
Expand Down Expand Up @@ -157,13 +175,9 @@ TYPED_TEST(GroupByCountScanFixedPointTest, GroupByCountScan)
auto const expect_keys = key_wrapper{1, 1, 1, 2, 2, 2, 2, 3, 3, 3};
auto const expect_vals = result_wrapper{1, 2, 3, 1, 2, 3, 4, 1, 2, 3};

// Count groupby aggregation is only supported with cudf::null_policy::EXCLUDE
EXPECT_THROW(test_single_scan(keys,
vals,
expect_keys,
expect_vals,
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>()),
cudf::logic_error);
auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));

auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
Expand All @@ -184,14 +198,10 @@ TEST_F(groupby_dictionary_count_scan_test, basic)
cudf::test::strings_column_wrapper expect_keys{"0", "1", "3", "3", "5", "5"};
result_wrapper expect_vals{1, 1, 1, 2, 1, 2};

// Count groupby aggregation is only supported with cudf::null_policy::EXCLUDE
auto agg1 = cudf::make_count_aggregation<cudf::groupby_scan_aggregation>();
EXPECT_THROW(test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1)),
cudf::logic_error);
test_single_scan(
keys,
vals,
expect_keys,
expect_vals,
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE));
auto agg1 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::EXCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg1));
auto agg2 =
cudf::make_count_aggregation<cudf::groupby_scan_aggregation>(cudf::null_policy::INCLUDE);
test_single_scan(keys, vals, expect_keys, expect_vals, std::move(agg2));
}
Loading