Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 90 additions & 39 deletions include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,55 +199,92 @@ template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;

// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
// merge sequence - optimized to avoid recursive instantiation
namespace detail {

// Helper to concatenate multiple sequences in one step using fold expression
template <typename... Seqs>
struct sequence_merge_impl;

// Base case: single sequence
template <index_t... Is>
struct sequence_merge_impl<Sequence<Is...>>
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
using type = Sequence<Is...>;
};

// Two sequences: direct concatenation
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Xs..., Ys...>;
};

template <typename Seq>
struct sequence_merge<Seq>
// Three sequences: direct concatenation (avoids one level of recursion)
template <index_t... Xs, index_t... Ys, index_t... Zs>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
{
using type = Seq;
using type = Sequence<Xs..., Ys..., Zs...>;
};

// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
// Four sequences: direct concatenation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like these specializations. It will be interesting to get a survey of the code to see how often the specializations are used and if these four smallest cases are the most impactful ones.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using the build traces to drive the optimizations. Maybe removing the unused code is one other aspect which could help with parsing times

template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type = Sequence<As..., Bs..., Cs..., Ds...>;
};

using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
// General case: binary tree reduction (O(log N) depth instead of O(N))
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
{
// Merge pairs first, then recurse
using left = typename sequence_merge_impl<S1, S2>::type;
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
using type = typename sequence_merge_impl<left, right>::type;
};

template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
} // namespace detail

template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
template <typename... Seqs>
struct sequence_merge
{
using type = typename detail::sequence_merge_impl<Seqs...>::type;
};

template <>
struct sequence_merge<>
{
using type = Sequence<>;
};

// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
namespace detail {

// Helper that applies functor F to indices and produces a Sequence
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1,
// ..., N-1>
template <typename T, T... Is>
struct sequence_gen_helper
{
// Apply a functor F to all indices at once via pack expansion (O(1) depth)
template <typename F>
using apply = Sequence<F{}(Number<Is>{})...>;
};

} // namespace detail

using type = typename sequence_gen_impl<0, NSize, F>::type;
template <index_t NSize, typename F>
struct sequence_gen
{
using type =
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
};

template <typename F>
struct sequence_gen<0, F>
{
using type = Sequence<>;
};

// arithmetic sequence
Expand Down Expand Up @@ -283,16 +320,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
};

// uniform sequence
// uniform sequence - optimized using __make_integer_seq
namespace detail {

template <typename T, T... Is>
struct uniform_sequence_helper
{
// Apply a constant value to all indices via pack expansion
template <index_t Value>
using apply = Sequence<((void)Is, Value)...>;
};

} // namespace detail

template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
template apply<I>;
};

using type = typename sequence_gen<NSize, F>::type;
template <index_t I>
struct uniform_sequence_gen<0, I>
{
using type = Sequence<>;
};

// reverse inclusive scan (with init) sequence
Expand Down
1 change: 1 addition & 0 deletions include/ck/utility/statically_indexed_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
using type = Tuple<Xs..., Ys...>;
};

// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
{
Expand Down