Skip to content

Commit c05fa96

Browse files
Optimize for CG again
1 parent b983e89 commit c05fa96

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

libcudacxx/include/cuda/__memcpy_async/cp_async_bulk_shared_global.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,46 @@
3737

3838
# include <cuda/std/__cccl/prologue.h>
3939

40+
// Forward declarations of the cooperative groups types that the overloads in this header dispatch on.
// Only the names are introduced: both types stay incomplete here, so none of their members may be used.
// NOTE(review): the nested `__v1` namespace plus the using-directive presumably mirrors the namespace layout of
// <cooperative_groups.h>; confirm the two stay in sync, otherwise these would declare unrelated types.
namespace cooperative_groups
{
namespace __v1
{
class thread_block;

// Template parameter names are reserved identifiers so that user-defined macros cannot clash with them.
template <unsigned int _Size, typename _ParentT>
class thread_block_tile;
} // namespace __v1
using namespace __v1;
} // namespace cooperative_groups
51+
4052
_CCCL_BEGIN_NAMESPACE_CUDA
4153

54+
// Returns the rank of the calling thread within the group `__g`.
//
// Thin forwarding wrapper around the group's own `thread_rank()` member function.
//
// __g: the group to query; its type must provide a `thread_rank()` member callable on a const object.
// Returns the result of `__g.thread_rank()`, converted to `unsigned int`.
template <typename _Group>
[[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int __thread_rank(const _Group& __g) noexcept
{
  return __g.thread_rank();
}
59+
60+
// elect from the whole thread block
//
// Returns true only for the single thread that `elect_sync` picks inside the warp whose id is 0, and false for every
// other thread of the block. Both warp primitives below are called with the full member mask 0xFFFFFFFF, so all
// lanes of a calling warp are expected to reach this function together.
[[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE bool
__elect_from_group([[maybe_unused]] const cooperative_groups::thread_block& __g) noexcept
{
  // Cannot call __g.thread_rank(), because the thread_block type is only forward declared here.
  // Cooperative groups (and this function) map a multidimensional thread index to the thread rank the same way
  // threads are assigned to warps, so the rank is recomputed from the built-in index variables instead.
  const unsigned int __tid = threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x;
  const unsigned int __warp_id = __tid / 32;
  // broadcast from lane 0, so that every lane of the warp holds the same warp id
  const unsigned int __uniform_warp_id = __shfl_sync(0xFFFFFFFF, __warp_id, 0);
  // elect a leader thread among warp 0; the short-circuit keeps all other warps from running the election
  return __uniform_warp_id == 0 && ::cuda::ptx::elect_sync(0xFFFFFFFF);
}
71+
72+
// elect from a single warp
//
// The group argument is used for overload selection only and is never inspected. The election is performed with the
// full member mask 0xFFFFFFFF, so all lanes of the calling warp are expected to reach this function together.
// Returns true for the one thread chosen by `elect_sync`, false for the remaining threads of the warp.
template <typename _Parent>
[[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE bool
__elect_from_group(const cooperative_groups::thread_block_tile<32, _Parent>&) noexcept
{
  return ::cuda::ptx::elect_sync(0xFFFFFFFF); // elect a leader thread among the calling warp
}
79+
4280
template <typename _Group>
4381
[[nodiscard]] _CCCL_DEVICE _CCCL_FORCEINLINE bool __elect_from_group(const _Group& __g) noexcept
4482
{

0 commit comments

Comments
 (0)