diff --git a/hwy/contrib/bit_pack/bit_pack-inl.h b/hwy/contrib/bit_pack/bit_pack-inl.h index 09d37ac5ef..d6a2023c44 100644 --- a/hwy/contrib/bit_pack/bit_pack-inl.h +++ b/hwy/contrib/bit_pack/bit_pack-inl.h @@ -16,6 +16,9 @@ #include #include +#include +#include + #include "hwy/base.h" // Per-target include guard @@ -29,6 +32,7 @@ #endif #include "hwy/highway.h" +#include "hwy/print-inl.h" HWY_BEFORE_NAMESPACE(); namespace hwy { @@ -2737,6 +2741,11 @@ constexpr size_t UnpackedIncr() { return (sizeof(T) * 8) / NumLoops(); } +template +constexpr uint8_t MaskBits8() { + return static_cast((1ull << kBits) - 1); +} + template constexpr uint32_t MaskBits32() { return static_cast((1ull << kBits) - 1); @@ -2751,8 +2760,692 @@ constexpr uint64_t MaskBits64<64>() { return ~uint64_t{0}; } +struct State { + size_t kBits; + size_t S; + size_t kLoadPos; + size_t kStorePos; + size_t kLoops; + + template + constexpr State NextState() const { + return State{ + .kBits = kBits, + .S = ((S < B) ? ((S + kBits == B) ? 0 : S + kBits) : (S % B)), + .kLoadPos = + kLoadPos + (S > B ? 1 : static_cast((S + kBits) == B)), + .kStorePos = kStorePos + static_cast(S < B), + .kLoops = S + kBits == B ? kLoops - 1 : kLoops, + }; + } + + template + constexpr bool ShouldGoToNextColumn() const { + return (S < B); + } + + constexpr bool ShouldAccumulate() const { return kBits == 0 || kLoops == 0; } +}; + +template +static V MakeMask(D d) { + HWY_IF_CONSTEXPR(std::is_same_v) { + return Set(d, MaskBits8()); + } + HWY_IF_CONSTEXPR(std::is_same_v) { + return Set(d, MaskBits32()); + } + HWY_IF_CONSTEXPR(std::is_same_v) { + return Set(d, MaskBits64()); + } +} + +template +static constexpr void AssignAndUpdateInOutFGt(D d, const T* pi, V& in, V& out, + V& f) { + constexpr size_t B = sizeof(T) * 8; + const V mask = MakeMask(d); + in = LoadU(d, pi + m.kLoadPos * Lanes(d)); + constexpr size_t shl_amount = (m.kBits - m.S % B) % B; + // Print(d, "m_0.S>B in_0", in_0, 0, Lanes(d)); + // Print(d, "m_0.S>B out_0", out_0, 0, Lanes(d)); + out = And(Or(out, ShiftLeft(in)), mask); +} + +template +static constexpr void AssignAndUpdateInOutFLt(D d, const T* pi, V& in, V& out, + V& f) { + constexpr size_t B = sizeof(T) * 8; + const V mask = MakeMask(d); + f = out; + HWY_IF_CONSTEXPR(m.S + m.kBits < B) { + // Optimize for the case when `S` is zero. + // We can skip the `ShiftRight` to align `in`. + HWY_IF_CONSTEXPR(m.S == 0) { out = And(in, mask); } + HWY_IF_CONSTEXPR(m.S != 0) { out = And(ShiftRight(in), mask); } + } + HWY_IF_CONSTEXPR(m.S + m.kBits >= B) { + // Print(d, "m_0.S=B out_0", out_0, 0, Lanes(d)); + out = ShiftRight(in); + } + HWY_IF_CONSTEXPR(m.kBits != B && (m.S + m.kBits == B) && m.kLoops > 1) { + in = LoadU(d, pi + m.kLoadPos * Lanes(d)); + } + HWY_IF_CONSTEXPR(m.kBits == B && m.kLoops > 0) { + in = LoadU(d, pi + m.kLoadPos * Lanes(d)); + out = in; + } +} + +#define NEXT2 \ + NextUnroller::FusedUnpack(d, pi_0, pi_1, raw, in_0, out_0, in_1, out_1, f_0, \ + f_1); +#define BYPASS2 \ + FusedBitPackUnroller::FusedUnpack( \ + d, pi_0, pi_1, raw, in_0, out_0, in_1, out_1, f_0, f_1); + +// Generates the implementation for bit-packing/un-packing `T` type numbers +// where each number takes `kBits` bits. +// `S` is the remainder bits left from the previous bit-packed block. +// `kLoadPos` is the offset from which the next vector block should be loaded. +// `kStorePos` is the offset into which the next vector block should be stored. +// `BlockPackingType` is the type of packing/unpacking for this block. +template +struct FusedBitPackUnroller { + static constexpr size_t B = sizeof(T) * 8; + + template + static inline void Unpack(D d, const T* HWY_RESTRICT pi_0, + T* HWY_RESTRICT raw, const V& mask_0, V& in_0, + V& out_0, V& f_0) { + HWY_IF_CONSTEXPR(m_0.kLoops > 0) { + constexpr State mn_0 = m_0.NextState(); + using NextUnroller = FusedBitPackUnroller; + + (void)pi_0; + (void)in_0; + (void)mask_0; + + const size_t N = Lanes(d); + HWY_IF_CONSTEXPR(m_0.S > B) { + AssignAndUpdateInOutFGt(d, pi_0, in_0, out_0, f_0); + return NextUnroller::Unpack(d, pi_0, raw, mask_0, in_0, out_0, f_0); + } + HWY_IF_CONSTEXPR(m_0.S < B) { + AssignAndUpdateInOutFLt(d, pi_0, in_0, out_0, f_0); + StoreU(f_0, d, raw + m_0.kStorePos * N); + return NextUnroller::Unpack(d, pi_0, raw, mask_0, in_0, out_0, f_0); + } + } + HWY_IF_CONSTEXPR(m_0.kLoops == 0) { + f_0 = out_0; + StoreU(f_0, d, raw + m_0.kStorePos * Lanes(d)); + return; + } + } + + template + static inline void AccumulateResult(D d, V& f_0, V& f_1, uint16_t* out, + size_t pos) { + const auto low = InterleaveWholeLower(d, f_0, f_1); + const auto high = InterleaveWholeUpper(d, f_0, f_1); + const Repartition d16; + + StoreU(BitCast(d16, low), d16, out + pos * Lanes(d)); + StoreU(BitCast(d16, high), d16, out + pos * Lanes(d) + Lanes(d16)); + } + + template + static inline void FusedUnpack(D d, const T* HWY_RESTRICT pi_0, + const T* HWY_RESTRICT pi_1, + BT* HWY_RESTRICT raw, V& in_0, V& out_0, + V& in_1, V& out_1, V& f_0, V& f_1) { + HWY_IF_CONSTEXPR(C == 2) { + AccumulateResult(d, f_0, f_1, raw, kMainPos); + HWY_IF_CONSTEXPR(m_0.kLoops != 0) { + HWY_IF_CONSTEXPR(m_1.kLoops != 0) { + return FusedBitPackUnroller::FusedUnpack(d, pi_0, pi_1, raw, in_0, + out_0, in_1, out_1, f_0, + f_1); + } + } + HWY_IF_CONSTEXPR(m_0.ShouldAccumulate() && m_1.ShouldAccumulate()) { + AccumulateResult(d, out_0, out_1, raw, kMainPos + 1); + } + } + + HWY_IF_CONSTEXPR(C == 0) { + HWY_IF_CONSTEXPR(m_0.kBits == 0 || m_0.kLoops == 0) { + HWY_IF_CONSTEXPR(m_0.kLoops == 0) { f_0 = out_0; } + return BYPASS2; + } + HWY_IF_CONSTEXPR(m_0.kLoops > 0) { + constexpr State mn_0 = m_0.NextState(); + using NextUnroller = + FusedBitPackUnroller()>; + HWY_IF_CONSTEXPR(m_0.S > B) { + AssignAndUpdateInOutFGt(d, pi_0, in_0, out_0, f_0); + return NEXT2; + } + HWY_IF_CONSTEXPR(m_0.S < B) { + AssignAndUpdateInOutFLt(d, pi_0, in_0, out_0, f_0); + return NEXT2; + } + } + } + + HWY_IF_CONSTEXPR(C == 1) { + HWY_IF_CONSTEXPR(m_1.kBits == 0 || m_1.kLoops == 0) { + HWY_IF_CONSTEXPR(m_1.kLoops == 0) { f_1 = out_1; } + return BYPASS2; + } + HWY_IF_CONSTEXPR(m_1.kLoops > 0) { + constexpr State mn_1 = m_1.NextState(); + using NextUnroller = + FusedBitPackUnroller()>; + HWY_IF_CONSTEXPR(m_1.S > B) { + AssignAndUpdateInOutFGt(d, pi_1, in_1, out_1, f_1); + return NEXT2; + } + HWY_IF_CONSTEXPR(m_1.S < B) { + AssignAndUpdateInOutFLt(d, pi_1, in_1, out_1, f_1); + return NEXT2; + } + } + } + } +}; + +#define NEXT4 \ + NextUnroller4::FusedUnpack4(d, pi_0, pi_1, pi_2, pi_3, raw, in_0, out_0, \ + in_1, out_1, in_2, out_2, in_3, out_3, f_0, f_1, \ + f_2, f_3); +#define BYPASS4 \ + FusedBitPackUnroller4::FusedUnpack4( \ + d, pi_0, pi_1, pi_2, pi_3, raw, in_0, out_0, in_1, out_1, in_2, out_2, \ + in_3, out_3, f_0, f_1, f_2, f_3); + +template +struct FusedBitPackUnroller4 { + static constexpr size_t B = sizeof(T) * 8; + + template + static inline void AccumulateResult4(D d, V& v0, V& v1, V& v2, V& v3, + uint32_t* raw, size_t pos) { + const Repartition d16; + const Repartition d32; + const size_t N = Lanes(d); + + uint32_t* out = raw + pos * N; + + auto v00 = BitCast(d16, InterleaveWholeLower(d, v0, v1)); + auto v11 = BitCast(d16, InterleaveWholeUpper(d, v0, v1)); + auto v22 = BitCast(d16, InterleaveWholeLower(d, v2, v3)); + auto v33 = BitCast(d16, InterleaveWholeUpper(d, v2, v3)); + StoreU(BitCast(d32, InterleaveWholeLower(d16, v00, v22)), d32, out); + StoreU(BitCast(d32, InterleaveWholeUpper(d16, v00, v22)), d32, out + N / 4); + StoreU(BitCast(d32, InterleaveWholeLower(d16, v11, v33)), d32, out + N / 2); + StoreU(BitCast(d32, InterleaveWholeUpper(d16, v11, v33)), d32, + out + 3 * N / 4); + } + + template + static inline void FusedUnpack4(D d, const T* HWY_RESTRICT pi_0, + const T* HWY_RESTRICT pi_1, + const T* HWY_RESTRICT pi_2, + const T* HWY_RESTRICT pi_3, + BT* HWY_RESTRICT raw, V& in_0, V& out_0, + V& in_1, V& out_1, V& in_2, V& out_2, V& in_3, + V& out_3, V& f_0, V& f_1, V& f_2, V& f_3) { + HWY_IF_CONSTEXPR(C == 2) { + AccumulateResult4(d, f_0, f_1, f_2, f_3, raw, kMainPos); + HWY_IF_CONSTEXPR(m_0.kLoops != 0) { + HWY_IF_CONSTEXPR(m_2.kLoops != 0) { + return FusedBitPackUnroller4::FusedUnpack4(d, pi_0, pi_1, pi_2, + pi_3, raw, in_0, out_0, + in_1, out_1, in_2, + out_2, in_3, out_3, f_0, + f_1, f_2, f_3); + } + } + HWY_IF_CONSTEXPR(m_0.ShouldAccumulate() && m_2.ShouldAccumulate()) { + AccumulateResult4(d, out_0, out_1, out_2, out_3, raw, kMainPos + 1); + } + } + + HWY_IF_CONSTEXPR(C == 0) { + HWY_IF_CONSTEXPR(m_0.kBits == 0 || m_0.kLoops == 0) { + HWY_IF_CONSTEXPR(m_0.kLoops == 0) { + f_0 = out_0; + f_1 = out_1; + } + return BYPASS4; + } + HWY_IF_CONSTEXPR(m_0.kLoops > 0) { + constexpr State mn_0 = m_0.NextState(); + using NextUnroller4 = + FusedBitPackUnroller4()>; + HWY_IF_CONSTEXPR(m_0.S > B) { + AssignAndUpdateInOutFGt(d, pi_0, in_0, out_0, f_0); + AssignAndUpdateInOutFGt(d, pi_1, in_1, out_1, f_1); + return NEXT4; + } + HWY_IF_CONSTEXPR(m_0.S < B) { + AssignAndUpdateInOutFLt(d, pi_0, in_0, out_0, f_0); + AssignAndUpdateInOutFLt(d, pi_1, in_1, out_1, f_1); + return NEXT4; + } + } + } + + HWY_IF_CONSTEXPR(C == 1) { + HWY_IF_CONSTEXPR(m_2.kBits == 0 || m_2.kLoops == 0) { + HWY_IF_CONSTEXPR(m_2.kLoops == 0) { + f_2 = out_2; + f_3 = out_3; + } + return BYPASS4; + } + HWY_IF_CONSTEXPR(m_2.kLoops > 0) { + constexpr State mn_2 = m_2.NextState(); + using NextUnroller4 = + FusedBitPackUnroller4()>; + HWY_IF_CONSTEXPR(m_2.S > B) { + AssignAndUpdateInOutFGt(d, pi_2, in_2, out_2, f_2); + AssignAndUpdateInOutFGt(d, pi_3, in_3, out_3, f_3); + return NEXT4; + } + HWY_IF_CONSTEXPR(m_2.S < B) { + AssignAndUpdateInOutFLt(d, pi_2, in_2, out_2, f_2); + AssignAndUpdateInOutFLt(d, pi_3, in_3, out_3, f_3); + return NEXT4; + } + } + } + } +}; + +#define NEXT8 \ + NextUnroller8::FusedUnpack8( \ + d, pi_0, pi_1, pi_2, pi_3, pi_4, pi_5, pi_6, pi_7, raw, in_0, out_0, \ + in_1, out_1, in_2, out_2, in_3, out_3, in_4, out_4, in_5, out_5, in_6, \ + out_6, in_7, out_7, f_0, f_1, f_2, f_3, f_4, f_5, f_6, f_7); +#define BYPASS8 \ + FusedBitPackUnroller8::FusedUnpack8( \ + d, pi_0, pi_1, pi_2, pi_3, pi_4, pi_5, pi_6, pi_7, raw, in_0, out_0, \ + in_1, out_1, in_2, out_2, in_3, out_3, in_4, out_4, in_5, out_5, in_6, \ + out_6, in_7, out_7, f_0, f_1, f_2, f_3, f_4, f_5, f_6, f_7); + +template +struct FusedBitPackUnroller8 { + static constexpr size_t B = sizeof(T) * 8; + template + static inline void AccumulateResult8(D d, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, uint64_t* raw, + size_t pos) { + const Repartition d16; + const Repartition d32; + const Repartition d64; + const size_t N = Lanes(d); + + uint64_t* out = raw + pos * N; + + auto v00 = BitCast(d16, InterleaveWholeLower(d, v0, v1)); + auto v11 = BitCast(d16, InterleaveWholeUpper(d, v0, v1)); + auto v22 = BitCast(d16, InterleaveWholeLower(d, v2, v3)); + auto v33 = BitCast(d16, InterleaveWholeUpper(d, v2, v3)); + auto v44 = BitCast(d16, InterleaveWholeLower(d, v4, v5)); + auto v55 = BitCast(d16, InterleaveWholeUpper(d, v4, v5)); + auto v66 = BitCast(d16, InterleaveWholeLower(d, v6, v7)); + auto v77 = BitCast(d16, InterleaveWholeUpper(d, v6, v7)); + + auto v000 = BitCast(d32, InterleaveWholeLower(d16, v00, v22)); + auto v111 = BitCast(d32, InterleaveWholeUpper(d16, v00, v22)); + auto v222 = BitCast(d32, InterleaveWholeLower(d16, v11, v33)); + auto v333 = BitCast(d32, InterleaveWholeUpper(d16, v11, v33)); + auto v444 = BitCast(d32, InterleaveWholeLower(d16, v44, v66)); + auto v555 = BitCast(d32, InterleaveWholeUpper(d16, v44, v66)); + auto v666 = BitCast(d32, InterleaveWholeLower(d16, v55, v77)); + auto v777 = BitCast(d32, InterleaveWholeUpper(d16, v55, v77)); + + StoreU(BitCast(d64, InterleaveWholeLower(d32, v000, v444)), d64, out); + StoreU(BitCast(d64, InterleaveWholeUpper(d32, v000, v444)), d64, + out + N / 8); + StoreU(BitCast(d64, InterleaveWholeLower(d32, v111, v555)), d64, + out + 2 * N / 8); + StoreU(BitCast(d64, InterleaveWholeUpper(d32, v111, v555)), d64, + out + 3 * N / 8); + StoreU(BitCast(d64, InterleaveWholeLower(d32, v222, v666)), d64, + out + 4 * N / 8); + StoreU(BitCast(d64, InterleaveWholeUpper(d32, v222, v666)), d64, + out + 5 * N / 8); + StoreU(BitCast(d64, InterleaveWholeLower(d32, v333, v777)), d64, + out + 6 * N / 8); + StoreU(BitCast(d64, InterleaveWholeUpper(d32, v333, v777)), d64, + out + 7 * N / 8); + } + template + static inline void FusedUnpack8( + D d, const T* HWY_RESTRICT pi_0, const T* HWY_RESTRICT pi_1, + const T* HWY_RESTRICT pi_2, const T* HWY_RESTRICT pi_3, + const T* HWY_RESTRICT pi_4, const T* HWY_RESTRICT pi_5, + const T* HWY_RESTRICT pi_6, const T* HWY_RESTRICT pi_7, + BT* HWY_RESTRICT raw, V& in_0, V& out_0, V& in_1, V& out_1, V& in_2, + V& out_2, V& in_3, V& out_3, V& in_4, V& out_4, V& in_5, V& out_5, + V& in_6, V& out_6, V& in_7, V& out_7, V& f_0, V& f_1, V& f_2, V& f_3, + V& f_4, V& f_5, V& f_6, V& f_7) { + HWY_IF_CONSTEXPR(C == 2) { + AccumulateResult8(d, f_0, f_1, f_2, f_3, f_4, f_5, f_6, f_7, raw, + kMainPos); + HWY_IF_CONSTEXPR(m_0.kLoops != 0) { + HWY_IF_CONSTEXPR(m_4.kLoops != 0) { + return FusedBitPackUnroller8::FusedUnpack8(d, pi_0, pi_1, pi_2, + pi_3, pi_4, pi_5, pi_6, + pi_7, raw, in_0, out_0, + in_1, out_1, in_2, + out_2, in_3, out_3, + in_4, out_4, in_5, + out_5, in_6, out_6, + in_7, out_7, f_0, f_1, + f_2, f_3, f_4, f_5, f_6, + f_7); + } + } + HWY_IF_CONSTEXPR(m_0.ShouldAccumulate() && m_4.ShouldAccumulate()) { + AccumulateResult8(d, out_0, out_1, out_2, out_3, out_4, out_5, out_6, + out_7, raw, kMainPos + 1); + } + } + + HWY_IF_CONSTEXPR(C == 0) { + HWY_IF_CONSTEXPR(m_0.kBits == 0 || m_0.kLoops == 0) { + HWY_IF_CONSTEXPR(m_0.kLoops == 0) { + f_0 = out_0; + f_1 = out_1; + f_2 = out_2; + f_3 = out_3; + } + return BYPASS8; + } + HWY_IF_CONSTEXPR(m_0.kLoops > 0) { + constexpr State mn_0 = m_0.NextState(); + using NextUnroller8 = + FusedBitPackUnroller8()>; + HWY_IF_CONSTEXPR(m_0.S > B) { + AssignAndUpdateInOutFGt(d, pi_0, in_0, out_0, f_0); + AssignAndUpdateInOutFGt(d, pi_1, in_1, out_1, f_1); + AssignAndUpdateInOutFGt(d, pi_2, in_2, out_2, f_2); + AssignAndUpdateInOutFGt(d, pi_3, in_3, out_3, f_3); + return NEXT8; + } + HWY_IF_CONSTEXPR(m_0.S < B) { + AssignAndUpdateInOutFLt(d, pi_0, in_0, out_0, f_0); + AssignAndUpdateInOutFLt(d, pi_1, in_1, out_1, f_1); + AssignAndUpdateInOutFLt(d, pi_2, in_2, out_2, f_2); + AssignAndUpdateInOutFLt(d, pi_3, in_3, out_3, f_3); + return NEXT8; + } + } + } + + HWY_IF_CONSTEXPR(C == 1) { + HWY_IF_CONSTEXPR(m_4.kBits == 0 || m_4.kLoops == 0) { + HWY_IF_CONSTEXPR(m_4.kLoops == 0) { + f_4 = out_4; + f_5 = out_5; + f_6 = out_6; + f_7 = out_7; + } + return BYPASS8; + } + HWY_IF_CONSTEXPR(m_4.kLoops > 0) { + constexpr State mn_4 = m_4.NextState(); + using NextUnroller8 = + FusedBitPackUnroller8()>; + HWY_IF_CONSTEXPR(m_4.S > B) { + AssignAndUpdateInOutFGt(d, pi_4, in_4, out_4, f_4); + AssignAndUpdateInOutFGt(d, pi_5, in_5, out_5, f_5); + AssignAndUpdateInOutFGt(d, pi_6, in_6, out_6, f_6); + AssignAndUpdateInOutFGt(d, pi_7, in_7, out_7, f_7); + return NEXT8; + } + HWY_IF_CONSTEXPR(m_4.S < B) { + AssignAndUpdateInOutFLt(d, pi_4, in_4, out_4, f_4); + AssignAndUpdateInOutFLt(d, pi_5, in_5, out_5, f_5); + AssignAndUpdateInOutFLt(d, pi_6, in_6, out_6, f_6); + AssignAndUpdateInOutFLt(d, pi_7, in_7, out_7, f_7); + return NEXT8; + } + } + } + } +}; + } // namespace detail +template // <= 32 +struct Packk8 { + template + HWY_INLINE void Pack(D d, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out, + const uint8_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits8()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = Zero(d); + V out = Zero(d); + detail::BitPackUnroller::Pack(d, raw, packed_out, + mask, + frame_of_reference, in, + out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_out += detail::PackedIncr() * Lanes(d); + } + } + + template + HWY_INLINE void Unpack(D d, const uint8_t* HWY_RESTRICT pi_0, + uint8_t* HWY_RESTRICT raw) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits8()); + V in_0 = LoadU(d, pi_0 + 0 * Lanes(d)); + V out_0 = And(in_0, mask); + V f_0 = Zero(d); + constexpr size_t B = sizeof(uint8_t) * 8; + constexpr detail::State m_0{ + .kBits = kBits0, + .S = kBits0 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits0 == B ? B - 1 : std::gcd(kBits0, B), + }; + detail::FusedBitPackUnroller::Unpack( + d, pi_0, raw, mask, in_0, out_0, f_0); + } + + template + HWY_INLINE void FusedUnpack2(D d, std::array ptrs, + uint16_t* HWY_RESTRICT raw) const { + constexpr bool kAllZero = (kBits0 == 0 && kBits1 == 0); + HWY_IF_CONSTEXPR(kAllZero) { + memset(raw, 0, 2 * 8 * Lanes(d)); + return; + } + HWY_IF_CONSTEXPR(!kAllZero) { + using V = VFromD; + const V mask_0 = Set(d, detail::MaskBits8()); + const V mask_1 = Set(d, detail::MaskBits8()); + + constexpr size_t B = sizeof(uint8_t) * 8; + V in_0 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[0])); + V out_0 = And(in_0, mask_0); + V f_0 = Zero(d); + + constexpr detail::State m_0{ + .kBits = kBits0, + .S = kBits0 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits0 == B ? B - 1 : std::gcd(kBits0, B), + }; + + V in_1 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[1])); + V out_1 = And(in_1, mask_1); + V f_1 = Zero(d); + constexpr detail::State m_1{ + .kBits = kBits1, + .S = kBits1 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits1 == B ? B - 1 : std::gcd(kBits1, B), + }; + + detail::FusedBitPackUnroller::FusedUnpack(d, ptrs[0], ptrs[1], raw, + in_0, out_0, in_1, out_1, + f_0, f_1); + } + } + + template + HWY_INLINE void FusedUnpack4(D d, std::array ptrs, + uint32_t* HWY_RESTRICT raw) const { + constexpr bool kAllZero = (kBits0 == 0 && kBits1 == 0); + HWY_IF_CONSTEXPR(kAllZero) { return; } + HWY_IF_CONSTEXPR(!kAllZero) { + using V = VFromD; + constexpr size_t B = sizeof(uint8_t) * 8; + const V mask_0 = Set(d, detail::MaskBits8()); + const V mask_1 = Set(d, detail::MaskBits8()); + V in_0 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[0])); + V out_0 = And(in_0, mask_0); + V f_0 = Zero(d); + V in_1 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[1])); + V out_1 = And(in_1, mask_1); + V f_1 = Zero(d); + + constexpr detail::State m_0{ + .kBits = kBits0, + .S = kBits0 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits0 == B ? B - 1 : std::gcd(kBits0, B), + }; + + const V mask_2 = Set(d, detail::MaskBits8()); + const V mask_3 = Set(d, detail::MaskBits8()); + V in_2 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[2])); + V out_2 = And(in_2, mask_2); + V f_2 = Zero(d); + + V in_3 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[3])); + V out_3 = And(in_3, mask_3); + V f_3 = Zero(d); + + constexpr detail::State m_2{ + .kBits = kBits1, + .S = kBits1 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits1 == B ? B - 1 : std::gcd(kBits1, B), + }; + + detail::FusedBitPackUnroller4::FusedUnpack4(d, ptrs[0], ptrs[1], + ptrs[2], ptrs[3], raw, + in_0, out_0, in_1, out_1, + in_2, out_2, in_3, out_3, + f_0, f_1, f_2, f_3); + } + } + + template + HWY_INLINE void FusedUnpack8(D d, std::array ptrs, + uint64_t* HWY_RESTRICT raw) const { + constexpr bool kAllZero = (kBits0 == 0 && kBits1 == 0); + HWY_IF_CONSTEXPR(kAllZero) { return; } + HWY_IF_CONSTEXPR(!kAllZero) { + using V = VFromD; + constexpr size_t B = sizeof(uint8_t) * 8; + const V mask_0 = Set(d, detail::MaskBits8()); + const V mask_1 = Set(d, detail::MaskBits8()); + const V mask_2 = Set(d, detail::MaskBits8()); + const V mask_3 = Set(d, detail::MaskBits8()); + V in_0 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[0])); + V out_0 = And(in_0, mask_0); + V f_0 = Zero(d); + V in_1 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[1])); + V out_1 = And(in_1, mask_1); + V f_1 = Zero(d); + V in_2 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[2])); + V out_2 = And(in_2, mask_2); + V f_2 = Zero(d); + V in_3 = (kBits0 == 0 ? Zero(d) : LoadU(d, ptrs[3])); + V out_3 = And(in_3, mask_3); + V f_3 = Zero(d); + + constexpr detail::State m_0{ + .kBits = kBits0, + .S = kBits0 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits0 == B ? B - 1 : std::gcd(kBits0, B), + }; + + const V mask_4 = Set(d, detail::MaskBits8()); + const V mask_5 = Set(d, detail::MaskBits8()); + const V mask_6 = Set(d, detail::MaskBits8()); + const V mask_7 = Set(d, detail::MaskBits8()); + + constexpr detail::State m_4{ + .kBits = kBits1, + .S = kBits1 % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = kBits1 == B ? B - 1 : std::gcd(kBits1, B), + }; + V in_4 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[4])); + V out_4 = And(in_4, mask_4); + V f_4 = Zero(d); + V in_5 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[5])); + V out_5 = And(in_5, mask_5); + V f_5 = Zero(d); + V in_6 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[6])); + V out_6 = And(in_6, mask_6); + V f_6 = Zero(d); + V in_7 = (kBits1 == 0 ? Zero(d) : LoadU(d, ptrs[7])); + V out_7 = And(in_7, mask_7); + V f_7 = Zero(d); + + detail::FusedBitPackUnroller8::FusedUnpack8(d, ptrs[0], ptrs[1], + ptrs[2], ptrs[3], ptrs[4], + ptrs[5], ptrs[6], ptrs[7], + raw, in_0, out_0, in_1, + out_1, in_2, out_2, in_3, + out_3, in_4, out_4, in_5, + out_5, in_6, out_6, in_7, + out_7, f_0, f_1, f_2, f_3, + f_4, f_5, f_6, f_7); + } + } +}; + template // <= 32 struct Pack32 { template - HWY_INLINE void Unpack(D d, const uint32_t* HWY_RESTRICT packed_in, - uint32_t* HWY_RESTRICT raw, - const uint32_t frame_of_reference_value = 0) const { + template + HWY_INLINE void Unpack(D d, const uint32_t* HWY_RESTRICT pi_0, + uint32_t* HWY_RESTRICT raw) const { using V = VFromD; - const V mask = Set(d, detail::MaskBits32()); - const V frame_of_reference = Set(d, frame_of_reference_value); - for (size_t i = 0; i < detail::NumLoops(); ++i) { - V in = LoadU(d, packed_in + 0 * Lanes(d)); - V out = And(in, mask); - detail::BitPackUnroller::Unpack(d, packed_in, raw, - mask, - frame_of_reference, - in, out); - raw += detail::UnpackedIncr() * Lanes(d); - packed_in += detail::PackedIncr() * Lanes(d); - } + const V mask = detail::MakeMask(d); + V in_0 = LoadU(d, pi_0 + 0 * Lanes(d)); + V out_0 = And(in_0, mask); + V f_0 = Zero(d); + constexpr size_t B = sizeof(uint32_t) * 8; + constexpr detail::State m_0{ + .kBits = kBits, + .S = kBits % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = std::gcd(kBits, B), + }; + detail::FusedBitPackUnroller::Unpack( + d, pi_0, raw, mask, in_0, out_0, f_0); } }; @@ -2821,25 +3513,24 @@ struct Pack64 { } } - template - HWY_INLINE void Unpack(D d, const uint64_t* HWY_RESTRICT packed_in, - uint64_t* HWY_RESTRICT raw, - const uint64_t frame_of_reference_value = 0) const { + template + HWY_INLINE void Unpack(D d, const uint64_t* HWY_RESTRICT pi_0, + uint64_t* HWY_RESTRICT raw) const { using V = VFromD; - const V mask = Set(d, detail::MaskBits64()); - const V frame_of_reference = Set(d, frame_of_reference_value); - for (size_t i = 0; i < detail::NumLoops(); ++i) { - V in = LoadU(d, packed_in + 0 * Lanes(d)); - V out = And(in, mask); - detail::BitPackUnroller::Unpack(d, packed_in, raw, - mask, - frame_of_reference, - in, out); - raw += detail::UnpackedIncr() * Lanes(d); - packed_in += detail::PackedIncr() * Lanes(d); - } + const V mask = detail::MakeMask(d); + V in_0 = LoadU(d, pi_0 + 0 * Lanes(d)); + V out_0 = And(in_0, mask); + V f_0 = Zero(d); + constexpr size_t B = sizeof(uint64_t) * 8; + constexpr detail::State m_0{ + .kBits = kBits, + .S = kBits % B, + .kLoadPos = 1, + .kStorePos = 0, + .kLoops = std::gcd(kBits, B), + }; + detail::FusedBitPackUnroller::Unpack( + d, pi_0, raw, mask, in_0, out_0, f_0); } }; diff --git a/hwy/contrib/bit_pack/bit_pack_test.cc b/hwy/contrib/bit_pack/bit_pack_test.cc index d6a3bcfa87..5de1f49b57 100644 --- a/hwy/contrib/bit_pack/bit_pack_test.cc +++ b/hwy/contrib/bit_pack/bit_pack_test.cc @@ -167,6 +167,17 @@ struct TestPack { } }; +void TestAllPackk() { + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); + ForShrinkableVectors>()(uint8_t()); +} + void TestAllPack8() { ForShrinkableVectors>()(uint8_t()); ForShrinkableVectors>()(uint8_t()); @@ -200,6 +211,8 @@ void TestAllPack16() { void TestAllPack32() { ForShrinkableVectors>()(uint32_t()); ForShrinkableVectors>()(uint32_t()); + ForShrinkableVectors>()(uint32_t()); + ForShrinkableVectors>()(uint32_t()); ForShrinkableVectors>()(uint32_t()); ForShrinkableVectors>()(uint32_t()); ForShrinkableVectors>()(uint32_t()); @@ -220,6 +233,7 @@ void TestAllPack64() { ForShrinkableVectors>()(uint64_t()); ForShrinkableVectors>()(uint64_t()); ForShrinkableVectors>()(uint64_t()); + ForShrinkableVectors>()(uint64_t()); #endif } @@ -233,6 +247,7 @@ HWY_AFTER_NAMESPACE(); namespace hwy { namespace { HWY_BEFORE_TEST(BitPackTest); +HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPackk); HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack8); HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack16); HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack32);