Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ data types.
## Arg sort routines on arrays
```cpp
std::vector<size_t> arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending);
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
std::vector<size_t> arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double,
uint64_t, int64_t]` Note that argsort and argselect are not accelerated with SIMD when using 16-bit
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
} \
template <> \
std::vector<size_t> argselect( \
type *arr, size_t k, size_t arrsize, bool hasnan) \
const type *arr, size_t k, size_t arrsize, bool hasnan) \
{ \
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
}
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
bool descending = false); \
template <typename T> \
XSS_HIDE_SYMBOL std::vector<size_t> \
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \
argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); \
}

namespace xss {
Expand Down
3 changes: 2 additions & 1 deletion lib/x86simdsort-scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ namespace scalar {
return arg;
}
template <typename T>
std::vector<size_t> argselect(T *arr, size_t k, size_t arrsize, bool hasnan)
std::vector<size_t>
argselect(const T *arr, size_t k, size_t arrsize, bool hasnan)
{
UNUSED(hasnan);
std::vector<size_t> arg(arrsize);
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
} \
template <> \
std::vector<size_t> argselect( \
type *arr, size_t k, size_t arrsize, bool hasnan) \
const type *arr, size_t k, size_t arrsize, bool hasnan) \
{ \
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
}
Expand Down
4 changes: 2 additions & 2 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ namespace x86simdsort {

#define DECLARE_INTERNAL_argselect(TYPE) \
static std::vector<size_t> (*internal_argselect##TYPE)( \
TYPE *, size_t, size_t, bool) \
const TYPE *, size_t, size_t, bool) \
= NULL; \
template <> \
std::vector<size_t> argselect( \
TYPE *arr, size_t k, size_t arrsize, bool hasnan) \
const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \
{ \
return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \
}
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ XSS_EXPORT_SYMBOL std::vector<size_t> argsort(const T *arr,
// argselect
template <typename T>
XSS_EXPORT_SYMBOL std::vector<size_t>
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false);

// keyvalue sort
template <typename T1, typename T2>
Expand Down
2 changes: 1 addition & 1 deletion src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Equivalent to `np.argselect` in
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html).

```cpp
void x86simdsortStatic::argselect<T>(T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false);
void x86simdsortStatic::argselect<T>(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false);
```
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
`double`.
Expand Down
10 changes: 5 additions & 5 deletions src/x86simdsort-static-incl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr,

template <typename T>
X86_SIMD_SORT_FINLINE std::vector<size_t>
argselect(T *arr, size_t k, size_t size, bool hasnan = false);
argselect(const T *arr, size_t k, size_t size, bool hasnan = false);

/* argselect API required by NumPy: */
template <typename T>
void X86_SIMD_SORT_FINLINE
argselect(T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false);
void X86_SIMD_SORT_FINLINE argselect(
const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false);

template <typename T1, typename T2>
X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key,
Expand Down Expand Up @@ -112,13 +112,13 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key,
} \
template <typename T> \
X86_SIMD_SORT_FINLINE void x86simdsortStatic::argselect( \
T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \
const T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \
{ \
ISA##_argselect(arr, arg, k, size, hasnan); \
} \
template <typename T> \
X86_SIMD_SORT_FINLINE std::vector<size_t> x86simdsortStatic::argselect( \
T *arr, size_t k, size_t size, bool hasnan) \
const T *arr, size_t k, size_t size, bool hasnan) \
{ \
std::vector<size_t> indices(size); \
std::iota(indices.begin(), indices.end(), 0); \
Expand Down
11 changes: 8 additions & 3 deletions src/xss-common-argsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,10 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
arrsize_t arrsize,
bool hasnan = false)
{
xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argselect<base_t, zmm_vector, ymm_vector>(
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
}

template <typename T>
Expand All @@ -751,8 +754,10 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
arrsize_t arrsize,
bool hasnan = false)
{
xss_argselect<T, avx2_vector, avx2_half_vector>(
arr, arg, k, arrsize, hasnan);
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
using base_t = std::remove_const_t<T>;
xss_argselect<base_t, avx2_vector, avx2_half_vector>(
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
}

#endif // XSS_COMMON_ARGSORT
Loading