diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6e68690048..18bb36d112 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -199,55 +199,92 @@ template using make_index_sequence = typename __make_integer_seq::seq_type; -// merge sequence -template -struct sequence_merge +// merge sequence - optimized to avoid recursive instantiation +namespace detail { + +// Helper to concatenate multiple sequences in one step using fold expression +template +struct sequence_merge_impl; + +// Base case: single sequence +template +struct sequence_merge_impl> { - using type = typename sequence_merge::type>::type; + using type = Sequence; }; +// Two sequences: direct concatenation template -struct sequence_merge, Sequence> +struct sequence_merge_impl, Sequence> { using type = Sequence; }; -template -struct sequence_merge +// Three sequences: direct concatenation (avoids one level of recursion) +template +struct sequence_merge_impl, Sequence, Sequence> { - using type = Seq; + using type = Sequence; }; -// generate sequence -template -struct sequence_gen +// Four sequences: direct concatenation +template +struct sequence_merge_impl, Sequence, Sequence, Sequence> { - template - struct sequence_gen_impl - { - static constexpr index_t NRemainLeft = NRemain / 2; - static constexpr index_t NRemainRight = NRemain - NRemainLeft; - static constexpr index_t IMiddle = IBegin + NRemainLeft; + using type = Sequence; +}; - using type = typename sequence_merge< - typename sequence_gen_impl::type, - typename sequence_gen_impl::type>::type; - }; +// General case: binary tree reduction (O(log N) depth instead of O(N)) +template +struct sequence_merge_impl +{ + // Merge pairs first, then recurse + using left = typename sequence_merge_impl::type; + using right = typename sequence_merge_impl::type; + using type = typename sequence_merge_impl::type; +}; - template - struct sequence_gen_impl - { - static constexpr index_t Is = G{}(Number{}); - using type = Sequence; - }; +} // namespace detail - template - struct sequence_gen_impl - { - using type = Sequence<>; - }; +template +struct sequence_merge +{ + using type = typename detail::sequence_merge_impl::type; +}; + +template <> +struct sequence_merge<> +{ + using type = Sequence<>; +}; + +// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation +namespace detail { + +// Helper that applies functor F to indices and produces a Sequence +// __make_integer_seq produces sequence_gen_helper +template +struct sequence_gen_helper +{ + // Apply a functor F to all indices at once via pack expansion (O(1) depth) + template + using apply = Sequence{})...>; +}; + +} // namespace detail - using type = typename sequence_gen_impl<0, NSize, F>::type; +template +struct sequence_gen +{ + using type = + typename __make_integer_seq::template apply; +}; + +template +struct sequence_gen<0, F> +{ + using type = Sequence<>; }; // arithmetic sequence @@ -283,16 +320,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1> using type = typename __make_integer_seq::type; }; -// uniform sequence +// uniform sequence - optimized using __make_integer_seq +namespace detail { + +template +struct uniform_sequence_helper +{ + // Apply a constant value to all indices via pack expansion + template + using apply = Sequence<((void)Is, Value)...>; +}; + +} // namespace detail + template struct uniform_sequence_gen { - struct F - { - __host__ __device__ constexpr index_t operator()(index_t) const { return I; } - }; + using type = typename __make_integer_seq:: + template apply; +}; - using type = typename sequence_gen::type; +template +struct uniform_sequence_gen<0, I> +{ + using type = Sequence<>; }; // reverse inclusive scan (with init) sequence diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index d0735a32f6..f3d73e84a7 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -20,6 +20,7 @@ struct tuple_concat, Tuple> using type = Tuple; }; +// StaticallyIndexedArrayImpl uses binary split for O(log N) depth template struct StaticallyIndexedArrayImpl {