diff --git a/include/xsimd/arch/xsimd_avx.hpp b/include/xsimd/arch/xsimd_avx.hpp
index ddcc9e4bf..66ef31bf4 100644
--- a/include/xsimd/arch/xsimd_avx.hpp
+++ b/include/xsimd/arch/xsimd_avx.hpp
@@ -920,16 +920,8 @@ namespace xsimd
             using int_t = as_integer_t;
             constexpr size_t half_size = batch::size / 2;
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return batch(T { 0 });
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                return load(mem, Mode {});
-            }
             // confined to lower 128-bit half → forward to SSE2
-            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half_size)
+            XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half_size)
             {
                 constexpr auto mlo = ::xsimd::detail::lower_half(batch_bool_constant {});
                 const auto lo = load_masked(reinterpret_cast(mem), mlo, convert {}, Mode {}, sse4_2 {});
@@ -970,16 +962,8 @@ namespace xsimd
         {
             constexpr size_t half_size = batch::size / 2;
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return;
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                src.store(mem, Mode {});
-            }
             // confined to lower 128-bit half → forward to SSE2
-            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half_size)
+            XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half_size)
             {
                 constexpr auto mlo = ::xsimd::detail::lower_half(mask);
                 const auto lo = detail::lower_half(src);
diff --git a/include/xsimd/arch/xsimd_avx2.hpp b/include/xsimd/arch/xsimd_avx2.hpp
index dbceb78da..de78b6f22 100644
--- a/include/xsimd/arch/xsimd_avx2.hpp
+++ b/include/xsimd/arch/xsimd_avx2.hpp
@@ -142,21 +142,10 @@ namespace xsimd
         XSIMD_INLINE typename std::enable_if::value && (sizeof(T) >= 4), batch>::type
         load_masked(T const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return _mm256_setzero_si256();
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                return load(mem, Mode {});
-            }
-            else
-            {
-                static_assert(sizeof(T) == 4 || sizeof(T) == 8, "load_masked supports only 32/64-bit integers on AVX2");
-                using int_t = typename std::conditional::type;
-                // Use the raw register-level maskload helpers for the remaining cases.
-                return detail::maskload(reinterpret_cast(mem), mask.as_batch());
-            }
+            static_assert(sizeof(T) == 4 || sizeof(T) == 8, "load_masked supports only 32/64-bit integers on AVX2");
+            using int_t = typename std::conditional::type;
+            // Use the raw register-level maskload helpers for the remaining cases.
+            return detail::maskload(reinterpret_cast(mem), mask.as_batch());
         }
         template
@@ -206,16 +195,8 @@ namespace xsimd
         {
             constexpr size_t lanes_per_half = sizeof(__m128i) / sizeof(T);
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return;
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                src.store(mem, Mode {});
-            }
             // confined to lower 128-bit half → forward to SSE
-            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= lanes_per_half)
+            XSIMD_IF_CONSTEXPR(mask.countl_zero() >= lanes_per_half)
             {
                 constexpr auto mlo = ::xsimd::detail::lower_half(mask);
                 const auto lo = detail::lower_half(src);
diff --git a/include/xsimd/arch/xsimd_avx512f.hpp b/include/xsimd/arch/xsimd_avx512f.hpp
index 990c07148..5b8599a8d 100644
--- a/include/xsimd/arch/xsimd_avx512f.hpp
+++ b/include/xsimd/arch/xsimd_avx512f.hpp
@@ -304,34 +304,23 @@ namespace xsimd
                                  batch_bool_constant mask, convert, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
+            constexpr auto half = batch::size / 2;
+            XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half) // lower-half AVX2 forwarding
             {
-                return batch(T { 0 });
+                constexpr auto mlo = ::xsimd::detail::lower_half(mask);
+                const auto lo = load_masked(mem, mlo, convert {}, Mode {}, avx2 {});
+                return detail::load_masked(lo); // zero-extend low half
             }
-            else XSIMD_IF_CONSTEXPR(mask.all())
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= half) // upper-half AVX2 forwarding
             {
-                return load(mem, Mode {});
+                constexpr auto mhi = ::xsimd::detail::upper_half(mask);
+                const auto hi = load_masked(mem + half, mhi, convert {}, Mode {}, avx2 {});
+                return detail::load_masked(hi, detail::high_tag {});
             }
             else
             {
-                constexpr auto half = batch::size / 2;
-                XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half) // lower-half AVX2 forwarding
-                {
-                    constexpr auto mlo = ::xsimd::detail::lower_half(mask);
-                    const auto lo = load_masked(mem, mlo, convert {}, Mode {}, avx2 {});
-                    return detail::load_masked(lo); // zero-extend low half
-                }
-                else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= half) // upper-half AVX2 forwarding
-                {
-                    constexpr auto mhi = ::xsimd::detail::upper_half(mask);
-                    const auto hi = load_masked(mem + half, mhi, convert {}, Mode {}, avx2 {});
-                    return detail::load_masked(hi, detail::high_tag {});
-                }
-                else
-                {
-                    // fallback to centralized pointer-level helper
-                    return detail::load_masked(mem, mask.mask(), Mode {});
-                }
+                // fallback to centralized pointer-level helper
+                return detail::load_masked(mem, mask.mask(), Mode {});
             }
         }
@@ -342,34 +331,23 @@ namespace xsimd
                                  batch_bool_constant mask, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
+            constexpr auto half = batch::size / 2;
+            XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half) // lower-half AVX2 forwarding
             {
-                return;
+                constexpr auto mlo = ::xsimd::detail::lower_half(mask);
+                const auto lo = detail::lower_half(src);
+                store_masked(mem, lo, mlo, Mode {}, avx2 {});
             }
-            else XSIMD_IF_CONSTEXPR(mask.all())
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= half) // upper-half AVX2 forwarding
             {
-                src.store(mem, Mode {});
+                constexpr auto mhi = ::xsimd::detail::upper_half(mask);
+                const auto hi = detail::upper_half(src);
+                store_masked(mem + half, hi, mhi, Mode {}, avx2 {});
             }
             else
             {
-                constexpr auto half = batch::size / 2;
-                XSIMD_IF_CONSTEXPR(mask.countl_zero() >= half) // lower-half AVX2 forwarding
-                {
-                    constexpr auto mlo = ::xsimd::detail::lower_half(mask);
-                    const auto lo = detail::lower_half(src);
-                    store_masked(mem, lo, mlo, Mode {}, avx2 {});
-                }
-                else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= half) // upper-half AVX2 forwarding
-                {
-                    constexpr auto mhi = ::xsimd::detail::upper_half(mask);
-                    const auto hi = detail::upper_half(src);
-                    store_masked(mem + half, hi, mhi, Mode {}, avx2 {});
-                }
-                else
-                {
-                    // fallback to centralized pointer-level helper
-                    detail::store_masked(mem, src, mask.mask(), Mode {});
-                }
+                // fallback to centralized pointer-level helper
+                detail::store_masked(mem, src, mask.mask(), Mode {});
             }
         }
diff --git a/include/xsimd/arch/xsimd_neon.hpp b/include/xsimd/arch/xsimd_neon.hpp
index 707257b5d..d97be3cc2 100644
--- a/include/xsimd/arch/xsimd_neon.hpp
+++ b/include/xsimd/arch/xsimd_neon.hpp
@@ -613,6 +613,44 @@ namespace xsimd
             return load_unaligned(mem, t, r);
         }

+        /* masked version */
+        namespace detail
+        {
+            template
+            struct load_masked;
+
+            template <>
+            struct load_masked<>
+            {
+                template
+                static XSIMD_INLINE batch apply(T const* mem, batch acc, std::integral_constant) noexcept
+                {
+                    return acc;
+                }
+            };
+            template
+            struct load_masked
+            {
+                template
+                static XSIMD_INLINE batch apply(T const* mem, batch acc, std::true_type) noexcept
+                {
+                    return load_masked::template apply(mem, insert(acc, mem[I], index {}), std::integral_constant {});
+                }
+                template
+                static XSIMD_INLINE batch apply(T const* mem, batch acc, std::false_type) noexcept
+                {
+                    return load_masked::template apply(mem, acc, std::integral_constant {});
+                }
+            };
+        }
+
+        template
+        XSIMD_INLINE batch load_masked(T const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept
+        {
+            // Call insert whenever Values... are true
+            return detail::load_masked::template apply<0>(mem, broadcast(T(0), A {}), std::integral_constant {});
+        }
+
         /*********
          * store *
          *********/
diff --git a/include/xsimd/arch/xsimd_sse2.hpp b/include/xsimd/arch/xsimd_sse2.hpp
index 012c9b252..63893cdbb 100644
--- a/include/xsimd/arch/xsimd_sse2.hpp
+++ b/include/xsimd/arch/xsimd_sse2.hpp
@@ -1071,15 +1071,7 @@ namespace xsimd
         template
         XSIMD_INLINE batch load_masked(float const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return _mm_setzero_ps();
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                return load(mem, Mode {});
-            }
-            else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2)
+            XSIMD_IF_CONSTEXPR(mask.countr_one() == 2)
             {
                 return _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<__m64 const*>(mem));
             }
@@ -1095,15 +1087,7 @@ namespace xsimd
         template
         XSIMD_INLINE batch load_masked(double const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return _mm_setzero_pd();
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                return load(mem, Mode {});
-            }
-            else XSIMD_IF_CONSTEXPR(mask.countr_one() == 1)
+            XSIMD_IF_CONSTEXPR(mask.countr_one() == 1)
             {
                 return _mm_move_sd(_mm_setzero_pd(), _mm_load_sd(mem));
             }
@@ -1121,15 +1105,7 @@ namespace xsimd
         template
         XSIMD_INLINE void store_masked(float* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return;
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                src.store(mem, Mode {});
-            }
-            else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2)
+            XSIMD_IF_CONSTEXPR(mask.countr_one() == 2)
             {
                 _mm_storel_pi(reinterpret_cast<__m64*>(mem), src);
             }
@@ -1144,17 +1120,9 @@ namespace xsimd
         }
         template
-        XSIMD_INLINE void store_masked(double* mem, batch const& src, batch_bool_constant mask, Mode mode, requires_arch) noexcept
+        XSIMD_INLINE void store_masked(double* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return;
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                src.store(mem, mode);
-            }
-            else XSIMD_IF_CONSTEXPR(mask.countr_one() == 1)
+            XSIMD_IF_CONSTEXPR(mask.countr_one() == 1)
             {
                 _mm_store_sd(mem, src);
             }
@@ -2205,15 +2173,7 @@ namespace xsimd
                                     aligned_mode, requires_arch) noexcept
         {
-            XSIMD_IF_CONSTEXPR(mask.none())
-            {
-                return;
-            }
-            else XSIMD_IF_CONSTEXPR(mask.all())
-            {
-                _mm_store_ps(mem, src);
-            }
-            else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2)
+            XSIMD_IF_CONSTEXPR(mask.countr_one() == 2)
             {
                 _mm_storel_pi(reinterpret_cast<__m64*>(mem), src);
             }
diff --git a/include/xsimd/types/xsimd_batch.hpp b/include/xsimd/types/xsimd_batch.hpp
index 2581152ea..7ee1572d6 100644
--- a/include/xsimd/types/xsimd_batch.hpp
+++ b/include/xsimd/types/xsimd_batch.hpp
@@ -146,10 +146,8 @@ namespace xsimd
         XSIMD_INLINE void store(U* mem, unaligned_mode) const noexcept;
         // Compile-time mask overloads
-        template
-        XSIMD_INLINE void store(U* mem, batch_bool_constant mask, aligned_mode) const noexcept;
-        template
-        XSIMD_INLINE void store(U* mem, batch_bool_constant mask, unaligned_mode) const noexcept;
+        template
+        XSIMD_INLINE void store(U* mem, batch_bool_constant mask, Mode) const noexcept;
         template
         XSIMD_NO_DISCARD static XSIMD_INLINE batch load_aligned(U const* mem) noexcept;
@@ -160,10 +158,8 @@ namespace xsimd
         template
         XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, unaligned_mode) noexcept;
         // Compile-time mask overloads
-        template
-        XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, aligned_mode = {}) noexcept;
-        template
-        XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, unaligned_mode) noexcept;
+        template
+        XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, Mode = {}) noexcept;
         template
         XSIMD_NO_DISCARD static XSIMD_INLINE batch gather(U const* src, batch const& index) noexcept;
@@ -418,19 +414,15 @@ namespace xsimd
         template
         XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, unaligned_mode) noexcept;
         // Compile-time mask overloads
-        template
-        XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, aligned_mode = {}) noexcept;
-        template
-        XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, unaligned_mode) noexcept;
+        template
+        XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, Mode = {}) noexcept;
         template
         XSIMD_INLINE void store(U* mem, aligned_mode) const noexcept;
         template
         XSIMD_INLINE void store(U* mem, unaligned_mode) const noexcept;
         // Compile-time mask overloads
-        template
-        XSIMD_INLINE void store(U* mem, batch_bool_constant mask, aligned_mode) const noexcept;
-        template
-        XSIMD_INLINE void store(U* mem, batch_bool_constant mask, unaligned_mode) const noexcept;
+        template
+        XSIMD_INLINE void store(U* mem, batch_bool_constant mask, Mode = {}) const noexcept;
         XSIMD_INLINE real_batch real() const noexcept;
         XSIMD_INLINE real_batch imag() const noexcept;
@@ -691,43 +683,49 @@ namespace xsimd
     }

     template
-    template
+    template
     XSIMD_INLINE batch batch::load(U const* mem,
                                    batch_bool_constant mask,
-                                   aligned_mode) noexcept
+                                   Mode mode) noexcept
     {
         detail::static_check_supported_config();
-        return kernel::load_masked(mem, mask, kernel::convert {}, aligned_mode {}, A {});
-    }
-
-    template
-    template
-    XSIMD_INLINE batch batch::load(U const* mem,
-                                   batch_bool_constant mask,
-                                   unaligned_mode) noexcept
-    {
-        detail::static_check_supported_config();
-        return kernel::load_masked(mem, mask, kernel::convert {}, unaligned_mode {}, A {});
-    }
-
-    template
-    template
-    XSIMD_INLINE void batch::store(U* mem,
-                                   batch_bool_constant mask,
-                                   aligned_mode) const noexcept
-    {
-        detail::static_check_supported_config();
-        kernel::store_masked(mem, *this, mask, aligned_mode {}, A {});
+        static_assert(std::is_same::value || std::is_same::value,
+                      "supported load mode");
+        XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            return load(mem, mode);
+        }
+        else XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return broadcast(0);
+        }
+        else
+        {
+            return kernel::load_masked(mem, mask, kernel::convert {}, mode, A {});
+        }
     }

     template
-    template
+    template
     XSIMD_INLINE void batch::store(U* mem,
                                    batch_bool_constant mask,
-                                   unaligned_mode) const noexcept
+                                   Mode mode) const noexcept
     {
         detail::static_check_supported_config();
-        kernel::store_masked(mem, *this, mask, unaligned_mode {}, A {});
+        static_assert(std::is_same::value || std::is_same::value,
+                      "supported store mode");
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return;
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            store(mem, mode);
+        }
+        else
+        {
+            kernel::store_masked(mem, *this, mask, mode, A {});
+        }
     }
@@ -1285,17 +1283,10 @@ namespace xsimd
     // Compile-time mask overloads for complex store
     template
-    template
-    XSIMD_INLINE void batch, A>::store(U* mem, batch_bool_constant mask, aligned_mode) const noexcept
-    {
-        kernel::store_masked(mem, *this, mask, aligned_mode {}, A {});
-    }
-
-    template
-    template
-    XSIMD_INLINE void batch, A>::store(U* mem, batch_bool_constant mask, unaligned_mode) const noexcept
+    template
+    XSIMD_INLINE void batch, A>::store(U* mem, batch_bool_constant mask, Mode mode) const noexcept
     {
-        kernel::store_masked(mem, *this, mask, unaligned_mode {}, A {});
+        kernel::store_masked(mem, *this, mask, mode, A {});
     }
     template
@@ -1328,21 +1319,12 @@ namespace xsimd
     // Compile-time mask overloads for complex load
     template
-    template
-    XSIMD_INLINE batch, A> batch, A>::load(U const* mem,
-                                           batch_bool_constant mask,
-                                           aligned_mode) noexcept
-    {
-        return kernel::load_masked(mem, mask, kernel::convert {}, aligned_mode {}, A {});
-    }
-
-    template
-    template
+    template
     XSIMD_INLINE batch, A> batch, A>::load(U const* mem,
                                            batch_bool_constant mask,
-                                           unaligned_mode) noexcept
+                                           Mode mode) noexcept
     {
-        return kernel::load_masked(mem, mask, kernel::convert {}, unaligned_mode {}, A {});
+        return kernel::load_masked(mem, mask, kernel::convert {}, mode, A {});
     }
     template
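The `xsimd_batch.hpp` hunks move the `mask.none()` / `mask.all()` fast paths out of the per-architecture kernels and into the generic `batch::load` / `batch::store` entry points, so the kernels only ever see partial masks (half-width forwarding on AVX/AVX-512, `countr_one` prefixes on SSE2, per-lane inserts on NEON). Below is a minimal standalone C++17 sketch of that dispatch shape; `const_mask`, `store_masked_kernel` and `store_masked` are illustrative stand-ins, not xsimd's real types or kernel signatures.

```cpp
#include <cstddef>

// Compile-time lane mask stand-in (xsimd's batch_bool_constant plays this role).
template <bool... Values>
struct const_mask
{
    static constexpr bool all = (Values && ...);
    static constexpr bool none = !(Values || ...);
};

// Stand-in for an architecture-specific kernel: with the restructuring above,
// it is only ever reached with a partial mask.
template <class T, bool... Values>
void store_masked_kernel(T* dst, T const* src, const_mask<Values...>)
{
    constexpr bool lanes[] = { Values... };
    for (std::size_t i = 0; i < sizeof...(Values); ++i)
        if (lanes[i])
            dst[i] = src[i];
}

// Generic front end, mirroring batch<T, A>::store(mem, mask, Mode) in the diff:
// the trivial all/none cases are resolved here, once, at compile time.
template <class T, bool... Values>
void store_masked(T* dst, T const* src, const_mask<Values...> mask)
{
    if constexpr (const_mask<Values...>::none)
    {
        return; // nothing to write
    }
    else if constexpr (const_mask<Values...>::all)
    {
        for (std::size_t i = 0; i < sizeof...(Values); ++i)
            dst[i] = src[i]; // plain full-width store
    }
    else
    {
        store_masked_kernel(dst, src, mask); // partial mask: defer to the kernel
    }
}

int main()
{
    float src[4] = { 1.f, 2.f, 3.f, 4.f };
    float dst[4] = {};
    store_masked(dst, src, const_mask<true, true, false, false> {});   // partial: kernel path
    store_masked(dst, src, const_mask<false, false, false, false> {}); // none: early return
}
```

The load path follows the same split in the diff: the all-true mask falls back to a plain `load`, the all-false mask returns `broadcast(0)`, and only genuinely partial masks reach `kernel::load_masked`.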