Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,31 +597,49 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
{
};

template <typename SeqMap>
struct sequence_map_inverse
// Invert a permutation sequence: given X2Y = {a, b, c, ...}, compute Y2X where Y2X[X2Y[i]] = i
// Example: Sequence<2,0,1> (meaning pos0->2, pos1->0, pos2->1) inverts to Sequence<1,2,0>
//
// Why this implementation is faster to compile than recursive templates:
//
// The old recursive approach created a new template type for each element:
// sequence_map_inverse<Seq<2,0,1>> -> sequence_map_inverse<Seq<0,1>> ->
// sequence_map_inverse<Seq<1>>
// Each "->" is a new type the compiler must create, track, and manage. For N elements, that's
// N template types, each with overhead (name mangling, debug info, symbol table entries).
//
// This implementation uses O(N) direct assignment with a fold expression:
// For input Sequence<2,0,1>, the fold expression ((result[Is] = pos++), ...) expands to:
// result[2]=0, result[0]=1, result[1]=2
// This builds the inverse permutation in a single pass without any searching.
//
template <index_t... Is>
struct sequence_map_inverse<Sequence<Is...>>
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
private:
struct InverseArray
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});

using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
index_t data[sizeof...(Is)] = {};
};

template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
static constexpr auto build_inverse()
{
using type = WorkingY2X;
};
InverseArray result{};
index_t pos = 0;
((result.data[Is] = pos++), ...);
return result;
}

using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
SeqMap::Size()>::type;
static constexpr InverseArray inverse = build_inverse();

template <index_t... Positions>
static constexpr auto compute(Sequence<Positions...>)
{
return Sequence<inverse.data[Positions]...>{};
}

public:
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};

template <index_t... Xs, index_t... Ys>
Expand Down