-
Notifications
You must be signed in to change notification settings - Fork 34
Make block_size flexible #235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: rfsaliev/cpp-runtime-binding
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -69,9 +69,10 @@ class DynamicVamanaIndexImpl { | |||
|
|
||||
| StorageKind get_storage_kind() const { return storage_kind_; } | ||||
|
|
||||
| void add(data::ConstSimpleDataView<float> data, std::span<const size_t> labels) { | ||||
| void add(data::ConstSimpleDataView<float> data, std::span<const size_t> labels, | ||||
| int blocksize_exp) { | ||||
| if (!impl_) { | ||||
| return init_impl(data, labels); | ||||
| return init_impl(data, labels, blocksize_exp); | ||||
| } | ||||
|
|
||||
| impl_->add_points(data, labels); | ||||
|
|
@@ -389,14 +390,17 @@ class DynamicVamanaIndexImpl { | |||
| const index::vamana::VamanaBuildParameters& parameters, | ||||
| const svs::data::ConstSimpleDataView<float>& data, | ||||
| std::span<const size_t> labels, | ||||
| int blocksize_exp, | ||||
| StorageArgs&&... storage_args | ||||
| ) { | ||||
| auto threadpool = default_threadpool(); | ||||
|
|
||||
| auto blocksize_bytes = svs::lib::PowerOfTwo(blocksize_exp); | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is a runtime value, we'd probably need some validation. Do you think it could be a good idea to create a struct similar to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that makes sense. What values of the exponent are acceptable? To the best of my understanding, blocksize_bytes should be in the range [4KB, 1GB] (based on possible page sizes).
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that sounds reasonable. @mihaic did you ever experiment with even bigger hugepages?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since Xeon only goes to 1 GiB, we can have that as the limit. |
||||
| auto storage = make_storage( | ||||
| std::forward<Tag>(tag), | ||||
| data, | ||||
| threadpool, | ||||
| blocksize_bytes, | ||||
| std::forward<StorageArgs>(storage_args)... | ||||
| ); | ||||
|
|
||||
|
|
@@ -413,25 +417,29 @@ class DynamicVamanaIndexImpl { | |||
| } | ||||
|
|
||||
| virtual void | ||||
| init_impl(data::ConstSimpleDataView<float> data, std::span<const size_t> labels) { | ||||
| init_impl(data::ConstSimpleDataView<float> data, std::span<const size_t> labels, | ||||
| int blocksize_exp) { | ||||
| impl_.reset(storage::dispatch_storage_kind( | ||||
| get_storage_kind(), | ||||
| [this]( | ||||
| auto&& tag, | ||||
| data::ConstSimpleDataView<float> data, | ||||
| std::span<const size_t> labels | ||||
| std::span<const size_t> labels, | ||||
| int blocksize_exp | ||||
| ) { | ||||
| using Tag = std::decay_t<decltype(tag)>; | ||||
| return build_impl( | ||||
| std::forward<Tag>(tag), | ||||
| this->metric_type_, | ||||
| this->vamana_build_parameters(), | ||||
| data, | ||||
| labels | ||||
| labels, | ||||
| blocksize_exp | ||||
| ); | ||||
| }, | ||||
| data, | ||||
| labels | ||||
| labels, | ||||
| blocksize_exp | ||||
| )); | ||||
| } | ||||
|
|
||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -198,7 +198,8 @@ template <> struct StorageFactory<UnsupportedStorageType> { | |
| template <svs::threads::ThreadPool Pool> | ||
| static StorageType init( | ||
| const svs::data::ConstSimpleDataView<float>& SVS_UNUSED(data), | ||
| Pool& SVS_UNUSED(pool) | ||
| Pool& SVS_UNUSED(pool), | ||
| svs::lib::PowerOfTwo SVS_UNUSED(blocksize_bytes) | ||
| ) { | ||
| throw StatusException( | ||
| ErrorCode::NOT_IMPLEMENTED, "Requested storage kind is not supported" | ||
|
|
@@ -218,8 +219,12 @@ template <typename ElementType> struct StorageFactory<SimpleDatasetType<ElementT | |
| using StorageType = SimpleDatasetType<ElementType>; | ||
|
|
||
| template <svs::threads::ThreadPool Pool> | ||
| static StorageType init(const svs::data::ConstSimpleDataView<float>& data, Pool& pool) { | ||
| StorageType result(data.size(), data.dimensions()); | ||
| static StorageType init(const svs::data::ConstSimpleDataView<float>& data, Pool& pool, | ||
| svs::lib::PowerOfTwo blocksize_bytes = svs::data::BlockingParameters::default_blocksize_bytes) { | ||
| auto parameters = svs::data::BlockingParameters{ | ||
| .blocksize_bytes = blocksize_bytes}; | ||
| typename StorageType::allocator_type alloc(parameters); | ||
| StorageType result(data.size(), data.dimensions(), alloc); | ||
| svs::threads::parallel_for( | ||
| pool, | ||
| svs::threads::StaticPartition(result.size()), | ||
|
|
@@ -244,7 +249,8 @@ struct StorageFactory<SQStorageType> { | |
| using StorageType = SQStorageType; | ||
|
|
||
| template <svs::threads::ThreadPool Pool> | ||
| static StorageType init(const svs::data::ConstSimpleDataView<float>& data, Pool& pool) { | ||
| static StorageType init(const svs::data::ConstSimpleDataView<float>& data, Pool& pool, | ||
| svs::lib::PowerOfTwo SVS_UNUSED(blocksize_bytes)) { | ||
| return SQStorageType::compress(data, pool); | ||
| } | ||
|
|
||
|
|
@@ -275,7 +281,8 @@ struct StorageFactory<LVQStorageType> { | |
| using StorageType = LVQStorageType; | ||
|
|
||
| template <svs::threads::ThreadPool Pool> | ||
| static StorageType init(const svs::data::ConstSimpleDataView<float>& data, Pool& pool) { | ||
| static StorageType init(const svs::data::ConstSimpleDataView<float>& data, Pool& pool, | ||
| svs::lib::PowerOfTwo SVS_UNUSED(blocksize_bytes)) { | ||
| return LVQStorageType::compress(data, pool, 0); | ||
| } | ||
|
|
||
|
|
@@ -309,6 +316,7 @@ struct StorageFactory<LeanVecStorageType> { | |
| static StorageType init( | ||
| const svs::data::ConstSimpleDataView<float>& data, | ||
| Pool& pool, | ||
| svs::lib::PowerOfTwo SVS_UNUSED(blocksize_bytes), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| size_t leanvec_d = 0, | ||
| std::optional<svs::leanvec::LeanVecMatrices<svs::Dynamic>> matrices = std::nullopt | ||
| ) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -83,6 +83,7 @@ void write_and_read_index( | |
| size_t n, | ||
| size_t d, | ||
| svs::runtime::v0::StorageKind storage_kind, | ||
| int blocksize_exp2, | ||
| svs::runtime::v0::MetricType metric = svs::runtime::v0::MetricType::L2 | ||
| ) { | ||
| // Build index | ||
|
|
@@ -99,7 +100,7 @@ void write_and_read_index( | |
| std::vector<size_t> labels(n); | ||
| std::iota(labels.begin(), labels.end(), 0); | ||
|
|
||
| status = index->add(n, labels.data(), xb.data()); | ||
| status = index->add(n, labels.data(), xb.data(), blocksize_exp2); | ||
| CATCH_REQUIRE(status.ok()); | ||
|
|
||
| svs_test::prepare_temp_directory(); | ||
|
|
@@ -141,7 +142,7 @@ void write_and_read_index( | |
|
|
||
| // Helper that writes and reads an index of requested size | ||||
| // Reports memory usage | ||
| UsageInfo run_save_and_load_test(const size_t target_mibytes) { | ||
| UsageInfo run_save_and_load_test(const size_t target_mibytes, int blocksize_exp2) { | ||
| // Generate requested MiB of test data | ||
| constexpr size_t mem_test_d = 128; | ||
| const size_t target_bytes = target_mibytes * 1024 * 1024; | ||
|
|
@@ -171,7 +172,7 @@ UsageInfo run_save_and_load_test(const size_t target_mibytes) { | |
| ); | ||
| CATCH_REQUIRE(status.ok()); | ||
| CATCH_REQUIRE(index != nullptr); | ||
| status = index->add(mem_test_n, labels.data(), large_test_data.data()); | ||
| status = index->add(mem_test_n, labels.data(), large_test_data.data(), blocksize_exp2); | ||
| CATCH_REQUIRE(status.ok()); | ||
|
|
||
| std::ofstream out(filename, std::ios::binary); | ||
|
|
@@ -224,7 +225,7 @@ CATCH_TEST_CASE("WriteAndReadIndexSVS", "[runtime]") { | |
| ); | ||
| }; | ||
| write_and_read_index( | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP32 | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP32, 30 | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -241,7 +242,7 @@ CATCH_TEST_CASE("WriteAndReadIndexSVSFP16", "[runtime]") { | |
| ); | ||
| }; | ||
| write_and_read_index( | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP16 | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP16, 30 | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -258,7 +259,7 @@ CATCH_TEST_CASE("WriteAndReadIndexSVSSQI8", "[runtime]") { | |
| ); | ||
| }; | ||
| write_and_read_index( | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::SQI8 | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::SQI8, 30 | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -275,7 +276,7 @@ CATCH_TEST_CASE("WriteAndReadIndexSVSLVQ4x4", "[runtime]") { | |
| ); | ||
| }; | ||
| write_and_read_index( | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LVQ4x4 | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LVQ4x4, 30 | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -293,7 +294,7 @@ CATCH_TEST_CASE("WriteAndReadIndexSVSVamanaLeanVec4x4", "[runtime]") { | |
| ); | ||
| }; | ||
| write_and_read_index( | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LeanVec4x4 | ||
| build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LeanVec4x4, 30 | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -330,6 +331,40 @@ CATCH_TEST_CASE("LeanVecWithTrainingData", "[runtime]") { | |
| svs::runtime::v0::DynamicVamanaIndex::destroy(index); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("LeanVecWithTrainingDataCustomBlockSize", "[runtime]") { | ||
| const auto& test_data = get_test_data(); | ||
| // Build LeanVec index with explicit training | ||
| svs::runtime::v0::DynamicVamanaIndex* index = nullptr; | ||
| svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; | ||
| svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndexLeanVec::build( | ||
| &index, | ||
| test_d, | ||
| svs::runtime::v0::MetricType::L2, | ||
| svs::runtime::v0::StorageKind::LeanVec4x4, | ||
| 32, | ||
| build_params | ||
| ); | ||
| if (!svs::runtime::v0::DynamicVamanaIndex::check_storage_kind( | ||
| svs::runtime::v0::StorageKind::LeanVec4x4 | ||
| ) | ||
| .ok()) { | ||
| CATCH_REQUIRE(!status.ok()); | ||
| CATCH_SKIP("Storage kind is not supported, skipping test."); | ||
| } | ||
| CATCH_REQUIRE(status.ok()); | ||
| CATCH_REQUIRE(index != nullptr); | ||
|
|
||
| // Add data - should work with provided leanvec dims | ||
| std::vector<size_t> labels(test_n); | ||
| std::iota(labels.begin(), labels.end(), 0); | ||
|
|
||
| int block_size_exp = 17; // block_size = 2^block_size_exp | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Beyond testing whether passing the param works, we should also test whether the block size actually changed. |
||
| status = index->add(test_n, labels.data(), test_data.data(), block_size_exp); | ||
| CATCH_REQUIRE(status.ok()); | ||
|
|
||
| svs::runtime::v0::DynamicVamanaIndex::destroy(index); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("FlatIndexWriteAndRead", "[runtime]") { | ||
| const auto& test_data = get_test_data(); | ||
| svs::runtime::v0::FlatIndex* index = nullptr; | ||
|
|
@@ -399,7 +434,7 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { | |
| // Add data | ||
| std::vector<size_t> labels(test_n); | ||
| std::iota(labels.begin(), labels.end(), 0); | ||
| status = index->add(test_n, labels.data(), test_data.data()); | ||
| status = index->add(test_n, labels.data(), test_data.data(), 30); | ||
| CATCH_REQUIRE(status.ok()); | ||
|
|
||
| const int nq = 8; | ||
|
|
@@ -445,7 +480,7 @@ CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") { | |
| // Add data | ||
| std::vector<size_t> labels(test_n); | ||
| std::iota(labels.begin(), labels.end(), 0); | ||
| status = index->add(test_n, labels.data(), test_data.data()); | ||
| status = index->add(test_n, labels.data(), test_data.data(), 30); | ||
| CATCH_REQUIRE(status.ok()); | ||
|
|
||
| const int nq = 5; | ||
|
|
@@ -472,19 +507,19 @@ CATCH_TEST_CASE("MemoryUsageOnLoad", "[runtime][memory]") { | |
| }; | ||
|
|
||
| CATCH_SECTION("SmallIndex") { | ||
| auto stats = run_save_and_load_test(10); | ||
| auto stats = run_save_and_load_test(10, 30); | ||
| CATCH_REQUIRE(stats.file_size < 20 * 1024 * 1024); | ||
| CATCH_REQUIRE(stats.rss_increase < threshold(stats.file_size)); | ||
| } | ||
|
|
||
| CATCH_SECTION("MediumIndex") { | ||
| auto stats = run_save_and_load_test(50); | ||
| auto stats = run_save_and_load_test(50, 30); | ||
| CATCH_REQUIRE(stats.file_size < 100 * 1024 * 1024); | ||
| CATCH_REQUIRE(stats.rss_increase < threshold(stats.file_size)); | ||
| } | ||
|
|
||
| CATCH_SECTION("LargeIndex") { | ||
| auto stats = run_save_and_load_test(200); | ||
| auto stats = run_save_and_load_test(200, 30); | ||
| CATCH_REQUIRE(stats.file_size < 400 * 1024 * 1024); | ||
| CATCH_REQUIRE(stats.rss_increase < threshold(stats.file_size)); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if there's a cleaner way to communicate this value to
init_impl(). Maybe @rfsaliev has an opinion?