1010
1111#include < cub/detail/choose_offset.cuh>
1212#include < cub/detail/launcher/cuda_driver.cuh>
13- #include < cub/detail/ptx-json-parser.h>
1413#include < cub/device/device_reduce.cuh>
15- #include < cub/grid/grid_even_share.cuh>
1614#include < cub/util_device.cuh>
1715
1816#include < cuda/std/algorithm>
@@ -44,29 +42,6 @@ static_assert(std::is_same_v<cub::detail::choose_offset_t<OffsetT>, OffsetT>, "O
4442
4543namespace reduce
4644{
// Runtime tuning-policy bundle for the reduce algorithm: stores two
// cub::detail::RuntimeReduceAgentPolicy values and exposes them through
// SingleTile() / Reduce(), plus a MaxPolicy alias and an Invoke() hook.
// NOTE(review): every line of this struct carries a '-' diff marker in this view,
// so it appears to be deleted by the change shown here (the '+' lines further down
// store an arch_policies object in build->runtime_policy instead) — confirm.
47- struct reduce_runtime_tuning_policy
48- {
// Agent policy handed out by SingleTile().
49- cub::detail::RuntimeReduceAgentPolicy single_tile;
// Agent policy handed out by Reduce().
50- cub::detail::RuntimeReduceAgentPolicy reduce;
51-
// Returns a copy of `single_tile` (plain `auto` return type, so by value).
52- auto SingleTile () const
53- {
54- return single_tile;
55- }
// Returns a copy of `reduce` (plain `auto` return type, so by value).
56- auto Reduce () const
57- {
58- return reduce;
59- }
60-
// MaxPolicy names this struct itself.
61- using MaxPolicy = reduce_runtime_tuning_policy;
62-
// Calls op.template Invoke<reduce_runtime_tuning_policy>(*this) and returns its
// cudaError_t result. The leading int parameter is unnamed and never read here;
// presumably it is a version/arch argument required by the caller's interface —
// TODO confirm against the dispatch code that calls Invoke().
63- template <typename F>
64- cudaError_t Invoke (int , F& op)
65- {
66- return op.template Invoke <reduce_runtime_tuning_policy>(*this );
67- }
68- };
69-
7045static cccl_type_info get_accumulator_type (cccl_op_t /* op*/ , cccl_iterator_t /* input_it*/ , cccl_value_t init)
7146{
7247 // TODO Should be decltype(op(init, *input_it)) but haven't implemented type arithmetic yet
@@ -179,7 +154,6 @@ CUresult cccl_device_reduce_build_ex(
179154 {
180155 const char * name = " device_reduce" ;
181156
182- const int cc = cc_major * 10 + cc_minor;
183157 const cccl_type_info accum_t = reduce::get_accumulator_type (op, input_it, init);
184158 const auto accum_cpp = cccl_type_enum_to_name (accum_t .type );
185159
@@ -193,7 +167,8 @@ CUresult cccl_device_reduce_build_ex(
193167
194168 const auto offset_t = cccl_type_enum_to_name (cccl_type_enum::CCCL_UINT64);
195169
196- auto policy_hub_expr = std::format (" cub::detail::reduce::policy_hub<{}, {}, {}>" , accum_cpp, offset_t , op_name);
170+ auto policy_hub_expr =
171+ std::format (" cub::detail::reduce::arch_policies_from_types<{}, {}, {}>" , accum_cpp, offset_t , op_name);
197172
198173 std::string final_src = std::format (
199174 R"XXX(
@@ -206,13 +181,7 @@ struct __align__({2}) storage_t {{
206181{3}
207182{4}
208183{5}
209- using device_reduce_policy = {6}::MaxPolicy;
210-
211- #include <cub/detail/ptx-json/json.h>
212- __device__ consteval auto& policy_generator() {{
213- return ptx_json::id<ptx_json::string("device_reduce_policy")>()
214- = cub::detail::reduce::ReducePolicyWrapper<device_reduce_policy::ActivePolicy>::EncodedPolicy();
215- }};
184+ using device_reduce_policy = {6};
216185)XXX" ,
217186 jit_template_header_contents, // 0
218187 input_it.value_type .size , // 1
@@ -249,7 +218,6 @@ __device__ consteval auto& policy_generator() {{
249218 " -rdc=true" ,
250219 " -dlto" ,
251220 " -DCUB_DISABLE_CDP" ,
252- " -DCUB_ENABLE_POLICY_PTX_JSON" ,
253221 " -std=c++20" };
254222
255223 // Add user's extra flags if config is provided
@@ -286,18 +254,40 @@ __device__ consteval auto& policy_generator() {{
286254 &build->single_tile_second_kernel , build->library , single_tile_second_kernel_lowered_name.c_str ()));
287255 check (cuLibraryGetKernel (&build->reduction_kernel , build->library , reduction_kernel_lowered_name.c_str ()));
288256
289- nlohmann::json runtime_policy =
290- cub::detail::ptx_json::parse (" device_reduce_policy" , {result.data .get (), result.size });
257+ // convert type information to CUB arch_policies
258+ using namespace cub ::detail::reduce;
259+
260+ auto accum_type = accum_type::other;
261+ if (accum_t .type == CCCL_FLOAT32)
262+ {
263+ accum_type = accum_type::float32;
264+ }
265+ if (accum_t .type == CCCL_FLOAT64)
266+ {
267+ accum_type = accum_type::double32;
268+ }
269+
270+ auto operation_t = op_type::unknown;
271+ switch (op.type )
272+ {
273+ case CCCL_PLUS:
274+ operation_t = op_type::plus;
275+ break ;
276+ case CCCL_MINIMUM:
277+ case CCCL_MAXIMUM:
278+ operation_t = op_type::min_or_max;
279+ break ;
280+ default :
281+ break ;
282+ }
291283
292- using cub::detail::RuntimeReduceAgentPolicy;
293- auto reduce_policy = RuntimeReduceAgentPolicy::from_json (runtime_policy, " ReducePolicy" );
294- auto st_policy = RuntimeReduceAgentPolicy::from_json (runtime_policy, " SingleTilePolicy" );
284+ const int offset_size = int {sizeof (uint64_t )};
295285
296- build->cc = cc ;
286+ build->cc = cc_major * 10 + cc_minor ;
297287 build->cubin = (void *) result.data .release ();
298288 build->cubin_size = result.size ;
299289 build->accumulator_size = accum_t .size ;
300- build->runtime_policy = new reduce::reduce_runtime_tuning_policy{st_policy, reduce_policy };
290+ build->runtime_policy = new arch_policies{accum_type, operation_t , offset_size, static_cast < int >( accum_t . size ) };
301291 }
302292 catch (const std::exception& exc)
303293 {
@@ -330,30 +320,19 @@ CUresult cccl_device_reduce(
330320 CUdevice cu_device;
331321 check (cuCtxGetDevice (&cu_device));
332322
333- auto exec_status = cub::DispatchReduce<
334- indirect_arg_t , // InputIteratorT
335- indirect_arg_t , // OutputIteratorT
336- ::cuda::std::size_t , // OffsetT
337- indirect_arg_t , // ReductionOpT
338- indirect_arg_t , // InitT
339- void , // AccumT
340- ::cuda::std::identity, // TransformOpT
341- reduce::reduce_runtime_tuning_policy, // PolicyHub
342- reduce::reduce_kernel_source, // KernelSource
343- cub::detail::CudaDriverLauncherFactory>:: // KernelLauncherFactory
344- Dispatch (
345- d_temp_storage,
346- *temp_storage_bytes,
347- d_in,
348- d_out,
349- num_items,
350- op,
351- init,
352- stream,
353- {},
354- {build},
355- cub::detail::CudaDriverLauncherFactory{cu_device, build.cc },
356- *reinterpret_cast <reduce::reduce_runtime_tuning_policy*>(build.runtime_policy ));
323+ auto exec_status = cub::detail::reduce::dispatch<void >(
324+ d_temp_storage,
325+ *temp_storage_bytes,
326+ indirect_arg_t {d_in}, // could be indirect_iterator_t, but CUB does not need to increment it
327+ indirect_arg_t {d_out}, // could be indirect_iterator_t, but CUB does not need to increment it
328+ static_cast <::cuda::std::size_t >(num_items),
329+ indirect_arg_t {op},
330+ indirect_arg_t {init},
331+ stream,
332+ ::cuda::std::identity{},
333+ *static_cast <cub::detail::reduce::arch_policies*>(build.runtime_policy ),
334+ reduce::reduce_kernel_source{build},
335+ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc });
357336
358337 error = static_cast <CUresult>(exec_status);
359338 }
@@ -383,8 +362,9 @@ CUresult cccl_device_reduce_cleanup(cccl_device_reduce_build_result_t* build_ptr
383362 return CUDA_ERROR_INVALID_VALUE;
384363 }
385364
386- std::unique_ptr<char []> cubin (reinterpret_cast <char *>(build_ptr->cubin ));
387- std::unique_ptr<char []> policy (reinterpret_cast <char *>(build_ptr->runtime_policy ));
365+ using namespace cub ::detail::reduce;
366+ std::unique_ptr<char []> cubin (static_cast <char *>(build_ptr->cubin ));
367+ std::unique_ptr<arch_policies> policy (static_cast <arch_policies*>(build_ptr->runtime_policy ));
388368 check (cuLibraryUnload (build_ptr->library ));
389369 }
390370 catch (const std::exception& exc)
0 commit comments