diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp
index c00591e..58305dc 100644
--- a/lib/x86simdsort-avx2.cpp
+++ b/lib/x86simdsort-avx2.cpp
@@ -1,94 +1,94 @@
// AVX2 specific routines:

#include "x86simdsort-static-incl.h"
#include "x86simdsort-internal.h"

#define DEFINE_ALL_METHODS(type) \
    template <> \
    void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \
    { \
        x86simdsortStatic::qsort(arr, arrsize, hasnan, descending); \
    } \
    template <> \
    void qselect( \
            type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \
    { \
        x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending); \
    } \
    template <> \
    void partial_qsort( \
            type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \
    { \
        x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending); \
    } \
    template <> \
    std::vector<size_t> argsort( \
            type *arr, size_t arrsize, bool hasnan, bool descending) \
    { \
        return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \
    } \
    template <> \
    std::vector<size_t> argselect( \
            type *arr, size_t k, size_t arrsize, bool hasnan) \
    { \
        return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
    }

#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
    template <> \
    void keyvalue_qsort(type1 *key, \
                        type2 *val, \
                        size_t arrsize, \
                        bool hasnan, \
                        bool descending) \
    { \
        x86simdsortStatic::keyvalue_qsort( \
                key, val, arrsize, hasnan, descending); \
    } \
    template <> \
    void keyvalue_select(type1 *key, \
                         type2 *val, \
                         size_t k, \
                         size_t arrsize, \
                         bool hasnan, \
                         bool descending) \
    { \
        x86simdsortStatic::keyvalue_select( \
                key, val, k, arrsize, hasnan, descending); \
    } \
    template <> \
    void keyvalue_partial_sort(type1 *key, \
                               type2 *val, \
                               size_t k, \
                               size_t arrsize, \
                               bool hasnan, \
                               bool descending) \
    { \
        x86simdsortStatic::keyvalue_partial_sort( \
                key, val, k, arrsize, hasnan, descending); \
    }

#define DEFINE_KEYVALUE_METHODS(type) \
    DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, double) \
    DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
    DEFINE_KEYVALUE_METHODS_BASE(type, float)

namespace xss {
namespace avx2 {
    DEFINE_ALL_METHODS(uint32_t)
    DEFINE_ALL_METHODS(int32_t)
    DEFINE_ALL_METHODS(float)
    DEFINE_ALL_METHODS(uint64_t)
    DEFINE_ALL_METHODS(int64_t)
    DEFINE_ALL_METHODS(double)
    DEFINE_KEYVALUE_METHODS(uint64_t)
    DEFINE_KEYVALUE_METHODS(int64_t)
    DEFINE_KEYVALUE_METHODS(double)
    DEFINE_KEYVALUE_METHODS(uint32_t)
    DEFINE_KEYVALUE_METHODS(int32_t)
    DEFINE_KEYVALUE_METHODS(float)
} // namespace avx2
} // namespace xss
\ No newline at end of file
diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp
index a30da7c..527462c 100644
--- a/src/avx2-emu-funcs.hpp
+++ b/src/avx2-emu-funcs.hpp
@@ -1,421 +1,421 @@
#ifndef AVX2_EMU_FUNCS
#define AVX2_EMU_FUNCS

#include <array>
#include <utility>

constexpr auto avx2_mask_helper_lut32 = [] {
    std::array<std::array<int32_t, 8>, 256> lut {};
    for (int64_t i = 0; i <= 0xFF; i++) {
        std::array<int32_t, 8> entry {};
        for (int j = 0; j < 8; j++) {
            if (((i >> j) & 1) == 1)
                entry[j] = 0xFFFFFFFF;
            else
                entry[j] = 0;
        }
        lut[i] = entry;
    }
    return lut;
}();

constexpr auto avx2_mask_helper_lut64 = [] {
    std::array<std::array<int64_t, 4>, 16> lut {};
    for (int64_t i = 0; i <= 0xF; i++) {
        std::array<int64_t, 4> entry {};
        for (int j = 0; j < 4; j++) {
            if (((i >> j) & 1) == 1)
                entry[j] = 0xFFFFFFFFFFFFFFFF;
            else
                entry[j] = 0;
        }
        lut[i] = entry;
    }
    return lut;
}();

constexpr auto avx2_mask_helper_lut32_half = [] {
    std::array<std::array<int32_t, 4>, 16> lut {};
    for (int64_t i = 0; i <= 0xF; i++) {
        std::array<int32_t, 4> entry {};
        for (int j = 0; j < 4; j++) {
            if (((i >> j) & 1) == 1)
                entry[j] = 0xFFFFFFFF;
            else
                entry[j] = 0;
        }
        lut[i] = entry;
    }
    return lut;
}();

constexpr auto avx2_compressstore_lut32_gen = [] {
    std::array<std::array<std::array<int32_t, 8>, 256>, 2> lutPair {};
    auto &permLut = lutPair[0];
    auto &leftLut = lutPair[1];
    for (int64_t i = 0; i <= 0xFF; i++) {
        std::array<int32_t, 8> indices {};
        std::array<int32_t, 8> leftEntry = {0, 0, 0, 0, 0, 0, 0, 0};
        int right = 7;
        int left = 0;
        for (int j = 0; j < 8; j++) {
            bool ge = (i >> j) & 1;
            if (ge) {
                indices[right] = j;
                right--;
            }
            else {
                indices[left] = j;
                leftEntry[left] = 0xFFFFFFFF;
                left++;
            }
        }
        permLut[i] = indices;
        leftLut[i] = leftEntry;
    }
    return lutPair;
}();

constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0];
constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1];
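// --- Editorial aside (not part of the patch): AVX2 has no AVX-512-style
// opmask registers, so the tables above emulate them. A scalar bitmask
// indexes a LUT of all-ones/all-zero lanes, and _mm256_movemask_ps converts
// the vector form back to a scalar. Minimal round-trip sketch, assuming
// <immintrin.h> and <cassert> are available; the demo name is hypothetical:
static inline void xss_editorial_mask_lut_demo()
{
    int32_t m = 0b10110001;
    __m256i vmask = _mm256_loadu_si256(
            (const __m256i *)avx2_mask_helper_lut32[m].data());
    int32_t back = _mm256_movemask_ps(_mm256_castsi256_ps(vmask));
    assert(back == m); // the vector mask round-trips to the same 8 bits
}
// --- End editorial aside -----------------------------------------------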
constexpr auto avx2_compressstore_lut32_half_gen = [] {
    std::array<std::array<std::array<int32_t, 4>, 16>, 2> lutPair {};
    auto &permLut = lutPair[0];
    auto &leftLut = lutPair[1];
    for (int64_t i = 0; i <= 0xF; i++) {
        std::array<int32_t, 4> indices {};
        std::array<int32_t, 4> leftEntry = {0, 0, 0, 0};
        int right = 3;
        int left = 0;
        for (int j = 0; j < 4; j++) {
            bool ge = (i >> j) & 1;
            if (ge) {
                indices[right] = j;
                right--;
            }
            else {
                indices[left] = j;
                leftEntry[left] = 0xFFFFFFFF;
                left++;
            }
        }
        permLut[i] = indices;
        leftLut[i] = leftEntry;
    }
    return lutPair;
}();

constexpr auto avx2_compressstore_lut32_half_perm
        = avx2_compressstore_lut32_half_gen[0];
constexpr auto avx2_compressstore_lut32_half_left
        = avx2_compressstore_lut32_half_gen[1];

constexpr auto avx2_compressstore_lut64_gen = [] {
    std::array<std::array<int32_t, 8>, 16> permLut {};
    std::array<std::array<int64_t, 4>, 16> leftLut {};
    for (int64_t i = 0; i <= 0xF; i++) {
        std::array<int32_t, 8> indices {};
        std::array<int64_t, 4> leftEntry = {0, 0, 0, 0};
        int right = 7;
        int left = 0;
        for (int j = 0; j < 4; j++) {
            bool ge = (i >> j) & 1;
            if (ge) {
                indices[right] = 2 * j + 1;
                indices[right - 1] = 2 * j;
                right -= 2;
            }
            else {
                indices[left + 1] = 2 * j + 1;
                indices[left] = 2 * j;
                leftEntry[left / 2] = 0xFFFFFFFFFFFFFFFF;
                left += 2;
            }
        }
        permLut[i] = indices;
        leftLut[i] = leftEntry;
    }
    return std::make_pair(permLut, leftLut);
}();
constexpr auto avx2_compressstore_lut64_perm
        = avx2_compressstore_lut64_gen.first;
constexpr auto avx2_compressstore_lut64_left
        = avx2_compressstore_lut64_gen.second;

X86_SIMD_SORT_INLINE
__m256i convert_int_to_avx2_mask(int32_t m)
{
    return _mm256_loadu_si256(
            (const __m256i *)avx2_mask_helper_lut32[m].data());
}

X86_SIMD_SORT_INLINE
int32_t convert_avx2_mask_to_int(__m256i m)
{
    return _mm256_movemask_ps(_mm256_castsi256_ps(m));
}

X86_SIMD_SORT_INLINE
__m256i convert_int_to_avx2_mask_64bit(int32_t m)
{
    return _mm256_loadu_si256(
            (const __m256i *)avx2_mask_helper_lut64[m].data());
}

X86_SIMD_SORT_INLINE
int32_t convert_avx2_mask_to_int_64bit(__m256i m)
{
    return _mm256_movemask_pd(_mm256_castsi256_pd(m));
}

X86_SIMD_SORT_INLINE
__m128i convert_int_to_avx2_mask_half(int32_t m)
{
    return _mm_loadu_si128(
            (const __m128i *)avx2_mask_helper_lut32_half[m].data());
}

X86_SIMD_SORT_INLINE
int32_t convert_avx2_mask_to_int_half(__m128i m)
{
    return _mm_movemask_ps(_mm_castsi128_ps(m));
}

// Emulators for intrinsics missing from AVX2 compared to AVX512
template <typename T>
T avx2_emu_reduce_max32(typename avx2_vector<T>::reg_t x)
{
    using vtype = avx2_vector<T>;
    using reg_t = typename vtype::reg_t;

    reg_t inter1 = vtype::max(
            x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
    reg_t inter2 = vtype::max(
            inter1, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(inter1));
    T arr[vtype::numlanes];
    vtype::storeu(arr, inter2);
    return std::max(arr[0], arr[7]);
}

template <typename T>
T avx2_emu_reduce_max32_half(typename avx2_half_vector<T>::reg_t x)
{
    using vtype = avx2_half_vector<T>;
    using reg_t = typename vtype::reg_t;

    reg_t inter1 = vtype::max(
            x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
    T arr[vtype::numlanes];
    vtype::storeu(arr, inter1);
    return std::max(arr[0], arr[3]);
}

template <typename T>
T avx2_emu_reduce_min32(typename avx2_vector<T>::reg_t x)
{
    using vtype = avx2_vector<T>;
    using reg_t = typename vtype::reg_t;

    reg_t inter1 = vtype::min(
            x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
    reg_t inter2 = vtype::min(
            inter1, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(inter1));
    T arr[vtype::numlanes];
    vtype::storeu(arr, inter2);
    return std::min(arr[0], arr[7]);
}

template <typename T>
T avx2_emu_reduce_min32_half(typename avx2_half_vector<T>::reg_t x)
{
    using vtype = avx2_half_vector<T>;
    using reg_t = typename vtype::reg_t;

    reg_t inter1 = vtype::min(
            x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
    T arr[vtype::numlanes];
    vtype::storeu(arr, inter1);
    return std::min(arr[0], arr[3]);
}

template <typename T>
T avx2_emu_reduce_max64(typename avx2_vector<T>::reg_t x)
{
    using vtype = avx2_vector<T>;
    typename vtype::reg_t inter1 = vtype::max(
            x, vtype::template permutexvar<SHUFFLE_MASK(1, 0, 3, 2)>(x));
    T arr[vtype::numlanes];
    vtype::storeu(arr, inter1);
    return std::max(arr[0], arr[3]);
}

template <typename T>
T avx2_emu_reduce_min64(typename avx2_vector<T>::reg_t x)
{
    using vtype = avx2_vector<T>;
    typename vtype::reg_t inter1 = vtype::min(
            x, vtype::template permutexvar<SHUFFLE_MASK(1, 0, 3, 2)>(x));
    T arr[vtype::numlanes];
    vtype::storeu(arr, inter1);
    return std::min(arr[0], arr[3]);
}

template <typename T>
void avx2_emu_mask_compressstoreu32(void *base_addr,
                                    typename avx2_vector<T>::opmask_t k,
                                    typename avx2_vector<T>::reg_t reg)
{
    using vtype = avx2_vector<T>;

    T *leftStore = (T *)base_addr;

    int32_t shortMask = convert_avx2_mask_to_int(k);
    const __m256i &perm = _mm256_loadu_si256(
            (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data());
    const __m256i &left = _mm256_loadu_si256(
            (const __m256i *)avx2_compressstore_lut32_left[shortMask].data());

    typename vtype::reg_t temp = vtype::permutexvar(perm, reg);

    vtype::mask_storeu(leftStore, left, temp);
}

template <typename T>
void avx2_emu_mask_compressstoreu32_half(
        void *base_addr,
        typename avx2_half_vector<T>::opmask_t k,
        typename avx2_half_vector<T>::reg_t reg)
{
    using vtype = avx2_half_vector<T>;

    T *leftStore = (T *)base_addr;

    int32_t shortMask = convert_avx2_mask_to_int_half(k);
    const __m128i &perm = _mm_loadu_si128(
            (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask]
                    .data());
    const __m128i &left = _mm_loadu_si128(
            (const __m128i *)avx2_compressstore_lut32_half_left[shortMask]
                    .data());

    typename vtype::reg_t temp = vtype::permutexvar(perm, reg);

    vtype::mask_storeu(leftStore, left, temp);
}

template <typename T>
void avx2_emu_mask_compressstoreu64(void *base_addr,
                                    typename avx2_vector<T>::opmask_t k,
                                    typename avx2_vector<T>::reg_t reg)
{
    using vtype = avx2_vector<T>;

    T *leftStore = (T *)base_addr;

    int32_t shortMask = convert_avx2_mask_to_int_64bit(k);
    const __m256i &perm = _mm256_loadu_si256(
            (const __m256i *)avx2_compressstore_lut64_perm[shortMask].data());
    const __m256i &left = _mm256_loadu_si256(
            (const __m256i *)avx2_compressstore_lut64_left[shortMask].data());

    typename vtype::reg_t temp = vtype::cast_from(
            _mm256_permutevar8x32_epi32(vtype::cast_to(reg), perm));

    vtype::mask_storeu(leftStore, left, temp);
}

template <typename T>
int avx2_double_compressstore32(void *left_addr,
                                void *right_addr,
                                typename avx2_vector<T>::opmask_t k,
                                typename avx2_vector<T>::reg_t reg)
{
    using vtype = avx2_vector<T>;

    T *leftStore = (T *)left_addr;
    T *rightStore = (T *)right_addr;

    int32_t shortMask = convert_avx2_mask_to_int(k);
    const __m256i &perm = _mm256_loadu_si256(
            (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data());

    typename vtype::reg_t temp = vtype::permutexvar(perm, reg);

    vtype::storeu(leftStore, temp);
    vtype::storeu(rightStore, temp);

    return _mm_popcnt_u32(shortMask);
}

template <typename T>
int avx2_double_compressstore32_half(void *left_addr,
                                     void *right_addr,
                                     typename avx2_half_vector<T>::opmask_t k,
                                     typename avx2_half_vector<T>::reg_t reg)
{
    using vtype = avx2_half_vector<T>;

    T *leftStore = (T *)left_addr;
    T *rightStore = (T *)right_addr;

    int32_t shortMask = convert_avx2_mask_to_int_half(k);
    const __m128i &perm = _mm_loadu_si128(
            (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask]
                    .data());

    typename vtype::reg_t temp = vtype::permutexvar(perm, reg);

    vtype::storeu(leftStore, temp);
    vtype::storeu(rightStore, temp);

    return _mm_popcnt_u32(shortMask);
}

template <typename T>
int32_t avx2_double_compressstore64(void *left_addr,
                                    void *right_addr,
                                    typename avx2_vector<T>::opmask_t k,
                                    typename avx2_vector<T>::reg_t reg)
{
    using vtype = avx2_vector<T>;

    T *leftStore = (T *)left_addr;
    T *rightStore = (T *)right_addr;

    int32_t shortMask = convert_avx2_mask_to_int_64bit(k);
    const __m256i &perm = _mm256_loadu_si256(
            (const __m256i *)avx2_compressstore_lut64_perm[shortMask].data());

    typename vtype::reg_t temp = vtype::cast_from(
            _mm256_permutevar8x32_epi32(vtype::cast_to(reg), perm));

    vtype::storeu(leftStore, temp);
    vtype::storeu(rightStore, temp);

    return _mm_popcnt_u32(shortMask);
}

template <typename T>
typename avx2_vector<T>::reg_t avx2_emu_max(typename avx2_vector<T>::reg_t x,
                                            typename avx2_vector<T>::reg_t y)
{
    using vtype = avx2_vector<T>;
    typename vtype::opmask_t nlt = vtype::gt(x, y);
    return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y),
                                                _mm256_castsi256_pd(x),
                                                _mm256_castsi256_pd(nlt)));
}

template <typename T>
typename avx2_vector<T>::reg_t avx2_emu_min(typename avx2_vector<T>::reg_t x,
                                            typename avx2_vector<T>::reg_t y)
{
    using vtype = avx2_vector<T>;
    typename vtype::opmask_t nlt = vtype::gt(x, y);
    return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x),
                                                _mm256_castsi256_pd(y),
                                                _mm256_castsi256_pd(nlt)));
}

#endif
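Editorial aside (not part of the patch): the double-compressstore helpers above
are the workhorse of the AVX2 partition step. One permuted register is written
to both the left and right store cursors; the LUT entry derived from the
comparison mask guarantees that lanes below the pivot land compacted at the
front and lanes at or above the pivot land compacted at the back. A scalar
model of that contract (hypothetical helper name, illustration only):

    #include <cstdint>

    // Returns how many elements went right, matching the popcount the
    // vector version derives from the comparison mask.
    template <typename T, int N>
    int scalar_double_compressstore(T *left, T *right, uint32_t mask,
                                    const T (&v)[N])
    {
        int l = 0, r = 0;
        for (int i = 0; i < N; i++) {
            if ((mask >> i) & 1) { right[r++] = v[i]; } // lanes >= pivot
            else { left[l++] = v[i]; } // lanes < pivot
        }
        return r;
    }

The vector version gets the same effect without branches: a single permute
groups both halves in one register, and the two unaligned stores plus the
moving cursors keep only the half that each side cares about.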
diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h
index 6c071c2..7b805fa 100644
--- a/src/xss-common-argsort.h
+++ b/src/xss-common-argsort.h
@@ -1,754 +1,754 @@
/*******************************************************************
 * Copyright (C) 2022 Intel Corporation
 * SPDX-License-Identifier: BSD-3-Clause
 * Authors: Raghuveer Devulapalli
 * ****************************************************************/

#ifndef XSS_COMMON_ARGSORT
#define XSS_COMMON_ARGSORT

#include "xss-network-keyvaluesort.hpp"
#include <algorithm>

template <typename T>
X86_SIMD_SORT_INLINE void std_argselect_withnan(
        T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right)
{
    std::nth_element(arg + left,
                     arg + k,
                     arg + right,
                     [arr](arrsize_t a, arrsize_t b) -> bool {
                         if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {
                             return arr[a] < arr[b];
                         }
                         else if (std::isnan(arr[a])) {
                             return false;
                         }
                         else {
                             return true;
                         }
                     });
}

/* argsort using std::sort */
template <typename T>
X86_SIMD_SORT_INLINE void
std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
{
    std::sort(arg + left,
              arg + right,
              [arr](arrsize_t left, arrsize_t right) -> bool {
                  if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {
                      return arr[left] < arr[right];
                  }
                  else if (std::isnan(arr[left])) {
                      return false;
                  }
                  else {
                      return true;
                  }
              });
}

/* argsort using std::sort */
template <typename T>
X86_SIMD_SORT_INLINE void
std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
{
    std::sort(arg + left,
              arg + right,
              [arr](arrsize_t left, arrsize_t right) -> bool {
                  // sort indices according to corresponding array element
                  return arr[left] < arr[right];
              });
}

/*
 * Partition one ZMM register based on the pivot and return the index of the
 * last element that is less than or equal to the pivot.
 */
template <typename vtype,
          typename argtype,
          typename type_t,
          typename reg_t,
          typename argreg_t>
X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg,
                                                  arrsize_t left,
                                                  arrsize_t right,
                                                  const argreg_t arg_vec,
                                                  const reg_t curr_vec,
                                                  const reg_t pivot_vec,
                                                  reg_t *smallest_vec,
                                                  reg_t *biggest_vec)
{
    /* which elements are larger than the pivot */
    typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec);
    int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask);
    argtype::mask_compressstoreu(
            arg + left, vtype::knot_opmask(gt_mask), arg_vec);
    argtype::mask_compressstoreu(
            arg + right - amount_gt_pivot, gt_mask, arg_vec);
    *smallest_vec = vtype::min(curr_vec, *smallest_vec);
    *biggest_vec = vtype::max(curr_vec, *biggest_vec);
    return amount_gt_pivot;
}
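// --- Editorial aside (not part of the patch): a concrete model of the
// AVX-512 path above. With pivot = 5 and keys {7,1,9,3,5,2,8,6}, the ge-mask
// is 0b11010101; compressing with its complement packs {1,3,2} to the left
// cursor, while compressing with the mask itself packs {7,9,5,8,6} to the
// right cursor. Standalone sketch, assuming AVX-512F/DQ and <immintrin.h>;
// the demo function name is hypothetical and lo/hi must have room for 8
// elements each:
static inline void xss_editorial_avx512_partition_demo(int64_t *lo, int64_t *hi)
{
    __m512i v = _mm512_setr_epi64(7, 1, 9, 3, 5, 2, 8, 6);
    __mmask8 ge = _mm512_cmpge_epi64_mask(v, _mm512_set1_epi64(5));
    _mm512_mask_compressstoreu_epi64(lo, _knot_mask8(ge), v); // 1, 3, 2
    _mm512_mask_compressstoreu_epi64(hi, ge, v); // 7, 9, 5, 8, 6
}
// --- End editorial aside -----------------------------------------------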
/*
 * Partition one AVX2 register based on the pivot and return the index of the
 * last element that is less than or equal to the pivot.
 */
template <typename vtype,
          typename argtype,
          typename type_t,
          typename reg_t,
          typename argreg_t>
X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg,
                                                arrsize_t left,
                                                arrsize_t right,
                                                const argreg_t arg_vec,
                                                const reg_t curr_vec,
                                                const reg_t pivot_vec,
                                                reg_t *smallest_vec,
                                                reg_t *biggest_vec)
{
    /* which elements are larger than the pivot */
    typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec);
    typename argtype::opmask_t ge_mask
            = resize_mask<vtype, argtype>(ge_mask_vtype);

    auto l_store = arg + left;
    auto r_store = arg + right - vtype::numlanes;

    int amount_ge_pivot
            = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec);

    *smallest_vec = vtype::min(curr_vec, *smallest_vec);
    *biggest_vec = vtype::max(curr_vec, *biggest_vec);

    return amount_ge_pivot;
}

template <typename vtype,
          typename argtype,
          typename type_t,
          typename reg_t,
          typename argreg_t>
X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
                                           arrsize_t left,
                                           arrsize_t right,
                                           const argreg_t arg_vec,
                                           const reg_t curr_vec,
                                           const reg_t pivot_vec,
                                           reg_t *smallest_vec,
                                           reg_t *biggest_vec)
{
    if constexpr (vtype::vec_type == simd_type::AVX512) {
        return partition_vec_avx512<vtype, argtype>(arg,
                                                    left,
                                                    right,
                                                    arg_vec,
                                                    curr_vec,
                                                    pivot_vec,
                                                    smallest_vec,
                                                    biggest_vec);
    }
    else if constexpr (vtype::vec_type == simd_type::AVX2) {
        return partition_vec_avx2<vtype, argtype>(arg,
                                                  left,
                                                  right,
                                                  arg_vec,
                                                  curr_vec,
                                                  pivot_vec,
                                                  smallest_vec,
                                                  biggest_vec);
    }
    else {
        static_assert(sizeof(argreg_t) == 0, "Should not get here");
    }
}

/*
 * Partition an array based on the pivot and return the index of the
 * last element that is less than or equal to the pivot.
 */
template <typename vtype, typename argtype, typename type_t>
X86_SIMD_SORT_INLINE arrsize_t argpartition(type_t *arr,
                                            arrsize_t *arg,
                                            arrsize_t left,
                                            arrsize_t right,
                                            type_t pivot,
                                            type_t *smallest,
                                            type_t *biggest)
{
    /* make the array length divisible by vtype::numlanes, shortening the
     * array */
    for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) {
        *smallest = std::min(*smallest, arr[arg[left]], comparison_func<vtype>);
        *biggest = std::max(*biggest, arr[arg[left]], comparison_func<vtype>);
        if (!comparison_func<vtype>(arr[arg[left]], pivot)) {
            std::swap(arg[left], arg[--right]);
        }
        else {
            ++left;
        }
    }

    if (left == right)
        return left; /* fewer than vtype::numlanes elements in the array */

    using reg_t = typename vtype::reg_t;
    using argreg_t = typename argtype::reg_t;
    reg_t pivot_vec = vtype::set1(pivot);
    reg_t min_vec = vtype::set1(*smallest);
    reg_t max_vec = vtype::set1(*biggest);

    if (right - left == vtype::numlanes) {
        argreg_t argvec = argtype::loadu(arg + left);
        reg_t vec = vtype::i64gather(arr, arg + left);
        int32_t amount_gt_pivot
                = partition_vec<vtype, argtype>(arg,
                                                left,
                                                left + vtype::numlanes,
                                                argvec,
                                                vec,
                                                pivot_vec,
                                                &min_vec,
                                                &max_vec);
        *smallest = vtype::reducemin(min_vec);
        *biggest = vtype::reducemax(max_vec);
        return left + (vtype::numlanes - amount_gt_pivot);
    }

    // first and last vtype::numlanes values are partitioned at the end
    argreg_t argvec_left = argtype::loadu(arg + left);
    reg_t vec_left = vtype::i64gather(arr, arg + left);
    argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes));
    reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes));
    // store points of the vectors
    arrsize_t r_store = right - vtype::numlanes;
    arrsize_t l_store = left;
    // indices for loading the elements
    left += vtype::numlanes;
    right -= vtype::numlanes;
    while (right - left != 0) {
        argreg_t arg_vec;
        reg_t curr_vec;
        /*
         * if fewer elements are stored on the right side of the array,
         * then the next elements are loaded from the right side,
         * otherwise from the left side
         */
        if ((r_store + vtype::numlanes) - right < left - l_store) {
            right -= vtype::numlanes;
            arg_vec = argtype::loadu(arg + right);
            curr_vec = vtype::i64gather(arr, arg + right);
        }
        else {
            arg_vec = argtype::loadu(arg + left);
            curr_vec = vtype::i64gather(arr, arg + left);
            left += vtype::numlanes;
        }
        // partition the current vector and save it on both sides of the array
        int32_t amount_gt_pivot
                = partition_vec<vtype, argtype>(arg,
                                                l_store,
                                                r_store + vtype::numlanes,
                                                arg_vec,
                                                curr_vec,
                                                pivot_vec,
                                                &min_vec,
                                                &max_vec);
        r_store -= amount_gt_pivot;
        l_store += (vtype::numlanes - amount_gt_pivot);
    }

    /* partition and save vec_left and vec_right */
    int32_t amount_gt_pivot
            = partition_vec<vtype, argtype>(arg,
                                            l_store,
                                            r_store + vtype::numlanes,
                                            argvec_left,
                                            vec_left,
                                            pivot_vec,
                                            &min_vec,
                                            &max_vec);
    l_store += (vtype::numlanes - amount_gt_pivot);
    amount_gt_pivot = partition_vec<vtype, argtype>(arg,
                                                    l_store,
                                                    l_store + vtype::numlanes,
                                                    argvec_right,
                                                    vec_right,
                                                    pivot_vec,
                                                    &min_vec,
                                                    &max_vec);
    l_store += (vtype::numlanes - amount_gt_pivot);
    *smallest = vtype::reducemin(min_vec);
    *biggest = vtype::reducemax(max_vec);
    return l_store;
}
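// --- Editorial aside (not part of the patch): the scalar shape of
// argpartition's in-place loop. Because one full vector of indices is always
// held in registers ("in flight"), there is always room to store a whole
// vector at whichever cursor the next load is taken from. Element-at-a-time
// analogue (hypothetical helper name, illustration only; assumes
// <algorithm> for std::swap):
template <typename T>
arrsize_t xss_editorial_scalar_argpartition(
        const T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right, T pivot)
{
    while (left < right) {
        if (arr[arg[left]] < pivot) { ++left; } // index stays on the left
        else { std::swap(arg[left], arg[--right]); } // index moves right
    }
    return left; // first position whose element is >= pivot
}
// --- End editorial aside -----------------------------------------------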
template <typename vtype, typename argtype, int num_unroll, typename type_t>
X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr,
                                                     arrsize_t *arg,
                                                     arrsize_t left,
                                                     arrsize_t right,
                                                     type_t pivot,
                                                     type_t *smallest,
                                                     type_t *biggest)
{
    if (right - left <= 8 * num_unroll * vtype::numlanes) {
        return argpartition<vtype, argtype>(
                arr, arg, left, right, pivot, smallest, biggest);
    }
    /* make the array length divisible by num_unroll * vtype::numlanes,
     * shortening the array */
    for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0;
         --i) {
        *smallest = std::min(*smallest, arr[arg[left]], comparison_func<vtype>);
        *biggest = std::max(*biggest, arr[arg[left]], comparison_func<vtype>);
        if (!comparison_func<vtype>(arr[arg[left]], pivot)) {
            std::swap(arg[left], arg[--right]);
        }
        else {
            ++left;
        }
    }

    if (left == right)
        return left; /* fewer than vtype::numlanes elements in the array */

    using reg_t = typename vtype::reg_t;
    using argreg_t = typename argtype::reg_t;
    reg_t pivot_vec = vtype::set1(pivot);
    reg_t min_vec = vtype::set1(*smallest);
    reg_t max_vec = vtype::set1(*biggest);

    // first and last vtype::numlanes values are partitioned at the end
    reg_t vec_left[num_unroll], vec_right[num_unroll];
    argreg_t argvec_left[num_unroll], argvec_right[num_unroll];
    X86_SIMD_SORT_UNROLL_LOOP(8)
    for (int ii = 0; ii < num_unroll; ++ii) {
        argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii);
        vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii);
        argvec_right[ii] = argtype::loadu(
                arg + (right - vtype::numlanes * (num_unroll - ii)));
        vec_right[ii] = vtype::i64gather(
                arr, arg + (right - vtype::numlanes * (num_unroll - ii)));
    }
    // store points of the vectors
    arrsize_t r_store = right - vtype::numlanes;
    arrsize_t l_store = left;
    // indices for loading the elements
    left += num_unroll * vtype::numlanes;
    right -= num_unroll * vtype::numlanes;
    while (right - left != 0) {
        argreg_t arg_vec[num_unroll];
        reg_t curr_vec[num_unroll];
        /*
         * if fewer elements are stored on the right side of the array,
         * then the next elements are loaded from the right side,
         * otherwise from the left side
         */
        if ((r_store + vtype::numlanes) - right < left - l_store) {
            right -= num_unroll * vtype::numlanes;
            X86_SIMD_SORT_UNROLL_LOOP(8)
            for (int ii = 0; ii < num_unroll; ++ii) {
                arg_vec[ii]
                        = argtype::loadu(arg + right + ii * vtype::numlanes);
                curr_vec[ii] = vtype::i64gather(
                        arr, arg + right + ii * vtype::numlanes);
            }
        }
        else {
            X86_SIMD_SORT_UNROLL_LOOP(8)
            for (int ii = 0; ii < num_unroll; ++ii) {
                arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes);
                curr_vec[ii] = vtype::i64gather(
                        arr, arg + left + ii * vtype::numlanes);
            }
            left += num_unroll * vtype::numlanes;
        }
        // partition the current vector and save it on both sides of the array
        X86_SIMD_SORT_UNROLL_LOOP(8)
        for (int ii = 0; ii < num_unroll; ++ii) {
            int32_t amount_gt_pivot
                    = partition_vec<vtype, argtype>(arg,
                                                    l_store,
                                                    r_store + vtype::numlanes,
                                                    arg_vec[ii],
                                                    curr_vec[ii],
                                                    pivot_vec,
                                                    &min_vec,
                                                    &max_vec);
            l_store += (vtype::numlanes - amount_gt_pivot);
            r_store -= amount_gt_pivot;
        }
    }

    /* partition and save vec_left and vec_right */
    X86_SIMD_SORT_UNROLL_LOOP(8)
    for (int ii = 0; ii < num_unroll; ++ii) {
        int32_t amount_gt_pivot
                = partition_vec<vtype, argtype>(arg,
                                                l_store,
                                                r_store + vtype::numlanes,
                                                argvec_left[ii],
                                                vec_left[ii],
                                                pivot_vec,
                                                &min_vec,
                                                &max_vec);
        l_store += (vtype::numlanes - amount_gt_pivot);
        r_store -= amount_gt_pivot;
    }
    X86_SIMD_SORT_UNROLL_LOOP(8)
    for (int ii = 0; ii < num_unroll; ++ii) {
        int32_t amount_gt_pivot
                = partition_vec<vtype, argtype>(arg,
                                                l_store,
                                                r_store + vtype::numlanes,
                                                argvec_right[ii],
                                                vec_right[ii],
                                                pivot_vec,
                                                &min_vec,
                                                &max_vec);
        l_store += (vtype::numlanes - amount_gt_pivot);
        r_store -= amount_gt_pivot;
    }
    *smallest = vtype::reducemin(min_vec);
    *biggest = vtype::reducemax(max_vec);
    return l_store;
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
                                            arrsize_t *arg,
                                            const arrsize_t left,
                                            const arrsize_t right)
{
    if constexpr (vtype::numlanes == 8) {
        if (right - left >= vtype::numlanes) {
            // median of 8
            arrsize_t size = (right - left) / 8;
            using reg_t = typename vtype::reg_t;
            reg_t rand_vec = vtype::set(arr[arg[left + size]],
                                        arr[arg[left + 2 * size]],
                                        arr[arg[left + 3 * size]],
                                        arr[arg[left + 4 * size]],
                                        arr[arg[left + 5 * size]],
                                        arr[arg[left + 6 * size]],
                                        arr[arg[left + 7 * size]],
                                        arr[arg[left + 8 * size]]);
            // pivot will never be a nan, since there are no nans!
            reg_t sort = vtype::sort_vec(rand_vec);
            return ((type_t *)&sort)[4];
        }
        else {
            return arr[arg[right]];
        }
    }
    else if constexpr (vtype::numlanes == 4) {
        if (right - left >= vtype::numlanes) {
            // median of 4
            arrsize_t size = (right - left) / 4;
            using reg_t = typename vtype::reg_t;
            reg_t rand_vec = vtype::set(arr[arg[left + size]],
                                        arr[arg[left + 2 * size]],
                                        arr[arg[left + 3 * size]],
                                        arr[arg[left + 4 * size]]);
            // pivot will never be a nan, since there are no nans!
            reg_t sort = vtype::sort_vec(rand_vec);
            return ((type_t *)&sort)[2];
        }
        else {
            return arr[arg[right]];
        }
    }
}
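// --- Editorial aside (not part of the patch): get_pivot_64bit takes the
// median of eight (or four) evenly spaced samples by sorting one vector and
// reading out the middle lane. Scalar equivalent for the 8-lane case
// (hypothetical helper name, illustration only; assumes <algorithm>):
template <typename T>
T xss_editorial_median_of_8(const T *arr, const arrsize_t *arg,
                            arrsize_t left, arrsize_t right)
{
    arrsize_t size = (right - left) / 8;
    T s[8];
    for (int i = 0; i < 8; i++) {
        s[i] = arr[arg[left + (i + 1) * size]]; // same sample points as above
    }
    std::nth_element(s, s + 4, s + 8);
    return s[4]; // the lane the vector version reads out of the sorted reg
}
// --- End editorial aside -----------------------------------------------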
biggest); + } + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left[num_unroll], vec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec[num_unroll]; + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); + } + } + else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = 
vtype::reducemin(min_vec);
+    *biggest = vtype::reducemax(max_vec);
+    return l_store;
+}
+
+template <typename vtype, typename type_t>
+X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
+                                            arrsize_t *arg,
+                                            const arrsize_t left,
+                                            const arrsize_t right)
+{
+    if constexpr (vtype::numlanes == 8) {
+        if (right - left >= vtype::numlanes) {
+            // median of 8
+            arrsize_t size = (right - left) / 8;
+            using reg_t = typename vtype::reg_t;
+            reg_t rand_vec = vtype::set(arr[arg[left + size]],
+                                        arr[arg[left + 2 * size]],
+                                        arr[arg[left + 3 * size]],
+                                        arr[arg[left + 4 * size]],
+                                        arr[arg[left + 5 * size]],
+                                        arr[arg[left + 6 * size]],
+                                        arr[arg[left + 7 * size]],
+                                        arr[arg[left + 8 * size]]);
+            // the pivot will never be a NaN, since there are no NaNs!
+            reg_t sort = vtype::sort_vec(rand_vec);
+            return ((type_t *)&sort)[4];
+        }
+        else {
+            return arr[arg[right]];
+        }
+    }
+    else if constexpr (vtype::numlanes == 4) {
+        if (right - left >= vtype::numlanes) {
+            // median of 4
+            arrsize_t size = (right - left) / 4;
+            using reg_t = typename vtype::reg_t;
+            reg_t rand_vec = vtype::set(arr[arg[left + size]],
+                                        arr[arg[left + 2 * size]],
+                                        arr[arg[left + 3 * size]],
+                                        arr[arg[left + 4 * size]]);
+            // the pivot will never be a NaN, since there are no NaNs!
+            reg_t sort = vtype::sort_vec(rand_vec);
+            return ((type_t *)&sort)[2];
+        }
+        else {
+            return arr[arg[right]];
+        }
+    }
+}
+
+template <typename vtype, typename argtype, typename type_t>
+X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
+                                   arrsize_t *arg,
+                                   arrsize_t left,
+                                   arrsize_t right,
+                                   arrsize_t max_iters,
+                                   arrsize_t task_threshold)
+{
+    /*
+     * Resort to std::sort if quicksort isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std_argsort(arr, arg, left, right + 1);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays of size <= 256
+     */
+    if (right + 1 - left <= 256) {
+        argsort_n(
+                arr, arg + left, (int32_t)(right + 1 - left));
+        return;
+    }
+    type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    arrsize_t pivot_index = argpartition_unrolled(
+            arr, arg, left, right + 1, pivot, &smallest, &biggest);
+#ifdef XSS_COMPILE_OPENMP
+    if (pivot != smallest) {
+        bool parallel_left = (pivot_index - left) > task_threshold;
+        if (parallel_left) {
+#pragma omp task
+            argsort_<vtype, argtype>(arr,
+                                     arg,
+                                     left,
+                                     pivot_index - 1,
+                                     max_iters - 1,
+                                     task_threshold);
+        }
+        else {
+            argsort_<vtype, argtype>(arr,
+                                     arg,
+                                     left,
+                                     pivot_index - 1,
+                                     max_iters - 1,
+                                     task_threshold);
+        }
+    }
+    if (pivot != biggest) {
+        bool parallel_right = (right - pivot_index) > task_threshold;
+
+        if (parallel_right) {
+#pragma omp task
+            argsort_<vtype, argtype>(arr,
+                                     arg,
+                                     pivot_index,
+                                     right,
+                                     max_iters - 1,
+                                     task_threshold);
+        }
+        else {
+            argsort_<vtype, argtype>(arr,
+                                     arg,
+                                     pivot_index,
+                                     right,
+                                     max_iters - 1,
+                                     task_threshold);
+        }
+    }
+#else
+    UNUSED(task_threshold);
+    if (pivot != smallest)
+        argsort_<vtype, argtype>(
+                arr, arg, left, pivot_index - 1, max_iters - 1, 0);
+    if (pivot != biggest)
+        argsort_<vtype, argtype>(
+                arr, arg, pivot_index, right, max_iters - 1, 0);
+#endif
+}
+
+template <typename vtype, typename argtype, typename type_t>
+X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
+                                     arrsize_t *arg,
+                                     arrsize_t pos,
+                                     arrsize_t left,
+                                     arrsize_t right,
+                                     arrsize_t max_iters)
+{
+    /*
+     * Resort to std::sort if quickselect isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std_argsort(arr, arg, left, right + 1);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays of size <= 256
+     */
+    if (right + 1 - left <= 256) {
+        argsort_n(
+                arr, arg + left, (int32_t)(right + 1 - left));
+        return;
+    }
+    type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
+    type_t smallest =
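+    /*
+     * Unlike argsort_, argselect_ recurses into only the side of the pivot
+     * that contains pos (classic quickselect), so on average it touches O(n)
+     * elements; everything left of the final pivot position compares no
+     * greater than everything right of it, which is all argselect needs to
+     * guarantee.
+     */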
vtype::type_max();
+    type_t biggest = vtype::type_min();
+    arrsize_t pivot_index = argpartition_unrolled(
+            arr, arg, left, right + 1, pivot, &smallest, &biggest);
+    if ((pivot != smallest) && (pos < pivot_index))
+        argselect_<vtype, argtype>(
+                arr, arg, pos, left, pivot_index - 1, max_iters - 1);
+    else if ((pivot != biggest) && (pos >= pivot_index))
+        argselect_<vtype, argtype>(
+                arr, arg, pos, pivot_index, right, max_iters - 1);
+}
+
+/* argsort methods for 32-bit and 64-bit dtypes */
+template <typename T,
+          template <typename...>
+          typename full_vector,
+          template <typename...>
+          typename half_vector>
+X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
+                                      arrsize_t *arg,
+                                      arrsize_t arrsize,
+                                      bool hasnan = false,
+                                      bool descending = false)
+{
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
+                                              half_vector<T>,
+                                              full_vector<T>>::type;
+
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int64_t),
+                                      full_vector<arrsize_t>,
+                                      half_vector<arrsize_t>>::type;
+
+    if (arrsize > 1) {
+        /* simdargsort does not work for float/double arrays containing NaN */
+        if constexpr (xss::fp::is_floating_point_v<T>) {
+            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
+                std_argsort_withnan(arr, arg, 0, arrsize);
+
+                if (descending) { std::reverse(arg, arg + arrsize); }
+
+                return;
+            }
+        }
+        UNUSED(hasnan);
+
+        /* early exit for already sorted arrays: float/double with NaN never reach here */
+        auto comp = descending
+                ? Comparator<vectype, true>::STDSortComparator
+                : Comparator<vectype, false>::STDSortComparator;
+        if (std::is_sorted(arr, arr + arrsize, comp)) { return; }
+
+#ifdef XSS_COMPILE_OPENMP
+
+        bool use_parallel = arrsize > 10000;
+
+        if (use_parallel) {
+            int thread_count = xss_get_num_threads();
+            arrsize_t task_threshold
+                    = std::max((arrsize_t)10000, arrsize / 100);
+
+            // We use omp parallel and then omp single to set up the threads that will run the omp task calls in argsort_.
+            // The omp single prevents multiple threads from running the initial argsort_ simultaneously and causing problems.
+            // Note that we do not use the if(...)
clause built into OpenMP, because it causes a performance regression for small arrays +#pragma omp parallel num_threads(thread_count) +#pragma omp single + argsort_(arr, + arg, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize), + task_threshold); +#pragma omp taskwait + } + else { + argsort_(arr, + arg, + 0, + arrsize - 1, + 2 * (arrsize_t)log2(arrsize), + std::numeric_limits::max()); + } +#else + argsort_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); +#endif + + if (descending) { std::reverse(arg, arg + arrsize); } + } + +#ifdef __MMX__ + // Workaround for compiler bug generating MMX instructions without emms + _mm_empty(); +#endif +} + +template +X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) +{ + xss_argsort( + arr, arg, arrsize, hasnan, descending); +} + +template +X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) +{ + xss_argsort( + arr, arg, arrsize, hasnan, descending); +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template + typename full_vector, + template + typename half_vector> +X86_SIMD_SORT_INLINE void xss_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */ + using vectype = typename std::conditional, + full_vector>::type; + + using argtype = + typename std::conditional, + full_vector>::type; + + if (arrsize > 1) { + if constexpr (xss::fp::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + +#ifdef __MMX__ + // Workaround for compiler bug generating MMX instructions without emms + _mm_empty(); +#endif +} + +template +X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + xss_argselect(arr, arg, k, arrsize, hasnan); +} + +template +X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + xss_argselect( + arr, arg, k, arrsize, hasnan); +} + +#endif // XSS_COMMON_ARGSORT diff --git a/src/xss-common-comparators.hpp b/src/xss-common-comparators.hpp index e786c91..447709a 100644 --- a/src/xss-common-comparators.hpp +++ b/src/xss-common-comparators.hpp @@ -1,127 +1,127 @@ -#ifndef XSS_COMMON_COMPARATORS -#define XSS_COMMON_COMPARATORS - -template -type_t prev_value(type_t value) -{ - // TODO this probably handles non-native float16 wrong - if constexpr (std::is_floating_point::value) { - return std::nextafter(value, -std::numeric_limits::infinity()); - } - else { - if (value > std::numeric_limits::min()) { return value - 1; } - else { - return value; - } - } -} - -template -type_t next_value(type_t value) -{ - // TODO this probably handles non-native float16 wrong - if constexpr (std::is_floating_point::value) { - return std::nextafter(value, std::numeric_limits::infinity()); - } - else { - if (value < std::numeric_limits::max()) { return value + 1; } - else { - return value; - } - } -} - -template -X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); - -template -struct Comparator { - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - using type_t = typename vtype::type_t; - - X86_SIMD_SORT_INLINE 
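-    /*
-     * descend is a compile-time parameter, so each branch below folds away
-     * at instantiation. Usage sketch (mirroring the call site in
-     * xss_argsort): std::is_sorted(arr, arr + n,
-     * Comparator<vtype, true>::STDSortComparator) tests for descending
-     * order with no runtime dispatch.
-     */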
bool STDSortComparator(const type_t &a, - const type_t &b) - { - if constexpr (descend) { return comparison_func(b, a); } - else { - return comparison_func(a, b); - } - } - - X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) - { - if constexpr (descend) { return vtype::ge(b, a); } - else { - return vtype::ge(a, b); - } - } - - X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b) - { - if constexpr (descend) { ::COEX(b, a); } - else { - ::COEX(a, b); - } - } - - // Returns a vector of values that would be sorted as far right as possible - // For ascending order, this is the maximum possible value - X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec() - { - if constexpr (descend) { return vtype::zmm_min(); } - else { - return vtype::zmm_max(); - } - } - - // Returns the value that would be leftmost of the two when sorted - // For ascending order, that is the smaller value - X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger) - { - if constexpr (descend) { - UNUSED(smaller); - return larger; - } - else { - UNUSED(larger); - return smaller; - } - } - - // Returns the value that would be rightmost of the two when sorted - // For ascending order, that is the larger value - X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger) - { - if constexpr (descend) { - UNUSED(larger); - return smaller; - } - else { - UNUSED(smaller); - return larger; - } - } - - // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample - // Try just doing the next largest value greater than this seemingly very common value to seperate them out - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median) - { - if constexpr (descend) { return median; } - else { - return next_value(median); - } - } - - // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample - // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median) - { - if constexpr (descend) { return prev_value(median); } - else { - return median; - } - } -}; - -#endif // XSS_COMMON_COMPARATORS +#ifndef XSS_COMMON_COMPARATORS +#define XSS_COMMON_COMPARATORS + +template +type_t prev_value(type_t value) +{ + // TODO this probably handles non-native float16 wrong + if constexpr (std::is_floating_point::value) { + return std::nextafter(value, -std::numeric_limits::infinity()); + } + else { + if (value > std::numeric_limits::min()) { return value - 1; } + else { + return value; + } + } +} + +template +type_t next_value(type_t value) +{ + // TODO this probably handles non-native float16 wrong + if constexpr (std::is_floating_point::value) { + return std::nextafter(value, std::numeric_limits::infinity()); + } + else { + if (value < std::numeric_limits::max()) { return value + 1; } + else { + return value; + } + } +} + +template +X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); + +template +struct Comparator { + using reg_t = typename vtype::reg_t; + using opmask_t = typename vtype::opmask_t; + using type_t = typename vtype::type_t; + + X86_SIMD_SORT_INLINE bool STDSortComparator(const type_t &a, + const type_t &b) + { + if constexpr (descend) { return comparison_func(b, a); } + else { + return comparison_func(a, b); + } + } + + X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) + { + if constexpr (descend) { return vtype::ge(b, 
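+        /*
+         * vtype::ge yields an opmask of the lanes that belong in the right
+         * partition; the amount_gt_pivot count returned by partition_vec
+         * then advances l_store/r_store in argpartition above. Swapping the
+         * operands to (b, a) is all that is needed to flip the sort
+         * direction at compile time.
+         */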
a); }
+        else {
+            return vtype::ge(a, b);
+        }
+    }
+
+    X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b)
+    {
+        if constexpr (descend) { ::COEX(b, a); }
+        else {
+            ::COEX(a, b);
+        }
+    }
+
+    // Returns a vector of values that would be sorted as far right as possible
+    // For ascending order, this is the maximum possible value
+    X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec()
+    {
+        if constexpr (descend) { return vtype::zmm_min(); }
+        else {
+            return vtype::zmm_max();
+        }
+    }
+
+    // Returns the value that would be leftmost of the two when sorted
+    // For ascending order, that is the smaller value
+    X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger)
+    {
+        if constexpr (descend) {
+            UNUSED(smaller);
+            return larger;
+        }
+        else {
+            UNUSED(larger);
+            return smaller;
+        }
+    }
+
+    // Returns the value that would be rightmost of the two when sorted
+    // For ascending order, that is the larger value
+    X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger)
+    {
+        if constexpr (descend) {
+            UNUSED(larger);
+            return smaller;
+        }
+        else {
+            UNUSED(smaller);
+            return larger;
+        }
+    }
+
+    // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample
+    // Try the next value greater than this seemingly very common value to separate it out
+    X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median)
+    {
+        if constexpr (descend) { return median; }
+        else {
+            return next_value(median);
+        }
+    }
+
+    // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample
+    // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition
+    X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median)
+    {
+        if constexpr (descend) { return prev_value(median); }
+        else {
+            return median;
+        }
+    }
+};
+
+#endif // XSS_COMMON_COMPARATORS
diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp
index 0c1d1d8..4bceda1 100644
--- a/src/xss-network-qsort.hpp
+++ b/src/xss-network-qsort.hpp
@@ -1,235 +1,235 @@
-#ifndef XSS_NETWORK_QSORT
-#define XSS_NETWORK_QSORT
-
-#include "xss-optimal-networks.hpp"
-
-template <typename mm_t>
-X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b);
-
-template <typename reg_t, typename comparator, int numVecs>
-X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs)
-{
-    if constexpr (numVecs == 1) {
-        UNUSED(regs);
-        return;
-    }
-    else if constexpr (numVecs == 2) {
-        comparator::COEX(regs[0], regs[1]);
-    }
-    else if constexpr (numVecs == 4) {
-        optimal_sort_4<reg_t, comparator>(regs);
-    }
-    else if constexpr (numVecs == 8) {
-        optimal_sort_8<reg_t, comparator>(regs);
-    }
-    else if constexpr (numVecs == 16) {
-        optimal_sort_16<reg_t, comparator>(regs);
-    }
-    else if constexpr (numVecs == 32) {
-        optimal_sort_32<reg_t, comparator>(regs);
-    }
-    else {
-        static_assert(numVecs == -1, "should not reach here");
-    }
-}
-
-/*
- * Swizzle ops explained:
- * swap_n<scale>: swap neighbouring blocks of size <scale/2> within block of size <scale>
- * reg i = [7,6,5,4,3,2,1,0]
- * swap_n<2>: = [[6,7],[4,5],[2,3],[0,1]]
- * swap_n<4>: = [[5,4,7,6],[1,0,3,2]]
- * swap_n<8>: = [[3,2,1,0,7,6,5,4]]
- * reverse_n<scale>: reverse elements within block of size <scale>
- * reg i = [7,6,5,4,3,2,1,0]
- * rev_n<2>: = [[6,7],[4,5],[2,3],[0,1]]
- * rev_n<4>: = [[4,5,6,7],[0,1,2,3]]
- * rev_n<8>: = [[0,1,2,3,4,5,6,7]]
- * merge_n<scale>: merge blocks of <scale/2> elements from two regs
- * reg b,a = [a,a,a,a,a,a,a,a], [b,b,b,b,b,b,b,b]
- * merge_n<2> = [a,b,a,b,a,b,a,b]
- * merge_n<4> = [a,a,b,b,a,a,b,b]
- * merge_n<8> = [a,a,a,a,b,b,b,b]
- */
-
-template <typename vtype, int numVecs, int scale, bool first, typename comparator>
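-/*
- * internal_merge_n_vec below is the in-register half of a bitonic merge:
- * each step compare-exchanges lanes within a block of "scale" lanes using
- * the swizzles documented above, then recurses with scale/2 until single
- * lanes remain; merge_substep_n_vec plays the same role across registers.
- */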
-X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) -{ - using reg_t = typename vtype::reg_t; - using swizzle = typename vtype::swizzle_ops; - if constexpr (scale <= 1) { - UNUSED(reg); - return; - } - else { - if constexpr (first) { - // Use reverse then merge - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - reg_t &v = reg[i]; - reg_t rev = swizzle::template reverse_n(v); - comparator::COEX(rev, v); - v = swizzle::template merge_n(v, rev); - } - } - else { - // Use swap then merge - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - reg_t &v = reg[i]; - reg_t swap = swizzle::template swap_n(v); - comparator::COEX(swap, v); - v = swizzle::template merge_n(v, swap); - } - } - internal_merge_n_vec(reg); - } -} - -template -X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs) -{ - using swizzle = typename vtype::swizzle_ops; - if constexpr (numVecs <= 1) { - UNUSED(regs); - return; - } - - // Reverse upper half of vectors - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2; i < numVecs; i++) { - regs[i] = swizzle::template reverse_n(regs[i]); - } - // Do compare exchanges - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - comparator::COEX(regs[i], regs[numVecs - 1 - i]); - } - - merge_substep_n_vec(regs); - merge_substep_n_vec(regs - + numVecs / 2); -} - -template -X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs) -{ - // Do cross vector merges - merge_substep_n_vec(regs); - - // Do internal vector merges - internal_merge_n_vec(regs); -} - -template -X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs) -{ - if constexpr (numPer > vtype::numlanes) { - UNUSED(regs); - return; - } - else { - merge_step_n_vec(regs); - merge_n_vec(regs); - } -} - -template -X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs) -{ - /* Run the initial sorting network to sort the columns of the [numVecs x - * num_lanes] matrix - */ - bitonic_sort_n_vec(vecs); - - // Merge the vectors using bitonic merging networks - merge_n_vec(vecs); -} - -template -X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) -{ - static_assert(numVecs > 0, "numVecs should be > 0"); - if constexpr (numVecs > 1) { - if (N * 2 <= numVecs * vtype::numlanes) { - sort_n_vec(arr, N); - return; - } - } - - reg_t vecs[numVecs]; - - // Generate masks for loading and storing - typename vtype::opmask_t ioMasks[numVecs - numVecs / 2]; - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - uint64_t num_to_read - = std::min((uint64_t)std::max(0, N - i * vtype::numlanes), - (uint64_t)vtype::numlanes); - ioMasks[j] = vtype::get_partial_loadmask(num_to_read); - } - - // Unmasked part of the load - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - vecs[i] = vtype::loadu(arr + i * vtype::numlanes); - } - // Masked part of the load - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - vecs[i] = vtype::mask_loadu(comparator::rightmostPossibleVec(), - ioMasks[j], - arr + i * vtype::numlanes); - } - - sort_vectors(vecs); - - // Unmasked part of the store - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - vtype::storeu(arr + i * vtype::numlanes, vecs[i]); - } - // Masked part of the store - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - vtype::mask_storeu(arr + i * vtype::numlanes, ioMasks[j], vecs[i]); - } -} - -template -X86_SIMD_SORT_INLINE void sort_n(typename 
vtype::type_t *arr, int N)
-{
-    constexpr int numVecs = maxN / vtype::numlanes;
-    constexpr bool isMultiple = (maxN == (vtype::numlanes * numVecs));
-    constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1)));
-    static_assert(powerOfTwo == true && isMultiple == true,
-                  "maxN must be vtype::numlanes times a power of 2");
-
-    sort_n_vec(arr, N);
-}
-#endif
+#ifndef XSS_NETWORK_QSORT
+#define XSS_NETWORK_QSORT
+
+#include "xss-optimal-networks.hpp"
+
+template <typename mm_t>
+X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b);
+
+template <typename reg_t, typename comparator, int numVecs>
+X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs)
+{
+    if constexpr (numVecs == 1) {
+        UNUSED(regs);
+        return;
+    }
+    else if constexpr (numVecs == 2) {
+        comparator::COEX(regs[0], regs[1]);
+    }
+    else if constexpr (numVecs == 4) {
+        optimal_sort_4<reg_t, comparator>(regs);
+    }
+    else if constexpr (numVecs == 8) {
+        optimal_sort_8<reg_t, comparator>(regs);
+    }
+    else if constexpr (numVecs == 16) {
+        optimal_sort_16<reg_t, comparator>(regs);
+    }
+    else if constexpr (numVecs == 32) {
+        optimal_sort_32<reg_t, comparator>(regs);
+    }
+    else {
+        static_assert(numVecs == -1, "should not reach here");
+    }
+}
+
+/*
+ * Swizzle ops explained:
+ * swap_n<scale>: swap neighbouring blocks of size <scale/2> within block of size <scale>
+ * reg i = [7,6,5,4,3,2,1,0]
+ * swap_n<2>: = [[6,7],[4,5],[2,3],[0,1]]
+ * swap_n<4>: = [[5,4,7,6],[1,0,3,2]]
+ * swap_n<8>: = [[3,2,1,0,7,6,5,4]]
+ * reverse_n<scale>: reverse elements within block of size <scale>
+ * reg i = [7,6,5,4,3,2,1,0]
+ * rev_n<2>: = [[6,7],[4,5],[2,3],[0,1]]
+ * rev_n<4>: = [[4,5,6,7],[0,1,2,3]]
+ * rev_n<8>: = [[0,1,2,3,4,5,6,7]]
+ * merge_n<scale>: merge blocks of <scale/2> elements from two regs
+ * reg b,a = [a,a,a,a,a,a,a,a], [b,b,b,b,b,b,b,b]
+ * merge_n<2> = [a,b,a,b,a,b,a,b]
+ * merge_n<4> = [a,a,b,b,a,a,b,b]
+ * merge_n<8> = [a,a,a,a,b,b,b,b]
+ */
+
+template <typename vtype, int numVecs, int scale, bool first, typename comparator>
+X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg)
+{
+    using reg_t = typename vtype::reg_t;
+    using swizzle = typename vtype::swizzle_ops;
+    if constexpr (scale <= 1) {
+        UNUSED(reg);
+        return;
+    }
+    else {
+        if constexpr (first) {
+            // Use reverse then merge
+            X86_SIMD_SORT_UNROLL_LOOP(64)
+            for (int i = 0; i < numVecs; i++) {
+                reg_t &v = reg[i];
+                reg_t rev = swizzle::template reverse_n<scale>(v);
+                comparator::COEX(rev, v);
+                v = swizzle::template merge_n<scale>(v, rev);
+            }
+        }
+        else {
+            // Use swap then merge
+            X86_SIMD_SORT_UNROLL_LOOP(64)
+            for (int i = 0; i < numVecs; i++) {
+                reg_t &v = reg[i];
+                reg_t swap = swizzle::template swap_n<scale>(v);
+                comparator::COEX(swap, v);
+                v = swizzle::template merge_n<scale>(v, swap);
+            }
+        }
+        internal_merge_n_vec<vtype, numVecs, scale / 2, false, comparator>(reg);
+    }
+}
+
+template <typename vtype, int numVecs, typename comparator, typename reg_t>
+X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs)
+{
+    using swizzle = typename vtype::swizzle_ops;
+    if constexpr (numVecs <= 1) {
+        UNUSED(regs);
+        return;
+    }
+
+    // Reverse upper half of vectors
+    X86_SIMD_SORT_UNROLL_LOOP(64)
+    for (int i = numVecs / 2; i < numVecs; i++) {
+        regs[i] = swizzle::template reverse_n<vtype::numlanes>(regs[i]);
+    }
+    // Do compare exchanges
+    X86_SIMD_SORT_UNROLL_LOOP(64)
+    for (int i = 0; i < numVecs / 2; i++) {
+        comparator::COEX(regs[i], regs[numVecs - 1 - i]);
+    }
+
+    merge_substep_n_vec<vtype, numVecs / 2, comparator>(regs);
+    merge_substep_n_vec<vtype, numVecs / 2, comparator>(regs + numVecs / 2);
+}
+
+template <typename vtype, int numVecs, int numPer, typename comparator, typename reg_t>
+X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs)
+{
+    // Do cross vector merges
+    merge_substep_n_vec<vtype, numVecs, comparator>(regs);
+
+    // Do internal vector merges
+    internal_merge_n_vec<vtype, numVecs, numPer, true, comparator>(regs);
+}
+
+template <typename vtype, int numVecs, int numPer, typename comparator, typename reg_t>
+X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs)
+{
+    if constexpr (numPer > vtype::numlanes) {
+        UNUSED(regs);
+        return;
+    }
+    else {
+        merge_step_n_vec<vtype, numVecs, numPer, comparator>(regs);
+        merge_n_vec<vtype, numVecs, numPer * 2, comparator>(regs);
+    }
+}
+
+template <typename vtype, int numVecs, typename comparator, typename reg_t>
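+/*
+ * sort_vectors treats its input as a [numVecs x numlanes] matrix:
+ * bitonic_sort_n_vec sorts each column across registers with an optimal
+ * network, and merge_n_vec then bitonic-merges the rows into one sorted
+ * sequence; e.g. four 8-lane registers sort 32 elements with no scalar
+ * fallback.
+ */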
+X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs) +{ + /* Run the initial sorting network to sort the columns of the [numVecs x + * num_lanes] matrix + */ + bitonic_sort_n_vec(vecs); + + // Merge the vectors using bitonic merging networks + merge_n_vec(vecs); +} + +template +X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) +{ + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * vtype::numlanes) { + sort_n_vec(arr, N); + return; + } + } + + reg_t vecs[numVecs]; + + // Generate masks for loading and storing + typename vtype::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * vtype::numlanes), + (uint64_t)vtype::numlanes); + ioMasks[j] = vtype::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + vecs[i] = vtype::loadu(arr + i * vtype::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + vecs[i] = vtype::mask_loadu(comparator::rightmostPossibleVec(), + ioMasks[j], + arr + i * vtype::numlanes); + } + + sort_vectors(vecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + vtype::storeu(arr + i * vtype::numlanes, vecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + vtype::mask_storeu(arr + i * vtype::numlanes, ioMasks[j], vecs[i]); + } +} + +template +X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) +{ + constexpr int numVecs = maxN / vtype::numlanes; + constexpr bool isMultiple = (maxN == (vtype::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be vtype::numlanes times a power of 2"); + + sort_n_vec(arr, N); +} +#endif diff --git a/src/xss-optimal-networks.hpp b/src/xss-optimal-networks.hpp index e722b1f..1bed39d 100644 --- a/src/xss-optimal-networks.hpp +++ b/src/xss-optimal-networks.hpp @@ -1,328 +1,328 @@ -// All of these sources files are generated from the optimal networks described in -// https://bertdobbelaere.github.io/sorting_networks.html - -template -X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) -{ - comparator::COEX(vecs[0], vecs[2]); - comparator::COEX(vecs[1], vecs[3]); - - comparator::COEX(vecs[0], vecs[1]); - comparator::COEX(vecs[2], vecs[3]); - - comparator::COEX(vecs[1], vecs[2]); -} - -template -X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) -{ - comparator::COEX(vecs[0], vecs[2]); - comparator::COEX(vecs[1], vecs[3]); - comparator::COEX(vecs[4], vecs[6]); - comparator::COEX(vecs[5], vecs[7]); - - comparator::COEX(vecs[0], vecs[4]); - comparator::COEX(vecs[1], vecs[5]); - comparator::COEX(vecs[2], vecs[6]); - comparator::COEX(vecs[3], vecs[7]); - - comparator::COEX(vecs[0], vecs[1]); - comparator::COEX(vecs[2], vecs[3]); - comparator::COEX(vecs[4], vecs[5]); - comparator::COEX(vecs[6], vecs[7]); - - comparator::COEX(vecs[2], vecs[4]); - comparator::COEX(vecs[3], vecs[5]); - - comparator::COEX(vecs[1], vecs[4]); - comparator::COEX(vecs[3], vecs[6]); - - comparator::COEX(vecs[1], vecs[2]); - comparator::COEX(vecs[3], vecs[4]); - comparator::COEX(vecs[5], vecs[6]); -} - 
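-/*
- * Each whitespace-separated run of COEX calls is one layer of the network;
- * the size-8 network above uses 19 compare-exchanges in 6 layers, matching
- * the optimal listings cited at the top of this file. COEX calls within a
- * layer are data-independent, so their vector min/max pairs can issue in
- * parallel.
- */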
-template -X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) -{ - comparator::COEX(vecs[0], vecs[13]); - comparator::COEX(vecs[1], vecs[12]); - comparator::COEX(vecs[2], vecs[15]); - comparator::COEX(vecs[3], vecs[14]); - comparator::COEX(vecs[4], vecs[8]); - comparator::COEX(vecs[5], vecs[6]); - comparator::COEX(vecs[7], vecs[11]); - comparator::COEX(vecs[9], vecs[10]); - - comparator::COEX(vecs[0], vecs[5]); - comparator::COEX(vecs[1], vecs[7]); - comparator::COEX(vecs[2], vecs[9]); - comparator::COEX(vecs[3], vecs[4]); - comparator::COEX(vecs[6], vecs[13]); - comparator::COEX(vecs[8], vecs[14]); - comparator::COEX(vecs[10], vecs[15]); - comparator::COEX(vecs[11], vecs[12]); - - comparator::COEX(vecs[0], vecs[1]); - comparator::COEX(vecs[2], vecs[3]); - comparator::COEX(vecs[4], vecs[5]); - comparator::COEX(vecs[6], vecs[8]); - comparator::COEX(vecs[7], vecs[9]); - comparator::COEX(vecs[10], vecs[11]); - comparator::COEX(vecs[12], vecs[13]); - comparator::COEX(vecs[14], vecs[15]); - - comparator::COEX(vecs[0], vecs[2]); - comparator::COEX(vecs[1], vecs[3]); - comparator::COEX(vecs[4], vecs[10]); - comparator::COEX(vecs[5], vecs[11]); - comparator::COEX(vecs[6], vecs[7]); - comparator::COEX(vecs[8], vecs[9]); - comparator::COEX(vecs[12], vecs[14]); - comparator::COEX(vecs[13], vecs[15]); - - comparator::COEX(vecs[1], vecs[2]); - comparator::COEX(vecs[3], vecs[12]); - comparator::COEX(vecs[4], vecs[6]); - comparator::COEX(vecs[5], vecs[7]); - comparator::COEX(vecs[8], vecs[10]); - comparator::COEX(vecs[9], vecs[11]); - comparator::COEX(vecs[13], vecs[14]); - - comparator::COEX(vecs[1], vecs[4]); - comparator::COEX(vecs[2], vecs[6]); - comparator::COEX(vecs[5], vecs[8]); - comparator::COEX(vecs[7], vecs[10]); - comparator::COEX(vecs[9], vecs[13]); - comparator::COEX(vecs[11], vecs[14]); - - comparator::COEX(vecs[2], vecs[4]); - comparator::COEX(vecs[3], vecs[6]); - comparator::COEX(vecs[9], vecs[12]); - comparator::COEX(vecs[11], vecs[13]); - - comparator::COEX(vecs[3], vecs[5]); - comparator::COEX(vecs[6], vecs[8]); - comparator::COEX(vecs[7], vecs[9]); - comparator::COEX(vecs[10], vecs[12]); - - comparator::COEX(vecs[3], vecs[4]); - comparator::COEX(vecs[5], vecs[6]); - comparator::COEX(vecs[7], vecs[8]); - comparator::COEX(vecs[9], vecs[10]); - comparator::COEX(vecs[11], vecs[12]); - - comparator::COEX(vecs[6], vecs[7]); - comparator::COEX(vecs[8], vecs[9]); -} - -template -X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) -{ - comparator::COEX(vecs[0], vecs[1]); - comparator::COEX(vecs[2], vecs[3]); - comparator::COEX(vecs[4], vecs[5]); - comparator::COEX(vecs[6], vecs[7]); - comparator::COEX(vecs[8], vecs[9]); - comparator::COEX(vecs[10], vecs[11]); - comparator::COEX(vecs[12], vecs[13]); - comparator::COEX(vecs[14], vecs[15]); - comparator::COEX(vecs[16], vecs[17]); - comparator::COEX(vecs[18], vecs[19]); - comparator::COEX(vecs[20], vecs[21]); - comparator::COEX(vecs[22], vecs[23]); - comparator::COEX(vecs[24], vecs[25]); - comparator::COEX(vecs[26], vecs[27]); - comparator::COEX(vecs[28], vecs[29]); - comparator::COEX(vecs[30], vecs[31]); - - comparator::COEX(vecs[0], vecs[2]); - comparator::COEX(vecs[1], vecs[3]); - comparator::COEX(vecs[4], vecs[6]); - comparator::COEX(vecs[5], vecs[7]); - comparator::COEX(vecs[8], vecs[10]); - comparator::COEX(vecs[9], vecs[11]); - comparator::COEX(vecs[12], vecs[14]); - comparator::COEX(vecs[13], vecs[15]); - comparator::COEX(vecs[16], vecs[18]); - comparator::COEX(vecs[17], vecs[19]); - comparator::COEX(vecs[20], vecs[22]); - 
comparator::COEX(vecs[21], vecs[23]); - comparator::COEX(vecs[24], vecs[26]); - comparator::COEX(vecs[25], vecs[27]); - comparator::COEX(vecs[28], vecs[30]); - comparator::COEX(vecs[29], vecs[31]); - - comparator::COEX(vecs[0], vecs[4]); - comparator::COEX(vecs[1], vecs[5]); - comparator::COEX(vecs[2], vecs[6]); - comparator::COEX(vecs[3], vecs[7]); - comparator::COEX(vecs[8], vecs[12]); - comparator::COEX(vecs[9], vecs[13]); - comparator::COEX(vecs[10], vecs[14]); - comparator::COEX(vecs[11], vecs[15]); - comparator::COEX(vecs[16], vecs[20]); - comparator::COEX(vecs[17], vecs[21]); - comparator::COEX(vecs[18], vecs[22]); - comparator::COEX(vecs[19], vecs[23]); - comparator::COEX(vecs[24], vecs[28]); - comparator::COEX(vecs[25], vecs[29]); - comparator::COEX(vecs[26], vecs[30]); - comparator::COEX(vecs[27], vecs[31]); - - comparator::COEX(vecs[0], vecs[8]); - comparator::COEX(vecs[1], vecs[9]); - comparator::COEX(vecs[2], vecs[10]); - comparator::COEX(vecs[3], vecs[11]); - comparator::COEX(vecs[4], vecs[12]); - comparator::COEX(vecs[5], vecs[13]); - comparator::COEX(vecs[6], vecs[14]); - comparator::COEX(vecs[7], vecs[15]); - comparator::COEX(vecs[16], vecs[24]); - comparator::COEX(vecs[17], vecs[25]); - comparator::COEX(vecs[18], vecs[26]); - comparator::COEX(vecs[19], vecs[27]); - comparator::COEX(vecs[20], vecs[28]); - comparator::COEX(vecs[21], vecs[29]); - comparator::COEX(vecs[22], vecs[30]); - comparator::COEX(vecs[23], vecs[31]); - - comparator::COEX(vecs[0], vecs[16]); - comparator::COEX(vecs[1], vecs[8]); - comparator::COEX(vecs[2], vecs[4]); - comparator::COEX(vecs[3], vecs[12]); - comparator::COEX(vecs[5], vecs[10]); - comparator::COEX(vecs[6], vecs[9]); - comparator::COEX(vecs[7], vecs[14]); - comparator::COEX(vecs[11], vecs[13]); - comparator::COEX(vecs[15], vecs[31]); - comparator::COEX(vecs[17], vecs[24]); - comparator::COEX(vecs[18], vecs[20]); - comparator::COEX(vecs[19], vecs[28]); - comparator::COEX(vecs[21], vecs[26]); - comparator::COEX(vecs[22], vecs[25]); - comparator::COEX(vecs[23], vecs[30]); - comparator::COEX(vecs[27], vecs[29]); - - comparator::COEX(vecs[1], vecs[2]); - comparator::COEX(vecs[3], vecs[5]); - comparator::COEX(vecs[4], vecs[8]); - comparator::COEX(vecs[6], vecs[22]); - comparator::COEX(vecs[7], vecs[11]); - comparator::COEX(vecs[9], vecs[25]); - comparator::COEX(vecs[10], vecs[12]); - comparator::COEX(vecs[13], vecs[14]); - comparator::COEX(vecs[17], vecs[18]); - comparator::COEX(vecs[19], vecs[21]); - comparator::COEX(vecs[20], vecs[24]); - comparator::COEX(vecs[23], vecs[27]); - comparator::COEX(vecs[26], vecs[28]); - comparator::COEX(vecs[29], vecs[30]); - - comparator::COEX(vecs[1], vecs[17]); - comparator::COEX(vecs[2], vecs[18]); - comparator::COEX(vecs[3], vecs[19]); - comparator::COEX(vecs[4], vecs[20]); - comparator::COEX(vecs[5], vecs[10]); - comparator::COEX(vecs[7], vecs[23]); - comparator::COEX(vecs[8], vecs[24]); - comparator::COEX(vecs[11], vecs[27]); - comparator::COEX(vecs[12], vecs[28]); - comparator::COEX(vecs[13], vecs[29]); - comparator::COEX(vecs[14], vecs[30]); - comparator::COEX(vecs[21], vecs[26]); - - comparator::COEX(vecs[3], vecs[17]); - comparator::COEX(vecs[4], vecs[16]); - comparator::COEX(vecs[5], vecs[21]); - comparator::COEX(vecs[6], vecs[18]); - comparator::COEX(vecs[7], vecs[9]); - comparator::COEX(vecs[8], vecs[20]); - comparator::COEX(vecs[10], vecs[26]); - comparator::COEX(vecs[11], vecs[23]); - comparator::COEX(vecs[13], vecs[25]); - comparator::COEX(vecs[14], vecs[28]); - comparator::COEX(vecs[15], 
vecs[27]); - comparator::COEX(vecs[22], vecs[24]); - - comparator::COEX(vecs[1], vecs[4]); - comparator::COEX(vecs[3], vecs[8]); - comparator::COEX(vecs[5], vecs[16]); - comparator::COEX(vecs[7], vecs[17]); - comparator::COEX(vecs[9], vecs[21]); - comparator::COEX(vecs[10], vecs[22]); - comparator::COEX(vecs[11], vecs[19]); - comparator::COEX(vecs[12], vecs[20]); - comparator::COEX(vecs[14], vecs[24]); - comparator::COEX(vecs[15], vecs[26]); - comparator::COEX(vecs[23], vecs[28]); - comparator::COEX(vecs[27], vecs[30]); - - comparator::COEX(vecs[2], vecs[5]); - comparator::COEX(vecs[7], vecs[8]); - comparator::COEX(vecs[9], vecs[18]); - comparator::COEX(vecs[11], vecs[17]); - comparator::COEX(vecs[12], vecs[16]); - comparator::COEX(vecs[13], vecs[22]); - comparator::COEX(vecs[14], vecs[20]); - comparator::COEX(vecs[15], vecs[19]); - comparator::COEX(vecs[23], vecs[24]); - comparator::COEX(vecs[26], vecs[29]); - - comparator::COEX(vecs[2], vecs[4]); - comparator::COEX(vecs[6], vecs[12]); - comparator::COEX(vecs[9], vecs[16]); - comparator::COEX(vecs[10], vecs[11]); - comparator::COEX(vecs[13], vecs[17]); - comparator::COEX(vecs[14], vecs[18]); - comparator::COEX(vecs[15], vecs[22]); - comparator::COEX(vecs[19], vecs[25]); - comparator::COEX(vecs[20], vecs[21]); - comparator::COEX(vecs[27], vecs[29]); - - comparator::COEX(vecs[5], vecs[6]); - comparator::COEX(vecs[8], vecs[12]); - comparator::COEX(vecs[9], vecs[10]); - comparator::COEX(vecs[11], vecs[13]); - comparator::COEX(vecs[14], vecs[16]); - comparator::COEX(vecs[15], vecs[17]); - comparator::COEX(vecs[18], vecs[20]); - comparator::COEX(vecs[19], vecs[23]); - comparator::COEX(vecs[21], vecs[22]); - comparator::COEX(vecs[25], vecs[26]); - - comparator::COEX(vecs[3], vecs[5]); - comparator::COEX(vecs[6], vecs[7]); - comparator::COEX(vecs[8], vecs[9]); - comparator::COEX(vecs[10], vecs[12]); - comparator::COEX(vecs[11], vecs[14]); - comparator::COEX(vecs[13], vecs[16]); - comparator::COEX(vecs[15], vecs[18]); - comparator::COEX(vecs[17], vecs[20]); - comparator::COEX(vecs[19], vecs[21]); - comparator::COEX(vecs[22], vecs[23]); - comparator::COEX(vecs[24], vecs[25]); - comparator::COEX(vecs[26], vecs[28]); - - comparator::COEX(vecs[3], vecs[4]); - comparator::COEX(vecs[5], vecs[6]); - comparator::COEX(vecs[7], vecs[8]); - comparator::COEX(vecs[9], vecs[10]); - comparator::COEX(vecs[11], vecs[12]); - comparator::COEX(vecs[13], vecs[14]); - comparator::COEX(vecs[15], vecs[16]); - comparator::COEX(vecs[17], vecs[18]); - comparator::COEX(vecs[19], vecs[20]); - comparator::COEX(vecs[21], vecs[22]); - comparator::COEX(vecs[23], vecs[24]); - comparator::COEX(vecs[25], vecs[26]); - comparator::COEX(vecs[27], vecs[28]); -} +// All of these sources files are generated from the optimal networks described in +// https://bertdobbelaere.github.io/sorting_networks.html + +template +X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) +{ + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + + comparator::COEX(vecs[1], vecs[2]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) +{ + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + + comparator::COEX(vecs[0], vecs[4]); + comparator::COEX(vecs[1], vecs[5]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[3], vecs[7]); + + comparator::COEX(vecs[0], vecs[1]); + 
comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[5]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[3], vecs[6]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) +{ + comparator::COEX(vecs[0], vecs[13]); + comparator::COEX(vecs[1], vecs[12]); + comparator::COEX(vecs[2], vecs[15]); + comparator::COEX(vecs[3], vecs[14]); + comparator::COEX(vecs[4], vecs[8]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[11]); + comparator::COEX(vecs[9], vecs[10]); + + comparator::COEX(vecs[0], vecs[5]); + comparator::COEX(vecs[1], vecs[7]); + comparator::COEX(vecs[2], vecs[9]); + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[6], vecs[13]); + comparator::COEX(vecs[8], vecs[14]); + comparator::COEX(vecs[10], vecs[15]); + comparator::COEX(vecs[11], vecs[12]); + + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[8]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[12], vecs[13]); + comparator::COEX(vecs[14], vecs[15]); + + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[10]); + comparator::COEX(vecs[5], vecs[11]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[12], vecs[14]); + comparator::COEX(vecs[13], vecs[15]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[12]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + comparator::COEX(vecs[8], vecs[10]); + comparator::COEX(vecs[9], vecs[11]); + comparator::COEX(vecs[13], vecs[14]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[5], vecs[8]); + comparator::COEX(vecs[7], vecs[10]); + comparator::COEX(vecs[9], vecs[13]); + comparator::COEX(vecs[11], vecs[14]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[6]); + comparator::COEX(vecs[9], vecs[12]); + comparator::COEX(vecs[11], vecs[13]); + + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[6], vecs[8]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[10], vecs[12]); + + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[12]); + + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); +} + +template +X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) +{ + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[12], vecs[13]); + comparator::COEX(vecs[14], vecs[15]); + comparator::COEX(vecs[16], vecs[17]); + comparator::COEX(vecs[18], vecs[19]); + comparator::COEX(vecs[20], vecs[21]); + comparator::COEX(vecs[22], vecs[23]); + comparator::COEX(vecs[24], vecs[25]); + comparator::COEX(vecs[26], vecs[27]); + comparator::COEX(vecs[28], vecs[29]); + comparator::COEX(vecs[30], vecs[31]); + + comparator::COEX(vecs[0], vecs[2]); + 
comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + comparator::COEX(vecs[8], vecs[10]); + comparator::COEX(vecs[9], vecs[11]); + comparator::COEX(vecs[12], vecs[14]); + comparator::COEX(vecs[13], vecs[15]); + comparator::COEX(vecs[16], vecs[18]); + comparator::COEX(vecs[17], vecs[19]); + comparator::COEX(vecs[20], vecs[22]); + comparator::COEX(vecs[21], vecs[23]); + comparator::COEX(vecs[24], vecs[26]); + comparator::COEX(vecs[25], vecs[27]); + comparator::COEX(vecs[28], vecs[30]); + comparator::COEX(vecs[29], vecs[31]); + + comparator::COEX(vecs[0], vecs[4]); + comparator::COEX(vecs[1], vecs[5]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[3], vecs[7]); + comparator::COEX(vecs[8], vecs[12]); + comparator::COEX(vecs[9], vecs[13]); + comparator::COEX(vecs[10], vecs[14]); + comparator::COEX(vecs[11], vecs[15]); + comparator::COEX(vecs[16], vecs[20]); + comparator::COEX(vecs[17], vecs[21]); + comparator::COEX(vecs[18], vecs[22]); + comparator::COEX(vecs[19], vecs[23]); + comparator::COEX(vecs[24], vecs[28]); + comparator::COEX(vecs[25], vecs[29]); + comparator::COEX(vecs[26], vecs[30]); + comparator::COEX(vecs[27], vecs[31]); + + comparator::COEX(vecs[0], vecs[8]); + comparator::COEX(vecs[1], vecs[9]); + comparator::COEX(vecs[2], vecs[10]); + comparator::COEX(vecs[3], vecs[11]); + comparator::COEX(vecs[4], vecs[12]); + comparator::COEX(vecs[5], vecs[13]); + comparator::COEX(vecs[6], vecs[14]); + comparator::COEX(vecs[7], vecs[15]); + comparator::COEX(vecs[16], vecs[24]); + comparator::COEX(vecs[17], vecs[25]); + comparator::COEX(vecs[18], vecs[26]); + comparator::COEX(vecs[19], vecs[27]); + comparator::COEX(vecs[20], vecs[28]); + comparator::COEX(vecs[21], vecs[29]); + comparator::COEX(vecs[22], vecs[30]); + comparator::COEX(vecs[23], vecs[31]); + + comparator::COEX(vecs[0], vecs[16]); + comparator::COEX(vecs[1], vecs[8]); + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[12]); + comparator::COEX(vecs[5], vecs[10]); + comparator::COEX(vecs[6], vecs[9]); + comparator::COEX(vecs[7], vecs[14]); + comparator::COEX(vecs[11], vecs[13]); + comparator::COEX(vecs[15], vecs[31]); + comparator::COEX(vecs[17], vecs[24]); + comparator::COEX(vecs[18], vecs[20]); + comparator::COEX(vecs[19], vecs[28]); + comparator::COEX(vecs[21], vecs[26]); + comparator::COEX(vecs[22], vecs[25]); + comparator::COEX(vecs[23], vecs[30]); + comparator::COEX(vecs[27], vecs[29]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[4], vecs[8]); + comparator::COEX(vecs[6], vecs[22]); + comparator::COEX(vecs[7], vecs[11]); + comparator::COEX(vecs[9], vecs[25]); + comparator::COEX(vecs[10], vecs[12]); + comparator::COEX(vecs[13], vecs[14]); + comparator::COEX(vecs[17], vecs[18]); + comparator::COEX(vecs[19], vecs[21]); + comparator::COEX(vecs[20], vecs[24]); + comparator::COEX(vecs[23], vecs[27]); + comparator::COEX(vecs[26], vecs[28]); + comparator::COEX(vecs[29], vecs[30]); + + comparator::COEX(vecs[1], vecs[17]); + comparator::COEX(vecs[2], vecs[18]); + comparator::COEX(vecs[3], vecs[19]); + comparator::COEX(vecs[4], vecs[20]); + comparator::COEX(vecs[5], vecs[10]); + comparator::COEX(vecs[7], vecs[23]); + comparator::COEX(vecs[8], vecs[24]); + comparator::COEX(vecs[11], vecs[27]); + comparator::COEX(vecs[12], vecs[28]); + comparator::COEX(vecs[13], vecs[29]); + comparator::COEX(vecs[14], vecs[30]); + comparator::COEX(vecs[21], vecs[26]); + + comparator::COEX(vecs[3], vecs[17]); 
+ comparator::COEX(vecs[4], vecs[16]); + comparator::COEX(vecs[5], vecs[21]); + comparator::COEX(vecs[6], vecs[18]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[8], vecs[20]); + comparator::COEX(vecs[10], vecs[26]); + comparator::COEX(vecs[11], vecs[23]); + comparator::COEX(vecs[13], vecs[25]); + comparator::COEX(vecs[14], vecs[28]); + comparator::COEX(vecs[15], vecs[27]); + comparator::COEX(vecs[22], vecs[24]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[3], vecs[8]); + comparator::COEX(vecs[5], vecs[16]); + comparator::COEX(vecs[7], vecs[17]); + comparator::COEX(vecs[9], vecs[21]); + comparator::COEX(vecs[10], vecs[22]); + comparator::COEX(vecs[11], vecs[19]); + comparator::COEX(vecs[12], vecs[20]); + comparator::COEX(vecs[14], vecs[24]); + comparator::COEX(vecs[15], vecs[26]); + comparator::COEX(vecs[23], vecs[28]); + comparator::COEX(vecs[27], vecs[30]); + + comparator::COEX(vecs[2], vecs[5]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[18]); + comparator::COEX(vecs[11], vecs[17]); + comparator::COEX(vecs[12], vecs[16]); + comparator::COEX(vecs[13], vecs[22]); + comparator::COEX(vecs[14], vecs[20]); + comparator::COEX(vecs[15], vecs[19]); + comparator::COEX(vecs[23], vecs[24]); + comparator::COEX(vecs[26], vecs[29]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[6], vecs[12]); + comparator::COEX(vecs[9], vecs[16]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[13], vecs[17]); + comparator::COEX(vecs[14], vecs[18]); + comparator::COEX(vecs[15], vecs[22]); + comparator::COEX(vecs[19], vecs[25]); + comparator::COEX(vecs[20], vecs[21]); + comparator::COEX(vecs[27], vecs[29]); + + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[8], vecs[12]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[13]); + comparator::COEX(vecs[14], vecs[16]); + comparator::COEX(vecs[15], vecs[17]); + comparator::COEX(vecs[18], vecs[20]); + comparator::COEX(vecs[19], vecs[23]); + comparator::COEX(vecs[21], vecs[22]); + comparator::COEX(vecs[25], vecs[26]); + + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[10], vecs[12]); + comparator::COEX(vecs[11], vecs[14]); + comparator::COEX(vecs[13], vecs[16]); + comparator::COEX(vecs[15], vecs[18]); + comparator::COEX(vecs[17], vecs[20]); + comparator::COEX(vecs[19], vecs[21]); + comparator::COEX(vecs[22], vecs[23]); + comparator::COEX(vecs[24], vecs[25]); + comparator::COEX(vecs[26], vecs[28]); + + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[12]); + comparator::COEX(vecs[13], vecs[14]); + comparator::COEX(vecs[15], vecs[16]); + comparator::COEX(vecs[17], vecs[18]); + comparator::COEX(vecs[19], vecs[20]); + comparator::COEX(vecs[21], vecs[22]); + comparator::COEX(vecs[23], vecs[24]); + comparator::COEX(vecs[25], vecs[26]); + comparator::COEX(vecs[27], vecs[28]); +}
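
To see the pieces above end to end, a minimal harness along these lines exercises the AVX2 argsort entry point. This is a sketch, not part of the patch: it assumes the internal headers that define arrsize_t and avx2_argsort are on the include path, that the index buffer is pre-filled with the identity permutation (as the public x86simdsort::argsort wrapper does before dispatching), and that the host CPU supports AVX2.

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

int main()
{
    std::vector<double> data = {3.0, -1.0, 2.5, 0.0, 9.0, -7.5, 4.25, 8.0};
    std::vector<arrsize_t> idx(data.size());
    std::iota(idx.begin(), idx.end(), 0); // identity permutation in, argsort out

    // hasnan = false: the input is known NaN-free, so the NaN scan is skipped
    avx2_argsort(data.data(), idx.data(), data.size(),
                 /*hasnan=*/false, /*descending=*/false);

    // data itself is untouched; idx now orders it ascending
    assert(std::is_sorted(
            idx.begin(), idx.end(), [&](arrsize_t a, arrsize_t b) {
                return data[a] < data[b];
            }));
    return 0;
}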