diff --git a/assets/test/bool.npy b/assets/test/bool.npy new file mode 100644 index 0000000..5128932 Binary files /dev/null and b/assets/test/bool.npy differ diff --git a/include/npy/npy.h b/include/npy/npy.h index f8c529e..dde3c10 100644 --- a/include/npy/npy.h +++ b/include/npy/npy.h @@ -85,10 +85,46 @@ enum class data_type_t : char { COMPLEX64, /// 128-bit complex number (std::complex) COMPLEX128, + /// Boolean value + BOOL, /// Unicode string (std::wstring) UNICODE_STRING }; +/// @brief Boolean datatype which uses 1-byte storage +/// @details std::vector uses bitfields to store boolean values, which +/// makes it inefficient for reading and copying data from numpy, which stores +/// boolean values as bytes. This struct mimics this scheme while remaining +/// implicitly convertable to and from boolean values and expressions. +struct boolean { + /// @brief The storage of the boolean, either 1 (true) or 0 (false) + std::uint8_t value; + + /// @brief Default constructor, sets value to false. + boolean() { value = 0; } + + /// @brief Templated constructor, allowing conversion from any type that can + /// be interpreted as a boolean. + /// @param v The value to convert into a boolean + /// @tparam T Must be implicitly convertable to bool + template boolean(const T &v) { value = v > 0 ? 1 : 0; } + + /// @brief Constructor from boolean values. + /// @param b The bool to store + boolean(bool b) { value = b ? 1 : 0; } + + /// @brief Implicit cast operator to bool + operator bool() const { return value != 0; } + + /// @brief Assignment operator + /// @param v The value to convert into a boolean + /// @tparam T Must be implicitly convertable to bool + template boolean &operator=(const T &v) { + value = v ? 1 : 0; + return *this; + } +}; + /// @brief Convert a data type and endianness to a NPY dtype string. /// @param dtype the data type /// @param endian the endianness. Defaults to the current endianness of the @@ -377,8 +413,8 @@ class npzstringwriter { compression_method_t compression = compression_method_t::STORED, endian_t endianness = npy::endian_t::NATIVE); - /// @brief Destructor. This will call @ref npy::npzstringwriter::close, if it has - /// not been called already. + /// @brief Destructor. This will call @ref npy::npzstringwriter::close, if it + /// has not been called already. ~npzstringwriter(); /// @brief Returns the contents of the string stream as a string. @@ -452,8 +488,8 @@ class npzfilewriter { compression_method_t compression = compression_method_t::STORED, endian_t endianness = npy::endian_t::NATIVE); - /// @brief Destructor. This will call @ref npy::npzfilewriter::close, if it has - /// not been called already. + /// @brief Destructor. This will call @ref npy::npzfilewriter::close, if it + /// has not been called already. ~npzfilewriter(); /// @brief Returns whether the NPZ file is open. diff --git a/src/dtype.cpp b/src/dtype.cpp index 8f81735..12cc250 100644 --- a/src/dtype.cpp +++ b/src/dtype.cpp @@ -9,13 +9,13 @@ #define GETC(x) static_cast((x).get()) namespace { -std::array BIG_ENDIAN_DTYPES = {"|i1", "|u1", ">i2", ">u2", - ">i4", ">u4", ">i8", ">u8", - ">f4", ">f8", ">c8", ">c16"}; +std::array BIG_ENDIAN_DTYPES = { + "|i1", "|u1", ">i2", ">u2", ">i4", ">u4", ">i8", + ">u8", ">f4", ">f8", ">c8", ">c16", "|b1"}; std::array LITTLE_ENDIAN_DTYPES = { - "|i1", "|u1", "> DTYPE_MAP = { {"|u1", {npy::data_type_t::UINT8, npy::endian_t::NATIVE}}, @@ -39,7 +39,8 @@ std::map> DTYPE_MAP = { {"c8", {npy::data_type_t::COMPLEX64, npy::endian_t::BIG}}, {"c16", {npy::data_type_t::COMPLEX128, npy::endian_t::BIG}}}; + {">c16", {npy::data_type_t::COMPLEX128, npy::endian_t::BIG}}, + {"|b1", {npy::data_type_t::BOOL, npy::endian_t::NATIVE}}}; } // namespace namespace npy { @@ -99,6 +100,19 @@ void read_values<>(std::basic_istream &input, int8_t *data_ptr, input.read(start, num_elements); } +template <> +void write_values<>(std::basic_ostream &output, const boolean *data_ptr, + size_t num_elements, endian_t) { + output.write(reinterpret_cast(data_ptr), num_elements); +} + +template <> +void read_values<>(std::basic_istream &input, boolean *data_ptr, + size_t num_elements, const header_info &) { + char *start = reinterpret_cast(data_ptr); + input.read(start, num_elements); +} + template <> void write_values<>(std::basic_ostream &output, const uint_least16_t *data_ptr, size_t num_elements, diff --git a/src/tensor.cpp b/src/tensor.cpp index 3493dcf..822c82d 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -54,4 +54,8 @@ template <> data_type_t tensor::get_dtype() { return data_type_t::UNICODE_STRING; } +template <> data_type_t tensor::get_dtype() { + return data_type_t::BOOL; +} + } // namespace npy \ No newline at end of file diff --git a/test/libnpy_tests.h b/test/libnpy_tests.h index 94f9e83..f6996bd 100644 --- a/test/libnpy_tests.h +++ b/test/libnpy_tests.h @@ -190,6 +190,17 @@ inline npy::tensor test_tensor(const std::vector &shape) { return tensor; } +template <> +inline npy::tensor test_tensor(const std::vector &shape) { + npy::tensor tensor(shape); + int i = 0; + for (auto it = tensor.begin(); it != tensor.end(); ++it, ++i) { + *it = (i % 2) == 1; // Alternating false, true, false, true, ... + } + + return tensor; +} + template npy::tensor test_fortran_tensor() { std::vector values = {0, 10, 20, 30, 40, 5, 15, 25, 35, 45, 1, 11, 21, 31, 41, 6, 16, 26, 36, 46, 2, 12, 22, 32, 42, 7, @@ -220,6 +231,21 @@ template <> inline npy::tensor test_fortran_tensor() { return tensor; } +template <> inline npy::tensor test_fortran_tensor() { + std::vector values = {0, 10, 20, 30, 40, 5, 15, 25, 35, 45, 1, 11, 21, + 31, 41, 6, 16, 26, 36, 46, 2, 12, 22, 32, 42, 7, + 17, 27, 37, 47, 3, 13, 23, 33, 43, 8, 18, 28, 38, + 48, 4, 14, 24, 34, 44, 9, 19, 29, 39, 49}; + npy::tensor tensor({5, 2, 5}, true); + auto dst = tensor.begin(); + auto src = values.begin(); + for (; dst != tensor.end(); ++src, ++dst) { + *dst = (*src % 2) == 1; + } + + return tensor; +} + template std::string npy_stream(npy::endian_t endianness = npy::endian_t::NATIVE) { std::ostringstream actual_stream; diff --git a/test/npy_read.cpp b/test/npy_read.cpp index aaf8b75..31c5c52 100644 --- a/test/npy_read.cpp +++ b/test/npy_read.cpp @@ -22,6 +22,7 @@ int test_npy_read() { test_read>(result, "complex64"); test_read>(result, "complex128"); test_read(result, "unicode"); + test_read(result, "bool"); return result; } \ No newline at end of file diff --git a/test/npy_write.cpp b/test/npy_write.cpp index aecaeaf..ac1c3d8 100644 --- a/test/npy_write.cpp +++ b/test/npy_write.cpp @@ -78,5 +78,9 @@ int test_npy_write() { actual = test::npy_stream(npy::endian_t::LITTLE); test::assert_equal(expected, actual, result, "npy_write_unicode"); + expected = test::read_asset("bool.npy"); + actual = test::npy_stream(); + test::assert_equal(expected, actual, result, "npy_write_bool"); + return result; };