Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions cub/cub/detail/env_dispatch.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remark: I assume we are retaining BSD-3 here, cause that's what we had in the file where we extracted the below code from.

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cub/detail/device_memory_resource.cuh>
#include <cub/detail/temporary_storage.cuh>

#include <cuda/__execution/tune.h>
#include <cuda/__memory_resource/get_memory_resource.h>
#include <cuda/__stream/get_stream.h>
#include <cuda/std/__execution/env.h>

CUB_NAMESPACE_BEGIN

namespace detail
{
//! @cond
//! Generic environment-based algorithm dispatch wrapper
//!
//! Handles common boilerplate for all env-based algorithms:
//! - Query stream, memory resource, and tuning from environment
//! - Two-phase call (query temp storage size, then execute)
//! - Temporary storage allocation/deallocation
//! - Memory resource querying from environment
//!
//! @param env The execution environment
//! @param determinism Pre-computed determinism type (algorithm-specific)
//! @param algorithm_callable Callable that invokes the algorithm implementation with determinism specified
//! @param algorithm_args Arguments to forward to the algorithm implementation
template <typename EnvT, typename DeterminismT, typename AlgorithmCallable, typename... AlgorithmArgs>
CUB_RUNTIME_FUNCTION static cudaError_t dispatch_with_env(
EnvT env, DeterminismT determinism, AlgorithmCallable algorithm_callable, AlgorithmArgs&&... algorithm_args)
{
// Query stream from environment
auto stream = ::cuda::std::execution::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}});

// Query memory resource from environment
auto mr =
::cuda::std::execution::__query_or(env, ::cuda::mr::__get_memory_resource, detail::device_memory_resource{});

// Query tuning from environment
using tuning_t =
::cuda::std::execution::__query_result_or_t<EnvT, ::cuda::execution::__get_tuning_t, ::cuda::std::execution::env<>>;

void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;

// Phase 1: Query temporary storage size
cudaError_t error = algorithm_callable(
tuning_t{},
d_temp_storage,
temp_storage_bytes,
::cuda::std::forward<AlgorithmArgs>(algorithm_args)...,
determinism,
stream.get());

if (error != cudaSuccess)
{
return error;
}

// Allocate temporary storage
error = CubDebug(detail::temporary_storage::allocate(stream, d_temp_storage, temp_storage_bytes, mr));
if (error != cudaSuccess)
{
return error;
}
Comment on lines +73 to +77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: I got really fond of this style of error handling recently:

Suggested change
error = CubDebug(detail::temporary_storage::allocate(stream, d_temp_storage, temp_storage_bytes, mr));
if (error != cudaSuccess)
{
return error;
}
if (const auto error = CubDebug(detail::temporary_storage::allocate(stream, d_temp_storage, temp_storage_bytes, mr)))
{
return error;
}

This way, we don't leak an error variable anywhere else, so we are always reporting the error we just got.


// Phase 2: Execute algorithm
error = algorithm_callable(
tuning_t{},
d_temp_storage,
temp_storage_bytes,
::cuda::std::forward<AlgorithmArgs>(algorithm_args)...,
determinism,
stream.get());

// Deallocate temporary storage (always attempt, even on error)
cudaError_t deallocate_error =
CubDebug(detail::temporary_storage::deallocate(stream, d_temp_storage, temp_storage_bytes, mr));

// Algorithm error takes precedence over deallocation error
return (error != cudaSuccess) ? error : deallocate_error;
}
//! @endcond
} // namespace detail

CUB_NAMESPACE_END
Loading
Loading