diff --git a/README.md b/README.md index b79e44a..580111f 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ data types. ## Arg sort routines on arrays ```cpp std::vector arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending); -std::vector arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan); +std::vector arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` Note that argsort and argselect are not accelerated with SIMD when using 16-bit diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 1e0761e..36b43e8 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -29,7 +29,7 @@ } \ template <> \ std::vector argselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan) \ + const type *arr, size_t k, size_t arrsize, bool hasnan) \ { \ return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ } diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index f8a14c0..055df2b 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -50,7 +50,7 @@ bool descending = false); \ template \ XSS_HIDE_SYMBOL std::vector \ - argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \ + argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); \ } namespace xss { diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 95fab42..9f08f9b 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -88,7 +88,8 @@ namespace scalar { return arg; } template - std::vector argselect(T *arr, size_t k, size_t arrsize, bool hasnan) + std::vector + argselect(const T *arr, size_t k, size_t arrsize, bool hasnan) { UNUSED(hasnan); std::vector arg(arrsize); diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index f4c4125..6260bab 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -29,7 +29,7 @@ } \ template <> \ std::vector argselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan) \ + const type *arr, size_t k, size_t arrsize, bool hasnan) \ { \ return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ } diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 7aecbea..35d6ce4 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -99,11 +99,11 @@ namespace x86simdsort { #define DECLARE_INTERNAL_argselect(TYPE) \ static std::vector (*internal_argselect##TYPE)( \ - TYPE *, size_t, size_t, bool) \ + const TYPE *, size_t, size_t, bool) \ = NULL; \ template <> \ std::vector argselect( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ { \ return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \ } diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index c410918..34ed101 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -43,7 +43,7 @@ XSS_EXPORT_SYMBOL std::vector argsort(const T *arr, // argselect template XSS_EXPORT_SYMBOL std::vector -argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); +argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); // keyvalue sort template diff --git a/src/README.md b/src/README.md index 2e52a45..ad5fc7b 100644 --- a/src/README.md +++ b/src/README.md @@ -75,7 +75,7 @@ Equivalent to `np.argselect` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html). ```cpp -void x86simdsortStatic::argselect(T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false); +void x86simdsortStatic::argselect(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 852da1f..2b0a11e 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -40,12 +40,12 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr, template X86_SIMD_SORT_FINLINE std::vector -argselect(T *arr, size_t k, size_t size, bool hasnan = false); +argselect(const T *arr, size_t k, size_t size, bool hasnan = false); /* argselect API required by NumPy: */ template -void X86_SIMD_SORT_FINLINE -argselect(T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false); +void X86_SIMD_SORT_FINLINE argselect( + const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false); template X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, @@ -112,13 +112,13 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::argselect( \ - T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \ + const T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \ { \ ISA##_argselect(arr, arg, k, size, hasnan); \ } \ template \ X86_SIMD_SORT_FINLINE std::vector x86simdsortStatic::argselect( \ - T *arr, size_t k, size_t size, bool hasnan) \ + const T *arr, size_t k, size_t size, bool hasnan) \ { \ std::vector indices(size); \ std::iota(indices.begin(), indices.end(), 0); \ diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 9af9e70..3e80083 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -741,7 +741,10 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, arrsize_t arrsize, bool hasnan = false) { - xss_argselect(arr, arg, k, arrsize, hasnan); + // Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argselect( + const_cast(arr), arg, k, arrsize, hasnan); } template @@ -751,8 +754,10 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, arrsize_t arrsize, bool hasnan = false) { - xss_argselect( - arr, arg, k, arrsize, hasnan); + // Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation + using base_t = std::remove_const_t; + xss_argselect( + const_cast(arr), arg, k, arrsize, hasnan); } #endif // XSS_COMMON_ARGSORT