diff --git a/.gitmodules b/.gitmodules index 5086d05f..321de80b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -39,3 +39,6 @@ path = thirdparty/magic_enum/magic_enum-0.9.7 url = https://github.com/Neargye/magic_enum.git ignore = all +[submodule "thirdparty/RaBitQ-Library/RaBitQ-Library-0.1"] + path = thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 + url = https://github.com/VectorDB-NTU/RaBitQ-Library.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 294af340..db2d35b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,16 @@ message(STATUS "BUILD_PYTHON_BINDINGS:${BUILD_PYTHON_BINDINGS}") option(BUILD_TOOLS "Build tools" ON) message(STATUS "BUILD_TOOLS:${BUILD_TOOLS}") +if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64|AMD64") + set(RABITQ_SUPPORTED ON) + add_definitions(-DRABITQ_SUPPORTED=1) + message(STATUS "RaBitQ support enabled for Linux x86_64") +else() + set(RABITQ_SUPPORTED OFF) + add_definitions(-DRABITQ_SUPPORTED=0) + message(STATUS "RaBitQ support disabled - only supported on Linux x86_64") +endif() + cc_directory(thirdparty) cc_directories(src) cc_directories(tests) diff --git a/python/tests/test_collection_hnsw_rabitq.py b/python/tests/test_collection_hnsw_rabitq.py new file mode 100644 index 00000000..7dfefe75 --- /dev/null +++ b/python/tests/test_collection_hnsw_rabitq.py @@ -0,0 +1,574 @@ +# Copyright 2025-present the zvec project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import platform +import sys + +import pytest +import math +import zvec + +pytestmark = pytest.mark.skipif( + not (sys.platform == "linux" and platform.machine() in ("x86_64", "AMD64")), + reason="HNSW RaBitQ only supported on Linux x86_64", +) +from zvec import ( + Collection, + CollectionOption, + DataType, + Doc, + FieldSchema, + HnswRabitqIndexParam, + HnswRabitqQueryParam, + MetricType, + VectorSchema, + VectorQuery, +) + + +# ==================== Fixtures ==================== + + +@pytest.fixture(scope="session") +def hnsw_rabitq_collection_schema(): + """Create a collection schema with HNSW RaBitQ index.""" + return zvec.CollectionSchema( + name="test_hnsw_rabitq_collection", + fields=[ + FieldSchema("id", DataType.INT64, nullable=False), + FieldSchema("name", DataType.STRING, nullable=False), + ], + vectors=[ + VectorSchema( + "embedding", + DataType.VECTOR_FP32, + dimension=128, + index_param=HnswRabitqIndexParam( + metric_type=MetricType.L2, + m=16, + ef_construction=200, + total_bits=7, + num_clusters=64, + ), + ), + ], + ) + + +@pytest.fixture(scope="session") +def collection_option(): + """Create collection options.""" + return CollectionOption(read_only=False, enable_mmap=True) + + +@pytest.fixture +def single_doc(): + """Create a single document for testing.""" + return Doc( + id="0", + fields={"id": 0, "name": "test_doc_0"}, + vectors={"embedding": [0.1 + i * 0.01 for i in range(128)]}, + ) + + +@pytest.fixture +def multiple_docs(): + """Create multiple documents for testing.""" + return [ + Doc( + id=f"{i}", + fields={"id": i, "name": f"test_doc_{i}"}, + vectors={"embedding": [i * 0.1 + j * 0.01 for j in range(128)]}, + ) + for i in range(1, 101) + ] + + +@pytest.fixture(scope="function") +def hnsw_rabitq_collection( + tmp_path_factory, hnsw_rabitq_collection_schema, collection_option +) -> Collection: + """ + Function-scoped fixture: creates and opens a collection with HNSW RaBitQ index. 
+ """ + temp_dir = tmp_path_factory.mktemp("zvec_hnsw_rabitq") + collection_path = temp_dir / "test_hnsw_rabitq_collection" + + coll = zvec.create_and_open( + path=str(collection_path), + schema=hnsw_rabitq_collection_schema, + option=collection_option, + ) + + assert coll is not None, "Failed to create and open HNSW RaBitQ collection" + assert coll.path == str(collection_path) + assert coll.schema.name == hnsw_rabitq_collection_schema.name + + try: + yield coll + finally: + if hasattr(coll, "destroy") and coll is not None: + try: + coll.destroy() + except Exception as e: + print(f"Warning: failed to destroy collection: {e}") + + +@pytest.fixture +def collection_with_single_doc( + hnsw_rabitq_collection: Collection, single_doc: Doc +) -> Collection: + """Setup: insert single doc into collection.""" + assert hnsw_rabitq_collection.stats.doc_count == 0 + result = hnsw_rabitq_collection.insert(single_doc) + assert bool(result) + assert result.ok() + assert hnsw_rabitq_collection.stats.doc_count == 1 + + yield hnsw_rabitq_collection + + # Teardown: delete single doc + hnsw_rabitq_collection.delete(single_doc.id) + assert hnsw_rabitq_collection.stats.doc_count == 0 + + +@pytest.fixture +def collection_with_multiple_docs( + hnsw_rabitq_collection: Collection, multiple_docs: list[Doc] +) -> Collection: + """Setup: insert multiple docs into collection.""" + assert hnsw_rabitq_collection.stats.doc_count == 0 + result = hnsw_rabitq_collection.insert(multiple_docs) + assert len(result) == len(multiple_docs) + for item in result: + assert item.ok() + assert hnsw_rabitq_collection.stats.doc_count == len(multiple_docs) + + yield hnsw_rabitq_collection + + # Teardown: delete multiple docs + hnsw_rabitq_collection.delete([doc.id for doc in multiple_docs]) + + +# ==================== Tests ==================== + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class TestHnswRabitqCollectionCreation: + """Test HNSW RaBitQ collection creation and schema validation.""" + + def 
test_collection_creation( + self, hnsw_rabitq_collection: Collection, hnsw_rabitq_collection_schema + ): + """Test that collection is created with correct schema.""" + assert hnsw_rabitq_collection is not None + assert hnsw_rabitq_collection.schema.name == hnsw_rabitq_collection_schema.name + assert len(hnsw_rabitq_collection.schema.fields) == len( + hnsw_rabitq_collection_schema.fields + ) + assert len(hnsw_rabitq_collection.schema.vectors) == len( + hnsw_rabitq_collection_schema.vectors + ) + + def test_vector_schema_validation(self, hnsw_rabitq_collection: Collection): + """Test that vector schema has correct HNSW RaBitQ configuration.""" + vector_schema = hnsw_rabitq_collection.schema.vector("embedding") + assert vector_schema is not None + assert vector_schema.name == "embedding" + assert vector_schema.data_type == DataType.VECTOR_FP32 + assert vector_schema.dimension == 128 + + index_param = vector_schema.index_param + assert index_param is not None + assert index_param.metric_type == MetricType.L2 + assert index_param.m == 16 + assert index_param.ef_construction == 200 + assert index_param.total_bits == 7 + assert index_param.num_clusters == 64 + + def test_collection_stats(self, hnsw_rabitq_collection: Collection): + """Test initial collection statistics.""" + stats = hnsw_rabitq_collection.stats + assert stats is not None + assert stats.doc_count == 0 + assert len(stats.index_completeness) == 1 + assert stats.index_completeness["embedding"] == 1 + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class TestHnswRabitqCollectionInsert: + """Test document insertion into HNSW RaBitQ collection.""" + + def test_insert_single_doc( + self, hnsw_rabitq_collection: Collection, single_doc: Doc + ): + """Test inserting a single document.""" + result = hnsw_rabitq_collection.insert(single_doc) + assert bool(result) + assert result.ok() + + stats = hnsw_rabitq_collection.stats + assert stats is not None + assert stats.doc_count == 1 + + def 
test_insert_multiple_docs( + self, hnsw_rabitq_collection: Collection, multiple_docs: list[Doc] + ): + """Test inserting multiple documents.""" + result = hnsw_rabitq_collection.insert(multiple_docs) + assert len(result) == len(multiple_docs) + for item in result: + assert item.ok() + + stats = hnsw_rabitq_collection.stats + assert stats is not None + assert stats.doc_count == len(multiple_docs) + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class TestHnswRabitqCollectionFetch: + """Test document fetching from HNSW RaBitQ collection.""" + + def test_fetch_single_doc( + self, collection_with_single_doc: Collection, single_doc: Doc + ): + """Test fetching a single document by ID.""" + result = collection_with_single_doc.fetch(ids=[single_doc.id]) + assert bool(result) + assert single_doc.id in result.keys() + + doc = result[single_doc.id] + assert doc is not None + assert doc.id == single_doc.id + assert doc.field("id") == single_doc.field("id") + assert doc.field("name") == single_doc.field("name") + + def test_fetch_multiple_docs( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test fetching multiple documents by IDs.""" + ids = [doc.id for doc in multiple_docs[:10]] + result = collection_with_multiple_docs.fetch(ids=ids) + assert bool(result) + assert len(result) == len(ids) + + for doc_id in ids: + assert doc_id in result + doc = result[doc_id] + assert doc is not None + assert doc.id == doc_id + + def test_fetch_nonexistent_doc(self, collection_with_single_doc: Collection): + """Test fetching a non-existent document.""" + result = collection_with_single_doc.fetch(ids=["nonexistent_id"]) + assert len(result) == 0 + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class TestHnswRabitqCollectionQuery: + """Test vector search queries on HNSW RaBitQ collection.""" + + def test_query_by_vector( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying by vector with HNSW 
RaBitQ index.""" + query_vector = multiple_docs[0].vector("embedding") + query = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + + result = collection_with_multiple_docs.query(vectors=query, topk=10) + assert len(result) > 0 + assert len(result) <= 10 + + # First result should be the query document itself (or very close) + first_doc = result[0] + assert first_doc is not None + assert first_doc.id is not None + + def test_query_by_id( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying by document ID with HNSW RaBitQ index.""" + query = VectorQuery( + field_name="embedding", + id=multiple_docs[0].id, + param=HnswRabitqQueryParam(ef=300), + ) + + result = collection_with_multiple_docs.query(vectors=query, topk=10) + assert len(result) > 0 + assert len(result) <= 10 + + def test_query_with_different_ef_values( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying with different ef parameter values.""" + query_vector = multiple_docs[0].vector("embedding") + + # Test with ef=100 + query_100 = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=100), + ) + result_100 = collection_with_multiple_docs.query(vectors=query_100, topk=10) + assert len(result_100) > 0 + + # Test with ef=500 + query_500 = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=500), + ) + result_500 = collection_with_multiple_docs.query(vectors=query_500, topk=10) + assert len(result_500) > 0 + + def test_query_with_topk( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying with different topk values.""" + query_vector = multiple_docs[0].vector("embedding") + query = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + + # Test topk=5 + result_5 = 
collection_with_multiple_docs.query(vectors=query, topk=5) + assert len(result_5) <= 5 + + # Test topk=20 + result_20 = collection_with_multiple_docs.query(vectors=query, topk=20) + assert len(result_20) <= 20 + + def test_query_with_filter( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying with filter conditions.""" + query_vector = multiple_docs[0].vector("embedding") + query = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + + # Query with id filter + result = collection_with_multiple_docs.query( + vectors=query, topk=10, filter="id < 50" + ) + assert len(result) > 0 + for doc in result: + assert doc.field("id") < 50 + + def test_query_with_output_fields( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying with specific output fields.""" + query_vector = multiple_docs[0].vector("embedding") + query = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + + result = collection_with_multiple_docs.query( + vectors=query, topk=10, output_fields=["id", "name"] + ) + assert len(result) > 0 + + first_doc = result[0] + assert "id" in first_doc.field_names() + assert "name" in first_doc.field_names() + + def test_query_with_include_vector( + self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] + ): + """Test querying with vector data included in results.""" + query_vector = multiple_docs[0].vector("embedding") + query = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + + result = collection_with_multiple_docs.query( + vectors=query, topk=10, include_vector=True + ) + assert len(result) > 0 + + first_doc = result[0] + assert first_doc.vector("embedding") is not None + assert len(first_doc.vector("embedding")) == 128 + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class 
TestHnswRabitqCollectionUpdate: + """Test document update in HNSW RaBitQ collection.""" + + def test_update_doc_fields( + self, collection_with_single_doc: Collection, single_doc: Doc + ): + """Test updating document fields.""" + updated_doc = Doc( + id=single_doc.id, + fields={"id": single_doc.field("id"), "name": "updated_name"}, + ) + + result = collection_with_single_doc.update(updated_doc) + assert bool(result) + assert result.ok() + + # Verify update + fetched = collection_with_single_doc.fetch(ids=[single_doc.id]) + assert single_doc.id in fetched + doc = fetched[single_doc.id] + assert doc.field("name") == "updated_name" + + def test_update_doc_vector( + self, collection_with_single_doc: Collection, single_doc: Doc + ): + """Test updating document vector.""" + new_vector = [0.5 + i * 0.01 for i in range(128)] + updated_doc = Doc( + id=single_doc.id, + vectors={"embedding": new_vector}, + ) + + result = collection_with_single_doc.update(updated_doc) + assert bool(result) + assert result.ok() + + # Verify update + fetched = collection_with_single_doc.fetch( + ids=[single_doc.id], + ) + assert single_doc.id in fetched + doc = fetched[single_doc.id] + assert doc.vector("embedding") is not None + embedding = doc.vector("embedding") + assert len(embedding) == 128 + # Verify vector values are approximately equal (float comparison) + for i in range(128): + assert math.isclose(embedding[i], new_vector[i], rel_tol=1e-5) + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class TestHnswRabitqCollectionDelete: + """Test document deletion from HNSW RaBitQ collection.""" + + def test_delete_single_doc( + self, collection_with_single_doc: Collection, single_doc: Doc + ): + """Test deleting a single document.""" + result = collection_with_single_doc.delete(single_doc.id) + assert bool(result) + assert result.ok() + + stats = collection_with_single_doc.stats + assert stats.doc_count == 0 + + def test_delete_multiple_docs( + self, collection_with_multiple_docs: 
Collection, multiple_docs: list[Doc] + ): + """Test deleting multiple documents.""" + ids_to_delete = [doc.id for doc in multiple_docs[:10]] + result = collection_with_multiple_docs.delete(ids_to_delete) + assert len(result) == len(ids_to_delete) + for item in result: + assert item.ok() + + stats = collection_with_multiple_docs.stats + assert stats.doc_count == len(multiple_docs) - len(ids_to_delete) + + +@pytest.mark.usefixtures("hnsw_rabitq_collection") +class TestHnswRabitqCollectionOptimizeAndReopen: + """Test collection optimize and reopen functionality.""" + + def test_optimize_close_reopen_and_query( + self, + tmp_path_factory, + hnsw_rabitq_collection_schema, + collection_option, + multiple_docs: list[Doc], + ): + """Test inserting 100 docs, optimize, close, reopen and query.""" + # Create collection and insert 100 documents + temp_dir = tmp_path_factory.mktemp("zvec_hnsw_rabitq_optimize") + collection_path = temp_dir / "test_optimize_collection" + + coll = zvec.create_and_open( + path=str(collection_path), + schema=hnsw_rabitq_collection_schema, + option=collection_option, + ) + + assert coll is not None + assert coll.stats.doc_count == 0 + + # Insert 100 documents + result = coll.insert(multiple_docs) + assert len(result) == len(multiple_docs) + for item in result: + assert item.ok() + assert coll.stats.doc_count == len(multiple_docs) + + # Call optimize + from zvec import OptimizeOption + + coll.optimize(option=OptimizeOption()) + + # Verify data is still accessible after optimize + query_vector = multiple_docs[0].vector("embedding") + query = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + result_before_close = coll.query(vectors=query, topk=10) + assert len(result_before_close) > 0 + + # Close collection (destroy will close it) + collection_path_str = str(collection_path) + del coll + + # Reopen collection + reopened_coll = zvec.open(path=collection_path_str, option=collection_option) + assert 
reopened_coll is not None + assert reopened_coll.stats.doc_count == len(multiple_docs) + + # Execute query on reopened collection + query_after_reopen = VectorQuery( + field_name="embedding", + vector=query_vector, + param=HnswRabitqQueryParam(ef=300), + ) + result_after_reopen = reopened_coll.query(vectors=query_after_reopen, topk=10) + assert len(result_after_reopen) > 0 + assert len(result_after_reopen) <= 10 + + # Verify query results are valid + first_doc = result_after_reopen[0] + assert first_doc is not None + assert first_doc.id is not None + assert first_doc.field("id") is not None + assert first_doc.field("name") is not None + + # Cleanup + reopened_coll.destroy() diff --git a/python/zvec/__init__.py b/python/zvec/__init__.py index ec35829d..877c0f43 100644 --- a/python/zvec/__init__.py +++ b/python/zvec/__init__.py @@ -44,6 +44,8 @@ FlatIndexParam, HnswIndexParam, HnswQueryParam, + HnswRabitqIndexParam, + HnswRabitqQueryParam, IndexOption, InvertIndexParam, IVFIndexParam, @@ -90,6 +92,8 @@ "VectorQuery", "InvertIndexParam", "HnswIndexParam", + "HnswRabitqIndexParam", + "HnswRabitqQueryParam", "FlatIndexParam", "IVFIndexParam", "CollectionOption", diff --git a/python/zvec/model/param/__init__.py b/python/zvec/model/param/__init__.py index 4dbeb249..c613edf5 100644 --- a/python/zvec/model/param/__init__.py +++ b/python/zvec/model/param/__init__.py @@ -20,6 +20,8 @@ FlatIndexParam, HnswIndexParam, HnswQueryParam, + HnswRabitqIndexParam, + HnswRabitqQueryParam, IndexOption, InvertIndexParam, IVFIndexParam, @@ -34,6 +36,8 @@ "FlatIndexParam", "HnswIndexParam", "HnswQueryParam", + "HnswRabitqIndexParam", + "HnswRabitqQueryParam", "IVFIndexParam", "IVFQueryParam", "IndexOption", diff --git a/python/zvec/model/param/__init__.pyi b/python/zvec/model/param/__init__.pyi index 7ecfadf8..cd1491ef 100644 --- a/python/zvec/model/param/__init__.pyi +++ b/python/zvec/model/param/__init__.pyi @@ -16,6 +16,8 @@ __all__: list[str] = [ "FlatIndexParam", "HnswIndexParam", 
"HnswQueryParam", + "HnswRabitqIndexParam", + "HnswRabitqQueryParam", "IVFIndexParam", "IVFQueryParam", "IndexOption", @@ -285,6 +287,135 @@ class HnswQueryParam(QueryParam): int: Size of the dynamic candidate list during HNSW search. """ +class HnswRabitqIndexParam(VectorIndexParam): + """ + + Parameters for configuring an HNSW (Hierarchical Navigable Small World) index with RabitQ quantization. + + HNSW is a graph-based approximate nearest neighbor search index. RabitQ is a + quantization method that provides high compression with minimal accuracy loss. + + Attributes: + metric_type (MetricType): Distance metric used for similarity computation. + Default is ``MetricType.IP`` (inner product). + total_bits (int): Total bits for RabitQ quantization. Default is 7. + num_clusters (int): Number of clusters for RabitQ. Default is 16. + m (int): Number of bi-directional links created for every new element + during construction. Higher values improve accuracy but increase + memory usage and construction time. Default is 50. + ef_construction (int): Size of the dynamic candidate list for nearest + neighbors during index construction. Larger values yield better + graph quality at the cost of slower build time. Default is 500. + sample_count (int): Sample count for RabitQ training. Default is 0. + + Examples: + >>> from zvec.typing import MetricType + >>> params = HnswRabitqIndexParam( + ... metric_type=MetricType.COSINE, + ... total_bits=8, + ... num_clusters=256, + ... m=16, + ... ef_construction=200, + ... sample_count=10000 + ... ) + >>> print(params) + {'metric_type': 'COSINE', 'total_bits': 8, 'num_clusters': 256, 'm': 16, 'ef_construction': 200, 'sample_count': 10000} + """ + + def __getstate__(self) -> tuple: ... 
+ def __init__( + self, + metric_type: _zvec.typing.MetricType = ..., + total_bits: typing.SupportsInt = 7, + num_clusters: typing.SupportsInt = 16, + m: typing.SupportsInt = 50, + ef_construction: typing.SupportsInt = 500, + sample_count: typing.SupportsInt = 0, + ) -> None: ... + def __repr__(self) -> str: ... + def __setstate__(self, arg0: tuple) -> None: ... + def to_dict(self) -> dict: + """ + Convert to dictionary with all fields + """ + + @property + def ef_construction(self) -> int: + """ + int: Candidate list size during index construction. + """ + + @property + def m(self) -> int: + """ + int: Maximum number of neighbors per node. + """ + + @property + def total_bits(self) -> int: + """ + int: Total bits for RabitQ quantization. + """ + + @property + def num_clusters(self) -> int: + """ + int: Number of clusters for RabitQ. + """ + + @property + def sample_count(self) -> int: + """ + int: Sample count for RabitQ training. + """ + +class HnswRabitqQueryParam(QueryParam): + """ + + Query parameters for HNSW index with RabitQ quantization. + + Controls the trade-off between search speed and accuracy via the `ef` parameter. + + Attributes: + type (IndexType): Always ``IndexType.HNSW_RABITQ``. + ef (int): Size of the dynamic candidate list during search. + Larger values improve recall but slow down search. + Default is 300. + radius (float): Search radius for range queries. Default is 0.0. + is_linear (bool): Force linear search. Default is False. + is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. + + Examples: + >>> params = HnswRabitqQueryParam(ef=300) + >>> print(params.ef) + 300 + """ + def __getstate__(self) -> tuple: ... + def __init__( + self, + ef: typing.SupportsInt = 300, + radius: typing.SupportsFloat = 0.0, + is_linear: bool = False, + is_using_refiner: bool = False, + ) -> None: + """ + Constructs an HnswRabitqQueryParam instance. + + Args: + ef (int, optional): Search-time candidate list size. 
+ Higher values improve accuracy. Defaults to 300. + radius (float, optional): Search radius for range queries. Default is 0.0. + is_linear (bool, optional): Force linear search. Default is False. + is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. + """ + def __repr__(self) -> str: ... + def __setstate__(self, arg0: tuple) -> None: ... + @property + def ef(self) -> int: + """ + int: Size of the dynamic candidate list during HNSW search. + """ + class IVFIndexParam(VectorIndexParam): """ diff --git a/src/binding/python/CMakeLists.txt b/src/binding/python/CMakeLists.txt index 83bfd4ce..d14d825c 100644 --- a/src/binding/python/CMakeLists.txt +++ b/src/binding/python/CMakeLists.txt @@ -24,6 +24,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Linux") $ $ $ + $ $ $ $ @@ -42,6 +43,7 @@ elseif (APPLE) -Wl,-force_load,$ -Wl,-force_load,$ -Wl,-force_load,$ + -Wl,-force_load,$ -Wl,-force_load,$ -Wl,-force_load,$ -Wl,-force_load,$ diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 9a5c953f..a3818153 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -31,6 +31,8 @@ static std::string index_type_to_string(const IndexType type) { return "IVF"; case IndexType::HNSW: return "HNSW"; + case IndexType::HNSW_RABITQ: + return "HNSW_RABITQ"; default: return "UNDEFINED"; } @@ -59,6 +61,8 @@ static std::string quantize_type_to_string(const QuantizeType type) { return "INT4"; case QuantizeType::FP16: return "FP16"; + case QuantizeType::RABITQ: + return "RABITQ"; default: return "UNDEFINED"; } @@ -376,6 +380,105 @@ encapsulates its construction hyperparameters. t[3].cast()); })); + // binding hnsw rabitq index params + py::class_> + hnsw_rabitq_params(m, "HnswRabitqIndexParam", R"pbdoc( +Parameters for configuring an HNSW (Hierarchical Navigable Small World) index with RabitQ quantization. 
+ +HNSW is a graph-based approximate nearest neighbor search index. RabitQ is a +quantization method that provides high compression with minimal accuracy loss. + +Attributes: + metric_type (MetricType): Distance metric used for similarity computation. + Default is ``MetricType.IP`` (inner product). + m (int): Number of bi-directional links created for every new element + during construction. Higher values improve accuracy but increase + memory usage and construction time. Default is 50. + ef_construction (int): Size of the dynamic candidate list for nearest + neighbors during index construction. Larger values yield better + graph quality at the cost of slower build time. Default is 500. + +Examples: + >>> from zvec.typing import MetricType + >>> params = HnswRabitqIndexParam( + ... metric_type=MetricType.COSINE, + ... m=16, + ... ef_construction=200 + ... ) + >>> print(params) + {'metric_type': 'COSINE', 'm': 16, 'ef_construction': 200} +)pbdoc"); + hnsw_rabitq_params + .def(py::init(), + py::arg("metric_type") = MetricType::IP, + py::arg("total_bits") = core_interface::kDefaultRabitqTotalBits, + py::arg("num_clusters") = core_interface::kDefaultRabitqNumClusters, + py::arg("m") = core_interface::kDefaultHnswNeighborCnt, + py::arg("ef_construction") = + core_interface::kDefaultHnswEfConstruction, + py::arg("sample_count") = 0) + .def_property_readonly("m", &HnswRabitqIndexParams::m, + "int: Maximum number of neighbors per node.") + .def_property_readonly( + "ef_construction", &HnswRabitqIndexParams::ef_construction, + "int: Candidate list size during index construction.") + .def_property_readonly("total_bits", &HnswRabitqIndexParams::total_bits, + "int: Total bits for RabitQ quantization.") + .def_property_readonly("num_clusters", + &HnswRabitqIndexParams::num_clusters, + "int: Number of clusters for RabitQ.") + .def_property_readonly("sample_count", + &HnswRabitqIndexParams::sample_count, + "int: Sample count for RabitQ training.") + .def( + "to_dict", + [](const 
HnswRabitqIndexParams &self) -> py::dict { + py::dict dict; + dict["type"] = index_type_to_string(self.type()); + dict["metric_type"] = metric_type_to_string(self.metric_type()); + dict["quantize_type"] = + quantize_type_to_string(self.quantize_type()); + dict["total_bits"] = self.total_bits(); + dict["num_clusters"] = self.num_clusters(); + dict["sample_count"] = self.sample_count(); + dict["m"] = self.m(); + dict["ef_construction"] = self.ef_construction(); + return dict; + }, + "Convert to dictionary with all fields") + .def( + "__repr__", + [](const HnswRabitqIndexParams &self) -> std::string { + return "{" + "\"type\":\"" + + index_type_to_string(self.type()) + + "\", \"metric_type\":\"" + + metric_type_to_string(self.metric_type()) + + "\", \"total_bits\":" + std::to_string(self.total_bits()) + + ", \"num_clusters\":" + std::to_string(self.num_clusters()) + + ", \"sample_count\":" + std::to_string(self.sample_count()) + + ", \"m\":" + std::to_string(self.m()) + + ", \"ef_construction\":" + + std::to_string(self.ef_construction()) + + ", \"quantize_type\":\"" + + quantize_type_to_string(self.quantize_type()) + "\"}"; + }) + .def(py::pickle( + [](const HnswRabitqIndexParams &self) { + return py::make_tuple(self.metric_type(), self.total_bits(), + self.num_clusters(), self.m(), + self.ef_construction(), self.sample_count()); + }, + [](py::tuple t) { + if (t.size() != 6) + throw std::runtime_error( + "Invalid state for HnswRabitqIndexParams"); + return std::make_shared( + t[0].cast(), t[1].cast(), t[2].cast(), + t[3].cast(), t[4].cast(), t[5].cast()); + })); + // FlatIndexParams py::class_> @@ -709,10 +812,81 @@ Constructs an IVFQueryParam instance. obj->set_is_linear(t[2].cast()); return obj; })); + + // binding hnsw rabitq query params + py::class_> + hnsw_rabitq_query_params(m, "HnswRabitqQueryParam", R"pbdoc( +Query parameters for HNSW RaBitQ (Hierarchical Navigable Small World with RaBitQ quantization) index. 
+ +Controls the trade-off between search speed and accuracy via the `ef` parameter. +RaBitQ provides efficient quantization while maintaining high search quality. + +Attributes: + type (IndexType): Always ``IndexType.HNSW_RABITQ``. + ef (int): Size of the dynamic candidate list during search. + Larger values improve recall but slow down search. + Default is 300. + radius (float): Search radius for range queries. Default is 0.0. + is_linear (bool): Force linear search. Default is False. + is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. + +Examples: + >>> params = HnswRabitqQueryParam(ef=300) + >>> print(params.ef) + 300 + >>> print(params.to_dict() if hasattr(params, 'to_dict') else params) + {"type":"HNSW_RABITQ", "ef":300} +)pbdoc"); + hnsw_rabitq_query_params + .def(py::init(), + py::arg("ef") = core_interface::kDefaultHnswEfSearch, + py::arg("radius") = 0.0f, py::arg("is_linear") = false, + py::arg("is_using_refiner") = false, + R"pbdoc( +Constructs an HnswRabitqQueryParam instance. + +Args: + ef (int, optional): Search-time candidate list size. + Higher values improve accuracy. Defaults to 300. + radius (float, optional): Search radius for range queries. Default is 0.0. + is_linear (bool, optional): Force linear search. Default is False. + is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. 
+)pbdoc") + .def_property_readonly( + "ef", + [](const HnswRabitqQueryParams &self) -> int { return self.ef(); }, + "int: Size of the dynamic candidate list during HNSW RaBitQ search.") + .def("__repr__", + [](const HnswRabitqQueryParams &self) -> std::string { + return "{" + "\"type\":\"" + + index_type_to_string(self.type()) + + "\", \"ef\":" + std::to_string(self.ef()) + + ", \"radius\":" + std::to_string(self.radius()) + + ", \"is_linear\":" + std::to_string(self.is_linear()) + + ", \"is_using_refiner\":" + + std::to_string(self.is_using_refiner()) + "}"; + }) + .def(py::pickle( + [](const HnswRabitqQueryParams &self) { + return py::make_tuple(self.ef(), self.radius(), self.is_linear(), + self.is_using_refiner()); + }, + [](py::tuple t) { + if (t.size() != 4) + throw std::runtime_error( + "Invalid state for HnswRabitqQueryParams"); + auto obj = + std::make_shared(t[0].cast()); + obj->set_radius(t[1].cast()); + obj->set_is_linear(t[2].cast()); + obj->set_is_using_refiner(t[3].cast()); + return obj; + })); } -void ZVecPyParams::bind_options(py::module_ &m) { - // binding collection options +void ZVecPyParams::bind_options(py::module_ &m) { // binding collection options py::class_(m, "CollectionOption", R"pbdoc( Options for opening or creating a collection. diff --git a/src/binding/python/typing/python_type.cc b/src/binding/python/typing/python_type.cc index ee057cf3..bb500346 100644 --- a/src/binding/python/typing/python_type.cc +++ b/src/binding/python/typing/python_type.cc @@ -96,6 +96,7 @@ Enumeration of supported index types in Zvec. )pbdoc") .value("UNDEFINED", IndexType::UNDEFINED) .value("HNSW", IndexType::HNSW) + .value("HNSW_RABITQ", IndexType::HNSW_RABITQ) .value("IVF", IndexType::IVF) .value("FLAT", IndexType::FLAT) .value("INVERT", IndexType::INVERT); @@ -131,7 +132,8 @@ Enumeration of supported quantization types for vector compression. 
.value("UNDEFINED", QuantizeType::UNDEFINED) .value("FP16", QuantizeType::FP16) .value("INT8", QuantizeType::INT8) - .value("INT4", QuantizeType::INT4); + .value("INT4", QuantizeType::INT4) + .value("RABITQ", QuantizeType::RABITQ); } void ZVecPyTyping::bind_status(py::module_ &m) { diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 7742db59..ed386853 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -12,6 +12,11 @@ cc_directory(mixed_reducer) git_version(GIT_SRCS_VER ${CMAKE_CURRENT_SOURCE_DIR}) file(GLOB_RECURSE ALL_CORE_SRCS *.cc *.c *.h) +# Remove hnsw-rabitq files if not supported +if(NOT RABITQ_SUPPORTED) + list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*hnsw-rabitq.*") +endif() + cc_library( NAME zvec_core STATIC STRICT PACKED SRCS ${ALL_CORE_SRCS} diff --git a/src/core/algorithm/CMakeLists.txt b/src/core/algorithm/CMakeLists.txt index 648dbefe..32a9153e 100644 --- a/src/core/algorithm/CMakeLists.txt +++ b/src/core/algorithm/CMakeLists.txt @@ -6,4 +6,25 @@ cc_directory(flat) cc_directory(flat_sparse) cc_directory(ivf) cc_directory(hnsw) -cc_directory(hnsw_sparse) \ No newline at end of file +cc_directory(hnsw_sparse) +if(RABITQ_SUPPORTED) + message(STATUS "BUILD RABITQ") + cc_directory(hnsw-rabitq) +else() + message(STATUS "NOT BUILD RABITQ") + # Empty stub library for unsupported platforms + file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc + "// Stub implementation for unsupported platforms\n" + "// RaBitQ only supports Linux x86_64\n" + "namespace zvec { namespace core { /* empty namespace for compatibility */ } }\n" + ) + + cc_library( + NAME core_knn_hnsw_rabitq + STATIC SHARED STRICT ALWAYS_LINK + SRCS ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc + LIBS core_framework + INCS . 
${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) +endif() diff --git a/src/core/algorithm/cluster/opt_kmeans_cluster.cc b/src/core/algorithm/cluster/opt_kmeans_cluster.cc index 49ae8f5f..26192df7 100644 --- a/src/core/algorithm/cluster/opt_kmeans_cluster.cc +++ b/src/core/algorithm/cluster/opt_kmeans_cluster.cc @@ -1250,7 +1250,7 @@ int OptKmeansCluster::init(const IndexMeta &meta, const ailego::Params ¶ms) { auto type_ = meta.data_type(); - if (meta.metric_name() == "InnerProduct") { + if (meta.metric_name() == "InnerProduct" || meta.metric_name() == "Cosine") { switch (type_) { case IndexMeta::DataType::DT_FP16: { algorithm_.reset( diff --git a/src/core/algorithm/flat/flat_builder.cc b/src/core/algorithm/flat/flat_builder.cc index d5f32305..be6810c0 100644 --- a/src/core/algorithm/flat/flat_builder.cc +++ b/src/core/algorithm/flat/flat_builder.cc @@ -127,7 +127,6 @@ int FlatBuilder::dump(const IndexDumper::Pointer &dumper) { return error_code; } - holder_ = nullptr; stats_.set_dumped_count(keys.size()); stats_.set_dumped_costtime(stamp.milli_seconds()); return 0; diff --git a/src/core/algorithm/hnsw-rabitq/CMakeLists.txt b/src/core/algorithm/hnsw-rabitq/CMakeLists.txt new file mode 100644 index 00000000..495dc4be --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/CMakeLists.txt @@ -0,0 +1,11 @@ +include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) +include(${PROJECT_ROOT_DIR}/cmake/option.cmake) + +cc_library( + NAME core_knn_hnsw_rabitq + STATIC SHARED STRICT ALWAYS_LINK + SRCS *.cc rabitq/*.cc + LIBS core_framework sparsehash rabitqlib + INCS . 
${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_algorithm.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_algorithm.cc new file mode 100644 index 00000000..b5f1fdda --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_algorithm.cc @@ -0,0 +1,525 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "hnsw_rabitq_algorithm.h" +#include +#include +#include +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +HnswRabitqAlgorithm::HnswRabitqAlgorithm(HnswRabitqEntity &entity) + : entity_(entity), + mt_(std::chrono::system_clock::now().time_since_epoch().count()), + lock_pool_(kLockCnt) {} + +int HnswRabitqAlgorithm::cleanup() { + return 0; +} + +int HnswRabitqAlgorithm::add_node(node_id_t id, level_t level, + HnswRabitqContext *ctx) { + spin_lock_.lock(); + + // std::cout << "id: " << id << ", level: " << level << std::endl; + + auto cur_max_level = entity_.cur_max_level(); + auto entry_point = entity_.entry_point(); + if (ailego_unlikely(entry_point == kInvalidNodeId)) { + entity_.update_ep_and_level(id, level); + spin_lock_.unlock(); + return 0; + } + spin_lock_.unlock(); + + if (ailego_unlikely(level > cur_max_level)) { + mutex_.lock(); + // re-check max level + cur_max_level = entity_.cur_max_level(); + entry_point = entity_.entry_point(); + if (level <= cur_max_level) { + mutex_.unlock(); + } + } + + level_t cur_level = cur_max_level; + ResultRecord dist = ctx->dist_calculator()(entry_point); + for (; cur_level > level; --cur_level) { + select_entry_point(cur_level, &entry_point, &dist, ctx); + } + + for (; cur_level >= 0; --cur_level) { + search_neighbors(cur_level, &entry_point, &dist, ctx->level_topk(cur_level), + ctx); + } + + // add neighbors from down level to top level, to avoid upper level visible + // to knn_search but the under layer level not ready + for (cur_level = 0; cur_level <= level; ++cur_level) { + add_neighbors(id, cur_level, ctx->level_topk(cur_level), ctx); + ctx->level_topk(cur_level).clear(); + } + + if (ailego_unlikely(level > cur_max_level)) { + spin_lock_.lock(); + entity_.update_ep_and_level(id, level); + spin_lock_.unlock(); + mutex_.unlock(); + } + + return 0; +} + +int HnswRabitqAlgorithm::search(HnswRabitqContext *ctx) const { + spin_lock_.lock(); + auto maxLevel = entity_.cur_max_level(); + auto 
entry_point = entity_.entry_point(); + spin_lock_.unlock(); + + if (ailego_unlikely(entry_point == kInvalidNodeId)) { + return 0; + } + + ResultRecord dist = ctx->dist_calculator().dist(entry_point); + for (level_t cur_level = maxLevel; cur_level >= 1; --cur_level) { + select_entry_point(cur_level, &entry_point, &dist, ctx); + } + + auto &topk_heap = ctx->topk_heap(); + topk_heap.clear(); + search_neighbors(0, &entry_point, &dist, topk_heap, ctx); + + if (ctx->group_by_search()) { + expand_neighbors_by_group(topk_heap, ctx); + } + + return 0; +} + +//! select_entry_point on hnsw level, ef = 1 +void HnswRabitqAlgorithm::select_entry_point(level_t level, + node_id_t *entry_point, + ResultRecord *dist, + HnswRabitqContext *ctx) const { + auto &entity = ctx->get_entity(); + HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); + while (true) { + const Neighbors neighbors = entity.get_neighbors(level, *entry_point); + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_get_neighbors())++; + } + uint32_t size = neighbors.size(); + if (size == 0) { + break; + } + + std::vector neighbor_vec_blocks; + int ret = dc.get_vector(&neighbors[0], size, neighbor_vec_blocks); + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_get_vector())++; + } + if (ailego_unlikely(ret != 0)) { + break; + } + + bool find_closer = false; + + float dists[size]; + const void *neighbor_vecs[size]; + for (uint32_t i = 0; i < size; ++i) { + neighbor_vecs[i] = neighbor_vec_blocks[i].data(); + } + + dc.batch_dist(neighbor_vecs, size, dists); + + for (uint32_t i = 0; i < size; ++i) { + ResultRecord cur_dist = dists[i]; + + if (cur_dist < *dist) { + *entry_point = neighbors[i]; + *dist = cur_dist; + find_closer = true; + } + } + + if (!find_closer) { + break; + } + } + + return; +} + +void HnswRabitqAlgorithm::add_neighbors(node_id_t id, level_t level, + TopkHeap &topk_heap, + HnswRabitqContext *ctx) { + if (ailego_unlikely(topk_heap.size() == 0)) { + return; + } + + 
HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); + + update_neighbors(dc, id, level, topk_heap); + + // reverse update neighbors + for (size_t i = 0; i < topk_heap.size(); ++i) { + reverse_update_neighbors(dc, topk_heap[i].first, level, id, + topk_heap[i].second, ctx->update_heap()); + } + + return; +} + +void HnswRabitqAlgorithm::search_neighbors(level_t level, + node_id_t *entry_point, + ResultRecord *dist, TopkHeap &topk, + HnswRabitqContext *ctx) const { + const auto &entity = ctx->get_entity(); + HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); + VisitFilter &visit = ctx->visit_filter(); + CandidateHeap &candidates = ctx->candidates(); + std::function filter = [](node_id_t) { return false; }; + if (ctx->filter().is_valid()) { + filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; + } + + candidates.clear(); + visit.clear(); + visit.set_visited(*entry_point); + if (!filter(*entry_point)) { + topk.emplace(*entry_point, *dist); + } + + candidates.emplace(*entry_point, *dist); + while (!candidates.empty() && !ctx->reach_scan_limit()) { + auto top = candidates.begin(); + node_id_t main_node = top->first; + ResultRecord main_dist = top->second; + + if (topk.full() && main_dist > topk[0].second) { + break; + } + + candidates.pop(); + const Neighbors neighbors = entity.get_neighbors(level, main_node); + ailego_prefetch(neighbors.data); + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_get_neighbors())++; + } + + node_id_t neighbor_ids[neighbors.size()]; + uint32_t size = 0; + for (uint32_t i = 0; i < neighbors.size(); ++i) { + node_id_t node = neighbors[i]; + if (visit.visited(node)) { + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_visit_dup_cnt())++; + } + continue; + } + visit.set_visited(node); + neighbor_ids[size++] = node; + } + if (size == 0) { + continue; + } + + std::vector neighbor_vec_blocks; + int ret = dc.get_vector(neighbor_ids, size, neighbor_vec_blocks); + if 
(ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_get_vector())++; + } + if (ailego_unlikely(ret != 0)) { + break; + } + + // do prefetch + static constexpr node_id_t BATCH_SIZE = 12; + static constexpr node_id_t PREFETCH_STEP = 2; + for (uint32_t i = 0; i < std::min(BATCH_SIZE * PREFETCH_STEP, size); ++i) { + ailego_prefetch(neighbor_vec_blocks[i].data()); + } + // done + + float dists[size]; + const void *neighbor_vecs[size]; + + for (uint32_t i = 0; i < size; ++i) { + neighbor_vecs[i] = neighbor_vec_blocks[i].data(); + } + + dc.batch_dist(neighbor_vecs, size, dists); + + for (uint32_t i = 0; i < size; ++i) { + node_id_t node = neighbor_ids[i]; + ResultRecord cur_dist = dists[i]; + + if ((!topk.full()) || cur_dist < topk[0].second) { + candidates.emplace(node, cur_dist); + // update entry_point for next level scan + if (cur_dist < *dist) { + *entry_point = node; + *dist = cur_dist; + } + if (!filter(node)) { + topk.emplace(node, cur_dist); + } + } // end if + } // end for + } // while + + return; +} + +void HnswRabitqAlgorithm::expand_neighbors_by_group( + TopkHeap &topk, HnswRabitqContext *ctx) const { + // if (!ctx->group_by().is_valid()) { + // return; + // } + + // const auto &entity = ctx->get_entity(); + // std::function group_by = [&](node_id_t id) { + // return ctx->group_by()(entity.get_key(id)); + // }; + + // // devide into groups + // std::map &group_topk_heaps = + // ctx->group_topk_heaps(); for (uint32_t i = 0; i < topk.size(); ++i) { + // node_id_t id = topk[i].first; + // auto score = topk[i].second; + + // std::string group_id = group_by(id); + + // auto &topk_heap = group_topk_heaps[group_id]; + // if (topk_heap.empty()) { + // topk_heap.limit(ctx->group_topk()); + // } + // topk_heap.emplace_back(id, score); + // } + + // // stage 2, expand to reach group num as possible + // if (group_topk_heaps.size() < ctx->group_num()) { + // VisitFilter &visit = ctx->visit_filter(); + // CandidateHeap &candidates = ctx->candidates(); + // 
HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); + + // std::function filter = [](node_id_t) { return false; }; + // if (ctx->filter().is_valid()) { + // filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); + // }; + // } + + // // refill to get enough groups + // candidates.clear(); + // visit.clear(); + // for (uint32_t i = 0; i < topk.size(); ++i) { + // node_id_t id = topk[i].first; + // ResultRecord score = topk[i].second; + + // visit.set_visited(id); + // candidates.emplace_back(id, score); + // } + + // // do expand + // while (!candidates.empty() && !ctx->reach_scan_limit()) { + // auto top = candidates.begin(); + // node_id_t main_node = top->first; + + // candidates.pop(); + // const Neighbors neighbors = entity.get_neighbors(0, main_node); + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_get_neighbors())++; + // } + + // node_id_t neighbor_ids[neighbors.size()]; + // uint32_t size = 0; + // for (uint32_t i = 0; i < neighbors.size(); ++i) { + // node_id_t node = neighbors[i]; + // if (visit.visited(node)) { + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_visit_dup_cnt())++; + // } + // continue; + // } + // visit.set_visited(node); + // neighbor_ids[size++] = node; + // } + // if (size == 0) { + // continue; + // } + + // std::vector neighbor_vec_blocks; + // int ret = entity.get_vector(neighbor_ids, size, neighbor_vec_blocks); + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_get_vector())++; + // } + // if (ailego_unlikely(ret != 0)) { + // break; + // } + + // static constexpr node_id_t PREFETCH_STEP = 2; + // for (uint32_t i = 0; i < size; ++i) { + // node_id_t node = neighbor_ids[i]; + // node_id_t prefetch_id = i + PREFETCH_STEP; + // if (prefetch_id < size) { + // ailego_prefetch(neighbor_vec_blocks[prefetch_id].data()); + // } + // ResultRecord cur_dist = dc.dist(neighbor_vec_blocks[i].data()); + + // if (!filter(node)) { + // std::string group_id = 
group_by(node); + + // auto &topk_heap = group_topk_heaps[group_id]; + // if (topk_heap.empty()) { + // topk_heap.limit(ctx->group_topk()); + // } + // topk_heap.emplace_back(node, cur_dist); + + // if (group_topk_heaps.size() >= ctx->group_num()) { + // break; + // } + // } + + // candidates.emplace(node, cur_dist); + // } // end for + // } // end while + // } // end if +} + +void HnswRabitqAlgorithm::update_neighbors(HnswRabitqAddDistCalculator &dc, + node_id_t id, level_t level, + TopkHeap &topk_heap) { + topk_heap.sort(); + + uint32_t max_neighbor_cnt = entity_.neighbor_cnt(level); + if (topk_heap.size() <= static_cast(entity_.prune_cnt())) { + if (topk_heap.size() <= static_cast(max_neighbor_cnt)) { + entity_.update_neighbors(level, id, topk_heap); + return; + } + } + + uint32_t cur_size = 0; + for (size_t i = 0; i < topk_heap.size(); ++i) { + node_id_t cur_node = topk_heap[i].first; + ResultRecord cur_node_dist = topk_heap[i].second; + bool good = true; + for (uint32_t j = 0; j < cur_size; ++j) { + ResultRecord tmp_dist = dc.dist(cur_node, topk_heap[j].first); + if (tmp_dist <= cur_node_dist) { + good = false; + break; + } + } + + if (good) { + topk_heap[cur_size].first = cur_node; + topk_heap[cur_size].second = cur_node_dist; + cur_size++; + if (cur_size >= max_neighbor_cnt) { + break; + } + } + } + + // when after-prune neighbor count is too seldom, + // we use this strategy to make-up enough edges + // not only just make-up out-degrees + // we also make-up enough in-degrees + uint32_t min_neighbors = entity_.min_neighbor_cnt(); + for (size_t k = cur_size; cur_size < min_neighbors && k < topk_heap.size(); + ++k) { + bool exist = false; + for (size_t j = 0; j < cur_size; ++j) { + if (topk_heap[j].first == topk_heap[k].first) { + exist = true; + break; + } + } + if (!exist) { + topk_heap[cur_size].first = topk_heap[k].first; + topk_heap[cur_size].second = topk_heap[k].second; + cur_size++; + } + } + + topk_heap.resize(cur_size); + 
entity_.update_neighbors(level, id, topk_heap); + + return; +} + +void HnswRabitqAlgorithm::reverse_update_neighbors( + HnswRabitqAddDistCalculator &dc, node_id_t id, level_t level, + node_id_t link_id, ResultRecord dist, TopkHeap &update_heap) { + const size_t max_neighbor_cnt = entity_.neighbor_cnt(level); + + uint32_t lock_idx = id & kLockMask; + lock_pool_[lock_idx].lock(); + const Neighbors neighbors = entity_.get_neighbors(level, id); + size_t size = neighbors.size(); + ailego_assert_with(size <= max_neighbor_cnt, "invalid neighbor size"); + if (size < max_neighbor_cnt) { + entity_.add_neighbor(level, id, size, link_id); + lock_pool_[lock_idx].unlock(); + return; + } + + update_heap.emplace(link_id, dist); + + for (size_t i = 0; i < size; ++i) { + node_id_t node = neighbors[i]; + ResultRecord cur_dist = dc.dist(id, node); + update_heap.emplace(node, cur_dist); + } + + //! TODO: optimize prune + //! prune edges + update_heap.sort(); + size_t cur_size = 0; + for (size_t i = 0; i < update_heap.size(); ++i) { + node_id_t cur_node = update_heap[i].first; + ResultRecord cur_node_dist = update_heap[i].second; + bool good = true; + for (size_t j = 0; j < cur_size; ++j) { + ResultRecord tmp_dist = dc.dist(cur_node, update_heap[j].first); + if (tmp_dist <= cur_node_dist) { + good = false; + break; + } + } + + if (good) { + update_heap[cur_size].first = cur_node; + update_heap[cur_size].second = cur_node_dist; + cur_size++; + if (cur_size >= max_neighbor_cnt) { + break; + } + } + } + + update_heap.resize(cur_size); + entity_.update_neighbors(level, id, update_heap); + + lock_pool_[lock_idx].unlock(); + + update_heap.clear(); + + return; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_algorithm.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_algorithm.h new file mode 100644 index 00000000..852d779e --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_algorithm.h @@ -0,0 +1,131 @@ +// Copyright 2025-present 
the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "hnsw_rabitq_context.h" +#include "hnsw_rabitq_dist_calculator.h" +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +//! hnsw graph algorithm implement +class HnswRabitqAlgorithm { + public: + typedef std::unique_ptr UPointer; + + public: + //! Constructor + explicit HnswRabitqAlgorithm(HnswRabitqEntity &entity); + + //! Destructor + ~HnswRabitqAlgorithm() = default; + + //! Cleanup HnswRabitqAlgorithm + int cleanup(); + + //! Add a node to hnsw graph + //! @id: the node unique id + //! @level: a node will be add to graph in each level [0, level] + //! return 0 on success, or errCode in failure + int add_node(node_id_t id, level_t level, HnswRabitqContext *ctx); + + //! do knn search in graph + //! return 0 on success, or errCode in failure. results saved in ctx + int search(HnswRabitqContext *ctx) const; + + //! Initiate HnswRabitqAlgorithm + int init() { + level_probas_.clear(); + double level_mult = + 1 / std::log(static_cast(entity_.scaling_factor())); + for (int level = 0;; level++) { + // refers faiss get_random_level alg + double proba = + std::exp(-level / level_mult) * (1 - std::exp(-1 / level_mult)); + if (proba < 1e-9) { + break; + } + level_probas_.push_back(proba); + } + + return 0; + } + + //! Generate a random level + //! 
return graph level + uint32_t get_random_level() const { + // gen rand float (0, 1) + double f = mt_() / static_cast(mt_.max()); + for (size_t level = 0; level < level_probas_.size(); level++) { + if (f < level_probas_[level]) { + return level; + } + f -= level_probas_[level]; + } + return level_probas_.size() - 1; + } + + private: + //! Select in upper layer to get entry point for next layer search + void select_entry_point(level_t level, node_id_t *entry_point, + ResultRecord *dist, HnswRabitqContext *ctx) const; + + //! update node id neighbors from topkHeap, and reverse link is also updated + void add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, + HnswRabitqContext *ctx); + + //! Given a node id and level, search the nearest neighbors in graph + //! Note: the nearest neighbors result keeps in topk, and entry_point and + //! dist will be updated to current level nearest node id and distance + void search_neighbors(level_t level, node_id_t *entry_point, + ResultRecord *dist, TopkHeap &topk, + HnswRabitqContext *ctx) const; + + //! Update the node's neighbors + void update_neighbors(HnswRabitqAddDistCalculator &dc, node_id_t id, + level_t level, TopkHeap &topk_heap); + + //! Checking linkId could be id's new neighbor, and add as neighbor if true + //! @dc distance calculator + //! @updateHeap temporary heap in updating neighbors + void reverse_update_neighbors(HnswRabitqAddDistCalculator &dc, node_id_t id, + level_t level, node_id_t link_id, + ResultRecord dist, TopkHeap &update_heap); + + //! 
expand neighbors until group nums are reached + void expand_neighbors_by_group(TopkHeap &topk, HnswRabitqContext *ctx) const; + + private: + HnswRabitqAlgorithm(const HnswRabitqAlgorithm &) = delete; + HnswRabitqAlgorithm &operator=(const HnswRabitqAlgorithm &) = delete; + + private: + static constexpr uint32_t kLockCnt{1U << 8}; + static constexpr uint32_t kLockMask{kLockCnt - 1U}; + + HnswRabitqEntity &entity_; + mutable std::mt19937 mt_{}; + std::vector level_probas_{}; + + mutable ailego::SpinMutex spin_lock_{}; // global spin lock + std::mutex mutex_{}; // global mutex + // TODO: spin lock? + std::vector lock_pool_{}; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder.cc new file mode 100644 index 00000000..76c8439b --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder.cc @@ -0,0 +1,553 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "hnsw_rabitq_builder.h" +#include +#include +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_factory.h" +#include "zvec/core/framework/index_logger.h" +#include "zvec/core/framework/index_meta.h" +#include "zvec/core/framework/index_provider.h" +#include "hnsw_rabitq_algorithm.h" +#include "hnsw_rabitq_entity.h" +#include "hnsw_rabitq_params.h" +#include "rabitq_converter.h" +#include "rabitq_params.h" +#include "rabitq_reformer.h" + +namespace zvec { +namespace core { + +HnswRabitqBuilder::HnswRabitqBuilder() {} + +int HnswRabitqBuilder::init(const IndexMeta &meta, + const ailego::Params ¶ms) { + LOG_INFO("Begin HnswRabitqBuilder::init"); + + meta_ = meta; + auto params_copy = params; + meta_.set_builder("HnswRabitqBuilder", HnswRabitqEntity::kRevision, + std::move(params_copy)); + + size_t memory_quota = 0UL; + params.get(PARAM_HNSW_RABITQ_BUILDER_MEMORY_QUOTA, &memory_quota); + params.get(PARAM_HNSW_RABITQ_BUILDER_THREAD_COUNT, &thread_cnt_); + params.get(PARAM_HNSW_RABITQ_BUILDER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_); + params.get(PARAM_HNSW_RABITQ_BUILDER_EFCONSTRUCTION, &ef_construction_); + params.get(PARAM_HNSW_RABITQ_BUILDER_CHECK_INTERVAL_SECS, + &check_interval_secs_); + + params.get(PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT, + &upper_max_neighbor_cnt_); + float multiplier = HnswRabitqEntity::kDefaultL0MaxNeighborCntMultiplier; + params.get(PARAM_HNSW_RABITQ_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER, + &multiplier); + l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_; + scaling_factor_ = upper_max_neighbor_cnt_; + params.get(PARAM_HNSW_RABITQ_BUILDER_SCALING_FACTOR, &scaling_factor_); + + multiplier = HnswRabitqEntity::kDefaultNeighborPruneMultiplier; + params.get(PARAM_HNSW_RABITQ_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier); + size_t prune_cnt = multiplier * upper_max_neighbor_cnt_; + + if (ef_construction_ == 0) { + ef_construction_ = 
HnswRabitqEntity::kDefaultEfConstruction; + } + if (upper_max_neighbor_cnt_ == 0) { + upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt; + } + if (upper_max_neighbor_cnt_ > kMaxNeighborCnt) { + LOG_ERROR("[%s] must be in range (0,%d]", + PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT.c_str(), + kMaxNeighborCnt); + return IndexError_InvalidArgument; + } + if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) { + LOG_ERROR("[%s]-[%d] must be <= [%s]-[%d]", + PARAM_HNSW_RABITQ_BUILDER_MIN_NEIGHBOR_COUNT.c_str(), + min_neighbor_cnt_, + PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT.c_str(), + upper_max_neighbor_cnt_); + return IndexError_InvalidArgument; + } + if (l0_max_neighbor_cnt_ == 0) { + l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt; + } + if (l0_max_neighbor_cnt_ > HnswRabitqEntity::kMaxNeighborCnt) { + LOG_ERROR("L0MaxNeighborCnt must be in range (0,%d)", + HnswRabitqEntity::kMaxNeighborCnt); + return IndexError_InvalidArgument; + } + if (scaling_factor_ == 0U) { + scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor; + } + if (scaling_factor_ < 5 || scaling_factor_ > 1000) { + LOG_ERROR("[%s] must be in range [5,1000]", + PARAM_HNSW_RABITQ_BUILDER_SCALING_FACTOR.c_str()); + return IndexError_InvalidArgument; + } + if (thread_cnt_ == 0) { + thread_cnt_ = std::thread::hardware_concurrency(); + } + if (thread_cnt_ > std::thread::hardware_concurrency()) { + LOG_WARN("[%s] greater than cpu cores %zu", + PARAM_HNSW_RABITQ_BUILDER_THREAD_COUNT.c_str(), + static_cast(std::thread::hardware_concurrency())); + } + if (prune_cnt == 0UL) { + prune_cnt = upper_max_neighbor_cnt_; + } + + metric_ = IndexFactory::CreateMetric(meta_.metric_name()); + if (!metric_) { + LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str()); + return IndexError_NoExist; + } + int ret = metric_->init(meta_, meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("IndexMetric init failed, ret=%d", ret); + return ret; + } + + uint32_t 
total_bits = 0; + params.get(PARAM_RABITQ_TOTAL_BITS, &total_bits); + if (total_bits == 0) { + total_bits = kDefaultRabitqTotalBits; + } + if (total_bits < 1 || total_bits > 9) { + LOG_ERROR("Invalid total_bits: %zu, must be in [1, 9]", (size_t)total_bits); + return IndexError_InvalidArgument; + } + uint8_t ex_bits = total_bits - 1; + entity_.set_ex_bits(ex_bits); + + uint32_t dimension = 0; + params.get(PARAM_HNSW_RABITQ_GENERAL_DIMENSION, &dimension); + if (dimension == 0) { + LOG_ERROR("%s not set", PARAM_HNSW_RABITQ_GENERAL_DIMENSION.c_str()); + return IndexError_InvalidArgument; + } + entity_.update_rabitq_params_and_vector_size(dimension); + + entity_.set_ef_construction(ef_construction_); + entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_); + entity_.set_min_neighbor_cnt(min_neighbor_cnt_); + entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_); + entity_.set_scaling_factor(scaling_factor_); + entity_.set_memory_quota(memory_quota); + entity_.set_prune_cnt(prune_cnt); + + ret = entity_.init(); + if (ret != 0) { + return ret; + } + + alg_ = HnswRabitqAlgorithm::UPointer(new HnswRabitqAlgorithm(entity_)); + + ret = alg_->init(); + if (ret != 0) { + return ret; + } + + // Create and initialize RaBitQ converter + converter_ = std::make_shared(); + if (!converter_) { + LOG_ERROR("Failed to create RabitqConverter"); + return IndexError_NoMemory; + } + + IndexMeta converter_meta = meta_; + converter_meta.set_dimension(dimension); + ret = converter_->init(converter_meta, params); + if (ret != 0) { + LOG_ERROR("Failed to initialize RabitqConverter: %d", ret); + return ret; + } + + state_ = BUILD_STATE_INITED; + LOG_INFO( + "End HnswRabitqBuilder::init, params: rawVectorSize=%u vectorSize=%zu " + "efConstruction=%u " + "l0NeighborCnt=%u upperNeighborCnt=%u scalingFactor=%u " + "memoryQuota=%zu neighborPruneCnt=%zu metricName=%s ", + meta_.element_size(), entity_.vector_size(), ef_construction_, + l0_max_neighbor_cnt_, upper_max_neighbor_cnt_, scaling_factor_, + 
memory_quota, prune_cnt, meta_.metric_name().c_str()); + + return 0; +} + +int HnswRabitqBuilder::cleanup(void) { + LOG_INFO("Begin HnswRabitqBuilder::cleanup"); + + l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt; + min_neighbor_cnt_ = 0; + upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt; + ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction; + scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor; + check_interval_secs_ = kDefaultLogIntervalSecs; + errcode_ = 0; + error_ = false; + entity_.cleanup(); + alg_->cleanup(); + meta_.clear(); + metric_.reset(); + stats_.clear_attributes(); + stats_.set_trained_count(0UL); + stats_.set_built_count(0UL); + stats_.set_dumped_count(0UL); + stats_.set_discarded_count(0UL); + stats_.set_trained_costtime(0UL); + stats_.set_built_costtime(0UL); + stats_.set_dumped_costtime(0UL); + state_ = BUILD_STATE_INIT; + + LOG_INFO("End HnswRabitqBuilder::cleanup"); + + return 0; +} + +int HnswRabitqBuilder::train(IndexThreads::Pointer, + IndexHolder::Pointer holder) { + if (state_ != BUILD_STATE_INITED) { + LOG_ERROR("Init the builder before HnswRabitqBuilder::train"); + return IndexError_NoReady; + } + + if (!holder) { + LOG_ERROR("Input holder is nullptr while training index"); + return IndexError_InvalidArgument; + } + if (!holder->is_matched(meta_)) { + LOG_ERROR("Input holder doesn't match index meta while training index"); + return IndexError_Mismatch; + } + LOG_INFO("Begin HnswRabitqBuilder::train"); + size_t trained_cost_time = 0; + size_t trained_count = 0; + + int ret = train_converter_and_load_reformer(holder); + if (ret != 0) { + return ret; + } + + if (metric_->support_train()) { + auto start_time = ailego::Monotime::MilliSeconds(); + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator for holder failed"); + return IndexError_Runtime; + } + while (iter->is_valid()) { + int ret = metric_->train(iter->data(), meta_.dimension()); + if 
(ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw build measure train failed, ret=%d", ret); + return ret; + } + iter->next(); + ++trained_count; + } + trained_cost_time = ailego::Monotime::MilliSeconds() - start_time; + } + stats_.set_trained_count(trained_count); + stats_.set_trained_costtime(trained_cost_time); + state_ = BUILD_STATE_TRAINED; + + LOG_INFO("End HnswRabitqBuilder::train"); + + return 0; +} + +int HnswRabitqBuilder::train_converter_and_load_reformer( + IndexHolder::Pointer holder) { + // Train converter (KMeans clustering) + int ret = converter_->train(holder); + if (ret != 0) { + LOG_ERROR("Failed to train RabitqConverter: %d", ret); + return ret; + } + auto memory_dumper = IndexFactory::CreateDumper("MemoryDumper"); + memory_dumper->init(ailego::Params()); + std::string file_id = ailego::StringHelper::Concat( + "rabitq_converter_", ailego::Monotime::MilliSeconds(), rand()); + ret = memory_dumper->create(file_id); + if (ret != 0) { + LOG_ERROR("Failed to create memory dumper: %d", ret); + return ret; + } + ret = converter_->dump(memory_dumper); + if (ret != 0) { + LOG_ERROR("Failed to dump RabitqConverter: %d", ret); + return ret; + } + ret = memory_dumper->close(); + if (ret != 0) { + LOG_ERROR("Failed to close memory dumper: %d", ret); + return ret; + } + + reformer_ = std::make_shared(); + ailego::Params reformer_params; + reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name()); + ret = reformer_->init(reformer_params); + if (ret != 0) { + LOG_ERROR("Failed to initialize RabitqReformer: %d", ret); + return ret; + } + auto memory_storage = IndexFactory::CreateStorage("MemoryReadStorage"); + ret = memory_storage->open(file_id, false); + if (ret != 0) { + LOG_ERROR("Failed to open memory storage: %d", ret); + return ret; + } + ret = reformer_->load(memory_storage); + if (ret != 0) { + LOG_ERROR("Failed to load RabitqReformer: %d", ret); + return ret; + } + // TODO: release memory of memory_storage + return 0; +} + +int 
HnswRabitqBuilder::train(const IndexTrainer::Pointer & /*trainer*/) { + if (state_ != BUILD_STATE_INITED) { + LOG_ERROR("Init the builder before HnswRabitqBuilder::train"); + return IndexError_NoReady; + } + + LOG_INFO("Begin HnswRabitqBuilder::train by trainer"); + + stats_.set_trained_count(0UL); + stats_.set_trained_costtime(0UL); + state_ = BUILD_STATE_TRAINED; + + LOG_INFO("End HnswRabitqBuilder::train by trainer"); + + return 0; +} + +int HnswRabitqBuilder::build(IndexThreads::Pointer threads, + IndexHolder::Pointer holder) { + if (state_ != BUILD_STATE_TRAINED) { + LOG_ERROR("Train the index before HnswRabitqBuilder::build"); + return IndexError_NoReady; + } + + if (!holder) { + LOG_ERROR("Input holder is nullptr while building index"); + return IndexError_InvalidArgument; + } + if (!holder->is_matched(meta_)) { + LOG_ERROR("Input holder doesn't match index meta while building index"); + return IndexError_Mismatch; + } + IndexProvider::Pointer provider = + std::dynamic_pointer_cast(holder); + if (!provider) { + LOG_ERROR("Rabitq builder expect IndexProvider"); + return IndexError_InvalidArgument; + } + + if (!threads) { + threads = std::make_shared(thread_cnt_, false); + if (!threads) { + return IndexError_NoMemory; + } + } + + auto start_time = ailego::Monotime::MilliSeconds(); + LOG_INFO("Begin HnswRabitqBuilder::build"); + + if (holder->count() != static_cast(-1)) { + LOG_DEBUG("HnswRabitqBuilder holder documents count %lu", holder->count()); + int ret = entity_.reserve_space(holder->count()); + if (ret != 0) { + LOG_ERROR("HnswBuilde reserver space failed"); + return ret; + } + } + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator for holder failed"); + return IndexError_Runtime; + } + int ret; + error_ = false; + IndexQueryMeta ometa; + ometa.set_meta(holder->data_type(), holder->dimension()); + while (iter->is_valid()) { + const void *vec = iter->data(); + // quantize vector + std::string converted_vector; + 
IndexQueryMeta converted_meta; + ret = reformer_->convert(vec, ometa, &converted_vector, &converted_meta); + if (ret != 0) { + LOG_ERROR("Rabitq hnsw convert failed, ret=%d", ret); + return ret; + } + + + level_t level = alg_->get_random_level(); + node_id_t id; + + if (converted_vector.size() != entity_.vector_size()) { + LOG_ERROR( + "Converted vector size %zu is not equal to entity vector size %zu", + converted_vector.size(), entity_.vector_size()); + return IndexError_InvalidArgument; + } + ret = entity_.add_vector(level, iter->key(), converted_vector.data(), &id); + if (ailego_unlikely(ret != 0)) { + return ret; + } + iter->next(); + } + + LOG_INFO("Finished save vector, start build graph..."); + + auto task_group = threads->make_group(); + if (!task_group) { + LOG_ERROR("Failed to create task group"); + return IndexError_Runtime; + } + + std::atomic finished{0}; + for (size_t i = 0; i < threads->count(); ++i) { + task_group->submit(ailego::Closure ::New(this, &HnswRabitqBuilder::do_build, + i, threads->count(), provider, + &finished)); + } + + while (!task_group->is_finished()) { + std::unique_lock lk(mutex_); + cond_.wait_until(lk, std::chrono::system_clock::now() + + std::chrono::seconds(check_interval_secs_)); + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to build index while waiting finish"); + return errcode_; + } + LOG_INFO("Built cnt %zu, finished percent %.3f%%", + static_cast(finished.load()), + finished.load() * 100.0f / entity_.doc_cnt()); + } + if (error_.load(std::memory_order_acquire)) { + LOG_ERROR("Failed to build index while waiting finish"); + return errcode_; + } + task_group->wait_finish(); + + stats_.set_built_count(finished.load()); + stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time); + + state_ = BUILD_STATE_BUILT; + LOG_INFO("End HnswRabitqBuilder::build with RaBitQ quantization"); + return 0; +} + +void HnswRabitqBuilder::do_build(node_id_t idx, size_t step_size, + IndexProvider::Pointer 
provider, + std::atomic *finished) { + AILEGO_DEFER([&]() { + std::lock_guard latch(mutex_); + cond_.notify_one(); + }); + HnswRabitqContext *ctx = new (std::nothrow) HnswRabitqContext( + meta_.dimension(), metric_, + std::shared_ptr(&entity_, [](HnswRabitqEntity *) {})); + if (ailego_unlikely(ctx == nullptr)) { + if (!error_.exchange(true)) { + LOG_ERROR("Failed to create context"); + errcode_ = IndexError_NoMemory; + } + return; + } + HnswRabitqContext::Pointer auto_ptr(ctx); + ctx->set_provider(std::move(provider)); + ctx->set_max_scan_num(entity_.doc_cnt()); + int ret = ctx->init(HnswRabitqContext::kBuilderContext); + if (ret != 0) { + if (!error_.exchange(true)) { + LOG_ERROR("Failed to init context"); + errcode_ = IndexError_Runtime; + } + return; + } + + for (node_id_t id = idx; id < entity_.doc_cnt(); id += step_size) { + ctx->reset_query(ctx->dist_calculator().get_vector(id)); + ret = alg_->add_node(id, entity_.get_level(id), ctx); + if (ailego_unlikely(ret != 0)) { + if (!error_.exchange(true)) { + LOG_ERROR("Hnsw graph add node failed"); + errcode_ = ret; + } + return; + } + ctx->clear(); + (*finished)++; + } +} + +int HnswRabitqBuilder::dump(const IndexDumper::Pointer &dumper) { + if (state_ != BUILD_STATE_BUILT) { + LOG_INFO("Build the index before HnswRabitqBuilder::dump"); + return IndexError_NoReady; + } + + LOG_INFO("Begin HnswRabitqBuilder::dump"); + + meta_.set_searcher("HnswRabitqSearcher", HnswRabitqEntity::kRevision, + ailego::Params()); + auto start_time = ailego::Monotime::MilliSeconds(); + + int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); + if (ret != 0) { + LOG_ERROR("Failed to serialize meta into dumper."); + return ret; + } + + // Dump RaBitQ centroids first + if (converter_) { + ret = converter_->dump(dumper); + if (ret != 0) { + LOG_ERROR("Failed to dump RabitqConverter: %d", ret); + return ret; + } + LOG_INFO("RaBitQ centroids dumped: %zu bytes, cost %zu ms", + converter_->stats().dumped_size(), + 
static_cast(converter_->stats().dumped_costtime())); + } + + ret = entity_.dump(dumper); + if (ret != 0) { + LOG_ERROR("HnswRabitqBuilder dump index failed"); + return ret; + } + + stats_.set_dumped_count(entity_.doc_cnt()); + stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time); + + LOG_INFO("End HnswRabitqBuilder::dump"); + return 0; +} + +INDEX_FACTORY_REGISTER_BUILDER(HnswRabitqBuilder); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder.h new file mode 100644 index 00000000..ffb5abef --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder.h @@ -0,0 +1,101 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "zvec/core/framework/index_builder.h" +#include "zvec/core/framework/index_converter.h" +#include "zvec/core/framework/index_reformer.h" +#include "hnsw_rabitq_algorithm.h" +#include "hnsw_rabitq_builder_entity.h" + +namespace zvec { +namespace core { + +class HnswRabitqBuilder : public IndexBuilder { + public: + //! Constructor + HnswRabitqBuilder(); + + //! Initialize the builder + virtual int init(const IndexMeta &meta, + const ailego::Params ¶ms) override; + + //! Cleanup the builder + virtual int cleanup(void) override; + + //! 
Train the data + virtual int train(IndexThreads::Pointer, + IndexHolder::Pointer holder) override; + + //! Train the data + virtual int train(const IndexTrainer::Pointer &trainer) override; + + + //! Build the index + virtual int build(IndexThreads::Pointer threads, + IndexHolder::Pointer holder) override; + + //! Dump index into storage + virtual int dump(const IndexDumper::Pointer &dumper) override; + + //! Retrieve statistics + virtual const Stats &stats(void) const override { + return stats_; + } + + private: + void do_build(node_id_t idx, size_t step_size, + IndexProvider::Pointer provider, + std::atomic *finished); + + int train_converter_and_load_reformer(IndexHolder::Pointer holder); + + constexpr static uint32_t kDefaultLogIntervalSecs = 15U; + constexpr static uint32_t kMaxNeighborCnt = 65535; + + private: + enum BUILD_STATE { + BUILD_STATE_INIT = 0, + BUILD_STATE_INITED = 1, + BUILD_STATE_TRAINED = 2, + BUILD_STATE_BUILT = 3 + }; + + HnswRabitqBuilderEntity entity_{}; + HnswRabitqAlgorithm::UPointer alg_; // impl graph algorithm + uint32_t thread_cnt_{0}; + uint32_t min_neighbor_cnt_{0}; + uint32_t upper_max_neighbor_cnt_{ + HnswRabitqEntity::kDefaultUpperMaxNeighborCnt}; + uint32_t l0_max_neighbor_cnt_{HnswRabitqEntity::kDefaultL0MaxNeighborCnt}; + uint32_t ef_construction_{HnswRabitqEntity::kDefaultEfConstruction}; + uint32_t scaling_factor_{HnswRabitqEntity::kDefaultScalingFactor}; + uint32_t check_interval_secs_{kDefaultLogIntervalSecs}; + + int errcode_{0}; + std::atomic_bool error_{false}; + IndexMeta meta_{}; + IndexMetric::Pointer metric_{}; + IndexConverter::Pointer converter_{}; // RaBitQ converter + IndexReformer::Pointer reformer_{}; // RaBitQ reformer + std::mutex mutex_{}; + std::condition_variable cond_{}; + Stats stats_{}; + + BUILD_STATE state_{BUILD_STATE_INIT}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder_entity.cc 
b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder_entity.cc new file mode 100644 index 00000000..43a3fc8a --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder_entity.cc @@ -0,0 +1,199 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "hnsw_rabitq_builder_entity.h" +#include +#include +#include "utility/sparse_utility.h" + +namespace zvec { +namespace core { + +HnswRabitqBuilderEntity::HnswRabitqBuilderEntity() { + update_ep_and_level(kInvalidNodeId, 0U); +} + +int HnswRabitqBuilderEntity::cleanup() { + memory_quota_ = 0UL; + neighbors_size_ = 0U; + upper_neighbors_size_ = 0U; + padding_size_ = 0U; + vectors_buffer_.clear(); + keys_buffer_.clear(); + neighbors_buffer_.clear(); + upper_neighbors_buffer_.clear(); + neighbors_index_.clear(); + + vectors_buffer_.shrink_to_fit(); + keys_buffer_.shrink_to_fit(); + neighbors_buffer_.shrink_to_fit(); + upper_neighbors_buffer_.shrink_to_fit(); + neighbors_index_.shrink_to_fit(); + + this->HnswRabitqEntity::cleanup(); + + return 0; +} + +int HnswRabitqBuilderEntity::init() { + size_t size = vector_size(); + + //! aligned size to 32 + set_node_size(AlignSize(size)); + //! 
if node size is aligned to 1k, the build performance will downgrade + if (node_size() % 1024 == 0) { + set_node_size(AlignSize(node_size() + 1)); + } + + padding_size_ = node_size() - size; + + neighbors_size_ = neighbors_size(); + upper_neighbors_size_ = upper_neighbors_size(); + + return 0; +} + +int HnswRabitqBuilderEntity::reserve_space(size_t docs) { + if (memory_quota_ > 0 && (node_size() * docs + neighbors_size_ * docs + + sizeof(NeighborIndex) * docs > + memory_quota_)) { + return IndexError_NoMemory; + } + + vectors_buffer_.reserve(node_size() * docs); + keys_buffer_.reserve(sizeof(key_t) * docs); + neighbors_buffer_.reserve(neighbors_size_ * docs); + neighbors_index_.reserve(docs); + + return 0; +} + +int HnswRabitqBuilderEntity::add_vector(level_t level, key_t key, + const void *vec, node_id_t *id) { + if (memory_quota_ > 0 && + (vectors_buffer_.capacity() + keys_buffer_.capacity() + + neighbors_buffer_.capacity() + upper_neighbors_buffer_.capacity() + + neighbors_index_.capacity() * sizeof(NeighborIndex)) > memory_quota_) { + LOG_ERROR("Add vector failed, used memory exceed quota, cur_doc=%zu", + static_cast(doc_cnt())); + return IndexError_NoMemory; + } + + vectors_buffer_.append(reinterpret_cast(vec), vector_size()); + vectors_buffer_.append(padding_size_, '\0'); + keys_buffer_.append(reinterpret_cast(&key), sizeof(key)); + + // init level 0 neighbors + neighbors_buffer_.append(neighbors_size_, '\0'); + + neighbors_index_.emplace_back(upper_neighbors_buffer_.size(), level); + + // init upper layer neighbors + for (level_t cur_level = 1; cur_level <= level; ++cur_level) { + upper_neighbors_buffer_.append(upper_neighbors_size_, '\0'); + } + + *id = (*mutable_doc_cnt())++; + + return 0; +} + +key_t HnswRabitqBuilderEntity::get_key(node_id_t id) const { + return *(reinterpret_cast(keys_buffer_.data() + + id * sizeof(key_t))); +} + +const void *HnswRabitqBuilderEntity::get_vector(node_id_t id) const { + return vectors_buffer_.data() + id * node_size(); +} 
+ +int HnswRabitqBuilderEntity::get_vector( + const node_id_t id, IndexStorage::MemoryBlock &block) const { + const void *vec = get_vector(id); + block.reset((void *)vec); + return 0; +} + +int HnswRabitqBuilderEntity::get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const { + for (uint32_t i = 0; i < count; ++i) { + vecs[i] = vectors_buffer_.data() + ids[i] * node_size(); + } + + return 0; +} + +int HnswRabitqBuilderEntity::get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const { + const void *vecs[count]; + get_vector(ids, count, vecs); + for (uint32_t i = 0; i < count; ++i) { + vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); + } + return 0; +} + +const Neighbors HnswRabitqBuilderEntity::get_neighbors(level_t level, + node_id_t id) const { + const NeighborsHeader *hd = get_neighbor_header(level, id); + return {hd->neighbor_cnt, hd->neighbors}; +} + +int HnswRabitqBuilderEntity::update_neighbors( + level_t level, node_id_t id, + const std::vector> &neighbors) { + NeighborsHeader *hd = + const_cast(get_neighbor_header(level, id)); + for (size_t i = 0; i < neighbors.size(); ++i) { + hd->neighbors[i] = neighbors[i].first; + } + hd->neighbor_cnt = neighbors.size(); + + // std::cout << "id: " << id << ", neighbour, id: "; + // for (size_t i = 0; i < neighbors.size(); ++i) { + // if (i == neighbors.size()-1) + // std::cout << neighbors[i].first << ", score:" << neighbors[i].second << + // std::endl; + // else + // std::cout << neighbors[i].first << ", score:" << neighbors[i].second << + // ", id: "; + // } + + return 0; +} + +void HnswRabitqBuilderEntity::add_neighbor(level_t level, node_id_t id, + uint32_t /*size*/, + node_id_t neighbor_id) { + NeighborsHeader *hd = + const_cast(get_neighbor_header(level, id)); + hd->neighbors[hd->neighbor_cnt++] = neighbor_id; + + return; +} + +int HnswRabitqBuilderEntity::dump(const IndexDumper::Pointer &dumper) { + key_t *keys = + 
reinterpret_cast(const_cast(keys_buffer_.data())); + auto ret = + dump_segments(dumper, keys, [&](node_id_t id) { return get_level(id); }); + if (ailego_unlikely(ret < 0)) { + return ret; + } + + return 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder_entity.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder_entity.h new file mode 100644 index 00000000..7460674b --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_builder_entity.h @@ -0,0 +1,139 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +class HnswRabitqBuilderEntity : public HnswRabitqEntity { + public: + //! Add vector and key to hnsw entity, and local id will be saved to id + virtual int add_vector(level_t level, key_t key, const void *vec, + node_id_t *id) override; + + //! Get primary key of the node id + virtual key_t get_key(node_id_t id) const override; + + //! Get vector feature data by key + virtual const void *get_vector(node_id_t id) const override; + + //! 
Batch get vectors feature data by keys + virtual int get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const override; + + virtual int get_vector(const node_id_t id, + IndexStorage::MemoryBlock &block) const override; + virtual int get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const override; + + //! Get the node id's neighbors on graph level + const NeighborsHeader *get_neighbor_header(level_t level, + node_id_t id) const { + if (level == 0) { + return reinterpret_cast( + neighbors_buffer_.data() + neighbors_size_ * id); + } else { + size_t offset = neighbors_index_[id].offset; + return reinterpret_cast( + upper_neighbors_buffer_.data() + offset + + (level - 1) * upper_neighbors_size_); + } + } + + //! Get the node id's neighbors on graph level + virtual const Neighbors get_neighbors(level_t level, + node_id_t id) const override; + + //! Replace node id in level's neighbors + virtual int update_neighbors( + level_t level, node_id_t id, + const std::vector> &neighbors) + override; + + //! add a neighbor to id in graph level + virtual void add_neighbor(level_t level, node_id_t id, uint32_t size, + node_id_t neighbor_id) override; + + //! Dump the hnsw graph to dumper + virtual int dump(const IndexDumper::Pointer &dumper) override; + + //! Cleanup the entity + virtual int cleanup(void) override; + + public: + //! Constructor + HnswRabitqBuilderEntity(); + + //! Get the node graph level by id + level_t get_level(node_id_t id) const { + return neighbors_index_[id].level; + } + + //! Init builerEntity + int init(); + + //! reserve buffer space for documents + //! @param docs number of documents + int reserve_space(size_t docs); + + //! Set memory quota params + inline void set_memory_quota(size_t memory_quota) { + memory_quota_ = memory_quota; + } + + //! Get neighbors size + inline size_t neighbors_size() const { + return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); + } + + //! 
Get upper neighbors size + inline size_t upper_neighbors_size() const { + return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); + } + + public: + HnswRabitqBuilderEntity(const HnswRabitqBuilderEntity &) = delete; + HnswRabitqBuilderEntity &operator=(const HnswRabitqBuilderEntity &) = delete; + + private: + friend class HnswRabitqSearcherEntity; + //! class internal used only + struct NeighborIndex { + NeighborIndex(size_t off, level_t l) : offset(off), level(l) {} + uint64_t offset : 48; + uint64_t level : 16; + }; + + std::string vectors_buffer_{}; // aligned vectors + std::string keys_buffer_{}; // aligned vectors + std::string neighbors_buffer_{}; // level 0 neighbors buffer + std::string upper_neighbors_buffer_{}; // upper layer neighbors buffer + + std::string sparse_data_buffer_{}; // aligned spase data buffer + size_t sparse_data_offset_{0}; // + + // upper layer offset + level in upper_neighbors_buffer_ + std::vector neighbors_index_{}; + size_t memory_quota_{0UL}; + size_t neighbors_size_{0U}; // level 0 neighbors size + size_t upper_neighbors_size_{0U}; // level 0 neighbors size + size_t padding_size_{}; // padding size for each vector element +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_chunk.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_chunk.cc new file mode 100644 index 00000000..7be23a5f --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_chunk.cc @@ -0,0 +1,221 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include "hnsw_rabitq_chunk.h" +#include +#include +#include +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_helper.h" +#include "zvec/core/framework/index_logger.h" +#include "zvec/core/framework/index_streamer.h" + +namespace zvec { +namespace core { + +int HnswRabitqChunkBroker::init_storage(size_t chunk_size) { + chunk_meta_.clear(); + chunk_meta_.chunk_size = chunk_size; + chunk_meta_.create_time = ailego::Realtime::Seconds(); + stats_.set_create_time(chunk_meta_.create_time); + chunk_meta_.update_time = ailego::Realtime::Seconds(); + stats_.set_update_time(chunk_meta_.update_time); + + //! alloc meta chunk + size_t size = sizeof(HnswChunkMeta); + size = (size + page_mask_) & (~page_mask_); + const std::string segment_id = + make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId); + int ret = stg_->append(segment_id, size); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret)); + return ret; + } + chunk_meta_segment_ = get_chunk(CHUNK_TYPE_META, kDefaultChunkSeqId); + if (ailego_unlikely(!chunk_meta_segment_)) { + LOG_ERROR("Get meta segment failed"); + return IndexError_Runtime; + } + + //! 
update meta info and write to storage + chunk_meta_.chunk_cnts[CHUNK_TYPE_META] += 1; + chunk_meta_.total_size += size; + (*stats_.mutable_index_size()) += size; + size = chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta)); + if (ailego_unlikely(size != sizeof(HnswChunkMeta))) { + LOG_ERROR("Storage write data failed, wsize=%zu", size); + return IndexError_WriteData; + } + + return 0; +} + +int HnswRabitqChunkBroker::load_storage(size_t chunk_size) { + IndexStorage::MemoryBlock data_block; + size_t size = chunk_meta_segment_->read(0UL, data_block, + chunk_meta_segment_->data_size()); + if (size != sizeof(HnswChunkMeta)) { + LOG_ERROR("Invalid hnsw meta chunk, read size=%zu chunk size=%zu", size, + chunk_meta_segment_->data_size()); + return IndexError_InvalidFormat; + } + std::memcpy(&chunk_meta_, data_block.data(), size); + if (chunk_meta_.chunk_size != chunk_size) { + LOG_ERROR( + "Params hnsw chunk size=%zu mismatch from previous %zu " + "in index", + chunk_size, (size_t)chunk_meta_.chunk_size); + return IndexError_Mismatch; + } + + *stats_.mutable_check_point() = stg_->check_point(); + stats_.set_revision_id(chunk_meta_.revision_id); + stats_.set_update_time(chunk_meta_.update_time); + stats_.set_create_time(chunk_meta_.create_time); + + char create_time[32]; + char update_time[32]; + ailego::Realtime::Gmtime(chunk_meta_.create_time, "%Y-%m-%d %H:%M:%S", + create_time, sizeof(create_time)); + ailego::Realtime::Gmtime(chunk_meta_.update_time, "%Y-%m-%d %H:%M:%S", + update_time, sizeof(update_time)); + LOG_DEBUG( + "Load index, indexSize=%zu chunkSize=%zu nodeChunks=%zu " + "upperNeighborChunks=%zu revisionId=%zu " + "createTime=%s updateTime=%s", + (size_t)chunk_meta_.total_size, (size_t)chunk_meta_.chunk_size, + (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_NODE], + (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_UPPER_NEIGHBOR], + (size_t)chunk_meta_.revision_id, create_time, update_time); + + return 0; +} + +int 
HnswRabitqChunkBroker::open(IndexStorage::Pointer stg, + size_t max_index_size, size_t chunk_size, + bool check_crc) { + if (ailego_unlikely(stg_)) { + LOG_ERROR("An storage instance is already opened"); + return IndexError_Duplicate; + } + stg_ = std::move(stg); + if (stg_->isHugePage()) { + page_mask_ = ailego::MemoryHelper::HugePageSize() - 1; + } else { + page_mask_ = ailego::MemoryHelper::PageSize() - 1; + } + check_crc_ = check_crc; + max_chunks_size_ = max_index_size; + dirty_ = false; + + const std::string segment_id = + make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId); + chunk_meta_segment_ = stg_->get(segment_id); + if (!chunk_meta_segment_) { + LOG_DEBUG("Create new index"); + return init_storage(chunk_size); + } + + return load_storage(chunk_size); +} + +int HnswRabitqChunkBroker::close(void) { + flush(0UL); + + stg_.reset(); + check_crc_ = false; + dirty_ = false; + + return 0; +} + +int HnswRabitqChunkBroker::flush(uint64_t checkpoint) { + ailego_assert_with(chunk_meta_segment_, "invalid meta segment"); + + chunk_meta_.update_time = ailego::Realtime::Seconds(); + stats_.set_update_time(chunk_meta_.update_time); + + size_t size = + chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta)); + if (ailego_unlikely(size != sizeof(HnswChunkMeta))) { + LOG_ERROR("Storage write data failed, wsize=%zu", size); + } + + stg_->refresh(checkpoint); + int ret = stg_->flush(); + if (ret == 0) { + (*stats_.mutable_check_point()) = checkpoint; + } else { + LOG_ERROR("Storage flush failed for %s", IndexError::What(ret)); + } + return ret; +} + +std::pair HnswRabitqChunkBroker::alloc_chunk( + int type, uint64_t seq_id, size_t size) { + ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow"); + + Chunk::Pointer chunk; + if (ailego_unlikely(!stg_)) { + LOG_ERROR("Init storage first"); + return std::make_pair(IndexError_Uninitialized, chunk); + } + + //! 
check exist a empty chunk with the same name + chunk = get_chunk(type, seq_id); + if (chunk) { + if (ailego_unlikely(chunk->capacity() == size && + chunk->data_size() == 0UL)) { + LOG_ERROR("Exist invalid chunk size %zu, expect size %zu", + chunk->capacity(), size); + chunk.reset(); + return std::make_pair(IndexError_Runtime, chunk); + } + return std::make_pair(0, chunk); + } + //! align to page size + size = (size + page_mask_) & (~page_mask_); + if (ailego_unlikely(chunk_meta_.total_size + size >= max_chunks_size_)) { + LOG_ERROR("No space to new a chunk, curIndexSize=%zu allocSize=%zu", + (size_t)chunk_meta_.total_size, size); + return std::make_pair(IndexError_IndexFull, chunk); + } + + std::string segment_id = make_segment_id(type, seq_id); + int ret = stg_->append(segment_id, size); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret)); + return std::make_pair(ret, chunk); + } + chunk_meta_.chunk_cnts[type] += 1; + chunk_meta_.total_size += size; + (*stats_.mutable_index_size()) += size; + + size = chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta)); + if (ailego_unlikely(size != sizeof(HnswChunkMeta))) { + LOG_ERROR("Storage append segment failed, wsize=%zu", size); + } + + chunk = get_chunk(type, seq_id); + return std::make_pair(chunk ? 
0 : IndexError_NoMemory, chunk); +} + +Chunk::Pointer HnswRabitqChunkBroker::get_chunk(int type, + uint64_t seq_id) const { + std::string segment_id = make_segment_id(type, seq_id); + return stg_->get(segment_id); +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_chunk.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_chunk.h new file mode 100644 index 00000000..823bf5c4 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_chunk.h @@ -0,0 +1,140 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_logger.h" +#include "zvec/core/framework/index_storage.h" +#include "zvec/core/framework/index_streamer.h" + +namespace zvec { +namespace core { + +using Chunk = IndexStorage::Segment; + +class HnswRabitqChunkBroker { + public: + typedef std::shared_ptr Pointer; + + enum CHUNK_TYPE { + CHUNK_TYPE_HEADER = 1, + CHUNK_TYPE_META = 2, + CHUNK_TYPE_NODE = 3, + CHUNK_TYPE_UPPER_NEIGHBOR = 4, + CHUNK_TYPE_NEIGHBOR_INDEX = 5, + CHUNK_TYPE_SPARSE_NODE = 6, + CHUNK_TYPE_MAX = 8 + }; + static constexpr size_t kDefaultChunkSeqId = 0UL; + + HnswRabitqChunkBroker(IndexStreamer::Stats &stats) : stats_(stats) {} + + //! 
Open storage + int open(IndexStorage::Pointer stg, size_t max_index_size, size_t chunk_size, + bool check_crc); + + int close(void); + + int flush(uint64_t checkpoint); + + //! alloc a new chunk with size, not thread-safe + std::pair alloc_chunk(int type, uint64_t seq_id, + size_t size); + + //! alloc a new chunk with chunk size + inline std::pair alloc_chunk(int type, uint64_t seq_id) { + return alloc_chunk(type, seq_id, chunk_meta_.chunk_size); + } + + Chunk::Pointer get_chunk(int type, uint64_t seq_id) const; + + inline size_t get_chunk_cnt(int type) const { + ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow"); + return chunk_meta_.chunk_cnts[type]; + } + + inline bool dirty(void) const { + return dirty_; + } + + inline void mark_dirty(void) { + if (!dirty_) { + dirty_ = true; + chunk_meta_.revision_id += 1; + stats_.set_revision_id(chunk_meta_.revision_id); + } + } + + const IndexStorage::Pointer storage(void) const { + return stg_; + } + + private: + HnswRabitqChunkBroker(const HnswRabitqChunkBroker &) = delete; + HnswRabitqChunkBroker &operator=(const HnswRabitqChunkBroker &) = delete; + + struct HnswChunkMeta { + HnswChunkMeta(void) { + memset(this, 0, sizeof(HnswChunkMeta)); + } + void clear() { + memset(this, 0, sizeof(HnswChunkMeta)); + } + + uint64_t chunk_cnts[CHUNK_TYPE_MAX]; + uint64_t chunk_size; // size of per chunk + uint64_t total_size; // total size of allocated chunk + uint64_t revision_id; // index revision + uint64_t create_time; + uint64_t update_time; + uint64_t reserved[3]; + }; + + static_assert(sizeof(HnswChunkMeta) % 32 == 0, + "HnswChunkMeta must be aligned with 32 bytes"); + + //! Init the storage after open an empty index + int init_storage(size_t chunk_size); + + //! 
Load index from storage + int load_storage(size_t chunk_size); + + static inline const std::string make_segment_id(int type, uint64_t seq_id) { + return "HnswT" + ailego::StringHelper::ToString(type) + "S" + + ailego::StringHelper::ToString(seq_id); + } + + private: + IndexStreamer::Stats &stats_; + HnswChunkMeta chunk_meta_{}; + size_t page_mask_{0UL}; + size_t max_chunks_size_{0UL}; + IndexStorage::Pointer stg_{}; + IndexStorage::Segment::Pointer chunk_meta_segment_{}; + bool check_crc_{false}; + bool dirty_{false}; // set as true if index is modified , the flag + // will not be cleared even if flushed +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_context.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_context.cc new file mode 100644 index 00000000..5a528e38 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_context.cc @@ -0,0 +1,297 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "hnsw_rabitq_context.h" +#include +#include "hnsw_rabitq_params.h" + +namespace zvec { +namespace core { + +HnswRabitqContext::HnswRabitqContext(size_t dimension, + const IndexMetric::Pointer &metric, + const HnswRabitqEntity::Pointer &entity) + : IndexContext(metric), + entity_(entity), + add_dc_(entity_.get(), metric, dimension) {} + +HnswRabitqContext::HnswRabitqContext(const IndexMetric::Pointer &metric, + const HnswRabitqEntity::Pointer &entity) + : IndexContext(metric), entity_(entity), add_dc_(entity_.get(), metric) {} + +HnswRabitqContext::~HnswRabitqContext() { + visit_filter_.destroy(); +} + +int HnswRabitqContext::init(ContextType type) { + int ret; + uint32_t doc_cnt; + + type_ = type; + + switch (type) { + case kBuilderContext: + ret = visit_filter_.init(VisitFilter::ByteMap, entity_->doc_cnt(), + max_scan_num_, negative_probability_); + if (ret != 0) { + LOG_ERROR("Create filter failed, mode %d", filter_mode_); + return ret; + } + candidates_.limit(max_scan_num_); + update_heap_.limit(entity_->l0_neighbor_cnt() + 1); + break; + + case kSearcherContext: + ret = visit_filter_.init(filter_mode_, entity_->doc_cnt(), max_scan_num_, + negative_probability_); + if (ret != 0) { + LOG_ERROR("Create filter failed, mode %d", filter_mode_); + return ret; + } + candidates_.limit(max_scan_num_); + break; + + case kStreamerContext: + // maxScanNum is unknown if inited from streamer, so the docCnt may + // change. 
we need to compute maxScanNum by scan ratio, and preserve + // max_doc_cnt space from visit filter + doc_cnt = entity_->doc_cnt(); + max_scan_num_ = compute_max_scan_num(doc_cnt); + reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt); + ret = visit_filter_.init(filter_mode_, reserve_max_doc_cnt_, + max_scan_num_, negative_probability_); + if (ret != 0) { + LOG_ERROR("Create filter failed, mode %d", filter_mode_); + return ret; + } + + update_heap_.limit(entity_->l0_neighbor_cnt() + 1); + candidates_.limit(max_scan_num_); + + check_need_adjuct_ctx(); + break; + + default: + LOG_ERROR("Init context failed"); + return IndexError_Runtime; + } + + return 0; +} + +int HnswRabitqContext::update(const ailego::Params ¶ms) { + auto update_visit_filter_param = [&]() { + bool need_update = false; + std::string p; + switch (type_) { + case kSearcherContext: + p = PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_ENABLE; + break; + case kStreamerContext: + p = PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_ENABLE; + break; + } + + if (params.has(p)) { + bool bf_enabled; + params.get(p, &bf_enabled); + if (bf_enabled ^ (filter_mode_ == VisitFilter::BloomFilter)) { + need_update = true; + filter_mode_ = + bf_enabled ? 
VisitFilter::BloomFilter : VisitFilter::ByteMap; + } + } + + float prob = negative_probability_; + p.clear(); + switch (type_) { + case kSearcherContext: + p = PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB; + break; + case kStreamerContext: + p = PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB; + break; + } + params.get(p, &prob); + if (filter_mode_ == VisitFilter::BloomFilter && + std::abs(prob - negative_probability_) > 1e-6) { + need_update = true; + } + if (need_update) { + visit_filter_.destroy(); + int max_doc_cnt = 0; + if (type_ == kSearcherContext) { + max_doc_cnt = entity_->doc_cnt(); + } else { + max_doc_cnt = reserve_max_doc_cnt_; + } + int ret = visit_filter_.init(filter_mode_, max_doc_cnt, max_scan_num_, + negative_probability_); + if (ret != 0) { + LOG_ERROR("Create filter failed, mode %d", filter_mode_); + return ret; + } + } + return 0; + }; + + switch (type_) { + case kSearcherContext: + if (params.has(PARAM_HNSW_RABITQ_SEARCHER_EF)) { + params.get(PARAM_HNSW_RABITQ_SEARCHER_EF, &ef_); + topk_heap_.limit(std::max(topk_, ef_)); + } + + if (params.has(PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO)) { + params.get(PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_); + max_scan_num_ = + static_cast(max_scan_ratio_ * entity_->doc_cnt()); + max_scan_num_ = std::max(10000U, max_scan_num_); + } + + if (params.has(PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD)) { + params.get(PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD, + &bruteforce_threshold_); + } + + return update_visit_filter_param(); + + case kStreamerContext: + if (params.has(PARAM_HNSW_RABITQ_STREAMER_EF)) { + params.get(PARAM_HNSW_RABITQ_STREAMER_EF, &ef_); + topk_heap_.limit(std::max(topk_, ef_)); + } + params.get(PARAM_HNSW_RABITQ_STREAMER_EF, &ef_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT, 
&min_scan_limit_); + if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) { + LOG_ERROR("[%s] must be in range (0.0f,1.0f]", + PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO.c_str()); + return IndexError_InvalidArgument; + } + if (max_scan_limit_ < min_scan_limit_) { + LOG_ERROR("[%s] must be >= [%s]", + PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT.c_str(), + PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT.c_str()); + return IndexError_InvalidArgument; + } + + if (params.has(PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD)) { + params.get(PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD, + &bruteforce_threshold_); + } + + return update_visit_filter_param(); + + default: + LOG_ERROR("update context failed, type=%zu", static_cast(type_)); + return IndexError_Runtime; + } +} + +int HnswRabitqContext::update_context(ContextType type, const IndexMeta &meta, + const IndexMetric::Pointer &metric, + const HnswRabitqEntity::Pointer &entity, + uint32_t magic_num) { + uint32_t doc_cnt; + + if (ailego_unlikely(type != type_)) { + LOG_ERROR( + "HnswRabitqContext doesn't support shared by different type, " + "src=%u dst=%u", + type_, type); + return IndexError_Unsupported; + } + + magic_ = kInvalidMgic; + + // TODO: support change filter mode? 
+ switch (type) { + case kBuilderContext: + LOG_ERROR("BuildContext doesn't support update"); + return IndexError_NotImplemented; + + case kSearcherContext: + if (!visit_filter_.reset(entity->doc_cnt(), max_scan_num_)) { + LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode()); + return IndexError_Runtime; + } + + candidates_.limit(max_scan_num_); + topk_heap_.limit(std::max(topk_, ef_)); + break; + + case kStreamerContext: + doc_cnt = entity->doc_cnt(); + max_scan_num_ = compute_max_scan_num(doc_cnt); + reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt); + if (!visit_filter_.reset(reserve_max_doc_cnt_, max_scan_num_)) { + LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode()); + return IndexError_Runtime; + } + + update_heap_.limit(entity->l0_neighbor_cnt() + 1); + candidates_.limit(max_scan_num_); + topk_heap_.limit(std::max(topk_, ef_)); + break; + + default: + LOG_ERROR("update context failed"); + return IndexError_Runtime; + } + + entity_ = entity; + dc().update(entity_.get(), metric, meta.dimension()); + magic_ = magic_num; + level_topks_.clear(); + + return 0; +} + +void HnswRabitqContext::fill_random_to_topk_full(void) { + static std::mt19937 mt( + std::chrono::system_clock::now().time_since_epoch().count()); + std::uniform_int_distribution dt(0, entity_->doc_cnt() - 1); + std::function gen; + node_id_t seqid; + std::function myfilter = [](node_id_t) { return false; }; + if (this->filter().is_valid()) { + myfilter = [&](node_id_t id) { + return this->filter()(entity_->get_key(id)); + }; + } + + if (topk_heap_.limit() < entity_->doc_cnt() / 2) { + gen = [&](void) { return dt(mt); }; + } else { + // If topk limit is big value, gen sequential id from an random initial + seqid = dt(mt); + gen = [&](void) { + seqid = seqid == (entity_->doc_cnt() - 1) ? 
0 : (seqid + 1); + return seqid; + }; + } + + for (size_t i = 0; !topk_heap_.full() && i < entity_->doc_cnt(); ++i) { + const auto id = gen(); + if (!visit_filter_.visited(id) && !myfilter(id)) { + visit_filter_.set_visited(id); + topk_heap_.emplace(id, dc().dist(id)); + } + } + return; +} + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_context.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_context.h new file mode 100644 index 00000000..7e73e9e8 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_context.h @@ -0,0 +1,542 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "utility/visit_filter.h" +#include "zvec/core/framework/index_context.h" +#include "zvec/core/framework/index_provider.h" +#include "hnsw_rabitq_dist_calculator.h" +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +class HnswRabitqContext : public IndexContext { + public: + //! Index Context Pointer + typedef std::unique_ptr Pointer; + + enum ContextType { + kUnknownContext = 0, + kSearcherContext = 1, + kBuilderContext = 2, + kStreamerContext = 3 + }; + + //! Construct + HnswRabitqContext(size_t dimension, const IndexMetric::Pointer &metric, + const HnswRabitqEntity::Pointer &entity); + + //! 
  //! Construct
  HnswRabitqContext(const IndexMetric::Pointer &metric,
                    const HnswRabitqEntity::Pointer &entity);

  //! Destructor
  virtual ~HnswRabitqContext();

 public:
  //! Set topk of search result; the heap capacity is max(topk, ef)
  virtual void set_topk(uint32_t val) override {
    topk_ = val;
    topk_heap_.limit(std::max(val, ef_));
  }

  //! Retrieve search result (first query slot)
  virtual const IndexDocumentList &result(void) const override {
    return results_[0];
  }

  //! Retrieve search result for the idx-th query
  virtual const IndexDocumentList &result(size_t idx) const override {
    return results_[idx];
  }

  //! Retrieve result object for output
  virtual IndexDocumentList *mutable_result(size_t idx) override {
    ailego_assert_with(idx < results_.size(), "invalid idx");
    return &results_[idx];
  }

  //! Retrieve search group result (first query slot)
  virtual const IndexGroupDocumentList &group_result(void) const override {
    return group_results_[0];
  }

  //! Retrieve search group result with index
  virtual const IndexGroupDocumentList &group_result(
      size_t idx) const override {
    return group_results_[idx];
  }

  //! Magic number used to detect a stale (mid-rebind) context
  virtual uint32_t magic(void) const override {
    return magic_;
  }

  //! Set mode of debug
  virtual void set_debug_mode(bool enable) override {
    debug_mode_ = enable;
  }

  //! Retrieve mode of debug
  virtual bool debug_mode(void) const override {
    return this->debugging();
  }

  //! Retrieve string of debug (scan/fetch counters for diagnostics)
  virtual std::string debug_string(void) const override {
    char buf[4096];
    size_t size = snprintf(
        buf, sizeof(buf),
        "scan_cnt=%zu,get_vector_cnt=%u,get_neighbors_cnt=%u,dup_node=%u",
        get_scan_num(), stats_get_vector_cnt_, stats_get_neighbors_cnt_,
        stats_visit_dup_cnt_);
    return std::string(buf, size);
  }

  //! Update the parameters of context
  virtual int update(const ailego::Params &params) override;

 public:
  //! Init context
  int init(ContextType type);

  //! Update context, the context may be shared by different searcher/streamer
  int update_context(ContextType type, const IndexMeta &meta,
                     const IndexMetric::Pointer &metric,
                     const HnswRabitqEntity::Pointer &entity,
                     uint32_t magic_num);

  inline const HnswRabitqEntity &get_entity() const {
    return *entity_;
  }

  //! Size the per-query result slots (group or flat, depending on mode)
  inline void resize_results(size_t size) {
    if (group_by_search()) {
      group_results_.resize(size);
    } else {
      results_.resize(size);
    }
  }

  inline void topk_to_result() {
    return topk_to_result(0);
  }

  //! Construct result from topk heap, result will be normalized
  inline void topk_to_result(uint32_t idx) {
    if (group_by_search()) {
      topk_to_group_result(idx);
    } else {
      topk_to_single_result(idx);
    }
  }

  //! Materialize the topk heap into results_[idx], sorted by distance and
  //! truncated by threshold(); optionally pads the heap with random docs.
  inline void topk_to_single_result(uint32_t idx) {
    if (force_padding_topk_ && !topk_heap_.full() &&
        topk_heap_.size() < entity_->doc_cnt()) {
      this->fill_random_to_topk_full();
    }
    if (ailego_unlikely(topk_heap_.size() == 0)) {
      return;
    }

    ailego_assert_with(idx < results_.size(), "invalid idx");
    int size = std::min(topk_, static_cast<uint32_t>(topk_heap_.size()));
    topk_heap_.sort();
    results_[idx].clear();

    for (int i = 0; i < size; ++i) {
      auto score = topk_heap_[i].second;
      // results are sorted, so the first over-threshold score ends the loop
      if (score.est_dist > this->threshold()) {
        break;
      }

      node_id_t id = topk_heap_[i].first;
      if (fetch_vector_) {
        results_[idx].emplace_back(entity_->get_key(id), score.est_dist, id,
                                   entity_->get_vector(id));
      } else {
        results_[idx].emplace_back(entity_->get_key(id), score.est_dist, id);
      }
    }

    return;
  }

  //! Construct result from topk heap, result will be normalized
  //! NOTE(review): group-by result materialization is currently disabled;
  //! the implementation below is kept for reference and this is a no-op.
  inline void topk_to_group_result(uint32_t idx) {
    // ailego_assert_with(idx < group_results_.size(), "invalid idx");

    // group_results_[idx].clear();

    // std::vector<std::pair<std::string, TopkHeap>> group_topk_list;
    // std::vector<std::pair<std::string, float>> best_score_in_groups;
    // for (auto itr = group_topk_heaps_.begin(); itr !=
    // group_topk_heaps_.end();
    //      itr++) {
    //   const std::string &group_id = (*itr).first;
    //   auto &heap = (*itr).second;
    //   heap.sort();

    //   if (heap.size() > 0) {
    //     float best_score = heap[0].second;
    //     best_score_in_groups.push_back(std::make_pair(group_id, best_score));
    //   }
    // }

    // std::sort(best_score_in_groups.begin(), best_score_in_groups.end(),
    //           [](const std::pair<std::string, float> &a,
    //              const std::pair<std::string, float> &b) -> int {
    //             return a.second < b.second;
    //           });

    // // truncate to group num
    // for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size();
    //      ++i) {
    //   const std::string &group_id = best_score_in_groups[i].first;

    //   group_topk_list.emplace_back(
    //       std::make_pair(group_id, group_topk_heaps_[group_id]));
    // }

    // group_results_[idx].resize(group_topk_list.size());

    // for (uint32_t i = 0; i < group_topk_list.size(); ++i) {
    //   const std::string &group_id = group_topk_list[i].first;
    //   group_results_[idx][i].set_group_id(group_id);

    //   uint32_t size = std::min(
    //       group_topk_,
    //       static_cast<uint32_t>(group_topk_list[i].second.size()));

    //   for (uint32_t j = 0; j < size; ++j) {
    //     auto score = group_topk_list[i].second[j].second;
    //     if (score > this->threshold()) {
    //       break;
    //     }

    //     node_id_t id = group_topk_list[i].second[j].first;

    //     if (fetch_vector_) {
    //       group_results_[idx][i].mutable_docs()->emplace_back(
    //           entity_->get_key(id), score, id, entity_->get_vector(id));
    //     } else {
    //       group_results_[idx][i].mutable_docs()->emplace_back(
    //           entity_->get_key(id), score, id);
    //     }
    //   }
    // }
  }

  //! Bind a new query vector, applying the metric's preprocess hook if any,
  //! and reset the per-query comparison counter.
  //! NOTE(review): the preprocess path copies `dim` bytes, which assumes
  //! dimension() is a byte count (1 byte per element) -- confirm for
  //! multi-byte element types.
  inline void reset_query(const void *query) {
    if (auto query_preprocess_func = index_metric_->get_query_preprocess_func();
        query_preprocess_func != nullptr) {
      size_t dim = dc().dimension();
      preprocess_buffer_.resize(dim);
      memcpy(preprocess_buffer_.data(), query, dim);
      query_preprocess_func(preprocess_buffer_.data(), dim);
      query = preprocess_buffer_.data();
    }

    dc().reset_query(query);
    dc().clear_compare_cnt();
    query_ = query;
  }

  inline HnswRabitqAddDistCalculator &dist_calculator() {
    return dc();
  }

  inline TopkHeap &topk_heap() {
    return topk_heap_;
  }

  inline TopkHeap &update_heap() {
    return update_heap_;
  }

  inline VisitFilter &visit_filter() {
    return visit_filter_;
  }

  inline CandidateHeap &candidates() {
    return candidates_;
  }

  inline void set_max_scan_num(uint32_t max_scan_num) {
    max_scan_num_ = max_scan_num;
  }

  inline void set_max_scan_limit(uint32_t max_scan_limit) {
    max_scan_limit_ = max_scan_limit;
  }

  inline void set_min_scan_limit(uint32_t min_scan_limit) {
    min_scan_limit_ = min_scan_limit;
  }

  inline void set_ef(uint32_t v) {
    ef_ = v;
  }

  inline void set_filter_mode(uint32_t v) {
    filter_mode_ = v;
  }

  inline void set_filter_negative_probability(float v) {
    negative_probability_ = v;
  }

  inline void set_max_scan_ratio(float v) {
    max_scan_ratio_ = v;
  }

  virtual void set_magic(uint32_t v) {
    magic_ = v;
  }

  virtual void set_force_padding_topk(bool v) {
    force_padding_topk_ = v;
  }

  void set_bruteforce_threshold(uint32_t v) override {
    bruteforce_threshold_ = v;
  }

  inline uint32_t get_bruteforce_threshold() const {
    return bruteforce_threshold_;
  }

  void set_fetch_vector(bool v) override {
    fetch_vector_ = v;
  }

  bool fetch_vector() const override {
    return fetch_vector_;
  }

  //! Reset context
  void reset(void) override {
    set_filter(nullptr);
    reset_threshold();
    set_fetch_vector(false);
    set_group_params(0, 0);
    reset_group_by();
  }

  inline std::map<std::string, TopkHeap> &group_topk_heaps() {
    return group_topk_heaps_;
  }

  //! Lazily create (and size) the per-level topk heap for level `level`
  inline TopkHeap &level_topk(int level) {
    if (ailego_unlikely(level_topks_.size() <= static_cast<size_t>(level))) {
      int cur_level = level_topks_.size();
      level_topks_.resize(level + 1);
      for (; cur_level <= level; ++cur_level) {
        size_t heap_size = std::max(entity_->neighbor_cnt(cur_level),
                                    entity_->ef_construction());
        level_topks_[cur_level].clear();
        level_topks_[cur_level].limit(heap_size);
      }
    }

    return level_topks_[level];
  }

  inline void check_need_adjuct_ctx(void) {
    check_need_adjuct_ctx(entity_->doc_cnt());
  }

  //! Headroom to reserve beyond cur_doc, clamped to [kMin, kMax]
  inline size_t compute_reserve_cnt(uint32_t cur_doc) const {
    if (cur_doc > kMaxReserveDocCnt) {
      return kMaxReserveDocCnt;
    } else if (cur_doc < kMinReserveDocCnt) {
      return kMinReserveDocCnt;
    }
    return cur_doc;
  }

  //! candidates heap and visitfilter need to resize as doc cnt growing up
  inline void check_need_adjuct_ctx(uint32_t doc_cnt) {
    if (ailego_unlikely(doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_)) {
      // grow geometrically until enough headroom is reserved
      while (doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_) {
        reserve_max_doc_cnt_ =
            reserve_max_doc_cnt_ + compute_reserve_cnt(reserve_max_doc_cnt_);
      }
      uint32_t max_scan_cnt = compute_max_scan_num(reserve_max_doc_cnt_);
      max_scan_num_ = max_scan_cnt;
      visit_filter_.reset(reserve_max_doc_cnt_, max_scan_cnt);
      candidates_.clear();
      candidates_.limit(max_scan_num_);
    }
  }

  //! ratio-based scan budget, clamped to [min_scan_limit_, max_scan_limit_]
  inline uint32_t compute_max_scan_num(uint32_t max_doc_cnt) const {
    uint32_t max_scan = max_doc_cnt * max_scan_ratio_;
    if (max_scan < min_scan_limit_) {
      max_scan = min_scan_limit_;
    } else if (max_scan > max_scan_limit_) {
      max_scan = max_scan_limit_;
    }
    return max_scan;
  }

  inline size_t get_scan_num() const {
    return dc().compare_cnt();
  }

  inline uint64_t reach_scan_limit() const {
    return dc().compare_cnt() >= max_scan_num_;
  }

  inline bool error() const {
    return dc().error();
  }

  //! Clear per-query state; keeps result slot capacity for the next query
  inline void clear() {
    add_dc_.clear();
    if (ailego_unlikely(this->debugging())) {
      stats_get_neighbors_cnt_ = 0u;
      stats_get_vector_cnt_ = 0u;
      stats_visit_dup_cnt_ = 0u;
    }
    // do not clear results_ for the next query will need it
    for (auto &it : results_) {
      it.clear();
    }
  }

  uint32_t *mutable_stats_get_neighbors() {
    return &stats_get_neighbors_cnt_;
  }

  uint32_t *mutable_stats_get_vector() {
    return &stats_get_vector_cnt_;
  }

  uint32_t *mutable_stats_visit_dup_cnt() {
    return &stats_visit_dup_cnt_;
  }

  inline bool debugging(void) const {
    return debug_mode_;
  }

  inline void update_dist_caculator_distance(
      const IndexMetric::MatrixDistance &distance,
      const IndexMetric::MatrixBatchDistance &batch_distance) {
    dc().update_distance(distance, batch_distance);
  }

  //! Get topk
  inline uint32_t topk() const override {
    return topk_;
  }

  //! Get group topk
  inline uint32_t group_topk() const {
    return group_topk_;
  }

  //! Get group num
  inline uint32_t group_num() const {
    return group_num_;
  }

  //! Get if group by search
  inline bool group_by_search() {
    return group_num_ > 0;
  }

  //! Set group params; total topk becomes group_num * group_topk
  void set_group_params(uint32_t group_num, uint32_t group_topk) override {
    group_num_ = group_num;
    group_topk_ = group_topk;

    topk_ = group_topk_ * group_num_;

    topk_heap_.limit(std::max(topk_, ef_));

    group_topk_heaps_.clear();
  }

  void set_provider(IndexProvider::Pointer provider) {
    add_dc_.set_provider(std::move(provider));
  }

  const void *query() const {
    return query_;
  }

 private:
  inline HnswRabitqAddDistCalculator &dc() {
    return add_dc_;
  }

  inline const HnswRabitqAddDistCalculator &dc() const {
    return add_dc_;
  }

 private:
  // Filling random nodes if topk not full
  void fill_random_to_topk_full(void);

  constexpr static uint32_t kTriggerReserveCnt = 4096UL;
  constexpr static uint32_t kMinReserveDocCnt = 4096UL;
  constexpr static uint32_t kMaxReserveDocCnt = 128 * 1024UL;
  constexpr static uint32_t kInvalidMgic = -1U;

 private:
  HnswRabitqEntity::Pointer entity_;
  HnswRabitqAddDistCalculator add_dc_;
  IndexMetric::Pointer metric_;

  bool debug_mode_{false};
  bool force_padding_topk_{false};
  uint32_t max_scan_num_{0};
  uint32_t max_scan_limit_{0};
  uint32_t min_scan_limit_{0};
  uint32_t reserve_max_doc_cnt_{kMinReserveDocCnt};
  uint32_t topk_{0};
  uint32_t group_topk_{0};
  uint32_t filter_mode_{VisitFilter::ByteMap};
  float negative_probability_{HnswRabitqEntity::kDefaultBFNegativeProbability};
  uint32_t ef_{HnswRabitqEntity::kDefaultEf};
  float max_scan_ratio_{HnswRabitqEntity::kDefaultScanRatio};
  uint32_t magic_{0U};
  std::vector<IndexDocumentList> results_{};
  std::vector<IndexGroupDocumentList> group_results_{};
  TopkHeap topk_heap_{};
  TopkHeap update_heap_{};
  std::vector<TopkHeap> level_topks_{};
  CandidateHeap candidates_{};
  VisitFilter visit_filter_{};
  uint32_t bruteforce_threshold_{};
  bool fetch_vector_{false};

  uint32_t group_num_{0};
  std::map<std::string, TopkHeap> group_topk_heaps_{};

  uint32_t type_{kUnknownContext};
  //! debug stats info
  uint32_t stats_get_neighbors_cnt_{0u};
  uint32_t stats_get_vector_cnt_{0u};
  uint32_t stats_visit_dup_cnt_{0u};
  // scratch buffer for the metric's query-preprocess hook
  std::string preprocess_buffer_;
  const void *query_{nullptr};
};

}  // namespace core
}  // namespace zvec
+// See the License for the specific language governing permissions and +// limitations under the License + +#include "core/algorithm/hnsw-rabitq/hnsw_rabitq_dist_calculator.h" +#include "zvec/core/framework/index_error.h" + +namespace zvec::core { + +int HnswRabitqAddDistCalculator::get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const { + for (uint32_t i = 0; i < count; ++i) { + const node_id_t id = ids[i]; + key_t key = entity_->get_key(id); + if (key == kInvalidKey) { + return IndexError_NoExist; + } + IndexStorage::MemoryBlock block; + int ret = provider_->get_vector(key, block); + if (ret != 0) { + return ret; + } + vec_blocks.push_back(std::move(block)); + } + return 0; +} + +} // namespace zvec::core diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_dist_calculator.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_dist_calculator.h new file mode 100644 index 00000000..5ccda8ac --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_dist_calculator.h @@ -0,0 +1,240 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "zvec/core/framework/index_meta.h" +#include "zvec/core/framework/index_metric.h" +#include "zvec/core/framework/index_provider.h" +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +//! 
HnswRabitqAddDistCalculator is only used for index construction +class HnswRabitqAddDistCalculator { + public: + typedef std::shared_ptr Pointer; + + public: + enum DistType { + DIST_NONE = 0, + DIST_DENSE = 1, + DIST_HYBRID = 2, + DIST_SPARSE = 3 + }; + + public: + //! Constructor + HnswRabitqAddDistCalculator(const HnswRabitqEntity *entity, + const IndexMetric::Pointer &metric, uint32_t dim) + : entity_(entity), + distance_(metric->distance()), + batch_distance_(metric->batch_distance()), + query_(nullptr), + dim_(dim), + compare_cnt_(0) {} + + //! Constructor + HnswRabitqAddDistCalculator(const HnswRabitqEntity *entity, + const IndexMetric::Pointer &metric, uint32_t dim, + const void *query) + : entity_(entity), + distance_(metric->distance()), + batch_distance_(metric->batch_distance()), + query_(query), + dim_(dim), + compare_cnt_(0) {} + + //! Constructor + HnswRabitqAddDistCalculator(const HnswRabitqEntity *entity, + const IndexMetric::Pointer &metric) + : entity_(entity), + distance_(metric->distance()), + batch_distance_(metric->batch_distance()), + query_(nullptr), + dim_(0), + compare_cnt_(0) {} + + void update(const HnswRabitqEntity *entity, + const IndexMetric::Pointer &metric) { + entity_ = entity; + distance_ = metric->distance(); + batch_distance_ = metric->batch_distance(); + } + + void update(const HnswRabitqEntity *entity, + const IndexMetric::Pointer &metric, uint32_t dim) { + entity_ = entity; + distance_ = metric->distance(); + batch_distance_ = metric->batch_distance(); + dim_ = dim; + } + + inline void update_distance( + const IndexMetric::MatrixDistance &distance, + const IndexMetric::MatrixBatchDistance &batch_distance) { + distance_ = distance; + batch_distance_ = batch_distance; + } + + //! Reset query vector data + inline void reset_query(const void *query) { + error_ = false; + query_ = query; + } + + //! 
Returns distance + inline dist_t dist(const void *vec_lhs, const void *vec_rhs) { + if (ailego_unlikely(vec_lhs == nullptr || vec_rhs == nullptr)) { + LOG_ERROR("Nullptr of dense vector"); + error_ = true; + return 0.0f; + } + + float score{0.0f}; + + distance_(vec_lhs, vec_rhs, dim_, &score); + + return score; + } + + //! Returns distance between query and vec. + inline dist_t dist(const void *vec) { + compare_cnt_++; + + return dist(vec, query_); + } + + //! Return distance between query and node id. + inline dist_t dist(node_id_t id) { + compare_cnt_++; + + const void *feat = get_vector(id); + if (ailego_unlikely(feat == nullptr)) { + LOG_ERROR("Get nullptr vector, id=%u", id); + error_ = true; + return 0.0f; + } + + return dist(feat, query_); + } + + //! Return dist node lhs between node rhs + inline dist_t dist(node_id_t lhs, node_id_t rhs) { + compare_cnt_++; + + const void *feat = get_vector(lhs); + const void *query = get_vector(rhs); + if (ailego_unlikely(feat == nullptr || query == nullptr)) { + LOG_ERROR("Get nullptr vector"); + error_ = true; + return 0.0f; + } + + return dist(feat, query); + } + + dist_t operator()(const void *vec) { + return dist(vec); + } + + dist_t operator()(id_t i) { + return dist(i); + } + + dist_t operator()(id_t lhs, id_t rhs) { + return dist(lhs, rhs); + } + + void batch_dist(const void **vecs, size_t num, dist_t *distances) { + compare_cnt_++; + + batch_distance_(vecs, query_, num, dim_, distances); + } + + inline dist_t batch_dist(node_id_t id) { + compare_cnt_++; + + const void *feat = get_vector(id); + if (ailego_unlikely(feat == nullptr)) { + LOG_ERROR("Get nullptr vector, id=%u", id); + error_ = true; + return 0.0f; + } + dist_t score = 0; + batch_distance_(&feat, query_, 1, dim_, &score); + + return score; + } + + inline void clear() { + compare_cnt_ = 0; + error_ = false; + } + + inline void clear_compare_cnt() { + compare_cnt_ = 0; + } + + inline bool error() const { + return error_; + } + + //! 
Get distances compute times + inline uint32_t compare_cnt() const { + return compare_cnt_; + } + + inline uint32_t dimension() const { + return dim_; + } + + void set_provider(IndexProvider::Pointer provider) { + provider_ = std::move(provider); + } + + int get_vector(const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const; + + const void *get_vector(node_id_t id) const { + key_t key = entity_->get_key(id); + if (key == kInvalidKey) { + return nullptr; + } + return provider_->get_vector(key); + } + + private: + HnswRabitqAddDistCalculator(const HnswRabitqAddDistCalculator &) = delete; + HnswRabitqAddDistCalculator &operator=(const HnswRabitqAddDistCalculator &) = + delete; + + private: + const HnswRabitqEntity *entity_; + IndexMetric::MatrixDistance distance_; + IndexMetric::MatrixBatchDistance batch_distance_; + + const void *query_; + uint32_t dim_; + + uint32_t compare_cnt_; // record distance compute times + uint32_t compare_cnt_batch_; // record batch distance compute time + bool error_{false}; + + // get raw vector + IndexProvider::Pointer provider_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_entity.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_entity.cc new file mode 100644 index 00000000..4d9678d3 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_entity.cc @@ -0,0 +1,366 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include "hnsw_rabitq_entity.h" +#include "utility/sparse_utility.h" +#include "zvec/core/framework/index_stats.h" + +namespace zvec { +namespace core { + +const std::string HnswRabitqEntity::kGraphHeaderSegmentId = "graph.header"; +const std::string HnswRabitqEntity::kGraphFeaturesSegmentId = "graph.features"; +const std::string HnswRabitqEntity::kGraphKeysSegmentId = "graph.keys"; +const std::string HnswRabitqEntity::kGraphNeighborsSegmentId = + "graph.neighbors"; +const std::string HnswRabitqEntity::kGraphOffsetsSegmentId = "graph.offsets"; +const std::string HnswRabitqEntity::kGraphMappingSegmentId = "graph.mapping"; +const std::string HnswRabitqEntity::kHnswHeaderSegmentId = "hnsw.header"; +const std::string HnswRabitqEntity::kHnswNeighborsSegmentId = "hnsw.neighbors"; +const std::string HnswRabitqEntity::kHnswOffsetsSegmentId = "hnsw.offsets"; + +int HnswRabitqEntity::CalcAndAddPadding(const IndexDumper::Pointer &dumper, + size_t data_size, + size_t *padding_size) { + *padding_size = AlignSize(data_size) - data_size; + if (*padding_size == 0) { + return 0; + } + + std::string padding(*padding_size, '\0'); + if (dumper->write(padding.data(), *padding_size) != *padding_size) { + LOG_ERROR("Append padding failed, size %zu", *padding_size); + return IndexError_WriteData; + } + return 0; +} + +int64_t HnswRabitqEntity::dump_segment(const IndexDumper::Pointer &dumper, + const std::string &segment_id, + const void *data, size_t size) const { + size_t len = dumper->write(data, size); + if (len != size) { + LOG_ERROR("Dump segment %s data failed, expect: %zu, actual: %zu", + segment_id.c_str(), size, len); + return IndexError_WriteData; + } + + size_t padding_size = AlignSize(size) - size; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, 
size %zu", padding_size); + return IndexError_WriteData; + } + } + + uint32_t crc = ailego::Crc32c::Hash(data, size); + int ret = dumper->append(segment_id, size, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); + return ret; + } + + return len + padding_size; +} + +int64_t HnswRabitqEntity::dump_header(const IndexDumper::Pointer &dumper, + const HNSWHeader &hd) const { + //! dump basic graph header. header is aligned and does not need padding + int64_t graph_hd_size = + dump_segment(dumper, kGraphHeaderSegmentId, &hd.graph, hd.graph.size); + if (graph_hd_size < 0) { + return graph_hd_size; + } + + //! dump basic graph header. header is aligned and does not need padding + int64_t hnsw_hd_size = + dump_segment(dumper, kHnswHeaderSegmentId, &hd.hnsw, hd.hnsw.size); + if (hnsw_hd_size < 0) { + return hnsw_hd_size; + } + + return graph_hd_size + hnsw_hd_size; +} + +void HnswRabitqEntity::reshuffle_vectors( + const std::function & /*get_level*/, + std::vector * /*n2o_mapping*/, + std::vector * /*o2n_mapping*/, key_t * /*keys*/) const { + // TODO + return; +} + +int64_t HnswRabitqEntity::dump_mapping_segment( + const IndexDumper::Pointer &dumper, const key_t *keys) const { + std::vector mapping(doc_cnt()); + + std::iota(mapping.begin(), mapping.end(), 0U); + std::sort(mapping.begin(), mapping.end(), + [&](node_id_t i, node_id_t j) { return keys[i] < keys[j]; }); + + size_t size = mapping.size() * sizeof(node_id_t); + + return dump_segment(dumper, kGraphMappingSegmentId, mapping.data(), size); +} + +int64_t HnswRabitqEntity::dump_segments( + const IndexDumper::Pointer &dumper, key_t *keys, + const std::function &get_level) const { + HNSWHeader dump_hd(header()); + + dump_hd.graph.node_size = AlignSize(vector_size()); + + std::vector n2o_mapping; // map new id to origin id + std::vector o2n_mapping; // map origin id to new id + reshuffle_vectors(get_level, &n2o_mapping, &o2n_mapping, keys); + if 
(!o2n_mapping.empty()) { + dump_hd.hnsw.entry_point = o2n_mapping[entry_point()]; + } + + //! Dump header + int64_t hd_size = dump_header(dumper, dump_hd); + if (hd_size < 0) { + return hd_size; + } + + //! Dump vectors + int64_t vecs_size = dump_vectors(dumper, n2o_mapping); + if (vecs_size < 0) { + return vecs_size; + } + + //! Dump neighbors + auto neighbors_size = + dump_neighbors(dumper, get_level, n2o_mapping, o2n_mapping); + if (neighbors_size < 0) { + return neighbors_size; + } + //! free memory + n2o_mapping = std::vector(); + o2n_mapping = std::vector(); + + //! Dump keys + size_t key_segment_size = doc_cnt() * sizeof(key_t); + int64_t keys_size = + dump_segment(dumper, kGraphKeysSegmentId, keys, key_segment_size); + if (keys_size < 0) { + return keys_size; + } + + //! Dump mapping + int64_t mapping_size = dump_mapping_segment(dumper, keys); + if (mapping_size < 0) { + return mapping_size; + } + + return hd_size + keys_size + vecs_size + neighbors_size + mapping_size; +} + +int64_t HnswRabitqEntity::dump_vectors( + const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping) const { + size_t vector_dump_size = vector_size(); + + size_t padding_size = AlignSize(vector_dump_size) - vector_dump_size; + + char padding[padding_size]; + memset(padding, 0, sizeof(padding)); + const void *data = nullptr; + uint32_t crc = 0U; + size_t vecs_size = 0UL; + + //! dump vectors + for (node_id_t id = 0; id < doc_cnt(); ++id) { + data = get_vector(reorder_mapping.empty() ? 
id : reorder_mapping[id]); + if (ailego_unlikely(!data)) { + return IndexError_ReadData; + } + size_t len = dumper->write(data, vector_size()); + if (len != vector_size()) { + LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len, + vector_size()); + return IndexError_WriteData; + } + + crc = ailego::Crc32c::Hash(data, vector_size(), crc); + vecs_size += vector_size(); + + if (padding_size == 0) { + continue; + } + + len = dumper->write(padding, padding_size); + if (len != padding_size) { + LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len, padding_size); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(padding, padding_size, crc); + vecs_size += padding_size; + } + + int ret = dumper->append(kGraphFeaturesSegmentId, vecs_size, 0UL, crc); + if (ret != 0) { + LOG_ERROR("Dump vectors segment meta failed, ret %d", ret); + return ret; + } + + return vecs_size; +} + +int64_t HnswRabitqEntity::dump_graph_neighbors( + const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const { + std::vector graph_meta; + graph_meta.reserve(doc_cnt()); + size_t offset = 0; + uint32_t crc = 0; + node_id_t mapping[l0_neighbor_cnt()]; + + uint32_t min_neighbor_count = 10000; + uint32_t max_neighbor_count = 0; + size_t sum_neighbor_count = 0; + + for (node_id_t id = 0; id < doc_cnt(); ++id) { + const Neighbors neighbors = + get_neighbors(0, reorder_mapping.empty() ? 
id : reorder_mapping[id]); + ailego_assert_with(!!neighbors.data, "invalid neighbors"); + ailego_assert_with(neighbors.size() <= l0_neighbor_cnt(), + "invalid neighbors"); + + uint32_t neighbor_count = neighbors.size(); + if (neighbor_count < min_neighbor_count) { + min_neighbor_count = neighbor_count; + } + if (neighbor_count > max_neighbor_count) { + max_neighbor_count = neighbor_count; + } + sum_neighbor_count += neighbor_count; + + graph_meta.emplace_back(offset, neighbor_count); + size_t size = neighbors.size() * sizeof(node_id_t); + const node_id_t *data = &neighbors[0]; + if (!neighbor_mapping.empty()) { + for (node_id_t i = 0; i < neighbors.size(); ++i) { + mapping[i] = neighbor_mapping[neighbors[i]]; + } + data = mapping; + } + if (dumper->write(data, size) != size) { + LOG_ERROR("Dump graph neighbor id=%zu failed, size %zu", + static_cast(id), size); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(data, size, crc); + offset += size; + } + + uint32_t average_neighbor_count = 0; + if (doc_cnt() > 0) { + average_neighbor_count = sum_neighbor_count / doc_cnt(); + } + LOG_INFO( + "Dump hnsw graph: min_neighbor_count[%u] max_neighbor_count[%u] " + "average_neighbor_count[%u]", + min_neighbor_count, max_neighbor_count, average_neighbor_count); + + size_t padding_size = 0; + int ret = CalcAndAddPadding(dumper, offset, &padding_size); + if (ret != 0) { + return ret; + } + ret = dumper->append(kGraphNeighborsSegmentId, offset, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s failed, ret %d", + kGraphNeighborsSegmentId.c_str(), ret); + return ret; + } + + //! 
dump level 0 neighbors meta + auto len = dump_segment(dumper, kGraphOffsetsSegmentId, graph_meta.data(), + graph_meta.size() * sizeof(GraphNeighborMeta)); + if (len < 0) { + return len; + } + + return len + offset + padding_size; +} + +int64_t HnswRabitqEntity::dump_upper_neighbors( + const IndexDumper::Pointer &dumper, + const std::function &get_level, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const { + std::vector hnsw_meta; + hnsw_meta.reserve(doc_cnt()); + size_t offset = 0; + uint32_t crc = 0; + node_id_t buffer[upper_neighbor_cnt() + 1]; + for (node_id_t id = 0; id < doc_cnt(); ++id) { + node_id_t new_id = reorder_mapping.empty() ? id : reorder_mapping[id]; + auto level = get_level(new_id); + if (level == 0) { + hnsw_meta.emplace_back(0U, 0U); + continue; + } + hnsw_meta.emplace_back(offset, level); + ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); + for (level_t cur_level = 1; cur_level <= level; ++cur_level) { + const Neighbors neighbors = get_neighbors(cur_level, new_id); + ailego_assert_with(!!neighbors.data, "invalid neighbors"); + ailego_assert_with(neighbors.size() <= neighbor_cnt(cur_level), + "invalid neighbors"); + memset(buffer, 0, sizeof(buffer)); + buffer[0] = neighbors.size(); + if (neighbor_mapping.empty()) { + memcpy(&buffer[1], &neighbors[0], neighbors.size() * sizeof(node_id_t)); + } else { + for (node_id_t i = 0; i < neighbors.size(); ++i) { + buffer[i + 1] = neighbor_mapping[neighbors[i]]; + } + } + if (dumper->write(buffer, sizeof(buffer)) != sizeof(buffer)) { + LOG_ERROR("Dump graph neighbor id=%zu failed, size %zu", + static_cast(id), sizeof(buffer)); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(buffer, sizeof(buffer), crc); + offset += sizeof(buffer); + } + } + size_t padding_size = 0; + int ret = CalcAndAddPadding(dumper, offset, &padding_size); + if (ret != 0) { + return ret; + } + + ret = dumper->append(kHnswNeighborsSegmentId, offset, padding_size, crc); + 
if (ret != 0) { + LOG_ERROR("Dump segment %s failed, ret %d", kHnswNeighborsSegmentId.c_str(), + ret); + return ret; + } + + //! dump level 0 neighbors meta + auto len = dump_segment(dumper, kHnswOffsetsSegmentId, hnsw_meta.data(), + hnsw_meta.size() * sizeof(HnswNeighborMeta)); + if (len < 0) { + return len; + } + + return len + offset + padding_size; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_entity.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_entity.h new file mode 100644 index 00000000..e5290e20 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_entity.h @@ -0,0 +1,643 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_dumper.h" +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_storage.h" + +namespace zvec { +namespace core { + +using node_id_t = uint32_t; +using key_t = uint64_t; +using level_t = int32_t; +using dist_t = float; +struct EstimateRecord { + float ip_x0_qr; + float est_dist; + float low_dist; + + bool operator<(const EstimateRecord &other) const { + return this->est_dist < other.est_dist; + } +}; +struct ResultRecord { + float est_dist; + float low_dist; + ResultRecord() : est_dist(0.0f), low_dist(0.0f) {} + ResultRecord(float dist) : est_dist(dist), low_dist(dist) {} + explicit ResultRecord(const EstimateRecord &other) + : est_dist(other.est_dist), low_dist(other.low_dist) {} + ResultRecord(float est_dist, float low_dist) + : est_dist(est_dist), low_dist(low_dist) {} + bool operator<(const ResultRecord &other) const { + return this->est_dist < other.est_dist; + } + bool operator<=(const ResultRecord &other) const { + return this->est_dist <= other.est_dist; + } + bool operator>(const ResultRecord &other) const { + return this->est_dist > other.est_dist; + } +}; +using TopkHeap = ailego::KeyValueHeap; +using CandidateHeap = + ailego::KeyValueHeap>; +constexpr node_id_t kInvalidNodeId = static_cast(-1); +constexpr key_t kInvalidKey = static_cast(-1); +class DistCalculator; + +struct GraphHeader { + uint32_t size; + uint32_t version; + uint32_t graph_type; + uint32_t doc_count; + uint32_t vector_size; + uint32_t node_size; + uint32_t l0_neighbor_count; + uint32_t prune_type; + uint32_t prune_neighbor_count; + uint32_t ef_construction; + uint32_t options; + uint32_t min_neighbor_count; + uint32_t padded_dim; + uint32_t size_bin_data; + uint32_t size_ex_data; + uint8_t ex_bits; + uint8_t reserved_[4067]; +}; + +static_assert(sizeof(GraphHeader) % 32 == 0, + "GraphHeader must be aligned with 32 
bytes"); + +//! Hnsw upper neighbor header +struct HnswHeader { + uint32_t size; // header size + uint32_t revision; // current total docs of the graph + uint32_t upper_neighbor_count; + uint32_t ef_construction; + uint32_t scaling_factor; + uint32_t max_level; + uint32_t entry_point; + uint32_t options; + uint8_t reserved_[30]; +}; + +static_assert(sizeof(HnswHeader) % 32 == 0, + "GraphHeader must be aligned with 32 bytes"); + +//! Hnsw common header and upper neighbor header +struct HNSWHeader { + HNSWHeader() { + clear(); + } + + HNSWHeader(const HNSWHeader &header) { + memcpy(this, &header, sizeof(header)); + } + + HNSWHeader &operator=(const HNSWHeader &header) { + memcpy(this, &header, sizeof(header)); + return *this; + } + + //! Reset state to zero, and the params is untouched + void inline reset() { + graph.doc_count = 0U; + hnsw.entry_point = kInvalidNodeId; + hnsw.max_level = 0; + } + + //! Clear all fields to init value + void inline clear() { + memset(this, 0, sizeof(HNSWHeader)); + hnsw.entry_point = kInvalidNodeId; + graph.size = sizeof(GraphHeader); + hnsw.size = sizeof(HnswHeader); + } + + size_t l0_neighbor_cnt() const { + return graph.l0_neighbor_count; + } + + size_t upper_neighbor_cnt() const { + return hnsw.upper_neighbor_count; + } + + size_t vector_size() const { + return graph.vector_size; + } + + uint8_t ex_bits() const { + return graph.ex_bits; + } + + uint32_t padded_dim() const { + return graph.padded_dim; + } + + size_t ef_construction() const { + return graph.ef_construction; + } + + size_t scaling_factor() const { + return hnsw.scaling_factor; + } + + size_t neighbor_prune_cnt() const { + return graph.prune_neighbor_count; + } + + node_id_t entry_point() const { + return hnsw.entry_point; + } + + node_id_t doc_cnt() const { + return graph.doc_count; + } + + GraphHeader graph; + HnswHeader hnsw; +}; + +struct NeighborsHeader { + uint32_t neighbor_cnt; + node_id_t neighbors[0]; +}; + +struct Neighbors { + Neighbors() : cnt{0}, 
data{nullptr} {} + + Neighbors(uint32_t cnt_in, const node_id_t *data_in) + : cnt{cnt_in}, data{data_in} {} + + Neighbors(IndexStorage::MemoryBlock &&mem_block) + : neighbor_block{std::move(mem_block)} { + auto hd = reinterpret_cast(neighbor_block.data()); + cnt = hd->neighbor_cnt; + data = hd->neighbors; + } + + size_t size(void) const { + return cnt; + } + + const node_id_t &operator[](size_t idx) const { + return data[idx]; + } + + uint32_t cnt; + const node_id_t *data; + IndexStorage::MemoryBlock neighbor_block; +}; + +//! level 0 neighbors offset +struct GraphNeighborMeta { + GraphNeighborMeta(size_t o, size_t cnt) : offset(o), neighbor_cnt(cnt) {} + + uint64_t offset : 48; + uint64_t neighbor_cnt : 16; +}; + +//! hnsw upper neighbors meta +struct HnswNeighborMeta { + HnswNeighborMeta(size_t o, size_t l) : offset(o), level(l) {} + + uint64_t offset : 48; // offset = idx * upper neighors size + uint64_t level : 16; +}; + +class HnswRabitqEntity { + public: + //! Constructor + HnswRabitqEntity() {} + + //! Constructor + HnswRabitqEntity(const HNSWHeader &hd) { + header_ = hd; + } + + //! Destructor + virtual ~HnswRabitqEntity() {} + + //! HnswRabitqEntity Pointerd; + typedef std::shared_ptr Pointer; + + //! Get max neighbor size of graph level + inline size_t neighbor_cnt(level_t level) const { + return level == 0 ? header_.graph.l0_neighbor_count + : header_.hnsw.upper_neighbor_count; + } + + //! get max neighbor size of graph level 0 + inline size_t l0_neighbor_cnt() const { + return header_.graph.l0_neighbor_count; + } + + //! get min neighbor size of graph + inline size_t min_neighbor_cnt() const { + return header_.graph.min_neighbor_count; + } + + //! get upper neighbor size of graph level other than 0 + inline size_t upper_neighbor_cnt() const { + return header_.hnsw.upper_neighbor_count; + } + + //! 
Get current total doc of the hnsw graph + inline node_id_t *mutable_doc_cnt() { + return &header_.graph.doc_count; + } + + inline node_id_t doc_cnt() const { + return header_.graph.doc_count; + } + + //! Get hnsw graph scaling params + inline size_t scaling_factor() const { + return header_.hnsw.scaling_factor; + } + + //! Get prune_size + inline size_t prune_cnt() const { + return header_.graph.prune_neighbor_count; + } + + //! Current entity of top level graph + inline node_id_t entry_point() const { + return header_.hnsw.entry_point; + } + + //! Current max graph level + inline level_t cur_max_level() const { + return header_.hnsw.max_level; + } + + //! Retrieve index vector size + size_t vector_size() const { + return header_.graph.vector_size; + } + + //! Retrieve node size + size_t node_size() const { + return header_.graph.node_size; + } + + //! Retrieve ef constuction + size_t ef_construction() const { + return header_.graph.ef_construction; + } + + uint8_t ex_bits() const { + return header_.graph.ex_bits; + } + + uint32_t padded_dim() const { + return header_.graph.padded_dim; + } + + uint32_t size_bin_data() const { + return header_.graph.size_bin_data; + } + + uint32_t size_ex_data() const { + return header_.graph.size_ex_data; + } + + void update_rabitq_params_and_vector_size(uint32_t dimension) { + uint32_t padded_dim = ((dimension + 63) / 64) * 64; + header_.graph.padded_dim = padded_dim; + // return (padded_dim / 8) + (sizeof(T) * 3); + /* + +------------------+----------+-------------+----------+ + | bin_code_ | f_add_ | f_rescale_ | f_error_ | + | (padded_dim/8) | (4 bytes)| (4 bytes) | (4 bytes)| + +------------------+----------+-------------+----------+ + |<----- 量化码 ---->|<------- 3个float参数 (12字节) --->| + */ + header_.graph.size_bin_data = + rabitqlib::BinDataMap::data_bytes(padded_dim); + // return ex_bits > 0 ? 
(padded_dim * ex_bits / 8) + (sizeof(T) * 2) : 0; + /* + +-------------------------+-------------+-----------------+ + | ex_code_ | f_add_ex_ | f_rescale_ex_ | + | (padded_dim*ex_bits/8) | (4 bytes) | (4 bytes) | + +-------------------------+-------------+-----------------+ + |<----- 扩展量化码 ------->|<---- 2个float参数 (8字节) ---->| + */ + + header_.graph.size_ex_data = rabitqlib::ExDataMap::data_bytes( + padded_dim, header_.graph.ex_bits); + + // cluster_id + bin_data + ex_data + header_.graph.vector_size = + sizeof(uint32_t) + size_bin_data() + size_ex_data(); + } + + void set_ex_bits(uint8_t ex_bits) { + header_.graph.ex_bits = ex_bits; + } + + void set_prune_cnt(size_t v) { + header_.graph.prune_neighbor_count = v; + } + + void set_scaling_factor(size_t val) { + header_.hnsw.scaling_factor = val; + } + + void set_l0_neighbor_cnt(size_t cnt) { + header_.graph.l0_neighbor_count = cnt; + } + + void set_min_neighbor_cnt(size_t cnt) { + header_.graph.min_neighbor_count = cnt; + } + + void set_upper_neighbor_cnt(size_t cnt) { + header_.hnsw.upper_neighbor_count = cnt; + } + + void set_ef_construction(size_t ef) { + header_.graph.ef_construction = ef; + } + + protected: + inline const HNSWHeader &header() const { + return header_; + } + + inline HNSWHeader *mutable_header() { + return &header_; + } + + inline size_t header_size() const { + return sizeof(header_); + } + + void set_node_size(size_t size) { + header_.graph.node_size = size; + } + + //! Dump all segment by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_segments( + const IndexDumper::Pointer &dumper, key_t *keys, + const std::function &get_level) const; + + private: + //! dump mapping segment, for get_vector_by_key in provider + int64_t dump_mapping_segment(const IndexDumper::Pointer &dumper, + const key_t *keys) const; + + //! dump hnsw head by dumper + //! 
Return dump size if success, errno(<0) in failure + int64_t dump_header(const IndexDumper::Pointer &dumper, + const HNSWHeader &hd) const; + + //! dump vectors by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_vectors(const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping) const; + + //! dump hnsw neighbors by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_neighbors(const IndexDumper::Pointer &dumper, + const std::function &get_level, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const { + auto len1 = dump_graph_neighbors(dumper, reorder_mapping, neighbor_mapping); + if (len1 < 0) { + return len1; + } + auto len2 = dump_upper_neighbors(dumper, get_level, reorder_mapping, + neighbor_mapping); + if (len2 < 0) { + return len2; + } + + return len1 + len2; + } + + //! dump segment by dumper + //! Return dump size if success, errno(<0) in failure + int64_t dump_segment(const IndexDumper::Pointer &dumper, + const std::string &segment_id, const void *data, + size_t size) const; + + //! Dump level 0 neighbors + //! Return dump size if success, errno(<0) in failure + int64_t dump_graph_neighbors( + const IndexDumper::Pointer &dumper, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const; + + //! Dump upper level neighbors + //! Return dump size if success, errno(<0) in failure + int64_t dump_upper_neighbors( + const IndexDumper::Pointer &dumper, + const std::function &get_level, + const std::vector &reorder_mapping, + const std::vector &neighbor_mapping) const; + + public: + //! Cleanup the entity + virtual int cleanup(void) { + header_.clear(); + return 0; + } + + //! Make a copy of searcher entity, to support thread-safe operation. + //! 
The segment in container cannot be read concurrenly + virtual const HnswRabitqEntity::Pointer clone() const { + LOG_ERROR("Update neighbors not implemented"); + return HnswRabitqEntity::Pointer(); + } + + //! Get primary key of the node id + virtual key_t get_key(node_id_t id) const = 0; + + //! Get vector feature data by key + virtual const void *get_vector(node_id_t id) const = 0; + + //! Get vectors feature data by keys + virtual int get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const = 0; + + virtual int get_vector(const node_id_t id, + IndexStorage::MemoryBlock &block) const = 0; + virtual int get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const = 0; + + //! Retrieve a vector using a primary key + virtual const void *get_vector_by_key(uint64_t /*key*/) const { + LOG_ERROR("get vector not implemented"); + return nullptr; + } + + virtual int get_vector_by_key(const key_t /*key*/, + IndexStorage::MemoryBlock & /*block*/) const { + return IndexError_NotImplemented; + } + + //! Get the node id's neighbors on graph level + //! Note: the neighbors cannot be modified, using the following + //! method to get WritableNeighbors if want to + virtual const Neighbors get_neighbors(level_t level, node_id_t id) const = 0; + + //! Add vector and key to hnsw entity, and local id will be saved in id + virtual int add_vector(level_t /*level*/, key_t /*key*/, const void * /*vec*/, + node_id_t * /*id*/) { + return IndexError_NotImplemented; + } + + //! Add vector and id to hnsw entity + virtual int add_vector_with_id(level_t /*level*/, node_id_t /*id*/, + const void * /*vec*/) { + return IndexError_NotImplemented; + } + + virtual int update_neighbors( + level_t /*level*/, node_id_t /*id*/, + const std::vector> & /*neighbors*/) { + LOG_ERROR("Update neighbors dense not implemented"); + + return 0; + } + + //! Append neighbor_id to node id neighbors on level, size is the current + //! neighbors size. 
Notice: the caller must be ensure the neighbors not full + virtual void add_neighbor(level_t /*level*/, node_id_t /*id*/, + uint32_t /*size*/, node_id_t /*neighbor_id*/) { + LOG_ERROR("Add neighbor not implemented"); + } + + //! Update entry point and max level + virtual void update_ep_and_level(node_id_t ep, level_t level) { + header_.hnsw.entry_point = ep; + header_.hnsw.max_level = level; + } + + virtual int load(const IndexStorage::Pointer & /*container*/, + bool /*check_crc*/) { + LOG_ERROR("Load not implemented"); + return IndexError_NotImplemented; + } + + virtual int dump(const IndexDumper::Pointer & /*dumper*/) { + LOG_ERROR("Dump not implemented"); + return IndexError_NotImplemented; + } + + static int CalcAndAddPadding(const IndexDumper::Pointer &dumper, + size_t data_size, size_t *padding_size); + + uint32_t get_cluster_id(const void *vec) const { + return *reinterpret_cast( + reinterpret_cast(vec) + cluster_id_offset()); + } + + const char *get_bin_data(const void *vec) const { + return reinterpret_cast(vec) + bin_data_offset(); + } + + const char *get_ex_data(const void *vec) const { + return reinterpret_cast(vec) + ex_data_offset(); + } + + uint32_t cluster_id_offset() const { + return 0; + } + + uint32_t bin_data_offset() const { + return cluster_id_offset() + sizeof(uint32_t); + } + + uint32_t ex_data_offset() const { + return bin_data_offset() + size_bin_data(); + } + + protected: + static inline size_t AlignSize(size_t size) { + return (size + 0x1F) & (~0x1F); + } + + static inline size_t AlignPageSize(size_t size) { + size_t page_mask = ailego::MemoryHelper::PageSize() - 1; + return (size + page_mask) & (~page_mask); + } + + static inline size_t AlignHugePageSize(size_t size) { + size_t page_mask = ailego::MemoryHelper::HugePageSize() - 1; + return (size + page_mask) & (~page_mask); + } + + //! 
rearrange vectors to improve cache locality + void reshuffle_vectors(const std::function &get_level, + std::vector *n2o_mapping, + std::vector *o2n_mapping, + key_t *keys) const; + + public: + const static std::string kGraphHeaderSegmentId; + const static std::string kGraphFeaturesSegmentId; + const static std::string kGraphKeysSegmentId; + const static std::string kGraphNeighborsSegmentId; + const static std::string kGraphOffsetsSegmentId; + const static std::string kGraphMappingSegmentId; + const static std::string kHnswHeaderSegmentId; + const static std::string kHnswNeighborsSegmentId; + const static std::string kHnswOffsetsSegmentId; + + constexpr static uint32_t kRevision = 0U; + constexpr static size_t kMaxGraphLayers = 15; + constexpr static uint32_t kDefaultEfConstruction = 500; + constexpr static uint32_t kDefaultEf = 500; + constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW + constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; + constexpr static uint32_t kMaxNeighborCnt = 65535; + constexpr static float kDefaultScanRatio = 0.1f; + constexpr static uint32_t kDefaultMinScanLimit = 10000; + constexpr static uint32_t kDefaultMaxScanLimit = + std::numeric_limits::max(); + constexpr static float kDefaultBFNegativeProbability = 0.001f; + constexpr static uint32_t kDefaultScalingFactor = 50U; + constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; + constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion + constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; + constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; + constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; + constexpr static size_t kDefaultMaxChunkCnt = 50000UL; + constexpr static float kDefaultNeighborPruneMultiplier = + 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier + constexpr static float kDefaultL0MaxNeighborCntMultiplier = + 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier + + protected: + 
HNSWHeader header_{}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_index_hash.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_index_hash.h new file mode 100644 index 00000000..4f01aabb --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_index_hash.h @@ -0,0 +1,231 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "hnsw_rabitq_chunk.h" + +namespace zvec { +namespace core { + +//! Persistent hashmap implement through open addressing algorithm +template ::value>::type> +class HnswIndexHashMap { + using key_type = Key; + using val_type = Val; + + struct Iterator { + key_type first; + val_type second; + }; + typedef Iterator *iterator; + typedef Iterator Item; + typedef const Iterator *const_iterator; + + class Slot { + public: + Slot(Chunk::Pointer &&chunk, const void *data) + : chunk_(std::move(chunk)), + items_(reinterpret_cast(data)) {} + //! 
Return a empty loc or the key item loc + + Slot(Chunk::Pointer &&chunk, IndexStorage::MemoryBlock &&mem_block) + : chunk_(std::move(chunk)), items_block_(std::move(mem_block)) { + items_ = reinterpret_cast(items_block_.data()); + } + const_iterator find(key_type key, uint32_t max_items, uint32_t mask) const { + auto it = &items_[key & mask]; + for (auto i = 0U; i < max_items; ++i) { + if (it->first == key || it->second == EmptyVal) { + // LOG_DEBUG("i=%u", i); + return it; + } + ++it; + if (it == &items_[max_items]) { + it = &items_[0]; + } + } + return nullptr; + } + + bool update(const_iterator it) { + uint32_t offset = reinterpret_cast(it) - + reinterpret_cast(&items_[0]); + if (ailego_unlikely(chunk_->write(offset, it, sizeof(Item)) != + sizeof(Item))) { + LOG_ERROR("Chunk write failed"); + return false; + } + return true; + } + + private: + Chunk::Pointer chunk_{}; + const Item *items_{nullptr}; // point to chunk data + IndexStorage::MemoryBlock items_block_{}; + }; + + public: + //! Init the hash + //! broker the index allocator + //! chunk_size the size of per chunk allocated, actual size may greater + //! factor factor = 1/ratio, ratio is the probability of a squence + //! number inserted to this container + //! max the max number key can be inserted + //! 
expansion_ratio memory expansion ratio + int init(HnswRabitqChunkBroker::Pointer &broker, uint32_t chunk_size, + uint32_t factor, size_t max, float expansion_ratio) { + ailego_assert_with(expansion_ratio > 1.0f, "ratio must > 1.0f"); + broker_ = broker; + + size_t items = std::ceil(chunk_size * 1.0f / sizeof(Item)); + slot_items_ = 1UL << static_cast((std::ceil(std::log2(items)))); + size_t range = slot_items_ * factor / expansion_ratio; + mask_bits_ = std::floor(std::log2(range)); + range = 1UL << mask_bits_; + size_t max_slots = std::ceil(max * 1.0f / range); + slots_.reserve(max_slots); + slot_loc_mask_ = slot_items_ - 1U; + int ret = load(); + if (ret != 0) { + return ret; + } + + LOG_DEBUG( + "HnswRabitqIndexHash init, chunkSize=%u factor=%u max=%zu " + "ratio=%f slotItems=%u maxSlots=%zu maskBits=%u " + "range=%zu", + chunk_size, factor, max, expansion_ratio, slot_items_, max_slots, + mask_bits_, range); + + return 0; + } + + int cleanup(void) { + broker_.reset(); + slots_.clear(); + slots_.shrink_to_fit(); + mask_bits_ = 0U; + slot_items_ = 0U; + slot_loc_mask_ = 0U; + + return 0; + } + + const_iterator end(void) const { + return nullptr; + } + + const_iterator find(const key_type key) const { + auto idx = key >> mask_bits_; + if (idx >= slots_.size()) { + return end(); + } + auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); + return it && it->second != EmptyVal ? it : nullptr; + } + + bool insert(key_type key, val_type val) { + auto idx = key >> mask_bits_; + if (idx >= slots_.size()) { + if (ailego_unlikely(idx >= slots_.capacity())) { + LOG_ERROR("no space to insert"); + return false; + } + for (auto i = slots_.size(); i <= idx; ++i) { + if (ailego_unlikely(!alloc_slot(i))) { + return false; + } + } + } + auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); + if (ailego_unlikely(it == nullptr)) { + LOG_ERROR("no space to insert"); + return false; + } + + //! TODO: write memory is ok? 
+ const_cast(it)->first = key; + const_cast(it)->second = val; + + return slots_[idx].update(it); + } + + private: + bool alloc_slot(size_t idx) { + ailego_assert_with(idx == slots_.size(), "invalid idx"); + + size_t size = slot_items_ * sizeof(Item); + auto p = broker_->alloc_chunk( + HnswRabitqChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, idx, size); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc data chunk failed"); + return false; + } + Chunk::Pointer chunk = p.second; + if (ailego_unlikely(chunk->resize(size) != size)) { + LOG_ERROR("Chunk resize failed, size=%zu", size); + return false; + } + //! Read the whole data to memory + IndexStorage::MemoryBlock data_block; + if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { + LOG_ERROR("Chunk read failed, size=%zu", size); + return false; + } + + slots_.emplace_back(std::move(chunk), std::move(data_block)); + return true; + } + + int load(void) { + size_t slots_cnt = broker_->get_chunk_cnt( + HnswRabitqChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX); + for (size_t i = 0UL; i < slots_cnt; ++i) { + auto chunk = broker_->get_chunk( + HnswRabitqChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, i); + if (!chunk) { + LOG_ERROR("Get chunk failed, seq=%zu", i); + return IndexError_InvalidFormat; + } + size_t size = sizeof(Item) * slot_items_; + if (chunk->data_size() < size) { + LOG_ERROR( + "Hash params may be mismatch, seq=%zu, data_size=%zu " + "expect=%zu", + i, chunk->data_size(), size); + return IndexError_InvalidFormat; + } + //! 
Read the whole data to memory + IndexStorage::MemoryBlock data_block; + if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { + LOG_ERROR("Chunk read failed, size=%zu", size); + return false; + } + slots_.emplace_back(std::move(chunk), std::move(data_block)); + } + return 0; + } + + private: + HnswRabitqChunkBroker::Pointer broker_{}; // chunk broker + std::vector slots_{}; + uint32_t mask_bits_{0U}; + uint32_t slot_items_{}; // must be a power of 2 + uint32_t slot_loc_mask_{}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_index_provider.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_index_provider.h new file mode 100644 index 00000000..29464e0f --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_index_provider.h @@ -0,0 +1,134 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "zvec/core/framework/index_provider.h" +#include "zvec/core/framework/index_searcher.h" +#include "zvec/core/framework/index_streamer.h" +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +class HnswRabitqIndexProvider : public IndexProvider { + public: + HnswRabitqIndexProvider(const IndexMeta &meta, + const HnswRabitqEntity::Pointer &entity, + const std::string &owner) + : meta_(meta), entity_(entity), owner_class_(owner) {} + + HnswRabitqIndexProvider(const HnswRabitqIndexProvider &) = delete; + HnswRabitqIndexProvider &operator=(const HnswRabitqIndexProvider &) = delete; + + public: // holder interface + //! Create a new iterator + IndexProvider::Iterator::Pointer create_iterator() override { + return HnswRabitqIndexProvider::Iterator::Pointer(new (std::nothrow) + Iterator(entity_)); + } + + //! Retrieve count of vectors + size_t count(void) const override { + return entity_->doc_cnt(); + } + + //! Retrieve dimension of vector + size_t dimension(void) const override { + return meta_.dimension(); + } + + //! Retrieve type of vector + IndexMeta::DataType data_type(void) const override { + return meta_.data_type(); + } + + //! Retrieve vector size in bytes + size_t element_size(void) const override { + return meta_.element_size(); + } + + public: // provider's unique interface + //! Retrieve a vector using a primary key + const void *get_vector(uint64_t key) const override { + return entity_->get_vector_by_key(key); + } + + int get_vector(const uint64_t key, + IndexStorage::MemoryBlock &block) const override { + return entity_->get_vector_by_key(key, block); + } + + //! Retrieve the owner class + const std::string &owner_class(void) const override { + return owner_class_; + } + + private: + class Iterator : public IndexProvider::Iterator { + public: + Iterator(const HnswRabitqEntity::Pointer &entity) + : entity_(entity), cur_id_(0U) {} + + //! Retrieve pointer of data + //! 
NOTICE: the vec feature will be changed after iterating to next, so + //! the caller need to keep a copy of it before iterator to next vector + virtual const void *data(void) const override { + return entity_->get_vector(cur_id_); + } + + //! Test if the iterator is valid + virtual bool is_valid(void) const override { + return cur_id_ < entity_->doc_cnt(); + } + + //! Retrieve primary key + virtual uint64_t key(void) const override { + return entity_->get_key(cur_id_); + } + + //! Next iterator + virtual void next(void) override { + // cur_id_ += 1; + cur_id_ = get_next_valid_id(cur_id_ + 1); + } + + //! Reset the iterator + void reset(void) { + cur_id_ = get_next_valid_id(0); + } + + private: + node_id_t get_next_valid_id(node_id_t start_id) { + for (node_id_t i = start_id; i < entity_->doc_cnt(); i++) { + if (entity_->get_key(i) != kInvalidNodeId) { + cur_id_ = i; + return i; + } + } + return kInvalidNodeId; + } + + private: + const HnswRabitqEntity::Pointer entity_; + node_id_t cur_id_; + }; + + private: + const IndexMeta &meta_; + const HnswRabitqEntity::Pointer entity_; + const std::string owner_class_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_params.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_params.h new file mode 100644 index 00000000..8b1c597c --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_params.h @@ -0,0 +1,121 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// BUGFIX(review): extraction stripped the include target; this header
// only needs std::string.
#include <string>

namespace zvec {
namespace core {

//! Parameter keys for the HNSW-RaBitQ index (general / builder / searcher /
//! streamer / reducer sections). Declared `inline` so the header can be
//! included from multiple TUs without ODR violations.

inline const std::string PARAM_HNSW_RABITQ_GENERAL_DIMENSION(
    "proxima.hnsw_rabitq.general.dimension");

inline const std::string PARAM_HNSW_RABITQ_BUILDER_THREAD_COUNT(
    "proxima.hnsw_rabitq.builder.thread_count");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_MEMORY_QUOTA(
    "proxima.hnsw_rabitq.builder.memory_quota");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_EFCONSTRUCTION(
    "proxima.hnsw_rabitq.builder.efconstruction");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_SCALING_FACTOR(
    "proxima.hnsw_rabitq.builder.scaling_factor");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_CHECK_INTERVAL_SECS(
    "proxima.hnsw_rabitq.builder.check_interval_secs");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER(
    "proxima.hnsw_rabitq.builder.neighbor_prune_multiplier");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_MIN_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.builder.min_neighbor_count");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.builder.max_neighbor_count");
inline const std::string
    PARAM_HNSW_RABITQ_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER(
        "proxima.hnsw_rabitq.builder.l0_max_neighbor_count_multiplier");

inline const std::string PARAM_HNSW_RABITQ_SEARCHER_EF(
    "proxima.hnsw_rabitq.searcher.ef");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD(
    "proxima.hnsw_rabitq.searcher.brute_force_threshold");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE(
    "proxima.hnsw_rabitq.searcher.neighbors_in_memory_enable");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO(
    "proxima.hnsw_rabitq.searcher.max_scan_ratio");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_CHECK_CRC_ENABLE(
    "proxima.hnsw_rabitq.searcher.check_crc_enable");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_ENABLE(
    "proxima.hnsw_rabitq.searcher.visit_bloomfilter_enable");
inline const std::string
    PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB(
        "proxima.hnsw_rabitq.searcher.visit_bloomfilter_negative_prob");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_FORCE_PADDING_RESULT_ENABLE(
    "proxima.hnsw_rabitq.searcher.force_padding_result_enable");

inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO(
    "proxima.hnsw_rabitq.streamer.max_scan_ratio");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT(
    "proxima.hnsw_rabitq.streamer.min_scan_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT(
    "proxima.hnsw_rabitq.streamer.max_scan_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_EF(
    "proxima.hnsw_rabitq.streamer.ef");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_EFCONSTRUCTION(
    "proxima.hnsw_rabitq.streamer.efconstruction");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.streamer.max_neighbor_count");
inline const std::string
    PARAM_HNSW_RABITQ_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER(
        "proxima.hnsw_rabitq.streamer.l0_max_neighbor_count_multiplier");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_SCALING_FACTOR(
    "proxima.hnsw_rabitq.streamer.scaling_factor");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD(
    "proxima.hnsw_rabitq.streamer.brute_force_threshold");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT(
    "proxima.hnsw_rabitq.streamer.docs_hard_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT(
    "proxima.hnsw_rabitq.streamer.docs_soft_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_INDEX_SIZE(
    "proxima.hnsw_rabitq.streamer.max_index_size");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_ENABLE(
    "proxima.hnsw_rabitq.streamer.visit_bloomfilter_enable");
inline const std::string
    PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB(
        "proxima.hnsw_rabitq.streamer.visit_bloomfilter_negative_prob");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_CHECK_CRC_ENABLE(
    "proxima.hnsw_rabitq.streamer.check_crc_enable");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER(
    "proxima.hnsw_rabitq.streamer.neighbor_prune_multiplier");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_CHUNK_SIZE(
    "proxima.hnsw_rabitq.streamer.chunk_size");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_FILTER_SAME_KEY(
    "proxima.hnsw_rabitq.streamer.filter_same_key");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_GET_VECTOR_ENABLE(
    "proxima.hnsw_rabitq.streamer.get_vector_enable");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MIN_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.streamer.min_neighbor_count");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_FORCE_PADDING_RESULT_ENABLE(
    "proxima.hnsw_rabitq.streamer.force_padding_result_enable");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_ESTIMATE_DOC_COUNT(
    "proxima.hnsw_rabitq.streamer.estimate_doc_count");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_USE_ID_MAP(
    "proxima.hnsw_rabitq.streamer.use_id_map");

inline const std::string PARAM_HNSW_RABITQ_REDUCER_WORKING_PATH(
    "proxima.hnsw_rabitq.reducer.working_path");
inline const std::string PARAM_HNSW_RABITQ_REDUCER_NUM_OF_ADD_THREADS(
    "proxima.hnsw_rabitq.reducer.num_of_add_threads");
inline const std::string PARAM_HNSW_RABITQ_REDUCER_INDEX_NAME(
    "proxima.hnsw_rabitq.reducer.index_name");
inline const std::string PARAM_HNSW_RABITQ_REDUCER_EFCONSTRUCTION(
    "proxima.hnsw_rabitq.reducer.efconstruction");

}  // namespace core
}  // namespace zvec
// --- diff residue (file boundary) ---
// diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_algorithm.cc
b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_algorithm.cc new file mode 100644 index 00000000..88f1a692 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_algorithm.cc @@ -0,0 +1,436 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "hnsw_rabitq_query_algorithm.h" +#include +#include +#include +#include +#include "zvec/ailego/internal/platform.h" +#include "hnsw_rabitq_entity.h" +#include "hnsw_rabitq_query_entity.h" + +namespace zvec { +namespace core { + +HnswRabitqQueryAlgorithm::HnswRabitqQueryAlgorithm( + HnswRabitqEntity &entity, size_t num_clusters, + rabitqlib::MetricType metric_type) + : entity_(entity), + mt_(std::chrono::system_clock::now().time_since_epoch().count()), + lock_pool_(kLockCnt), + num_clusters_(num_clusters), + metric_type_(metric_type) { + ex_bits_ = entity_.ex_bits(); + padded_dim_ = entity_.padded_dim(); + ip_func_ = rabitqlib::select_excode_ipfunc(ex_bits_); + LOG_INFO( + "Create query algorithm. 
num_clusters=%zu ex_bits=%zu padded_dim=%zu", + num_clusters_, ex_bits_, padded_dim_); +} + +int HnswRabitqQueryAlgorithm::cleanup() { + return 0; +} + +int HnswRabitqQueryAlgorithm::search(HnswRabitqQueryEntity *entity, + HnswRabitqContext *ctx) const { + spin_lock_.lock(); + auto maxLevel = entity_.cur_max_level(); + auto entry_point = entity_.entry_point(); + spin_lock_.unlock(); + + if (ailego_unlikely(entry_point == kInvalidNodeId)) { + return 0; + } + + EstimateRecord curest; + get_bin_est(entity_.get_vector(entry_point), curest, *entity); + + for (level_t cur_level = maxLevel; cur_level >= 1; --cur_level) { + select_entry_point(cur_level, &entry_point, &curest, ctx, entity); + } + + auto &topk_heap = ctx->topk_heap(); + topk_heap.clear(); + search_neighbors(0, &entry_point, &curest, topk_heap, ctx, entity); + + if (ctx->group_by_search()) { + expand_neighbors_by_group(topk_heap, ctx); + } + + return 0; +} + + +//! select_entry_point on hnsw level, ef = 1 +void HnswRabitqQueryAlgorithm::select_entry_point( + level_t level, node_id_t *entry_point, EstimateRecord *curest, + HnswRabitqContext *ctx, HnswRabitqQueryEntity *query_entity) const { + auto &entity = ctx->get_entity(); + while (true) { + const Neighbors neighbors = entity.get_neighbors(level, *entry_point); + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_get_neighbors())++; + } + ailego_prefetch(neighbors.data); + uint32_t size = neighbors.size(); + if (size == 0) { + break; + } + + // TODO: use MemoryBlock + // + // std::vector neighbor_vec_blocks; + // // 需要调用provider来获取 + // int ret = entity.get_vector(&neighbors[0], size, neighbor_vec_blocks); + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_get_vector())++; + // } + // if (ailego_unlikely(ret != 0)) { + // break; + // } + + bool find_closer = false; + for (uint32_t i = 0; i < size; ++i) { + EstimateRecord candest; + get_bin_est(entity_.get_vector(neighbors[i]), candest, *query_entity); + + if 
(candest.est_dist < curest->est_dist) { + *curest = candest; + *entry_point = neighbors[i]; + find_closer = true; + } + } + + if (!find_closer) { + break; + } + } + + return; +} + +void HnswRabitqQueryAlgorithm::search_neighbors( + level_t level, node_id_t *entry_point, EstimateRecord *dist, TopkHeap &topk, + HnswRabitqContext *ctx, HnswRabitqQueryEntity *query_entity) const { + const auto &entity = ctx->get_entity(); + VisitFilter &visit = ctx->visit_filter(); + CandidateHeap &candidates = ctx->candidates(); + std::function filter = [](node_id_t) { return false; }; + if (ctx->filter().is_valid()) { + filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; + } + + candidates.clear(); + visit.clear(); + visit.set_visited(*entry_point); + if (!filter(*entry_point)) { + topk.emplace(*entry_point, ResultRecord(*dist)); + } + + candidates.emplace(*entry_point, ResultRecord(*dist)); + while (!candidates.empty() && !ctx->reach_scan_limit()) { + auto top = candidates.begin(); + node_id_t main_node = top->first; + auto main_dist = top->second; + + if (topk.full() && main_dist.est_dist > topk[0].second.est_dist) { + break; + } + + candidates.pop(); + const Neighbors neighbors = entity.get_neighbors(level, main_node); + ailego_prefetch(neighbors.data); + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_get_neighbors())++; + } + + node_id_t neighbor_ids[neighbors.size()]; + uint32_t size = 0; + for (uint32_t i = 0; i < neighbors.size(); ++i) { + node_id_t node = neighbors[i]; + if (visit.visited(node)) { + if (ailego_unlikely(ctx->debugging())) { + (*ctx->mutable_stats_visit_dup_cnt())++; + } + continue; + } + visit.set_visited(node); + neighbor_ids[size++] = node; + } + if (size == 0) { + continue; + } + + // std::vector neighbor_vec_blocks; + // int ret = entity.get_vector(neighbor_ids, size, neighbor_vec_blocks); + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_get_vector())++; + // } + // if (ailego_unlikely(ret != 
0)) { + // break; + // } + + // // do prefetch + // static constexpr node_id_t BATCH_SIZE = 12; + // static constexpr node_id_t PREFETCH_STEP = 2; + // for (uint32_t i = 0; i < std::min(BATCH_SIZE * PREFETCH_STEP, size); ++i) + // { + // ailego_prefetch(neighbor_vec_blocks[i].data()); + // } + // // done + + // float dists[size]; + // const void *neighbor_vecs[size]; + + // for (uint32_t i = 0; i < size; ++i) { + // neighbor_vecs[i] = neighbor_vec_blocks[i].data(); + // } + + // dc.batch_dist(neighbor_vecs, size, dists); + + for (uint32_t i = 0; i < size; ++i) { + node_id_t node = neighbor_ids[i]; + EstimateRecord candest; + auto *cand_vector = entity_.get_vector(node); + ailego_prefetch(cand_vector); + get_bin_est(cand_vector, candest, *query_entity); + + if (ex_bits_ > 0) { + // Check preliminary score against current worst full estimate. + bool flag_update_KNNs = + (!topk.full()) || candest.low_dist < topk[0].second.est_dist; + + if (flag_update_KNNs) { + // Compute the full estimate if promising. 
+ get_full_est(cand_vector, candest, *query_entity); + } else { + continue; + } + } else { + // Candidate cand{ResultRecord(candest.est_dist, candest.low_dist), + // static_cast(candidate_id)}; + // boundedKNN.insert(cand); + } + candidates.emplace(node, ResultRecord(candest)); + // update entry_point for next level scan + if (candest < *dist) { + *entry_point = node; + *dist = candest; + } + if (!filter(node)) { + topk.emplace(node, ResultRecord(candest)); + } + + + // TODO: check loop type + + // if ((!topk.full()) || cur_dist < topk[0].second) { + // candidates.emplace(node, cur_dist); + // // update entry_point for next level scan + // if (cur_dist < *dist) { + // *entry_point = node; + // *dist = cur_dist; + // } + // if (!filter(node)) { + // topk.emplace(node, cur_dist); + // } + // } // end if + + } // end for + } // while + + return; +} + +void HnswRabitqQueryAlgorithm::expand_neighbors_by_group( + TopkHeap &topk, HnswRabitqContext *ctx) const { + // if (!ctx->group_by().is_valid()) { + // return; + // } + + // const auto &entity = ctx->get_entity(); + // std::function group_by = [&](node_id_t id) { + // return ctx->group_by()(entity.get_key(id)); + // }; + + // // devide into groups + // std::map &group_topk_heaps = + // ctx->group_topk_heaps(); for (uint32_t i = 0; i < topk.size(); ++i) { + // node_id_t id = topk[i].first; + // auto score = topk[i].second; + + // std::string group_id = group_by(id); + + // auto &topk_heap = group_topk_heaps[group_id]; + // if (topk_heap.empty()) { + // topk_heap.limit(ctx->group_topk()); + // } + // topk_heap.emplace_back(id, score); + // } + + // // stage 2, expand to reach group num as possible + // if (group_topk_heaps.size() < ctx->group_num()) { + // VisitFilter &visit = ctx->visit_filter(); + // CandidateHeap &candidates = ctx->candidates(); + // HnswRabitqDistCalculator &dc = ctx->dist_calculator(); + + // std::function filter = [](node_id_t) { return false; }; + // if (ctx->filter().is_valid()) { + // filter = 
[&](node_id_t id) { return ctx->filter()(entity.get_key(id)); + // }; + // } + + // // refill to get enough groups + // candidates.clear(); + // visit.clear(); + // for (uint32_t i = 0; i < topk.size(); ++i) { + // node_id_t id = topk[i].first; + // auto score = topk[i].second; + + // visit.set_visited(id); + // candidates.emplace_back(id, score); + // } + + // // do expand + // while (!candidates.empty() && !ctx->reach_scan_limit()) { + // auto top = candidates.begin(); + // node_id_t main_node = top->first; + + // candidates.pop(); + // const Neighbors neighbors = entity.get_neighbors(0, main_node); + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_get_neighbors())++; + // } + + // node_id_t neighbor_ids[neighbors.size()]; + // uint32_t size = 0; + // for (uint32_t i = 0; i < neighbors.size(); ++i) { + // node_id_t node = neighbors[i]; + // if (visit.visited(node)) { + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_visit_dup_cnt())++; + // } + // continue; + // } + // visit.set_visited(node); + // neighbor_ids[size++] = node; + // } + // if (size == 0) { + // continue; + // } + + // std::vector neighbor_vec_blocks; + // int ret = entity.get_vector(neighbor_ids, size, neighbor_vec_blocks); + // if (ailego_unlikely(ctx->debugging())) { + // (*ctx->mutable_stats_get_vector())++; + // } + // if (ailego_unlikely(ret != 0)) { + // break; + // } + + // static constexpr node_id_t PREFETCH_STEP = 2; + // for (uint32_t i = 0; i < size; ++i) { + // node_id_t node = neighbor_ids[i]; + // node_id_t prefetch_id = i + PREFETCH_STEP; + // if (prefetch_id < size) { + // ailego_prefetch(neighbor_vec_blocks[prefetch_id].data()); + // } + // dist_t cur_dist = dc.dist(neighbor_vec_blocks[i].data()); + + // if (!filter(node)) { + // std::string group_id = group_by(node); + + // auto &topk_heap = group_topk_heaps[group_id]; + // if (topk_heap.empty()) { + // topk_heap.limit(ctx->group_topk()); + // } + // topk_heap.emplace_back(node, 
cur_dist); + + // if (group_topk_heaps.size() >= ctx->group_num()) { + // break; + // } + // } + + // candidates.emplace(node, cur_dist); + // } // end for + // } // end while + // } // end if +} + +void HnswRabitqQueryAlgorithm::get_bin_est( + const void *vector, EstimateRecord &res, + HnswRabitqQueryEntity &entity) const { + const auto &q_to_centroids = entity.q_to_centroids; + auto &query_wrapper = *entity.query_wrapper; + uint32_t cluster_id = entity_.get_cluster_id(vector); + const char *bin_data = entity_.get_bin_data(vector); + if (metric_type_ == rabitqlib::METRIC_IP) { + float norm = q_to_centroids[cluster_id]; + float error = q_to_centroids[cluster_id + num_clusters_]; + rabitqlib::split_single_estdist(bin_data, query_wrapper, padded_dim_, + res.ip_x0_qr, res.est_dist, res.low_dist, + -norm, error); + } else { + // L2 distance + float norm = q_to_centroids[cluster_id]; + rabitqlib::split_single_estdist(bin_data, query_wrapper, padded_dim_, + res.ip_x0_qr, res.est_dist, res.low_dist, + norm * norm, norm); + } +} + +void HnswRabitqQueryAlgorithm::get_ex_est(const void *vector, + EstimateRecord &res, + HnswRabitqQueryEntity &entity) const { + const auto &q_to_centroids = entity.q_to_centroids; + auto &query_wrapper = *entity.query_wrapper; + uint32_t cluster_id = entity_.get_cluster_id(vector); + const char *ex_data = entity_.get_ex_data(vector); + query_wrapper.set_g_add(q_to_centroids[cluster_id]); + float est_dist = rabitqlib::split_distance_boosting( + ex_data, ip_func_, query_wrapper, padded_dim_, ex_bits_, res.ip_x0_qr); + float low_dist = est_dist - (res.est_dist - res.low_dist) / (1 << ex_bits_); + res.est_dist = est_dist; + res.low_dist = low_dist; + // Note that res.ip_x0_qr becomes invalid after this function. 
+} + +void HnswRabitqQueryAlgorithm::get_full_est( + const void *vector, EstimateRecord &res, + HnswRabitqQueryEntity &entity) const { + const auto &q_to_centroids = entity.q_to_centroids; + auto &query_wrapper = *entity.query_wrapper; + uint32_t cluster_id = entity_.get_cluster_id(vector); + const char *bin_data = entity_.get_bin_data(vector); + const char *ex_data = entity_.get_ex_data(vector); + + if (metric_type_ == rabitqlib::METRIC_IP) { + float norm = q_to_centroids[cluster_id]; + float error = q_to_centroids[cluster_id + num_clusters_]; + rabitqlib::split_single_fulldist(bin_data, ex_data, ip_func_, query_wrapper, + padded_dim_, ex_bits_, res.est_dist, + res.low_dist, res.ip_x0_qr, -norm, error); + } else { + // L2 distance + float norm = q_to_centroids[cluster_id]; + rabitqlib::split_single_fulldist( + bin_data, ex_data, ip_func_, query_wrapper, padded_dim_, ex_bits_, + res.est_dist, res.low_dist, res.ip_x0_qr, norm * norm, norm); + } +} + + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_algorithm.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_algorithm.h new file mode 100644 index 00000000..e3df766f --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_algorithm.h @@ -0,0 +1,136 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include "hnsw_rabitq_context.h" +#include "hnsw_rabitq_dist_calculator.h" +#include "hnsw_rabitq_entity.h" +#include "hnsw_rabitq_query_entity.h" + +namespace zvec { +namespace core { + +//! hnsw graph algorithm implement +class HnswRabitqQueryAlgorithm { + public: + typedef std::unique_ptr UPointer; + + public: + //! Constructor + explicit HnswRabitqQueryAlgorithm(HnswRabitqEntity &entity, + size_t num_clusters, + rabitqlib::MetricType metric_type); + + //! Destructor + ~HnswRabitqQueryAlgorithm() = default; + + //! Cleanup HnswRabitqQueryAlgorithm + int cleanup(); + + //! do knn search in graph + //! return 0 on success, or errCode in failure. results saved in ctx + int search(HnswRabitqQueryEntity *entity, HnswRabitqContext *ctx) const; + + //! Initiate HnswRabitqQueryAlgorithm + int init() { + level_probas_.clear(); + double level_mult = + 1 / std::log(static_cast(entity_.scaling_factor())); + for (int level = 0;; level++) { + // refers faiss get_random_level alg + double proba = + std::exp(-level / level_mult) * (1 - std::exp(-1 / level_mult)); + if (proba < 1e-9) { + break; + } + level_probas_.push_back(proba); + } + + return 0; + } + + //! Generate a random level + //! return graph level + uint32_t get_random_level() const { + // gen rand float (0, 1) + double f = mt_() / static_cast(mt_.max()); + for (size_t level = 0; level < level_probas_.size(); level++) { + if (f < level_probas_[level]) { + return level; + } + f -= level_probas_[level]; + } + return level_probas_.size() - 1; + } + void get_full_est(node_id_t id, EstimateRecord &res, + HnswRabitqQueryEntity &entity) const { + return get_full_est(entity_.get_vector(id), res, entity); + } + + private: + //! Select in upper layer to get entry point for next layer search + void select_entry_point(level_t level, node_id_t *entry_point, + EstimateRecord *dist, HnswRabitqContext *ctx, + HnswRabitqQueryEntity *entity) const; + + + //! 
Given a node id and level, search the nearest neighbors in graph + //! Note: the nearest neighbors result keeps in topk, and entry_point and + //! dist will be updated to current level nearest node id and distance + void search_neighbors(level_t level, node_id_t *entry_point, + EstimateRecord *dist, TopkHeap &topk, + HnswRabitqContext *ctx, + HnswRabitqQueryEntity *entity) const; + + + //! expand neighbors until group nums are reached + void expand_neighbors_by_group(TopkHeap &topk, HnswRabitqContext *ctx) const; + + void get_full_est(const void *vector, EstimateRecord &res, + HnswRabitqQueryEntity &entity) const; + void get_bin_est(const void *vector, EstimateRecord &res, + HnswRabitqQueryEntity &entity) const; + + void get_ex_est(const void *vector, EstimateRecord &res, + HnswRabitqQueryEntity &entity) const; + + private: + HnswRabitqQueryAlgorithm(const HnswRabitqQueryAlgorithm &) = delete; + HnswRabitqQueryAlgorithm &operator=(const HnswRabitqQueryAlgorithm &) = + delete; + + + private: + static constexpr uint32_t kLockCnt{1U << 8}; + static constexpr uint32_t kLockMask{kLockCnt - 1U}; + + HnswRabitqEntity &entity_; + mutable std::mt19937 mt_{}; + std::vector level_probas_{}; + + mutable ailego::SpinMutex spin_lock_{}; // global spin lock + std::mutex mutex_{}; // global mutex + // TODO: spin lock? 
+ std::vector lock_pool_{}; + size_t num_clusters_{0}; + rabitqlib::MetricType metric_type_{rabitqlib::METRIC_L2}; + size_t padded_dim_{0}; + size_t ex_bits_{0}; + float (*ip_func_)(const float *, const uint8_t *, size_t); +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_entity.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_entity.h new file mode 100644 index 00000000..760cdd2e --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_query_entity.h @@ -0,0 +1,28 @@ +// Copyright 2025-present the centaurdb project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#pragma once + +#include +#include +#include +namespace zvec::core { + +struct HnswRabitqQueryEntity { + std::vector rotated_query; + std::vector q_to_centroids; + std::unique_ptr> query_wrapper; +}; + +} // namespace zvec::core \ No newline at end of file diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher.cc new file mode 100644 index 00000000..01dde32f --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher.cc @@ -0,0 +1,496 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "hnsw_rabitq_searcher.h" +#include "hnsw_rabitq_algorithm.h" +#include "hnsw_rabitq_entity.h" +#include "hnsw_rabitq_index_provider.h" +#include "hnsw_rabitq_params.h" +#include "hnsw_rabitq_searcher_entity.h" + +namespace zvec { +namespace core { + +HnswRabitqSearcher::HnswRabitqSearcher() {} + +HnswRabitqSearcher::~HnswRabitqSearcher() {} + +int HnswRabitqSearcher::init(const ailego::Params &search_params) { + params_ = search_params; + params_.get(PARAM_HNSW_RABITQ_SEARCHER_EF, &ef_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_ENABLE, + &bf_enabled_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_CHECK_CRC_ENABLE, &check_crc_enabled_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE, + &neighbors_in_memory_enabled_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB, + &bf_negative_probability_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD, + &bruteforce_threshold_); + params_.get(PARAM_HNSW_RABITQ_SEARCHER_FORCE_PADDING_RESULT_ENABLE, + &force_padding_topk_enabled_); + + if (ef_ == 0) { + ef_ = HnswRabitqEntity::kDefaultEf; + } + if (bf_negative_probability_ <= 0.0f || bf_negative_probability_ >= 1.0f) { + LOG_ERROR( + "[%s] must be in range (0,1)", + PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str()); + return IndexError_InvalidArgument; + } + + entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_); + + ailego::Params reformer_params; + 
reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name());
+  int ret = reformer_.init(reformer_params);
+  if (ret != 0) {
+    LOG_ERROR("Failed to initialize RabitqReformer: %d", ret);
+    return ret;
+  }
+
+  state_ = STATE_INITED;
+
+  LOG_DEBUG(
+      "Init params: ef=%u maxScanRatio=%f bfEnabled=%u checkCrcEnabled=%u "
+      "neighborsInMemoryEnabled=%u bfNegativeProb=%f bruteForceThreshold=%u "
+      "forcePadding=%u",
+      ef_, max_scan_ratio_, bf_enabled_, check_crc_enabled_,
+      neighbors_in_memory_enabled_, bf_negative_probability_,
+      bruteforce_threshold_, force_padding_topk_enabled_);
+
+  return 0;
+}
+
+void HnswRabitqSearcher::print_debug_info() {
+  for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
+    Neighbors neighbours = entity_.get_neighbors(0, id);
+    std::cout << "node: " << id << "; ";
+    for (uint32_t i = 0; i < neighbours.size(); ++i) {
+      std::cout << neighbours[i];
+
+      if (i == neighbours.size() - 1) {
+        std::cout << std::endl;
+      } else {
+        std::cout << ", ";
+      }
+    }
+  }
+}
+
+int HnswRabitqSearcher::cleanup() {
+  LOG_INFO("Begin HnswRabitqSearcher::cleanup");
+
+  metric_.reset();
+  meta_.clear();
+  stats_.clear_attributes();
+  stats_.set_loaded_count(0UL);
+  stats_.set_loaded_costtime(0UL);
+  max_scan_ratio_ = HnswRabitqEntity::kDefaultScanRatio;
+  max_scan_num_ = 0U;
+  ef_ = HnswRabitqEntity::kDefaultEf;
+  bf_enabled_ = false;
+  bf_negative_probability_ = HnswRabitqEntity::kDefaultBFNegativeProbability;
+  bruteforce_threshold_ = HnswRabitqEntity::kDefaultBruteForceThreshold;
+  check_crc_enabled_ = false;
+  neighbors_in_memory_enabled_ = false;
+  entity_.cleanup();
+  state_ = STATE_INIT;
+
+  LOG_INFO("End HnswRabitqSearcher::cleanup");
+
+  return 0;
+}
+
+int HnswRabitqSearcher::load(IndexStorage::Pointer container,
+                             IndexMetric::Pointer metric) {
+  if (state_ != STATE_INITED) {
+    LOG_ERROR("Init the searcher first before load index");
+    return IndexError_Runtime;
+  }
+
+  LOG_INFO("Begin HnswRabitqSearcher::load");
+
+  auto start_time = 
ailego::Monotime::MilliSeconds(); + + int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_); + if (ret != 0) { + LOG_ERROR("Failed to deserialize meta from container"); + return ret; + } + + ret = reformer_.load(container); + if (ret != 0) { + LOG_ERROR("Failed to load reformer from container: %d", ret); + return ret; + } + + ret = entity_.load(container, check_crc_enabled_); + if (ret != 0) { + LOG_ERROR("HnswRabitqSearcher load index failed"); + return ret; + } + + alg_ = HnswRabitqQueryAlgorithm::UPointer(new HnswRabitqQueryAlgorithm( + entity_, reformer_.num_clusters(), reformer_.rabitq_metric_type())); + + if (metric) { + metric_ = metric; + } else { + metric_ = IndexFactory::CreateMetric(meta_.metric_name()); + if (!metric_) { + LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str()); + return IndexError_NoExist; + } + ret = metric_->init(meta_, meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("IndexMetric init failed, ret=%d", ret); + return ret; + } + if (metric_->query_metric()) { + metric_ = metric_->query_metric(); + } + } + + if (!metric_->is_matched(meta_)) { + LOG_ERROR("IndexMetric not match index meta"); + return IndexError_Mismatch; + } + + max_scan_num_ = static_cast(max_scan_ratio_ * entity_.doc_cnt()); + max_scan_num_ = std::max(4096U, max_scan_num_); + + stats_.set_loaded_count(entity_.doc_cnt()); + stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time); + state_ = STATE_LOADED; + magic_ = IndexContext::GenerateMagic(); + + LOG_INFO("End HnswRabitqSearcher::load"); + + return 0; +} + +int HnswRabitqSearcher::unload() { + LOG_INFO("HnswRabitqSearcher unload index"); + + meta_.clear(); + entity_.cleanup(); + metric_.reset(); + max_scan_num_ = 0; + stats_.set_loaded_count(0UL); + stats_.set_loaded_costtime(0UL); + state_ = STATE_INITED; + + return 0; +} + +int HnswRabitqSearcher::update_context(HnswRabitqContext *ctx) const { + const HnswRabitqEntity::Pointer entity = entity_.clone(); + if 
(!entity) { + LOG_ERROR("Failed to clone search context entity"); + return IndexError_Runtime; + } + ctx->set_max_scan_num(max_scan_num_); + ctx->set_bruteforce_threshold(bruteforce_threshold_); + + return ctx->update_context(HnswRabitqContext::kSearcherContext, meta_, + metric_, entity, magic_); +} + +int HnswRabitqSearcher::search_impl(const void *query, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + } + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext failed"); + return IndexError_Cast; + } + + if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { + return search_bf_impl(query, qmeta, count, context); + } + // return search_bf_impl(query, qmeta, count, context); + + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer + int ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + ctx->clear(); + ctx->resize_results(count); + for (size_t q = 0; q < count; ++q) { + HnswRabitqQueryEntity entity; + int ret = reformer_.transform_to_entity(query, &entity); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw searcher transform failed"); + return ret; + } + ctx->reset_query(query); + ret = alg_->search(&entity, ctx); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw searcher fast search failed"); + return ret; + } + ctx->topk_to_result(q); + query = static_cast(query) + qmeta.element_size(); + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + +int HnswRabitqSearcher::search_bf_impl(const void *query, + const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const { + if (ailego_unlikely(!query || !context)) { + LOG_ERROR("The context is not created by this searcher"); + return IndexError_Mismatch; + 
} + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext failed"); + return IndexError_Cast; + } + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer + int ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + ctx->clear(); + ctx->resize_results(count); + + if (ctx->group_by_search()) { + // if (!ctx->group_by().is_valid()) { + // LOG_ERROR("Invalid group-by function"); + // return IndexError_InvalidArgument; + // } + + // std::function group_by = [&](node_id_t id) { + // return ctx->group_by()(entity_.get_key(id)); + // }; + + // for (size_t q = 0; q < count; ++q) { + // ctx->reset_query(query); + // ctx->group_topk_heaps().clear(); + + // for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { + // if (entity_.get_key(id) == kInvalidKey) { + // continue; + // } + // if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) + // { + // dist_t dist = ctx->dist_calculator().dist(id); + + // std::string group_id = group_by(id); + + // auto &topk_heap = ctx->group_topk_heaps()[group_id]; + // if (topk_heap.empty()) { + // topk_heap.limit(ctx->group_topk()); + // } + // topk_heap.emplace_back(id, dist); + // } + // } + // ctx->topk_to_result(q); + // query = static_cast(query) + qmeta.element_size(); + // } + } else { + for (size_t q = 0; q < count; ++q) { + HnswRabitqQueryEntity entity; + int ret = reformer_.transform_to_entity(query, &entity); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw searcher transform failed"); + return ret; + } + ctx->reset_query(query); + ctx->topk_heap().clear(); + for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { + if (entity_.get_key(id) == kInvalidKey) { + continue; + } + if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { + EstimateRecord dist; + alg_->get_full_est(id, dist, entity); + ctx->topk_heap().emplace(id, dist); + } + } + ctx->topk_to_result(q); + query = 
static_cast(query) + qmeta.element_size(); + } + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + +int HnswRabitqSearcher::search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + // if (ailego_unlikely(!query || !context)) { + // LOG_ERROR("The context is not created by this searcher"); + // return IndexError_Mismatch; + // } + + // if (ailego_unlikely(p_keys.size() != count)) { + // LOG_ERROR("The size of p_keys is not equal to count"); + // return IndexError_InvalidArgument; + // } + + // HnswRabitqContext *ctx = dynamic_cast(context.get()); + // ailego_do_if_false(ctx) { + // LOG_ERROR("Cast context to HnswRabitqContext failed"); + // return IndexError_Cast; + // } + // if (ctx->magic() != magic_) { + // //! context is created by another searcher or streamer + // int ret = update_context(ctx); + // if (ret != 0) { + // return ret; + // } + // } + + // ctx->clear(); + // ctx->resize_results(count); + + // if (ctx->group_by_search()) { + // if (!ctx->group_by().is_valid()) { + // LOG_ERROR("Invalid group-by function"); + // return IndexError_InvalidArgument; + // } + + // std::function group_by = [&](node_id_t id) { + // return ctx->group_by()(entity_.get_key(id)); + // }; + + // for (size_t q = 0; q < count; ++q) { + // ctx->reset_query(query); + // ctx->group_topk_heaps().clear(); + + // for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { + // uint64_t pk = p_keys[q][idx]; + // if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { + // node_id_t id = entity_.get_id(pk); + // if (id != kInvalidNodeId) { + // dist_t dist = ctx->dist_calculator().dist(id); + // std::string group_id = group_by(id); + + // auto &topk_heap = ctx->group_topk_heaps()[group_id]; + // if (topk_heap.empty()) { + // topk_heap.limit(ctx->group_topk()); + // } + // topk_heap.emplace_back(id, dist); + // } + // } + // } + // 
ctx->topk_to_result(q); + // query = static_cast(query) + qmeta.element_size(); + // } + // } else { + // for (size_t q = 0; q < count; ++q) { + // ctx->reset_query(query); + // ctx->topk_heap().clear(); + // for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { + // uint64_t pk = p_keys[q][idx]; + // if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { + // node_id_t id = entity_.get_id(pk); + // if (id != kInvalidNodeId) { + // dist_t dist = ctx->dist_calculator().dist(id); + // ctx->topk_heap().emplace(id, dist); + // } + // } + // } + // ctx->topk_to_result(q); + // query = static_cast(query) + qmeta.element_size(); + // } + // } + + // if (ailego_unlikely(ctx->error())) { + // return IndexError_Runtime; + // } + + return 0; +} + +IndexSearcher::Context::Pointer HnswRabitqSearcher::create_context() const { + if (ailego_unlikely(state_ != STATE_LOADED)) { + LOG_ERROR("Load the index first before create context"); + return Context::Pointer(); + } + const HnswRabitqEntity::Pointer search_ctx_entity = entity_.clone(); + if (!search_ctx_entity) { + LOG_ERROR("Failed to create search context entity"); + return Context::Pointer(); + } + HnswRabitqContext *ctx = new (std::nothrow) + HnswRabitqContext(meta_.dimension(), metric_, search_ctx_entity); + if (ailego_unlikely(ctx == nullptr)) { + LOG_ERROR("Failed to new HnswRabitqContext"); + return Context::Pointer(); + } + ctx->set_ef(ef_); + ctx->set_max_scan_num(max_scan_num_); + uint32_t filter_mode = + bf_enabled_ ? 
VisitFilter::BloomFilter : VisitFilter::ByteMap; + ctx->set_filter_mode(filter_mode); + ctx->set_filter_negative_probability(bf_negative_probability_); + ctx->set_magic(magic_); + ctx->set_force_padding_topk(force_padding_topk_enabled_); + ctx->set_bruteforce_threshold(bruteforce_threshold_); + if (ailego_unlikely(ctx->init(HnswRabitqContext::kSearcherContext)) != 0) { + LOG_ERROR("Init HnswRabitqContext failed"); + delete ctx; + return Context::Pointer(); + } + + return Context::Pointer(ctx); +} + +IndexProvider::Pointer HnswRabitqSearcher::create_provider(void) const { + LOG_DEBUG("HnswRabitqSearcher create provider"); + + auto entity = entity_.clone(); + if (ailego_unlikely(!entity)) { + LOG_ERROR("Clone HnswRabitqEntity failed"); + return Provider::Pointer(); + } + return Provider::Pointer(new (std::nothrow) HnswRabitqIndexProvider( + meta_, entity, "HnswRabitqSearcher")); +} + +const void *HnswRabitqSearcher::get_vector(uint64_t key) const { + return entity_.get_vector_by_key(key); +} + +INDEX_FACTORY_REGISTER_SEARCHER(HnswRabitqSearcher); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher.h new file mode 100644 index 00000000..062498e6 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher.h @@ -0,0 +1,142 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "zvec/core/framework/index_framework.h" +#include "hnsw_rabitq_query_algorithm.h" +#include "hnsw_rabitq_searcher_entity.h" +#include "rabitq_reformer.h" + +namespace zvec { +namespace core { + +class HnswRabitqSearcher : public IndexSearcher { + public: + using ContextPointer = IndexSearcher::Context::Pointer; + + public: + HnswRabitqSearcher(void); + ~HnswRabitqSearcher(void); + + HnswRabitqSearcher(const HnswRabitqSearcher &) = delete; + HnswRabitqSearcher &operator=(const HnswRabitqSearcher &) = delete; + + protected: + //! Initialize Searcher + virtual int init(const ailego::Params ¶ms) override; + + //! Cleanup Searcher + virtual int cleanup(void) override; + + //! Load Index from storage + virtual int load(IndexStorage::Pointer container, + IndexMetric::Pointer metric) override; + + //! Unload index from storage + virtual int unload(void) override; + + //! KNN Search + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_impl(query, qmeta, 1, context); + } + + //! KNN Search + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + ContextPointer &context) const override; + + //! Linear Search + virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_bf_impl(query, qmeta, 1, context); + } + + //! Linear Search + virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + ContextPointer &context) const override; + + //! Linear search by primary keys + virtual int search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, ContextPointer &context) const override { + return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); + } + + //! 
Linear search by primary keys + virtual int search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + ContextPointer &context) const override; + + //! Fetch vector by key + virtual const void *get_vector(uint64_t key) const override; + + //! Create a searcher context + virtual ContextPointer create_context() const override; + + //! Create a new iterator + virtual IndexProvider::Pointer create_provider(void) const override; + + //! Retrieve statistics + virtual const Stats &stats(void) const override { + return stats_; + } + + //! Retrieve meta of index + virtual const IndexMeta &meta(void) const override { + return meta_; + } + + //! Retrieve params of index + virtual const ailego::Params ¶ms(void) const override { + return params_; + } + + virtual void print_debug_info() override; + + private: + //! To share ctx across streamer/searcher, we need to update the context for + //! current streamer/searcher + int update_context(HnswRabitqContext *ctx) const; + + private: + enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; + + HnswRabitqSearcherEntity entity_{}; + HnswRabitqQueryAlgorithm::UPointer alg_; // impl graph algorithm + + IndexMetric::Pointer metric_{}; + IndexMeta meta_{}; + ailego::Params params_{}; + Stats stats_; + uint32_t ef_{HnswRabitqEntity::kDefaultEf}; + uint32_t max_scan_num_{0U}; + uint32_t bruteforce_threshold_{HnswRabitqEntity::kDefaultBruteForceThreshold}; + float max_scan_ratio_{HnswRabitqEntity::kDefaultScanRatio}; + bool bf_enabled_{false}; + bool check_crc_enabled_{false}; + bool neighbors_in_memory_enabled_{false}; + bool force_padding_topk_enabled_{false}; + float bf_negative_probability_{ + HnswRabitqEntity::kDefaultBFNegativeProbability}; + uint32_t magic_{0U}; + RabitqReformer reformer_; + + State state_{STATE_INIT}; +}; + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git 
a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher_entity.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher_entity.cc new file mode 100644 index 00000000..415060c2 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher_entity.cc @@ -0,0 +1,515 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "hnsw_rabitq_searcher_entity.h" +#include +#include "utility/sparse_utility.h" + +namespace zvec { +namespace core { + +HnswRabitqSearcherEntity::HnswRabitqSearcherEntity() {} + +int HnswRabitqSearcherEntity::cleanup(void) { + storage_.reset(); + vectors_.reset(); + keys_.reset(); + neighbors_.reset(); + neighbors_meta_.reset(); + neighbors_in_memory_enabled_ = false; + loaded_ = false; + + this->HnswRabitqEntity::cleanup(); + + return 0; +} + +key_t HnswRabitqSearcherEntity::get_key(node_id_t id) const { + const void *key; + if (ailego_unlikely(keys_->read(id * sizeof(key_t), &key, sizeof(key_t)) != + sizeof(key_t))) { + LOG_ERROR("Read key from segment failed"); + return kInvalidKey; + } + return *(reinterpret_cast(key)); +} + +//! Get vector local id by key +node_id_t HnswRabitqSearcherEntity::get_id(key_t key) const { + if (ailego_unlikely(!mapping_)) { + LOG_ERROR("Index missing mapping segment"); + return kInvalidNodeId; + } + + //! 
Do binary search + node_id_t start = 0UL; + node_id_t end = doc_cnt(); + const void *data; + node_id_t idx = 0u; + while (start < end) { + idx = start + (end - start) / 2; + if (ailego_unlikely( + mapping_->read(idx * sizeof(node_id_t), &data, sizeof(node_id_t)) != + sizeof(node_id_t))) { + LOG_ERROR("Read key from segment failed"); + return kInvalidNodeId; + } + const key_t *mkey; + node_id_t local_id = *reinterpret_cast(data); + if (ailego_unlikely(keys_->read(local_id * sizeof(key_t), + (const void **)(&mkey), + sizeof(key_t)) != sizeof(key_t))) { + LOG_ERROR("Read key from segment failed"); + return kInvalidNodeId; + } + if (*mkey < key) { + start = idx + 1; + } else if (*mkey > key) { + end = idx; + } else { + return local_id; + } + } + return kInvalidNodeId; +} + +const void *HnswRabitqSearcherEntity::get_vector_by_key(key_t key) const { + node_id_t local_id = get_id(key); + if (ailego_unlikely(local_id == kInvalidNodeId)) { + return nullptr; + } + + return get_vector(local_id); +} + +const void *HnswRabitqSearcherEntity::get_vector(node_id_t id) const { + size_t read_size = vector_size(); + size_t offset = node_size() * id; + + const void *vec; + if (ailego_unlikely(vectors_->read(offset, &vec, read_size) != read_size)) { + LOG_ERROR("Read vector from segment failed"); + return nullptr; + } + return vec; +} + +int HnswRabitqSearcherEntity::get_vector( + const node_id_t id, IndexStorage::MemoryBlock &block) const { + const void *vec = get_vector(id); + block.reset((void *)vec); + return 0; +} + +const void *HnswRabitqSearcherEntity::get_vectors() const { + const void *vec; + size_t len = node_size() * doc_cnt(); + if (vectors_->read(0, &vec, len) != len) { + LOG_ERROR("Read vectors from segment failed"); + return nullptr; + } + return vec; +} + +int HnswRabitqSearcherEntity::get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const { + ailego_assert_with(count <= segment_datas_.size(), "invalid count"); + + size_t read_size = vector_size(); 
+ + for (uint32_t i = 0; i < count; ++i) { + segment_datas_[i].offset = node_size() * ids[i]; + segment_datas_[i].length = read_size; + + ailego_assert_with(segment_datas_[i].offset < vectors_->data_size(), + "invalid offset"); + } + if (ailego_unlikely(!vectors_->read(&segment_datas_[0], count))) { + LOG_ERROR("Read vectors from segment failed"); + return IndexError_ReadData; + } + for (uint32_t i = 0; i < count; ++i) { + vecs[i] = segment_datas_[i].data; + } + + return 0; +} + +int HnswRabitqSearcherEntity::get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const { + const void *vecs[count]; + get_vector(ids, count, vecs); + for (uint32_t i = 0; i < count; ++i) { + vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); + } + return 0; +} + +const Neighbors HnswRabitqSearcherEntity::get_neighbors(level_t level, + node_id_t id) const { + if (level == 0) { + if (neighbors_in_memory_enabled_) { + auto hd = reinterpret_cast( + fixed_neighbors_.get() + neighbors_size() * id); + return {hd->neighbor_cnt, hd->neighbors}; + } + + const GraphNeighborMeta *m; + if (ailego_unlikely(neighbors_meta_->read(id * sizeof(GraphNeighborMeta), + (const void **)(&m), + sizeof(GraphNeighborMeta)) != + sizeof(GraphNeighborMeta))) { + LOG_ERROR("Read neighbors meta from segment failed"); + return {0, nullptr}; + } + + const void *data; + if (ailego_unlikely(neighbors_->read(m->offset, &data, + m->neighbor_cnt * sizeof(node_id_t)) != + m->neighbor_cnt * sizeof(node_id_t))) { + LOG_ERROR("Read neighbors from segment failed"); + return {0, nullptr}; + } + return {static_cast(m->neighbor_cnt), + reinterpret_cast(data)}; + } + + //! 
Read level > 0 neighbors + const HnswNeighborMeta *m; + if (ailego_unlikely(upper_neighbors_meta_->read(id * sizeof(HnswNeighborMeta), + (const void **)(&m), + sizeof(HnswNeighborMeta)) != + sizeof(HnswNeighborMeta))) { + LOG_ERROR("Read neighbors meta from segment failed"); + return {0, nullptr}; + } + + ailego_assert_with(level <= m->level, "invalid level"); + size_t offset = m->offset + (level - 1) * upper_neighbors_size(); + ailego_assert_with(offset <= upper_neighbors_->data_size(), "invalid offset"); + const void *data; + if (ailego_unlikely( + upper_neighbors_->read(offset, &data, upper_neighbors_size()) != + upper_neighbors_size())) { + LOG_ERROR("Read neighbors from segment failed"); + return {0, nullptr}; + } + + auto hd = reinterpret_cast(data); + return {hd->neighbor_cnt, hd->neighbors}; +} + +int HnswRabitqSearcherEntity::load(const IndexStorage::Pointer &container, + bool check_crc) { + storage_ = container; + + int ret = load_segments(check_crc); + if (ret != 0) { + return ret; + } + + loaded_ = true; + + LOG_INFO( + "Index info: docCnt=%u entryPoint=%u maxLevel=%d efConstruct=%zu " + "l0NeighborCnt=%zu upperNeighborCnt=%zu scalingFactor=%zu " + "vectorSize=%zu nodeSize=%zu vectorSegmentSize=%zu keySegmentSize=%zu " + "neighborsSegmentSize=%zu neighborsMetaSegmentSize=%zu ", + doc_cnt(), entry_point(), cur_max_level(), ef_construction(), + l0_neighbor_cnt(), upper_neighbor_cnt(), scaling_factor(), vector_size(), + node_size(), vectors_->data_size(), keys_->data_size(), + neighbors_ == nullptr ? 0 : neighbors_->data_size(), + neighbors_meta_ == nullptr ? 0 : neighbors_meta_->data_size()); + + return 0; +} + +int HnswRabitqSearcherEntity::load_segments(bool check_crc) { + //! 
load header + const void *data = nullptr; + HNSWHeader hd; + auto graph_hd_segment = storage_->get(kGraphHeaderSegmentId); + if (!graph_hd_segment || graph_hd_segment->data_size() < sizeof(hd.graph)) { + LOG_ERROR("Miss or invalid segment %s", kGraphHeaderSegmentId.c_str()); + return IndexError_InvalidFormat; + } + if (graph_hd_segment->read(0, reinterpret_cast(&data), + sizeof(hd.graph)) != sizeof(hd.graph)) { + LOG_ERROR("Read segment %s failed", kGraphHeaderSegmentId.c_str()); + return IndexError_ReadData; + } + memcpy(&hd.graph, data, sizeof(hd.graph)); + + auto hnsw_hd_segment = storage_->get(kHnswHeaderSegmentId); + if (!hnsw_hd_segment || hnsw_hd_segment->data_size() < sizeof(hd.hnsw)) { + LOG_ERROR("Miss or invalid segment %s", kHnswHeaderSegmentId.c_str()); + return IndexError_InvalidFormat; + } + if (hnsw_hd_segment->read(0, reinterpret_cast(&data), + sizeof(hd.hnsw)) != sizeof(hd.hnsw)) { + LOG_ERROR("Read segment %s failed", kHnswHeaderSegmentId.c_str()); + return IndexError_ReadData; + } + memcpy(&hd.hnsw, data, sizeof(hd.hnsw)); + *mutable_header() = hd; + segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt())); + + vectors_ = storage_->get(kGraphFeaturesSegmentId); + if (!vectors_) { + LOG_ERROR("IndexStorage get segment %s failed", + kGraphFeaturesSegmentId.c_str()); + return IndexError_InvalidFormat; + } + keys_ = storage_->get(kGraphKeysSegmentId); + if (!keys_) { + LOG_ERROR("IndexStorage get segment %s failed", + kGraphKeysSegmentId.c_str()); + return IndexError_InvalidFormat; + } + + neighbors_ = storage_->get(kGraphNeighborsSegmentId); + if (!neighbors_ || (neighbors_->data_size() == 0 && doc_cnt() > 1)) { + LOG_ERROR("IndexStorage get segment %s failed or empty", + kGraphNeighborsSegmentId.c_str()); + return IndexError_InvalidArgument; + } + neighbors_meta_ = storage_->get(kGraphOffsetsSegmentId); + if (!neighbors_meta_ || + neighbors_meta_->data_size() < sizeof(GraphNeighborMeta) * doc_cnt()) { + LOG_ERROR("IndexStorage get 
segment %s failed or invalid size", + kGraphOffsetsSegmentId.c_str()); + return IndexError_InvalidArgument; + } + + upper_neighbors_ = storage_->get(kHnswNeighborsSegmentId); + if (!upper_neighbors_ || + (upper_neighbors_->data_size() == 0 && cur_max_level() > 0)) { + LOG_ERROR("IndexStorage get segment %s failed or empty", + kHnswNeighborsSegmentId.c_str()); + return IndexError_InvalidArgument; + } + + upper_neighbors_meta_ = storage_->get(kHnswOffsetsSegmentId); + if (!upper_neighbors_meta_ || upper_neighbors_meta_->data_size() < + sizeof(HnswNeighborMeta) * doc_cnt()) { + LOG_ERROR("IndexStorage get segment %s failed or invalid size", + kHnswOffsetsSegmentId.c_str()); + return IndexError_InvalidArgument; + } + + mapping_ = storage_->get(kGraphMappingSegmentId); + if (!mapping_ || mapping_->data_size() < sizeof(node_id_t) * doc_cnt()) { + LOG_ERROR("IndexStorage get segment %s failed or invalid size", + kGraphMappingSegmentId.c_str()); + return IndexError_InvalidArgument; + } + + if (check_crc) { + std::vector segments; + segments.emplace_back(graph_hd_segment); + segments.emplace_back(hnsw_hd_segment); + segments.emplace_back(vectors_); + segments.emplace_back(keys_); + + segments.emplace_back(neighbors_); + segments.emplace_back(neighbors_meta_); + segments.emplace_back(upper_neighbors_); + segments.emplace_back(upper_neighbors_meta_); + + if (!do_crc_check(segments)) { + LOG_ERROR("Check index crc failed, the index may broken"); + return IndexError_Runtime; + } + } + + if (neighbors_in_memory_enabled_) { + int ret = load_and_flat_neighbors(); + if (ret != 0) { + return ret; + } + } + + return 0; +} + +int HnswRabitqSearcherEntity::load_and_flat_neighbors() { + fixed_neighbors_.reset( + new (std::nothrow) char[neighbors_size() * doc_cnt()]{}, + std::default_delete()); + if (!fixed_neighbors_) { + LOG_ERROR("Malloc memory failed"); + return IndexError_NoMemory; + } + + //! 
Get a new segment to release the buffer after loading neighbors
+  auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId);
+  if (!neighbors_meta) {
+    LOG_ERROR("IndexStorage get segment graph.offsets failed");
+    return IndexError_InvalidArgument;
+  }
+
+  const GraphNeighborMeta *neighbors_index = nullptr;
+  if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index),
+                           neighbors_meta->data_size()) !=
+      neighbors_meta->data_size()) {
+    LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str());
+    return IndexError_InvalidArgument;
+  }
+
+  const char *neighbor_data;
+  for (node_id_t id = 0; id < doc_cnt(); ++id) {
+    size_t rd_size = neighbors_index[id].neighbor_cnt * sizeof(node_id_t);
+    if (ailego_unlikely(
+            neighbors_->read(neighbors_index[id].offset,
+                             reinterpret_cast(&neighbor_data),
+                             rd_size) != rd_size)) {
+      LOG_ERROR("Read neighbors from segment failed");
+      return IndexError_ReadData;
+    }
+    // copy level 0 neighbors to fixed size neighbors memory
+    char *dst = fixed_neighbors_.get() + neighbors_size() * id;
+    *reinterpret_cast(dst) = neighbors_index[id].neighbor_cnt;
+    memcpy(dst + sizeof(uint32_t), neighbor_data, rd_size);
+  }
+
+  return 0;
+}
+
+int HnswRabitqSearcherEntity::get_fixed_neighbors(
+    std::vector *fixed_neighbors) const {
+  //! 
Get a new segment to release the buffer after loading neighbors
+  auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId);
+  if (!neighbors_meta) {
+    LOG_ERROR("IndexStorage get segment graph.offsets failed");
+    return IndexError_InvalidArgument;
+  }
+
+  const GraphNeighborMeta *neighbors_index = nullptr;
+  size_t meta_size = neighbors_meta->data_size();
+  if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index),
+                           meta_size) != meta_size) {
+    LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str());
+    return IndexError_InvalidArgument;
+  }
+
+  size_t fixed_neighbor_cnt = l0_neighbor_cnt();
+  fixed_neighbors->resize((fixed_neighbor_cnt + 1) * doc_cnt(), kInvalidNodeId);
+
+  size_t neighbors_cnt_offset = fixed_neighbor_cnt * doc_cnt();
+  size_t total_neighbor_cnt = 0;
+  for (node_id_t id = 0; id < doc_cnt(); ++id) {
+    size_t cur_neighbor_cnt = neighbors_index[id].neighbor_cnt;
+    if (cur_neighbor_cnt == 0) {
+      (*fixed_neighbors)[neighbors_cnt_offset + id] = 0;
+      continue;
+    }
+    size_t rd_size = cur_neighbor_cnt * sizeof(node_id_t);
+    const uint32_t *neighbors;
+    if (neighbors_->read(neighbors_index[id].offset,
+                         reinterpret_cast(&neighbors),
+                         rd_size) != rd_size) {
+      LOG_ERROR("Read neighbors from segment failed");
+      return IndexError_ReadData;
+    }
+
+    // copy level 0 neighbors to fixed size neighbors memory
+    auto it = fixed_neighbors->begin() + id * fixed_neighbor_cnt;
+    std::copy(neighbors, neighbors + cur_neighbor_cnt, it);
+
+    (*fixed_neighbors)[neighbors_cnt_offset + id] = cur_neighbor_cnt;
+    total_neighbor_cnt += cur_neighbor_cnt;
+  }
+  LOG_INFO("total neighbor cnt: %zu, average neighbor cnt: %zu",
+           total_neighbor_cnt, total_neighbor_cnt / doc_cnt());
+
+  return 0;
+}
+
+bool HnswRabitqSearcherEntity::do_crc_check(
+    std::vector &segments) const {
+  constexpr size_t blk_size = 4096;
+  const void *data;
+  for (auto &segment : segments) {
+    size_t offset = 0;
+    size_t rd_size;
+    uint32_t crc = 0;
+    while (offset < 
segment->data_size()) {
+      size_t size = std::min(blk_size, segment->data_size() - offset);
+      if ((rd_size = segment->read(offset, &data, size)) <= 0) {
+        break;
+      }
+      offset += rd_size;
+      crc = ailego::Crc32c::Hash(data, rd_size, crc);
+    }
+    if (crc != segment->data_crc()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+const HnswRabitqEntity::Pointer HnswRabitqSearcherEntity::clone() const {
+  auto vectors = vectors_->clone();
+  if (ailego_unlikely(!vectors)) {
+    LOG_ERROR("clone segment %s failed", kGraphFeaturesSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+  auto keys = keys_->clone();
+  if (ailego_unlikely(!keys)) {
+    LOG_ERROR("clone segment %s failed", kGraphKeysSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+
+  auto mapping = mapping_->clone();
+  if (ailego_unlikely(!mapping)) {
+    LOG_ERROR("clone segment %s failed", kGraphMappingSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+
+  auto neighbors = neighbors_->clone();
+  if (ailego_unlikely(!neighbors)) {
+    LOG_ERROR("clone segment %s failed", kGraphNeighborsSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+  auto upper_neighbors = upper_neighbors_->clone();
+  if (ailego_unlikely(!upper_neighbors)) {
+    LOG_ERROR("clone segment %s failed", kHnswNeighborsSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+  auto neighbors_meta = neighbors_meta_->clone();
+  if (ailego_unlikely(!neighbors_meta)) {
+    LOG_ERROR("clone segment %s failed", kGraphOffsetsSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+  auto upper_neighbors_meta = upper_neighbors_meta_->clone();
+  if (ailego_unlikely(!upper_neighbors_meta)) {
+    LOG_ERROR("clone segment %s failed", kHnswOffsetsSegmentId.c_str());
+    return HnswRabitqEntity::Pointer();
+  }
+
+  SegmentGroupParam neighbor_group{neighbors, neighbors_meta, upper_neighbors,
+                                   upper_neighbors_meta};
+
+  HnswRabitqSearcherEntity *entity = new (std::nothrow)
+      HnswRabitqSearcherEntity(header(), vectors, keys, mapping, 
neighbor_group, + fixed_neighbors_, neighbors_in_memory_enabled_); + if (ailego_unlikely(!entity)) { + LOG_ERROR("HnswRabitqSearcherEntity new failed"); + } + + return HnswRabitqEntity::Pointer(entity); +} + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher_entity.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher_entity.h new file mode 100644 index 00000000..164e0fdf --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_searcher_entity.h @@ -0,0 +1,158 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "hnsw_rabitq_builder_entity.h" +#include "hnsw_rabitq_entity.h" + +namespace zvec { +namespace core { + +class HnswRabitqSearcherEntity : public HnswRabitqEntity { + public: + using Pointer = std::shared_ptr; + using SegmentPointer = IndexStorage::Segment::Pointer; + + public: + struct SegmentGroupParam { + SegmentGroupParam(SegmentPointer neighbors_in, + SegmentPointer neighbors_meta_in, + SegmentPointer upper_neighbors_in, + SegmentPointer upper_neighbors_meta_in) + : neighbors{neighbors_in}, + neighbors_meta{neighbors_meta_in}, + upper_neighbors{upper_neighbors_in}, + upper_neighbors_meta{upper_neighbors_meta_in} {} + + SegmentPointer neighbors{nullptr}; + SegmentPointer neighbors_meta{nullptr}; + SegmentPointer upper_neighbors{nullptr}; + SegmentPointer upper_neighbors_meta{nullptr}; + }; + + //! Constructor + HnswRabitqSearcherEntity(); + + //! Make a copy of searcher entity, to support thread-safe operation. + //! The segment in container cannot be read concurrenly + virtual const HnswRabitqEntity::Pointer clone() const override; + + //! Get primary key of the node id + virtual key_t get_key(node_id_t id) const override; + + //! Get vector local id by key + node_id_t get_id(key_t key) const; + + //! Get vector feature data by key + virtual const void *get_vector_by_key(key_t key) const override; + + //! Get vector feature data by id + virtual const void *get_vector(node_id_t id) const override; + + //! Get vector feature data by id + virtual int get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const override; + + virtual int get_vector(const node_id_t id, + IndexStorage::MemoryBlock &block) const override; + virtual int get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const override; + + //! Get all vectors + const void *get_vectors() const; + + //! 
Get the node id's neighbors on graph level + virtual const Neighbors get_neighbors(level_t level, + node_id_t id) const override; + + virtual int load(const IndexStorage::Pointer &container, + bool check_crc) override; + + int load_segments(bool check_crc); + + virtual int cleanup(void) override; + + public: + bool is_loaded() const { + return loaded_; + } + + void set_neighbors_in_memory(bool enabled) { + neighbors_in_memory_enabled_ = enabled; + } + + //! get fixed length neighbors data + int get_fixed_neighbors(std::vector *fixed_neighbors) const; + + private: + //! Constructor + HnswRabitqSearcherEntity(const HNSWHeader &hd, const SegmentPointer &vectors, + const SegmentPointer &keys, + const SegmentPointer &mapping, + const SegmentGroupParam &neighbor_group, + const std::shared_ptr &fixed_neighbors, + bool neighbors_in_memory_enabled) + : HnswRabitqEntity(hd), + vectors_(vectors), + keys_(keys), + mapping_(mapping), + neighbors_(neighbor_group.neighbors), + neighbors_meta_(neighbor_group.neighbors_meta), + upper_neighbors_(neighbor_group.upper_neighbors), + upper_neighbors_meta_(neighbor_group.upper_neighbors_meta), + neighbors_in_memory_enabled_(neighbors_in_memory_enabled) { + segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt()), + IndexStorage::SegmentData(0U, 0U)); + fixed_neighbors_ = fixed_neighbors; + } + + bool do_crc_check(std::vector &segments) const; + + inline size_t neighbors_size() const { + return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); + } + + inline size_t upper_neighbors_size() const { + return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); + } + + //! 
If neighbors_in_memory_enabled, load the level0 neighbors to memory + int load_and_flat_neighbors(void); + + public: + HnswRabitqSearcherEntity(const HnswRabitqSearcherEntity &) = delete; + HnswRabitqSearcherEntity &operator=(const HnswRabitqSearcherEntity &) = + delete; + + private: + IndexStorage::Pointer storage_{}; + + SegmentPointer vectors_{}; + SegmentPointer keys_{}; + SegmentPointer mapping_{}; + + SegmentPointer neighbors_{}; + SegmentPointer neighbors_meta_{}; + SegmentPointer upper_neighbors_{}; + SegmentPointer upper_neighbors_meta_{}; + + mutable std::vector segment_datas_{}; + std::shared_ptr fixed_neighbors_{}; // level 0 fixed size neighbors + bool neighbors_in_memory_enabled_{false}; + bool loaded_{false}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer.cc new file mode 100644 index 00000000..f89f236a --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer.cc @@ -0,0 +1,967 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "hnsw_rabitq_streamer.h" +#include +#include +#include +#include +#include +#include +#include "algorithm/hnsw-rabitq/rabitq_reformer.h" +#include "zvec/ailego/container/params.h" +#include "zvec/ailego/logger/logger.h" +#include "hnsw_rabitq_algorithm.h" +#include "hnsw_rabitq_context.h" +#include "hnsw_rabitq_dist_calculator.h" +#include "hnsw_rabitq_index_provider.h" +#include "rabitq_params.h" +#include "rabitq_utils.h" + +namespace zvec { +namespace core { +HnswRabitqStreamer::HnswRabitqStreamer() : entity_(stats_) {} + +HnswRabitqStreamer::HnswRabitqStreamer(IndexProvider::Pointer provider, + RabitqReformer::Pointer reformer) + : entity_(stats_), + reformer_(std::move(reformer)), + provider_(std::move(provider)) {} + +HnswRabitqStreamer::~HnswRabitqStreamer() { + if (state_ == STATE_INITED) { + this->cleanup(); + } +} + +int HnswRabitqStreamer::init(const IndexMeta &imeta, + const ailego::Params ¶ms) { + meta_ = imeta; + meta_.set_streamer("HnswRabitqStreamer", HnswRabitqEntity::kRevision, params); + + params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_INDEX_SIZE, &max_index_size_); + + params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT, + &upper_max_neighbor_cnt_); + float multiplier = HnswRabitqEntity::kDefaultL0MaxNeighborCntMultiplier; + params.get(PARAM_HNSW_RABITQ_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER, + &multiplier); + l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_; + + multiplier = HnswRabitqEntity::kDefaultNeighborPruneMultiplier; + params.get(PARAM_HNSW_RABITQ_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier); + size_t prune_cnt = multiplier * upper_max_neighbor_cnt_; + scaling_factor_ = upper_max_neighbor_cnt_; + params.get(PARAM_HNSW_RABITQ_STREAMER_SCALING_FACTOR, &scaling_factor_); + + params.get(PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT, &docs_hard_limit_); + params.get(PARAM_HNSW_RABITQ_STREAMER_EF, &ef_); + params.get(PARAM_HNSW_RABITQ_STREAMER_EFCONSTRUCTION, &ef_construction_); + 
params.get(PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_); + params.get(PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB, + &bf_negative_prob_); + params.get(PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD, + &bruteforce_threshold_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_); + params.get(PARAM_HNSW_RABITQ_STREAMER_CHECK_CRC_ENABLE, &check_crc_enabled_); + params.get(PARAM_HNSW_RABITQ_STREAMER_CHUNK_SIZE, &chunk_size_); + params.get(PARAM_HNSW_RABITQ_STREAMER_FILTER_SAME_KEY, &filter_same_key_); + params.get(PARAM_HNSW_RABITQ_STREAMER_GET_VECTOR_ENABLE, + &get_vector_enabled_); + params.get(PARAM_HNSW_RABITQ_STREAMER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_); + params.get(PARAM_HNSW_RABITQ_STREAMER_FORCE_PADDING_RESULT_ENABLE, + &force_padding_topk_enabled_); + params.get(PARAM_HNSW_RABITQ_STREAMER_USE_ID_MAP, &use_id_map_); + entity_.set_use_key_info_map(use_id_map_); + + params.get(PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT, &docs_soft_limit_); + if (docs_soft_limit_ > 0 && docs_soft_limit_ > docs_hard_limit_) { + LOG_ERROR("[%s] must be >= [%s]", + PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT.c_str(), + PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT.c_str()); + return IndexError_InvalidArgument; + } else if (docs_soft_limit_ == 0UL) { + docs_soft_limit_ = + docs_hard_limit_ * HnswRabitqEntity::kDefaultDocsSoftLimitRatio; + } + + if (ef_ == 0U) { + ef_ = HnswRabitqEntity::kDefaultEf; + } + if (ef_construction_ == 0U) { + ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction; + } + if (upper_max_neighbor_cnt_ == 0U) { + upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt; + } + if (upper_max_neighbor_cnt_ > HnswRabitqEntity::kMaxNeighborCnt) { + LOG_ERROR("[%s] must be in range (0,%d)", + 
PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT.c_str(), + HnswRabitqEntity::kMaxNeighborCnt); + return IndexError_InvalidArgument; + } + if (l0_max_neighbor_cnt_ == 0U) { + l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt; + } + if (l0_max_neighbor_cnt_ > HnswRabitqEntity::kMaxNeighborCnt) { + LOG_ERROR("MaxL0NeighborCnt must be in range (0,%d)", + HnswRabitqEntity::kMaxNeighborCnt); + return IndexError_InvalidArgument; + } + if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) { + LOG_ERROR("[%s]-[%zu] must be <= [%s]-[%zu]", + PARAM_HNSW_RABITQ_STREAMER_MIN_NEIGHBOR_COUNT.c_str(), + static_cast(min_neighbor_cnt_), + PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT.c_str(), + static_cast(upper_max_neighbor_cnt_)); + return IndexError_InvalidArgument; + } + + if (bf_negative_prob_ <= 0.0f || bf_negative_prob_ >= 1.0f) { + LOG_ERROR( + "[%s] must be in range (0,1)", + PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str()); + return IndexError_InvalidArgument; + } + + if (scaling_factor_ == 0U) { + scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor; + } + if (scaling_factor_ < 5 || scaling_factor_ > 1000) { + LOG_ERROR("[%s] must be in range [5,1000]", + PARAM_HNSW_RABITQ_STREAMER_SCALING_FACTOR.c_str()); + return IndexError_InvalidArgument; + } + + if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) { + LOG_ERROR("[%s] must be in range (0.0f,1.0f]", + PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO.c_str()); + return IndexError_InvalidArgument; + } + + if (max_scan_limit_ < min_scan_limit_) { + LOG_ERROR("[%s] must be >= [%s]", + PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT.c_str(), + PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT.c_str()); + return IndexError_InvalidArgument; + } + + if (prune_cnt == 0UL) { + prune_cnt = upper_max_neighbor_cnt_; + } + if (chunk_size_ == 0UL) { + chunk_size_ = HnswRabitqEntity::kDefaultChunkSize; + } + if (chunk_size_ > HnswRabitqEntity::kMaxChunkSize) { + LOG_ERROR("[%s] must be < %zu", + 
PARAM_HNSW_RABITQ_STREAMER_CHUNK_SIZE.c_str(), + HnswRabitqEntity::kMaxChunkSize); + return IndexError_InvalidArgument; + } + uint32_t total_bits = 0; + params.get(PARAM_RABITQ_TOTAL_BITS, &total_bits); + if (total_bits == 0) { + total_bits = kDefaultRabitqTotalBits; + } + if (total_bits < 1 || total_bits > 9) { + LOG_ERROR("Invalid total_bits: %zu, must be in [1, 9]", (size_t)total_bits); + return IndexError_InvalidArgument; + } + uint8_t ex_bits = total_bits - 1; + entity_.set_ex_bits(ex_bits); + + uint32_t dimension = 0; + params.get(PARAM_HNSW_RABITQ_GENERAL_DIMENSION, &dimension); + if (dimension == 0) { + LOG_ERROR("%s not set", PARAM_HNSW_RABITQ_GENERAL_DIMENSION.c_str()); + return IndexError_InvalidArgument; + } + entity_.update_rabitq_params_and_vector_size(dimension); + + entity_.set_ef_construction(ef_construction_); + entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_); + entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_); + entity_.set_scaling_factor(scaling_factor_); + entity_.set_prune_cnt(prune_cnt); + + entity_.set_chunk_size(chunk_size_); + entity_.set_filter_same_key(filter_same_key_); + entity_.set_get_vector(get_vector_enabled_); + entity_.set_min_neighbor_cnt(min_neighbor_cnt_); + + int ret = entity_.init(docs_hard_limit_); + if (ret != 0) { + LOG_ERROR("Hnsw entity init failed for %s", IndexError::What(ret)); + return ret; + } + + LOG_DEBUG( + "Init params: maxIndexSize=%zu docsHardLimit=%zu docsSoftLimit=%zu " + "efConstruction=%u ef=%u upperMaxNeighborCnt=%u l0MaxNeighborCnt=%u " + "scalingFactor=%u maxScanRatio=%.3f minScanLimit=%zu maxScanLimit=%zu " + "bfEnabled=%d bruteFoceThreshold=%zu bfNegativeProbability=%.5f " + "checkCrcEnabled=%d pruneSize=%zu vectorSize=%u chunkSize=%zu " + "filterSameKey=%u getVectorEnabled=%u minNeighborCount=%u " + "forcePadding=%u ", + max_index_size_, docs_hard_limit_, docs_soft_limit_, ef_construction_, + ef_, upper_max_neighbor_cnt_, l0_max_neighbor_cnt_, scaling_factor_, + max_scan_ratio_, 
min_scan_limit_, max_scan_limit_, bf_enabled_, + bruteforce_threshold_, bf_negative_prob_, check_crc_enabled_, prune_cnt, + meta_.element_size(), chunk_size_, filter_same_key_, get_vector_enabled_, + min_neighbor_cnt_, force_padding_topk_enabled_); + + alg_ = HnswRabitqAlgorithm::UPointer(new HnswRabitqAlgorithm(entity_)); + + ret = alg_->init(); + if (ret != 0) { + return ret; + } + + state_ = STATE_INITED; + + return 0; +} + +int HnswRabitqStreamer::cleanup(void) { + if (state_ == STATE_OPENED) { + this->close(); + } + + LOG_INFO("HnswRabitqStreamer cleanup"); + + meta_.clear(); + metric_.reset(); + stats_.clear(); + entity_.cleanup(); + + if (alg_) { + alg_->cleanup(); + } + + max_index_size_ = 0UL; + docs_hard_limit_ = HnswRabitqEntity::kDefaultDocsHardLimit; + docs_soft_limit_ = 0UL; + upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt; + l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt; + ef_ = HnswRabitqEntity::kDefaultEf; + ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction; + bf_enabled_ = false; + scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor; + bruteforce_threshold_ = HnswRabitqEntity::kDefaultBruteForceThreshold; + max_scan_limit_ = HnswRabitqEntity::kDefaultMaxScanLimit; + min_scan_limit_ = HnswRabitqEntity::kDefaultMinScanLimit; + chunk_size_ = HnswRabitqEntity::kDefaultChunkSize; + bf_negative_prob_ = HnswRabitqEntity::kDefaultBFNegativeProbability; + max_scan_ratio_ = HnswRabitqEntity::kDefaultScanRatio; + state_ = STATE_INIT; + check_crc_enabled_ = false; + filter_same_key_ = false; + get_vector_enabled_ = false; + + return 0; +} + +int HnswRabitqStreamer::open(IndexStorage::Pointer stg) { + LOG_INFO("HnswRabitqStreamer open"); + + if (ailego_unlikely(state_ != STATE_INITED)) { + LOG_ERROR("Open storage failed, init streamer first!"); + return IndexError_NoReady; + } + + // try to load reformer + if (reformer_ == nullptr) { + reformer_ = std::make_shared(); + ailego::Params reformer_params; 
+ reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name()); + int ret = reformer_->init(reformer_params); + if (ret != 0) { + LOG_ERROR("Failed to initialize RabitqReformer: %d", ret); + return ret; + } + + ret = reformer_->load(stg); + if (ret != 0) { + LOG_ERROR("Failed to load reformer, ret=%d", ret); + return ret; + } + } else { + if (!stg->has(RABITQ_CONVERER_SEG_ID)) { + int ret = reformer_->dump(stg); + if (ret != 0) { + LOG_ERROR("Failed to dump reformer, ret=%d", ret); + return ret; + } + LOG_INFO("Dump reformer success."); + } + } + + int ret = entity_.open(std::move(stg), max_index_size_, check_crc_enabled_); + if (ret != 0) { + return ret; + } + IndexMeta index_meta; + ret = entity_.get_index_meta(&index_meta); + if (ret == IndexError_NoExist) { + // Set IndexMeta for the new index + ret = entity_.set_index_meta(meta_); + if (ret != 0) { + LOG_ERROR("Failed to set index meta for %s", IndexError::What(ret)); + return ret; + } + } else if (ret != 0) { + LOG_ERROR("Failed to get index meta for %s", IndexError::What(ret)); + return ret; + } else { + if (index_meta.dimension() != meta_.dimension() || + index_meta.element_size() != meta_.element_size() || + index_meta.metric_name() != meta_.metric_name() || + index_meta.data_type() != meta_.data_type()) { + LOG_ERROR("IndexMeta mismatch from the previous in index"); + return IndexError_Mismatch; + } + // The IndexMetric Params may be updated like MipsSquaredEuclidean + auto metric_params = index_meta.metric_params(); + metric_params.merge(meta_.metric_params()); + meta_.set_metric(index_meta.metric_name(), 0, metric_params); + } + + metric_ = IndexFactory::CreateMetric(meta_.metric_name()); + if (!metric_) { + LOG_ERROR("Failed to create metric %s", meta_.metric_name().c_str()); + return IndexError_NoExist; + } + ret = metric_->init(meta_, meta_.metric_params()); + if (ret != 0) { + LOG_ERROR("Failed to init metric, ret=%d", ret); + return ret; + } + + if (!metric_->distance()) { + 
LOG_ERROR("Invalid metric distance"); + return IndexError_InvalidArgument; + } + + if (!metric_->batch_distance()) { + LOG_ERROR("Invalid metric batch distance"); + return IndexError_InvalidArgument; + } + + add_distance_ = metric_->distance(); + add_batch_distance_ = metric_->batch_distance(); + + search_distance_ = add_distance_; + search_batch_distance_ = add_batch_distance_; + + if (metric_->query_metric() && metric_->query_metric()->distance() && + metric_->query_metric()->batch_distance()) { + search_distance_ = metric_->query_metric()->distance(); + search_batch_distance_ = metric_->query_metric()->batch_distance(); + } + + state_ = STATE_OPENED; + magic_ = IndexContext::GenerateMagic(); + + query_alg_ = HnswRabitqQueryAlgorithm::UPointer(new HnswRabitqQueryAlgorithm( + entity_, reformer_->num_clusters(), reformer_->rabitq_metric_type())); + + return 0; +} + +int HnswRabitqStreamer::close(void) { + LOG_INFO("HnswRabitqStreamer close"); + + stats_.clear(); + meta_.set_metric(metric_->name(), 0, metric_->params()); + entity_.set_index_meta(meta_); + int ret = entity_.close(); + if (ret != 0) { + return ret; + } + state_ = STATE_INITED; + + return 0; +} + +int HnswRabitqStreamer::flush(uint64_t checkpoint) { + LOG_INFO("HnswRabitqStreamer flush checkpoint=%zu", (size_t)checkpoint); + + meta_.set_metric(metric_->name(), 0, metric_->params()); + entity_.set_index_meta(meta_); + return entity_.flush(checkpoint); +} + +int HnswRabitqStreamer::dump(const IndexDumper::Pointer &dumper) { + LOG_INFO("HnswRabitqStreamer dump"); + + shared_mutex_.lock(); + AILEGO_DEFER([&]() { shared_mutex_.unlock(); }); + + meta_.set_searcher("HnswRabitqSearcher", HnswRabitqEntity::kRevision, + ailego::Params()); + + int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); + if (ret != 0) { + LOG_ERROR("Failed to serialize meta into dumper."); + return ret; + } + ret = reformer_->dump(dumper); + if (ret != 0) { + LOG_ERROR("Failed to dump reformer into dumper."); + return ret; + } 
+ return entity_.dump(dumper); +} + +IndexStreamer::Context::Pointer HnswRabitqStreamer::create_context(void) const { + if (ailego_unlikely(state_ != STATE_OPENED)) { + LOG_ERROR("Create context failed, open storage first!"); + return Context::Pointer(); + } + + HnswRabitqEntity::Pointer entity = entity_.clone(); + if (ailego_unlikely(!entity)) { + LOG_ERROR("CreateContext clone init failed"); + return Context::Pointer(); + } + HnswRabitqContext *ctx = + new (std::nothrow) HnswRabitqContext(meta_.dimension(), metric_, entity); + if (ailego_unlikely(ctx == nullptr)) { + LOG_ERROR("Failed to new HnswRabitqContext"); + return Context::Pointer(); + } + ctx->set_ef(ef_); + ctx->set_max_scan_limit(max_scan_limit_); + ctx->set_min_scan_limit(min_scan_limit_); + ctx->set_max_scan_ratio(max_scan_ratio_); + ctx->set_filter_mode(bf_enabled_ ? VisitFilter::BloomFilter + : VisitFilter::ByteMap); + ctx->set_filter_negative_probability(bf_negative_prob_); + ctx->set_magic(magic_); + ctx->set_force_padding_topk(force_padding_topk_enabled_); + ctx->set_bruteforce_threshold(bruteforce_threshold_); + + if (ailego_unlikely(ctx->init(HnswRabitqContext::kStreamerContext)) != 0) { + LOG_ERROR("Init HnswRabitqContext failed"); + delete ctx; + return Context::Pointer(); + } + uint32_t estimate_doc_count = 0; + if (meta_.streamer_params().get(PARAM_HNSW_RABITQ_STREAMER_ESTIMATE_DOC_COUNT, + &estimate_doc_count)) { + LOG_DEBUG("HnswRabitqStreamer doc_count[%zu] estimate[%zu]", + (size_t)entity_.doc_cnt(), (size_t)estimate_doc_count); + } + ctx->check_need_adjuct_ctx(std::max(entity_.doc_cnt(), estimate_doc_count)); + + return Context::Pointer(ctx); +} + +IndexProvider::Pointer HnswRabitqStreamer::create_provider(void) const { + LOG_DEBUG("HnswRabitqStreamer create provider"); + + auto entity = entity_.clone(); + if (ailego_unlikely(!entity)) { + LOG_ERROR("Clone HnswRabitqEntity failed"); + return nullptr; + } + return Provider::Pointer( + new HnswRabitqIndexProvider(meta_, entity, 
"HnswRabitqStreamer")); +} + +int HnswRabitqStreamer::update_context(HnswRabitqContext *ctx) const { + const HnswRabitqEntity::Pointer entity = entity_.clone(); + if (!entity) { + LOG_ERROR("Failed to clone search context entity"); + return IndexError_Runtime; + } + ctx->set_max_scan_limit(max_scan_limit_); + ctx->set_min_scan_limit(min_scan_limit_); + ctx->set_max_scan_ratio(max_scan_ratio_); + ctx->set_bruteforce_threshold(bruteforce_threshold_); + return ctx->update_context(HnswRabitqContext::kStreamerContext, meta_, + metric_, entity, magic_); +} + +//! Add a vector with id into index +int HnswRabitqStreamer::add_with_id_impl( + uint32_t id, const void *query, const IndexQueryMeta &qmeta, + IndexStreamer::Context::Pointer &context) { + if (!provider_) { + LOG_ERROR("Provider is nullptr, cannot add vector"); + return IndexError_InvalidArgument; + } + + int ret = check_params(query, qmeta); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext failed"); + return IndexError_Cast; + } + if (ctx->magic() != magic_) { + //! 
context is created by another searcher or streamer + ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { + if (entity_.doc_cnt() >= docs_hard_limit_) { + LOG_ERROR("Current docs %zu exceed [%s]", + static_cast(entity_.doc_cnt()), + PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT.c_str()); + const std::lock_guard lk(mutex_); + (*stats_.mutable_discarded_count())++; + return IndexError_IndexFull; + } else { + LOG_WARN("Current docs %zu exceed [%s]", + static_cast(entity_.doc_cnt()), + PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT.c_str()); + } + } + if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { + LOG_ERROR("Cannot add vector while dumping index"); + (*stats_.mutable_discarded_count())++; + return IndexError_Unsupported; + } + AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); + + ctx->clear(); + ctx->update_dist_caculator_distance(add_distance_, add_batch_distance_); + ctx->reset_query(query); + ctx->check_need_adjuct_ctx(entity_.doc_cnt()); + ctx->set_provider(provider_); + + if (metric_->support_train()) { + const std::lock_guard lk(mutex_); + ret = metric_->train(query, meta_.dimension()); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw streamer metric train failed"); + (*stats_.mutable_discarded_count())++; + return ret; + } + } + + std::string converted_vector; + IndexQueryMeta converted_meta; + ret = reformer_->convert(query, qmeta, &converted_vector, &converted_meta); + if (ret != 0) { + LOG_ERROR("Rabitq hnsw convert failed, ret=%d", ret); + return ret; + } + + level_t level = alg_->get_random_level(); + ret = entity_.add_vector_with_id(level, id, converted_vector.data()); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw streamer add vector failed"); + (*stats_.mutable_discarded_count())++; + return ret; + } + + ret = alg_->add_node(id, level, ctx); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw steamer add node failed"); + 
(*stats_.mutable_discarded_count())++; + return ret; + } + + if (ailego_unlikely(ctx->error())) { + (*stats_.mutable_discarded_count())++; + return IndexError_Runtime; + } + (*stats_.mutable_added_count())++; + + return 0; +} + +//! Add a vector into index +int HnswRabitqStreamer::add_impl(uint64_t pkey, const void *query, + const IndexQueryMeta &qmeta, + IndexStreamer::Context::Pointer &context) { + if (!provider_) { + LOG_ERROR("Provider is nullptr, cannot add vector"); + return IndexError_InvalidArgument; + } + + int ret = check_params(query, qmeta); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext failed"); + return IndexError_Cast; + } + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer + ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { + if (entity_.doc_cnt() >= docs_hard_limit_) { + LOG_ERROR("Current docs %zu exceed [%s]", + static_cast(entity_.doc_cnt()), + PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT.c_str()); + const std::lock_guard lk(mutex_); + (*stats_.mutable_discarded_count())++; + return IndexError_IndexFull; + } else { + LOG_WARN("Current docs %zu exceed [%s]", + static_cast(entity_.doc_cnt()), + PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT.c_str()); + } + } + if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { + LOG_ERROR("Cannot add vector while dumping index"); + (*stats_.mutable_discarded_count())++; + return IndexError_Unsupported; + } + AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); + + ctx->clear(); + ctx->update_dist_caculator_distance(add_distance_, add_batch_distance_); + ctx->reset_query(query); + ctx->check_need_adjuct_ctx(entity_.doc_cnt()); + ctx->set_provider(provider_); + + if (metric_->support_train()) { + const std::lock_guard lk(mutex_); + ret = 
metric_->train(query, meta_.dimension()); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw streamer metric train failed"); + (*stats_.mutable_discarded_count())++; + return ret; + } + } + + std::string converted_vector; + IndexQueryMeta converted_meta; + ret = reformer_->convert(query, qmeta, &converted_vector, &converted_meta); + if (ret != 0) { + LOG_ERROR("Rabitq hnsw convert failed, ret=%d", ret); + return ret; + } + + level_t level = alg_->get_random_level(); + node_id_t id; + ret = entity_.add_vector(level, pkey, converted_vector.data(), &id); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw streamer add vector failed"); + (*stats_.mutable_discarded_count())++; + return ret; + } + + ret = alg_->add_node(id, level, ctx); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw steamer add node failed"); + (*stats_.mutable_discarded_count())++; + return ret; + } + + if (ailego_unlikely(ctx->error())) { + (*stats_.mutable_discarded_count())++; + return IndexError_Runtime; + } + (*stats_.mutable_added_count())++; + + return 0; +} + + +int HnswRabitqStreamer::search_impl( + const void *query, const IndexQueryMeta &qmeta, + IndexStreamer::Context::Pointer &context) const { + return search_impl(query, qmeta, 1, context); +} + +//! Similarity search +int HnswRabitqStreamer::search_impl( + const void *query, const IndexQueryMeta &qmeta, uint32_t count, + IndexStreamer::Context::Pointer &context) const { + int ret = check_params(query, qmeta); + if (ailego_unlikely(ret != 0)) { + return ret; + } + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext failed"); + return IndexError_Cast; + } + + if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { + return search_bf_impl(query, qmeta, count, context); + } + + if (ctx->magic() != magic_) { + //! 
context is created by another searcher or streamer + ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + ctx->clear(); + ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); + ctx->resize_results(count); + ctx->check_need_adjuct_ctx(entity_.doc_cnt()); + for (size_t q = 0; q < count; ++q) { + HnswRabitqQueryEntity entity; + ret = reformer_->transform_to_entity(query, &entity); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw searcher transform failed"); + return ret; + } + ctx->reset_query(query); + ret = query_alg_->search(&entity, ctx); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw searcher fast search failed"); + return ret; + } + ctx->topk_to_result(q); + query = static_cast(query) + qmeta.element_size(); + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + +void HnswRabitqStreamer::print_debug_info() { + for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { + if (entity_.get_key(id) == kInvalidKey) { + continue; + } + Neighbors neighbours = entity_.get_neighbors(0, id); + std::cout << "node: " << id << "; "; + if (neighbours.size() == 0) std::cout << std::endl; + for (uint32_t i = 0; i < neighbours.size(); ++i) { + std::cout << neighbours[i]; + + if (i == neighbours.size() - 1) { + std::cout << std::endl; + } else { + std::cout << ", "; + } + } + } + + // entity_.print_key_map(); +} + +int HnswRabitqStreamer::search_bf_impl( + const void *query, const IndexQueryMeta &qmeta, + IndexStreamer::Context::Pointer &context) const { + return search_bf_impl(query, qmeta, 1, context); +} + +int HnswRabitqStreamer::search_bf_impl( + const void *query, const IndexQueryMeta &qmeta, uint32_t count, + IndexStreamer::Context::Pointer &context) const { + int ret = check_params(query, qmeta); + if (ailego_unlikely(ret != 0)) { + return ret; + } + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext 
failed"); + return IndexError_Cast; + } + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer + ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + ctx->clear(); + ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); + ctx->resize_results(count); + + if (ctx->group_by_search()) { + if (!ctx->group_by().is_valid()) { + LOG_ERROR("Invalid group-by function"); + return IndexError_InvalidArgument; + } + + std::function group_by = [&](node_id_t id) { + return ctx->group_by()(entity_.get_key(id)); + }; + + for (size_t q = 0; q < count; ++q) { + ctx->reset_query(query); + ctx->group_topk_heaps().clear(); + + for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { + if (entity_.get_key(id) == kInvalidKey) { + continue; + } + + if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { + dist_t dist = ctx->dist_calculator().dist(id); + + std::string group_id = group_by(id); + + auto &topk_heap = ctx->group_topk_heaps()[group_id]; + if (topk_heap.empty()) { + topk_heap.limit(ctx->group_topk()); + } + topk_heap.emplace_back(id, dist); + } + } + ctx->topk_to_result(q); + query = static_cast(query) + qmeta.element_size(); + } + } else { + for (size_t q = 0; q < count; ++q) { + HnswRabitqQueryEntity entity; + ret = reformer_->transform_to_entity(query, &entity); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Hnsw rabitq streamer transform failed"); + return ret; + } + ctx->reset_query(query); + ctx->topk_heap().clear(); + for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { + if (entity_.get_key(id) == kInvalidKey) { + continue; + } + if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { + EstimateRecord dist; + query_alg_->get_full_est(id, dist, entity); + ctx->topk_heap().emplace(id, dist); + } + } + ctx->topk_to_result(q); + query = static_cast(query) + qmeta.element_size(); + } + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + 
return 0; +} + +int HnswRabitqStreamer::search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + Context::Pointer &context) const { + int ret = check_params(query, qmeta); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + if (ailego_unlikely(p_keys.size() != count)) { + LOG_ERROR("The size of p_keys is not equal to count"); + return IndexError_InvalidArgument; + } + + HnswRabitqContext *ctx = dynamic_cast(context.get()); + ailego_do_if_false(ctx) { + LOG_ERROR("Cast context to HnswRabitqContext failed"); + return IndexError_Cast; + } + if (ctx->magic() != magic_) { + //! context is created by another searcher or streamer + ret = update_context(ctx); + if (ret != 0) { + return ret; + } + } + + ctx->clear(); + ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); + ctx->resize_results(count); + + if (ctx->group_by_search()) { + if (!ctx->group_by().is_valid()) { + LOG_ERROR("Invalid group-by function"); + return IndexError_InvalidArgument; + } + + std::function group_by = [&](node_id_t id) { + return ctx->group_by()(entity_.get_key(id)); + }; + + for (size_t q = 0; q < count; ++q) { + ctx->reset_query(query); + ctx->group_topk_heaps().clear(); + + for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { + uint64_t pk = p_keys[q][idx]; + if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { + node_id_t id = entity_.get_id(pk); + if (id != kInvalidNodeId) { + dist_t dist = ctx->dist_calculator().dist(id); + std::string group_id = group_by(id); + + auto &topk_heap = ctx->group_topk_heaps()[group_id]; + if (topk_heap.empty()) { + topk_heap.limit(ctx->group_topk()); + } + topk_heap.emplace_back(id, dist); + } + } + } + ctx->topk_to_result(q); + query = static_cast(query) + qmeta.element_size(); + } + } else { + auto &filter = ctx->filter(); + auto &topk = ctx->topk_heap(); + + for (size_t q = 0; q < count; ++q) { + ctx->reset_query(query); + topk.clear(); + for (size_t idx = 
0; idx < p_keys[q].size(); ++idx) { + key_t pk = p_keys[q][idx]; + if (!filter.is_valid() || !filter(pk)) { + node_id_t id = entity_.get_id(pk); + if (id != kInvalidNodeId) { + dist_t dist = ctx->dist_calculator().dist(id); + topk.emplace(id, dist); + } + } + } + ctx->topk_to_result(q); + query = static_cast(query) + qmeta.element_size(); + } + } + + if (ailego_unlikely(ctx->error())) { + return IndexError_Runtime; + } + + return 0; +} + + +INDEX_FACTORY_REGISTER_STREAMER(HnswRabitqStreamer); + +} // namespace core +} // namespace zvec \ No newline at end of file diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer.h new file mode 100644 index 00000000..167fdf2a --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer.h @@ -0,0 +1,249 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include "algorithm/hnsw-rabitq/rabitq_reformer.h" +#include "zvec/core/framework/index_framework.h" +#include "zvec/core/framework/index_provider.h" +#include "zvec/core/framework/index_reformer.h" +#include "hnsw_rabitq_algorithm.h" +#include "hnsw_rabitq_query_algorithm.h" +#include "hnsw_rabitq_streamer_entity.h" + +namespace zvec { +namespace core { + +class HnswRabitqStreamer : public IndexStreamer { + public: + using ContextPointer = IndexStreamer::Context::Pointer; + + HnswRabitqStreamer(); + explicit HnswRabitqStreamer(IndexProvider::Pointer provider, + RabitqReformer::Pointer reformer = nullptr); + virtual ~HnswRabitqStreamer(void); + + HnswRabitqStreamer(const HnswRabitqStreamer &streamer) = delete; + HnswRabitqStreamer &operator=(const HnswRabitqStreamer &streamer) = delete; + + void set_provider(IndexProvider::Pointer provider) { + provider_ = std::move(provider); + } + + void set_reformer(IndexReformer::Pointer reformer) { + reformer_ = std::dynamic_pointer_cast(reformer); + } + + protected: + //! Initialize Streamer + virtual int init(const IndexMeta &imeta, + const ailego::Params ¶ms) override; + + //! Cleanup Streamer + virtual int cleanup(void) override; + + //! Create a context + virtual Context::Pointer create_context(void) const override; + + //! Create a new iterator + virtual IndexProvider::Pointer create_provider(void) const override; + + //! Add a vector into index + virtual int add_impl(uint64_t pkey, const void *query, + const IndexQueryMeta &qmeta, + Context::Pointer &context) override; + + //! Add a vector with id into index + virtual int add_with_id_impl(uint32_t id, const void *query, + const IndexQueryMeta &qmeta, + Context::Pointer &context) override; + + //! Similarity search + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + Context::Pointer &context) const override; + + //! 
Similarity search + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const override; + + //! Similarity brute force search + virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + Context::Pointer &context) const override; + + //! Similarity brute force search + virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const override; + + //! Linear search by primary keys + virtual int search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, ContextPointer &context) const override { + return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); + } + + //! Linear search by primary keys + virtual int search_bf_by_p_keys_impl( + const void *query, const std::vector> &p_keys, + const IndexQueryMeta &qmeta, uint32_t count, + ContextPointer &context) const override; + + //! Fetch vector by key + virtual const void *get_vector(uint64_t key) const override { + return entity_.get_vector_by_key(key); + } + + virtual int get_vector(const uint64_t key, + IndexStorage::MemoryBlock &block) const override { + return entity_.get_vector_by_key(key, block); + } + + //! Fetch vector by id + virtual const void *get_vector_by_id(uint32_t id) const override { + return entity_.get_vector(id); + } + + virtual int get_vector_by_id( + const uint32_t id, IndexStorage::MemoryBlock &block) const override { + return entity_.get_vector(id, block); + } + + //! Open index from file path + virtual int open(IndexStorage::Pointer stg) override; + + //! Close file + virtual int close(void) override; + + //! flush file + virtual int flush(uint64_t checkpoint) override; + + //! Dump index into storage + virtual int dump(const IndexDumper::Pointer &dumper) override; + + //! Retrieve statistics + virtual const Stats &stats(void) const override { + return stats_; + } + + //! 
Retrieve meta of index + virtual const IndexMeta &meta(void) const override { + return meta_; + } + + virtual void print_debug_info() override; + + private: + inline int check_params(const void *query, + const IndexQueryMeta &qmeta) const { + if (ailego_unlikely(!query)) { + LOG_ERROR("null query"); + return IndexError_InvalidArgument; + } + if (ailego_unlikely(qmeta.dimension() != meta_.dimension() || + qmeta.data_type() != meta_.data_type() || + qmeta.element_size() != meta_.element_size())) { + LOG_ERROR("Unsupported query meta"); + return IndexError_Mismatch; + } + return 0; + } + + inline int check_sparse_count_is_zero(const uint32_t *sparse_count, + uint32_t count) const { + for (uint32_t i = 0; i < count; ++i) { + if (sparse_count[i] != 0) + LOG_ERROR("Sparse cout is not empty. Index: %u, Sparse Count: %u", i, + sparse_count[i]); + return IndexError_InvalidArgument; + } + + return 0; + } + + private: + //! To share ctx across streamer/searcher, we need to update the context for + //! 
current streamer/searcher + int update_context(HnswRabitqContext *ctx) const; + + private: + enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 }; + class Stats : public IndexStreamer::Stats { + public: + void clear(void) { + set_revision_id(0u); + set_loaded_count(0u); + set_added_count(0u); + set_discarded_count(0u); + set_index_size(0u); + set_dumped_size(0u); + set_check_point(0u); + set_create_time(0u); + set_update_time(0u); + clear_attributes(); + } + }; + + HnswRabitqStreamerEntity entity_; + HnswRabitqAlgorithm::UPointer alg_; + IndexMeta meta_{}; + IndexMetric::Pointer metric_{}; + + IndexMetric::MatrixDistance add_distance_{}; + IndexMetric::MatrixDistance search_distance_{}; + + IndexMetric::MatrixBatchDistance add_batch_distance_{}; + IndexMetric::MatrixBatchDistance search_batch_distance_{}; + + RabitqReformer::Pointer reformer_{}; // RaBitQ reformer + HnswRabitqQueryAlgorithm::UPointer query_alg_; // query algorithm + // provider_ provides raw vector, which is used to build graph + IndexProvider::Pointer provider_{}; + + Stats stats_{}; + std::mutex mutex_{}; + + size_t max_index_size_{0UL}; + size_t chunk_size_{HnswRabitqEntity::kDefaultChunkSize}; + size_t docs_hard_limit_{HnswRabitqEntity::kDefaultDocsHardLimit}; + size_t docs_soft_limit_{0UL}; + uint32_t min_neighbor_cnt_{0u}; + uint32_t upper_max_neighbor_cnt_{ + HnswRabitqEntity::kDefaultUpperMaxNeighborCnt}; + uint32_t l0_max_neighbor_cnt_{HnswRabitqEntity::kDefaultL0MaxNeighborCnt}; + uint32_t ef_{HnswRabitqEntity::kDefaultEf}; + uint32_t ef_construction_{HnswRabitqEntity::kDefaultEfConstruction}; + uint32_t scaling_factor_{HnswRabitqEntity::kDefaultScalingFactor}; + size_t bruteforce_threshold_{HnswRabitqEntity::kDefaultBruteForceThreshold}; + size_t max_scan_limit_{HnswRabitqEntity::kDefaultMaxScanLimit}; + size_t min_scan_limit_{HnswRabitqEntity::kDefaultMinScanLimit}; + float bf_negative_prob_{HnswRabitqEntity::kDefaultBFNegativeProbability}; + float 
max_scan_ratio_{HnswRabitqEntity::kDefaultScanRatio}; + + uint32_t magic_{0U}; + State state_{STATE_INIT}; + bool bf_enabled_{false}; + bool check_crc_enabled_{false}; + bool filter_same_key_{false}; + bool get_vector_enabled_{false}; + bool force_padding_topk_enabled_{false}; + bool use_id_map_{true}; + + //! avoid add vector while dumping index + ailego::SharedMutex shared_mutex_{}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer_entity.cc b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer_entity.cc new file mode 100644 index 00000000..35501ed9 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer_entity.cc @@ -0,0 +1,709 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hnsw_rabitq_streamer_entity.h" +#include + +// #define DEBUG_PRINT + +namespace zvec { +namespace core { + +HnswRabitqStreamerEntity::HnswRabitqStreamerEntity(IndexStreamer::Stats &stats) + : stats_(stats) {} + +HnswRabitqStreamerEntity::~HnswRabitqStreamerEntity() {} + +int HnswRabitqStreamerEntity::init(size_t max_doc_cnt) { + if (std::pow(scaling_factor(), kMaxGraphLayers) < max_doc_cnt) { + LOG_ERROR("scalingFactor=%zu is too small", scaling_factor()); + return IndexError_InvalidArgument; + } + + std::lock_guard lock(mutex_); + broker_ = std::make_shared(stats_); + upper_neighbor_index_ = std::make_shared(); + keys_map_lock_ = std::make_shared(); + keys_map_ = std::make_shared>(); + if (!keys_map_ || !upper_neighbor_index_ || !broker_ || !keys_map_lock_) { + LOG_ERROR("HnswRabitqStreamerEntity new object failed"); + return IndexError_NoMemory; + } + keys_map_->set_empty_key(kInvalidKey); + + neighbor_size_ = neighbors_size(); + upper_neighbor_size_ = upper_neighbors_size(); + + //! 
vector + key + level 0 neighbors + size_t size = vector_size() + sizeof(key_t) + neighbor_size_; + + size = AlignSize(size); + set_node_size(size); + return 0; +} + +int HnswRabitqStreamerEntity::cleanup() { + std::lock_guard lock(mutex_); + mutable_header()->clear(); + chunk_size_ = kDefaultChunkSize; + node_index_mask_bits_ = 0U; + node_index_mask_ = 0U; + node_cnt_per_chunk_ = 0U; + neighbor_size_ = 0U; + upper_neighbor_size_ = 0U; + if (upper_neighbor_index_) { + upper_neighbor_index_->cleanup(); + } + if (keys_map_) { + keys_map_->clear(); + } + node_chunks_.clear(); + upper_neighbor_chunks_.clear(); + filter_same_key_ = false; + get_vector_enabled_ = false; + broker_.reset(); + + return 0; +} + +int HnswRabitqStreamerEntity::update_neighbors( + level_t level, node_id_t id, + const std::vector> &neighbors) { + char buffer[neighbor_size_]; + NeighborsHeader *hd = reinterpret_cast(buffer); + hd->neighbor_cnt = neighbors.size(); + size_t i = 0; + for (; i < neighbors.size(); ++i) { + hd->neighbors[i] = neighbors[i].first; + } + + auto loc = get_neighbor_chunk_loc(level, id); + size_t size = reinterpret_cast(&hd->neighbors[i]) - &buffer[0]; + size_t ret = loc.first->write(loc.second, hd, size); + if (ailego_unlikely(ret != size)) { + LOG_ERROR("Write neighbor header failed, ret=%zu", ret); + + return IndexError_Runtime; + } + + return 0; +} + +const Neighbors HnswRabitqStreamerEntity::get_neighbors(level_t level, + node_id_t id) const { + Chunk *chunk = nullptr; + size_t offset = 0UL; + size_t neighbor_size = neighbor_size_; + if (level == 0UL) { + uint32_t chunk_idx = id >> node_index_mask_bits_; + offset = + (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); + + sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, + &node_chunks_); + ailego_assert_with(chunk_idx < node_chunks_.size(), "invalid chunk idx"); + chunk = node_chunks_[chunk_idx].get(); + } else { + auto p = get_upper_neighbor_chunk_loc(level, id); + chunk = 
upper_neighbor_chunks_[p.first].get(); + offset = p.second; + neighbor_size = upper_neighbor_size_; + } + + ailego_assert_with(offset < chunk->data_size(), "invalid chunk offset"); + IndexStorage::MemoryBlock neighbor_block; + size_t size = chunk->read(offset, neighbor_block, neighbor_size); + if (ailego_unlikely(size != neighbor_size)) { + LOG_ERROR("Read neighbor header failed, ret=%zu", size); + return Neighbors(); + } + return Neighbors(std::move(neighbor_block)); +} + +//! Get vector data by key +const void *HnswRabitqStreamerEntity::get_vector(node_id_t id) const { + auto loc = get_vector_chunk_loc(id); + const void *vec = nullptr; + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = node_chunks_[loc.first]->read(loc.second, &vec, read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", + static_cast(loc.second), read_size, ret); + } + + return vec; +} + +int HnswRabitqStreamerEntity::get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const { + for (auto i = 0U; i < count; ++i) { + auto loc = get_vector_chunk_loc(ids[i]); + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = node_chunks_[loc.first]->read(loc.second, &vecs[i], read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", + static_cast(loc.second), read_size, ret); + return IndexError_ReadData; + } + } + return 0; +} + +int HnswRabitqStreamerEntity::get_vector( + const node_id_t id, IndexStorage::MemoryBlock &block) const { + auto loc = get_vector_chunk_loc(id); + ailego_assert_with(loc.first < 
node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = node_chunks_[loc.first]->read(loc.second, block, read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", + static_cast(loc.second), read_size, ret); + return IndexError_ReadData; + } + return 0; +} + +int HnswRabitqStreamerEntity::get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const { + vec_blocks.resize(count); + for (auto i = 0U; i < count; ++i) { + auto loc = get_vector_chunk_loc(ids[i]); + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + + size_t read_size = vector_size(); + + size_t ret = + node_chunks_[loc.first]->read(loc.second, vec_blocks[i], read_size); + if (ailego_unlikely(ret != read_size)) { + LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", + static_cast(loc.second), read_size, ret); + return IndexError_ReadData; + } + } + return 0; +} + +key_t HnswRabitqStreamerEntity::get_key(node_id_t id) const { + if (use_key_info_map_) { + auto loc = get_key_chunk_loc(id); + IndexStorage::MemoryBlock key_block; + ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), + "invalid chunk offset"); + size_t ret = + node_chunks_[loc.first]->read(loc.second, key_block, sizeof(key_t)); + if (ailego_unlikely(ret != sizeof(key_t))) { + LOG_ERROR("Read vector failed, ret=%zu", ret); + return kInvalidKey; + } + + return *reinterpret_cast(key_block.data()); + } else { + return id; + } +} + +void HnswRabitqStreamerEntity::add_neighbor(level_t level, node_id_t id, + uint32_t size, + node_id_t neighbor_id) { + auto loc = get_neighbor_chunk_loc(level, 
id); + size_t offset = + loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); + ailego_assert_with(size < neighbor_cnt(level), "invalid neighbor size"); + ailego_assert_with(offset < loc.first->data_size(), "invalid chunk offset"); + size_t ret = loc.first->write(offset, &neighbor_id, sizeof(node_id_t)); + if (ailego_unlikely(ret != sizeof(node_id_t))) { + LOG_ERROR("Write neighbor id failed, ret=%zu", ret); + return; + } + + uint32_t neighbors = size + 1; + ret = loc.first->write(loc.second, &neighbors, sizeof(uint32_t)); + if (ailego_unlikely(ret != sizeof(uint32_t))) { + LOG_ERROR("Write neighbor cnt failed, ret=%zu", ret); + } + + return; +} + +int HnswRabitqStreamerEntity::init_chunks(const Chunk::Pointer &header_chunk) { + if (header_chunk->data_size() < header_size()) { + LOG_ERROR("Invalid header chunk size"); + return IndexError_InvalidFormat; + } + IndexStorage::MemoryBlock header_block; + size_t size = header_chunk->read(0UL, header_block, header_size()); + if (ailego_unlikely(size != header_size())) { + LOG_ERROR("Read header chunk failed"); + return IndexError_ReadData; + } + *mutable_header() = + *reinterpret_cast(header_block.data()); + + int ret = check_hnsw_index(&header()); + if (ret != 0) { + broker_->close(); + return ret; + } + + node_chunks_.resize( + broker_->get_chunk_cnt(HnswRabitqChunkBroker::CHUNK_TYPE_NODE)); + for (auto seq = 0UL; seq < node_chunks_.size(); ++seq) { + node_chunks_[seq] = + broker_->get_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, seq); + if (!node_chunks_[seq]) { + LOG_ERROR("Missing hnsw streamer data chunk %zu th of %zu", seq, + node_chunks_.size()); + return IndexError_InvalidFormat; + } + } + + upper_neighbor_chunks_.resize( + broker_->get_chunk_cnt(HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR)); + for (auto seq = 0UL; seq < upper_neighbor_chunks_.size(); ++seq) { + upper_neighbor_chunks_[seq] = broker_->get_chunk( + HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, seq); + if 
(!upper_neighbor_chunks_[seq]) { + LOG_ERROR("Missing hnsw streamer index chunk %zu th of %zu", seq, + upper_neighbor_chunks_.size()); + return IndexError_InvalidFormat; + } + } + + return 0; +} + +int HnswRabitqStreamerEntity::open(IndexStorage::Pointer stg, + uint64_t max_index_size, bool check_crc) { + std::lock_guard lock(mutex_); + bool huge_page = stg->isHugePage(); + LOG_DEBUG("huge_page: %d", (int)huge_page); + int ret = init_chunk_params(max_index_size, huge_page); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("init_chunk_params failed for %s", IndexError::What(ret)); + return ret; + } + ret = broker_->open(std::move(stg), max_index_size_, chunk_size_, check_crc); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Open index failed for %s", IndexError::What(ret)); + return ret; + } + ret = upper_neighbor_index_->init(broker_, upper_neighbor_chunk_size_, + scaling_factor(), estimate_doc_capacity(), + kUpperHashMemoryInflateRatio); + if (ailego_unlikely(ret != 0)) { + LOG_ERROR("Init neighbor hash map failed"); + return ret; + } + + //! init header + auto header_chunk = + broker_->get_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_HEADER, + HnswRabitqChunkBroker::kDefaultChunkSeqId); + if (!header_chunk) { // open empty index, create one + auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_HEADER, + HnswRabitqChunkBroker::kDefaultChunkSeqId, + header_size()); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc header chunk failed"); + return p.first; + } + size_t size = p.second->write(0UL, &header(), header_size()); + if (ailego_unlikely(size != header_size())) { + LOG_ERROR("Write header chunk failed"); + return IndexError_WriteData; + } + return 0; + } + + //! Open an exist hnsw index + ret = init_chunks(header_chunk); + if (ailego_unlikely(ret != 0)) { + return ret; + } + + //! 
total docs including features wrote in index but neighbors may not ready + node_id_t total_vecs = 0; + if (node_chunks_.size() > 0) { + size_t last_idx = node_chunks_.size() - 1; + auto last_chunk = node_chunks_[last_idx]; + if (last_chunk->data_size() % node_size()) { + LOG_WARN("The index may broken"); + return IndexError_InvalidFormat; + } + total_vecs = last_idx * node_cnt_per_chunk_ + + node_chunks_[last_idx]->data_size() / node_size(); + } + + LOG_INFO( + "Open index, l0NeighborCnt=%zu upperNeighborCnt=%zu " + "efConstruction=%zu curDocCnt=%u totalVecs=%u maxLevel=%u", + l0_neighbor_cnt(), upper_neighbor_cnt(), ef_construction(), doc_cnt(), + total_vecs, cur_max_level()); + //! try to correct the docCnt if index not fully flushed + if (doc_cnt() != total_vecs) { + LOG_WARN("Index closed abnormally, using totalVecs as curDocCnt"); + *mutable_doc_cnt() = total_vecs; + } + if (filter_same_key_ || get_vector_enabled_) { + if (use_key_info_map_) { + for (node_id_t id = 0U; id < doc_cnt(); ++id) { + if (get_key(id) == kInvalidKey) { + continue; + } + (*keys_map_)[get_key(id)] = id; + } + } + } + + stats_.set_loaded_count(doc_cnt()); + + return 0; +} + +int HnswRabitqStreamerEntity::close() { + LOG_DEBUG("close index"); + + std::lock_guard lock(mutex_); + flush_header(); + mutable_header()->reset(); + upper_neighbor_index_->cleanup(); + keys_map_->clear(); + header_.clear(); + node_chunks_.clear(); + upper_neighbor_chunks_.clear(); + + return broker_->close(); +} + +int HnswRabitqStreamerEntity::flush(uint64_t checkpoint) { + LOG_INFO("Flush index, curDocs=%zu", static_cast(doc_cnt())); + + std::lock_guard lock(mutex_); + flush_header(); + int ret = broker_->flush(checkpoint); + if (ret != 0) { + return ret; + } + + return 0; +} + +int HnswRabitqStreamerEntity::dump(const IndexDumper::Pointer &dumper) { + LOG_INFO("Dump index, curDocs=%zu", static_cast(doc_cnt())); + + //! 
sort by keys, to support get_vector by key in searcher + std::vector keys(doc_cnt()); + for (node_id_t i = 0; i < doc_cnt(); ++i) { + keys[i] = get_key(i); + } + + //! dump neighbors + auto get_level = [&](node_id_t id) { + auto it = upper_neighbor_index_->find(id); + if (it == upper_neighbor_index_->end()) { + return 0U; + }; + auto meta = reinterpret_cast(&it->second); + return meta->level; + }; + auto ret = dump_segments(dumper, keys.data(), get_level); + if (ailego_unlikely(ret < 0)) { + return ret; + } + *stats_.mutable_dumped_size() += ret; + + return 0; +} + +int HnswRabitqStreamerEntity::check_hnsw_index(const HNSWHeader *hd) const { + if (l0_neighbor_cnt() != hd->l0_neighbor_cnt() || + upper_neighbor_cnt() != hd->upper_neighbor_cnt()) { + LOG_ERROR("Param neighbor cnt: %zu:%zu mismatch index previous %zu:%zu", + l0_neighbor_cnt(), upper_neighbor_cnt(), hd->l0_neighbor_cnt(), + hd->upper_neighbor_cnt()); + return IndexError_Mismatch; + } + if (vector_size() != hd->vector_size()) { + LOG_ERROR("vector size %zu mismatch index previous %zu", vector_size(), + hd->vector_size()); + return IndexError_Mismatch; + } + if (ef_construction() != hd->ef_construction()) { + LOG_WARN("Param efConstruction %zu mismatch index previous %zu", + ef_construction(), hd->ef_construction()); + } + if (scaling_factor() != hd->scaling_factor()) { + LOG_WARN("Param scalingFactor %zu mismatch index previous %zu", + scaling_factor(), hd->scaling_factor()); + return IndexError_Mismatch; + } + if (prune_cnt() != hd->neighbor_prune_cnt()) { + LOG_WARN("Param pruneCnt %zu mismatch index previous %zu", prune_cnt(), + hd->neighbor_prune_cnt()); + return IndexError_Mismatch; + } + if ((hd->entry_point() != kInvalidNodeId && + hd->entry_point() >= hd->doc_cnt()) || + (hd->entry_point() == kInvalidNodeId && hd->doc_cnt() > 0U)) { + LOG_WARN("Invalid entryPoint %zu, docCnt %zu", + static_cast(hd->entry_point()), + static_cast(hd->doc_cnt())); + return IndexError_InvalidFormat; + } + if 
(hd->entry_point() == kInvalidNodeId && + broker_->get_chunk_cnt(HnswRabitqChunkBroker::CHUNK_TYPE_NODE) > 0) { + LOG_WARN("The index is broken, maybe it haven't flush"); + return IndexError_InvalidFormat; + } + + return 0; +} + +int HnswRabitqStreamerEntity::add_vector(level_t level, key_t key, + const void *vec, node_id_t *id) { + Chunk::Pointer node_chunk; + size_t chunk_offset = -1UL; + + std::lock_guard lock(mutex_); + // duplicate check + if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { + LOG_WARN("Try to add duplicate key, ignore it"); + return IndexError_Duplicate; + } + + node_id_t local_id = static_cast(doc_cnt()); + uint32_t chunk_index = node_chunks_.size() - 1U; + if (chunk_index == -1U || + (node_chunks_[chunk_index]->data_size() >= + node_cnt_per_chunk_ * node_size())) { // no space left and need to alloc + if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { + LOG_ERROR("add vector failed for no memory quota"); + return IndexError_IndexFull; + } + chunk_index++; + auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, + chunk_index, chunk_size_); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc data chunk failed"); + return p.first; + } + node_chunk = p.second; + chunk_offset = 0UL; + node_chunks_.emplace_back(node_chunk); + } else { + node_chunk = node_chunks_[chunk_index]; + chunk_offset = node_chunk->data_size(); + } + + size_t size = node_chunk->write(chunk_offset, vec, vector_size()); + if (ailego_unlikely(size != vector_size())) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); + if (ailego_unlikely(size != sizeof(key_t))) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + //! 
level 0 neighbors is inited to zero by default + + int ret = add_upper_neighbor(level, local_id); + if (ret != 0) { + return ret; + } + + chunk_offset += node_size(); + if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", chunk_offset); + return IndexError_Runtime; + } + if (filter_same_key_ || get_vector_enabled_) { + if (use_key_info_map_) { + keys_map_lock_->lock(); + (*keys_map_)[key] = local_id; + keys_map_lock_->unlock(); + } + } + + *mutable_doc_cnt() += 1; + broker_->mark_dirty(); + *id = local_id; + + return 0; +} + +int HnswRabitqStreamerEntity::add_vector_with_id(level_t level, node_id_t id, + const void *vec) { + Chunk::Pointer node_chunk; + size_t chunk_offset = -1UL; + key_t key = id; + + std::lock_guard lock(mutex_); + + // duplicate check + if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { + LOG_WARN("Try to add duplicate key, ignore it"); + return IndexError_Duplicate; + } + + // set node_chunk & chunk_offset if succeed + auto func_get_node_chunk_and_offset = [&](node_id_t node_id) -> int { + uint32_t chunk_index = node_id >> node_index_mask_bits_; + ailego_assert_with(chunk_index <= node_chunks_.size(), "invalid chunk idx"); + // belongs to next chunk + if (chunk_index == node_chunks_.size()) { + if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { + LOG_ERROR("add vector failed for no memory quota"); + return IndexError_IndexFull; + } + auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, + chunk_index, chunk_size_); + if (ailego_unlikely(p.first != 0)) { + LOG_ERROR("Alloc data chunk failed"); + return p.first; + } + node_chunk = p.second; + node_chunks_.emplace_back(node_chunk); + } + + node_chunk = node_chunks_[chunk_index]; + chunk_offset = (node_id & node_index_mask_) * node_size(); + return 0; + }; + + for (size_t start_id = doc_cnt(); start_id < id; ++start_id) { + if (auto ret = func_get_node_chunk_and_offset(start_id); 
ret != 0) { + LOG_ERROR("func_get_node_chunk_and_offset failed"); + return ret; + } + size_t size = node_chunk->write(chunk_offset + vector_size(), &kInvalidKey, + sizeof(key_t)); + if (ailego_unlikely(size != sizeof(key_t))) { + LOG_ERROR("Chunk write key failed, ret=%zu", size); + return IndexError_WriteData; + } + + chunk_offset += node_size(); + if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", chunk_offset); + return IndexError_Runtime; + } + } + + if (auto ret = func_get_node_chunk_and_offset(id); ret != 0) { + LOG_ERROR("func_get_node_chunk_and_offset failed"); + return ret; + } + + size_t size = node_chunk->write(chunk_offset, vec, vector_size()); + if (ailego_unlikely(size != vector_size())) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + + size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); + if (ailego_unlikely(size != sizeof(key_t))) { + LOG_ERROR("Chunk write vec failed, ret=%zu", size); + return IndexError_WriteData; + } + //! 
level 0 neighbors is inited to zero by default + + int ret = add_upper_neighbor(level, id); + if (ret != 0) { + return ret; + } + + if (*mutable_doc_cnt() <= id) { + *mutable_doc_cnt() = id + 1; + chunk_offset += node_size(); + if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", chunk_offset); + return IndexError_Runtime; + } + } + + if (filter_same_key_ || get_vector_enabled_) { + if (use_key_info_map_) { + keys_map_lock_->lock(); + (*keys_map_)[key] = id; + keys_map_lock_->unlock(); + } + } + + broker_->mark_dirty(); + + return 0; +} + +void HnswRabitqStreamerEntity::update_ep_and_level(node_id_t ep, + level_t level) { + HnswRabitqEntity::update_ep_and_level(ep, level); + flush_header(); + + return; +} + +const HnswRabitqEntity::Pointer HnswRabitqStreamerEntity::clone() const { + std::vector node_chunks; + node_chunks.reserve(node_chunks_.size()); + for (size_t i = 0UL; i < node_chunks_.size(); ++i) { + node_chunks.emplace_back(node_chunks_[i]->clone()); + if (ailego_unlikely(!node_chunks[i])) { + LOG_ERROR("HnswRabitqStreamerEntity get chunk failed in clone"); + return HnswRabitqEntity::Pointer(); + } + } + + std::vector upper_neighbor_chunks; + upper_neighbor_chunks.reserve(upper_neighbor_chunks_.size()); + for (size_t i = 0UL; i < upper_neighbor_chunks_.size(); ++i) { + upper_neighbor_chunks.emplace_back(upper_neighbor_chunks_[i]->clone()); + if (ailego_unlikely(!upper_neighbor_chunks[i])) { + LOG_ERROR("HnswRabitqStreamerEntity get chunk failed in clone"); + return HnswRabitqEntity::Pointer(); + } + } + + HnswRabitqStreamerEntity *entity = + new (std::nothrow) HnswRabitqStreamerEntity( + stats_, header(), chunk_size_, node_index_mask_bits_, + upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, + upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, + std::move(node_chunks), std::move(upper_neighbor_chunks), broker_); + if (ailego_unlikely(!entity)) { + 
LOG_ERROR("HnswRabitqStreamerEntity new failed"); + } + return HnswRabitqEntity::Pointer(entity); +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer_entity.h b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer_entity.h new file mode 100644 index 00000000..ea36143a --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/hnsw_rabitq_streamer_entity.h @@ -0,0 +1,522 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_framework.h" +#include "hnsw_rabitq_chunk.h" +#include "hnsw_rabitq_entity.h" +#include "hnsw_rabitq_index_hash.h" +#include "hnsw_rabitq_params.h" + +namespace zvec { +namespace core { + +//! HnswRabitqStreamerEntity manage vector data, pkey, and node's neighbors +class HnswRabitqStreamerEntity : public HnswRabitqEntity { + public: + //! Cleanup + //! return 0 on success, or errCode in failure + virtual int cleanup() override; + + //! Make a copy of streamer entity, to support thread-safe operation. + //! The segment in container cannot be read concurrenly + virtual const HnswRabitqEntity::Pointer clone() const override; + + //! Get primary key of the node id + virtual key_t get_key(node_id_t id) const override; + + //! Get vector feature data by key + virtual const void *get_vector(node_id_t id) const override; + + //! 
Get vectors feature data by local ids + virtual int get_vector(const node_id_t *ids, uint32_t count, + const void **vecs) const override; + + virtual int get_vector(const node_id_t id, + IndexStorage::MemoryBlock &block) const override; + + virtual int get_vector( + const node_id_t *ids, uint32_t count, + std::vector &vec_blocks) const override; + + //! Get the node id's neighbors on graph level + //! Note: the neighbors cannot be modified, using the following + //! method to get WritableNeighbors if want to + virtual const Neighbors get_neighbors(level_t level, + node_id_t id) const override; + + //! Add vector and key to hnsw entity, and local id will be saved in id + virtual int add_vector(level_t level, key_t key, const void *vec, + node_id_t *id) override; + + //! Add vector and id to hnsw entity + virtual int add_vector_with_id(level_t level, node_id_t id, + const void *vec) override; + + virtual int update_neighbors( + level_t level, node_id_t id, + const std::vector> &neighbors) + override; + + //! Append neighbor_id to node id neighbors on level + //! Notice: the caller must be ensure the neighbors not full + virtual void add_neighbor(level_t level, node_id_t id, uint32_t size, + node_id_t neighbor_id) override; + + //! Dump index by dumper + virtual int dump(const IndexDumper::Pointer &dumper) override; + + virtual void update_ep_and_level(node_id_t ep, level_t level) override; + + void set_use_key_info_map(bool use_id_map) { + use_key_info_map_ = use_id_map; + LOG_DEBUG("use_key_info_map_: %d", (int)use_key_info_map_); + } + + public: + //! Constructor + HnswRabitqStreamerEntity(IndexStreamer::Stats &stats); + + //! Destructor + ~HnswRabitqStreamerEntity(); + + //! Get vector feature data by key + virtual const void *get_vector_by_key(key_t key) const override { + auto id = get_id(key); + return id == kInvalidNodeId ? 
nullptr : get_vector(id); + } + + virtual int get_vector_by_key( + const key_t key, IndexStorage::MemoryBlock &block) const override { + auto id = get_id(key); + if (id != kInvalidNodeId) { + return get_vector(id, block); + } else { + return IndexError_InvalidArgument; + } + } + + //! Init entity + int init(size_t max_doc_cnt); + + //! Flush graph entity to disk + //! return 0 on success, or errCode in failure + int flush(uint64_t checkpoint); + + //! Open entity from storage + //! return 0 on success, or errCode in failure + int open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc); + + //! Close entity + //! return 0 on success, or errCode in failure + int close(); + + //! Set meta information from entity + int set_index_meta(const IndexMeta &meta) const { + return IndexHelper::SerializeToStorage(meta, broker_->storage().get()); + } + + //! Get meta information from entity + int get_index_meta(IndexMeta *meta) const { + return IndexHelper::DeserializeFromStorage(broker_->storage().get(), meta); + } + + //! Set params: chunk size + inline void set_chunk_size(size_t val) { + chunk_size_ = val; + } + + //! Set params + inline void set_filter_same_key(bool val) { + filter_same_key_ = val; + } + + //! Set params + inline void set_get_vector(bool val) { + get_vector_enabled_ = val; + } + + //! Get vector local id by key + inline node_id_t get_id(key_t key) const { + if (use_key_info_map_) { + keys_map_lock_->lock_shared(); + auto it = keys_map_->find(key); + keys_map_lock_->unlock_shared(); + return it == keys_map_->end() ? kInvalidNodeId : it->second; + } else { + return key; + } + } + + void print_key_map() const { + std::cout << "key map begins" << std::endl; + + auto iter = keys_map_->begin(); + while (iter != keys_map_->end()) { + std::cout << "key: " << iter->first << ", id: " << iter->second + << std::endl; + ; + iter++; + } + + std::cout << "key map ends" << std::endl; + } + + //! 
Get l0 neighbors size + inline size_t neighbors_size() const { + return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); + } + + //! Get neighbors size for level > 0 + inline size_t upper_neighbors_size() const { + return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); + } + + + private: + union UpperNeighborIndexMeta { + struct { + uint32_t level : 4; + uint32_t index : 28; // index is composite type: chunk idx, and the + // N th neighbors in chunk, they two composite + // the 28 bits location + }; + uint32_t data; + }; + + template + using HashMap = google::dense_hash_map>; + template + using HashMapPointer = std::shared_ptr>; + + template + using HashSet = google::dense_hash_set>; + template + using HashSetPointer = std::shared_ptr>; + + //! upper neighbor index hashmap + using NIHashMap = HnswIndexHashMap; + using NIHashMapPointer = std::shared_ptr; + + //! Private construct, only be called by clone method + HnswRabitqStreamerEntity(IndexStreamer::Stats &stats, const HNSWHeader &hd, + size_t chunk_size, uint32_t node_index_mask_bits, + uint32_t upper_neighbor_mask_bits, + bool filter_same_key, bool get_vector_enabled, + const NIHashMapPointer &upper_neighbor_index, + std::shared_ptr &keys_map_lock, + const HashMapPointer &keys_map, + bool use_key_info_map, + std::vector &&node_chunks, + std::vector &&upper_neighbor_chunks, + const HnswRabitqChunkBroker::Pointer &broker) + : stats_(stats), + chunk_size_(chunk_size), + node_index_mask_bits_(node_index_mask_bits), + node_cnt_per_chunk_(1UL << node_index_mask_bits_), + node_index_mask_(node_cnt_per_chunk_ - 1), + upper_neighbor_mask_bits_(upper_neighbor_mask_bits), + upper_neighbor_mask_((1U << upper_neighbor_mask_bits_) - 1), + filter_same_key_(filter_same_key), + get_vector_enabled_(get_vector_enabled), + use_key_info_map_(use_key_info_map), + upper_neighbor_index_(upper_neighbor_index), + keys_map_lock_(keys_map_lock), + keys_map_(keys_map), + 
node_chunks_(std::move(node_chunks)), + upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), + broker_(broker) { + *mutable_header() = hd; + + neighbor_size_ = neighbors_size(); + upper_neighbor_size_ = upper_neighbors_size(); + } + + //! Called only in searching procedure per context, so no need to lock + void sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE type, size_t idx, + std::vector *chunks) const { + if (ailego_likely(idx < chunks->size())) { + return; + } + for (size_t i = chunks->size(); i <= idx; ++i) { + auto chunk = broker_->get_chunk(type, i); + // the storage can ensure get chunk will success after the first get + ailego_assert_with(!!chunk, "get chunk failed"); + chunks->emplace_back(std::move(chunk)); + } + } + + //! return pair: chunk index + chunk offset + inline std::pair get_vector_chunk_loc( + node_id_t id) const { + uint32_t chunk_idx = id >> node_index_mask_bits_; + uint32_t offset = (id & node_index_mask_) * node_size(); + + sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, + &node_chunks_); + return std::make_pair(chunk_idx, offset); + } + + //! 
return pair: chunk index + chunk offset + inline std::pair get_key_chunk_loc(node_id_t id) const { + uint32_t chunk_idx = id >> node_index_mask_bits_; + uint32_t offset = (id & node_index_mask_) * node_size() + vector_size(); + + sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, + &node_chunks_); + return std::make_pair(chunk_idx, offset); + } + + inline std::pair get_upper_neighbor_chunk_loc( + level_t level, node_id_t id) const { + auto it = upper_neighbor_index_->find(id); + ailego_assert_abort(it != upper_neighbor_index_->end(), + "Get upper neighbor header failed"); + auto meta = reinterpret_cast(&it->second); + uint32_t chunk_idx = (meta->index) >> upper_neighbor_mask_bits_; + uint32_t offset = (((meta->index) & upper_neighbor_mask_) + level - 1) * + upper_neighbor_size_; + sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_idx, + &upper_neighbor_chunks_); + ailego_assert_abort(chunk_idx < upper_neighbor_chunks_.size(), + "invalid chunk idx"); + ailego_assert_abort(offset < upper_neighbor_chunks_[chunk_idx]->data_size(), + "invalid chunk offset"); + return std::make_pair(chunk_idx, offset); + } + + //! return pair: chunk + chunk offset + inline std::pair get_neighbor_chunk_loc(level_t level, + node_id_t id) const { + if (level == 0UL) { + uint32_t chunk_idx = id >> node_index_mask_bits_; + uint32_t offset = + (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); + + sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, + &node_chunks_); + ailego_assert_abort(chunk_idx < node_chunks_.size(), "invalid chunk idx"); + ailego_assert_abort(offset < node_chunks_[chunk_idx]->data_size(), + "invalid chunk offset"); + return std::make_pair(node_chunks_[chunk_idx].get(), offset); + } else { + auto p = get_upper_neighbor_chunk_loc(level, id); + return std::make_pair(upper_neighbor_chunks_[p.first].get(), p.second); + } + } + + //! 
Check that the hnsw index is valid
return IndexError_Runtime; + } + + if (ailego_unlikely(chunk->resize(chunk_offset) != chunk_offset)) { + LOG_ERROR("Chunk resize to %zu failed", (size_t)chunk_offset); + return IndexError_Runtime; + } + + return 0; + } + + size_t estimate_doc_capacity() const { + return node_chunks_.capacity() * node_cnt_per_chunk_; + } + + int init_chunk_params(size_t max_index_size, bool huge_page) { + node_cnt_per_chunk_ = std::max(1, chunk_size_ / node_size()); + //! align node cnt per chunk to pow of 2 + node_index_mask_bits_ = std::ceil(std::log2(node_cnt_per_chunk_)); + node_cnt_per_chunk_ = 1UL << node_index_mask_bits_; + if (huge_page) { + chunk_size_ = AlignHugePageSize(node_cnt_per_chunk_ * node_size()); + } else { + chunk_size_ = AlignPageSize(node_cnt_per_chunk_ * node_size()); + } + node_index_mask_ = node_cnt_per_chunk_ - 1; + + if (max_index_size == 0UL) { + max_index_size_ = chunk_size_ * kDefaultMaxChunkCnt; + } else { + max_index_size_ = max_index_size; + } + + //! To get a balanced upper neighbor chunk size. + //! If the upper chunk size is equal to node chunk size, it may waste + //! upper neighbor chunk space; if the upper neighbor chunk size is too + //! small, the will need large upper neighbor chunks index space. So to + //! 
use a balanced ratio equal to the square root of the node/neighbor size ratio
Init node chunk and neighbor chunks + int init_chunks(const Chunk::Pointer &header_chunk); + + int flush_header(void) { + if (!broker_->dirty()) { + // do not need to flush + return 0; + } + auto header_chunk = + broker_->get_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_HEADER, + HnswRabitqChunkBroker::kDefaultChunkSeqId); + if (ailego_unlikely(!header_chunk)) { + LOG_ERROR("get header chunk failed"); + return IndexError_Runtime; + } + size_t size = header_chunk->write(0UL, &header(), header_size()); + if (ailego_unlikely(size != header_size())) { + LOG_ERROR("Write header chunk failed"); + return IndexError_WriteData; + } + + return 0; + } + + private: + HnswRabitqStreamerEntity(const HnswRabitqStreamerEntity &) = delete; + HnswRabitqStreamerEntity &operator=(const HnswRabitqStreamerEntity &) = + delete; + static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; + + private: + IndexStreamer::Stats &stats_; + HNSWHeader header_{}; + std::mutex mutex_{}; + size_t max_index_size_{0UL}; + uint32_t chunk_size_{kDefaultChunkSize}; + uint32_t upper_neighbor_chunk_size_{kDefaultChunkSize}; + uint32_t node_index_mask_bits_{0U}; + uint32_t node_cnt_per_chunk_{0U}; + uint32_t node_index_mask_{0U}; + uint32_t neighbor_size_{0U}; + uint32_t upper_neighbor_size_{0U}; + //! UpperNeighborIndex.index composite chunkIdx and offset in chunk by the + //! following mask + uint32_t upper_neighbor_mask_bits_{0U}; + uint32_t upper_neighbor_mask_{0U}; + bool filter_same_key_{false}; + bool get_vector_enabled_{false}; + bool use_key_info_map_{true}; + + NIHashMapPointer upper_neighbor_index_{}; + + mutable std::shared_ptr keys_map_lock_{}; + HashMapPointer keys_map_{}; + + //! the chunks will be changed in searcher, so need mutable + //! data chunk include: vector, key, level 0 neighbors + mutable std::vector node_chunks_{}; + + //! 
upper neighbor chunk include
+ +#include "rabitq_converter.h" +#include +#include +#include +#include +#include +#include +#include "algorithm/hnsw-rabitq/rabitq_reformer.h" +#include "zvec/core/framework/index_cluster.h" +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_factory.h" +#include "zvec/core/framework/index_features.h" +#include "zvec/core/framework/index_holder.h" +#include "zvec/core/framework/index_meta.h" +#include "rabitq_params.h" +#include "rabitq_utils.h" + +namespace zvec { +namespace core { + +RabitqConverter::~RabitqConverter() { + this->cleanup(); +} + +int RabitqConverter::init(const IndexMeta &meta, const ailego::Params ¶ms) { + // Copy meta and ensure it has metric information + meta_ = meta; + dimension_ = meta.dimension(); + + // Ensure meta has metric set + if (meta_.metric_name().empty()) { + meta_.set_metric("SquaredEuclidean", 0, ailego::Params()); + } + + // Round up dimension to multiple of 64 + padded_dim_ = ((dimension_ + 63) / 64) * 64; + + // Get RaBitQ parameters with defaults + uint32_t total_bits = 0; + params.get(PARAM_RABITQ_TOTAL_BITS, &total_bits); + if (total_bits == 0) { + total_bits = kDefaultRabitqTotalBits; + } + if (total_bits < 1 || total_bits > 9) { + LOG_ERROR("Invalid total_bits: %zu, must be in [1, 9]", (size_t)total_bits); + return IndexError_InvalidArgument; + } + ex_bits_ = total_bits - 1; + + params.get(PARAM_RABITQ_NUM_CLUSTERS, &num_clusters_); + if (num_clusters_ == 0) { + num_clusters_ = kDefaultNumClusters; + } + + // Validate parameters + if (num_clusters_ == 0 || num_clusters_ > 256) { + LOG_ERROR("Invalid num_clusters: %zu, must be in [1, 256]", num_clusters_); + return IndexError_InvalidArgument; + } + + if (ex_bits_ > 8) { + LOG_ERROR("Invalid ex_bits: %zu, must be <= 8", ex_bits_); + return IndexError_InvalidArgument; + } + + if (meta.data_type() != IndexMeta::DataType::DT_FP32) { + LOG_ERROR("RaBitQ only supports FP32 data type"); + return IndexError_Unsupported; + } + 
params.get(PARAM_RABITQ_SAMPLE_COUNT, &sample_count_); + + std::string rotator_type_str; + params.get(PARAM_RABITQ_ROTATOR_TYPE, &rotator_type_str); + if (rotator_type_str.empty()) { + rotator_type_ = rabitqlib::RotatorType::FhtKacRotator; + } else if (strncasecmp(rotator_type_str.c_str(), "fht", 3) == 0) { + rotator_type_ = rabitqlib::RotatorType::FhtKacRotator; + } else if (strncasecmp(rotator_type_str.c_str(), "matrix", 6) == 0) { + rotator_type_ = rabitqlib::RotatorType::MatrixRotator; + } else { + LOG_ERROR("Invalid rotator_type: %s", rotator_type_str.c_str()); + return IndexError_InvalidArgument; + } + + // Create rotator + rotator_.reset( + rabitqlib::choose_rotator(dimension_, rotator_type_, padded_dim_)); + + LOG_INFO( + "RabitqConverter initialized: dim=%zu, padded_dim=%zu, " + "num_clusters=%zu, ex_bits=%zu, rotator_type=%d[%s] sample_count[%zu]", + dimension_, padded_dim_, num_clusters_, ex_bits_, (int)rotator_type_, + rotator_type_str.c_str(), sample_count_); + + return 0; +} + +int RabitqConverter::cleanup() { + centroids_.clear(); + rotated_centroids_.clear(); + result_holder_.reset(); + rotator_.reset(); + return 0; +} + +int RabitqConverter::train(IndexHolder::Pointer holder) { + if (!holder) { + LOG_ERROR("Null holder for training"); + return IndexError_InvalidArgument; + } + + ailego::ElapsedTime timer; + + size_t vector_count = holder->count(); + if (vector_count == 0) { + LOG_ERROR("No vectors for training"); + return IndexError_InvalidArgument; + } + + // do sampling from all data + size_t sample_count = vector_count; + if (sample_count_ > 0) { + sample_count = std::min(sample_count_, vector_count); + } + LOG_INFO("Training with %zu vectors from %zu of holder", sample_count, + vector_count); + auto sampler = std::make_shared>( + meta_, sample_count); + auto iter = holder->create_iterator(); + if (!iter) { + LOG_ERROR("Create iterator error"); + return IndexError_Runtime; + } + for (; iter->is_valid(); iter->next()) { + 
sampler->emplace(iter->data()); + } + + // Holder is not needed, cleanup it. + holder.reset(); + + if (sampler->count() == 0) { + LOG_ERROR("Load training data error"); + return IndexError_InvalidLength; + } + + + // Create KmeansCluster for training centroids + auto cluster = IndexFactory::CreateCluster("OptKmeansCluster"); + if (!cluster) { + LOG_ERROR("Failed to create OptKmeansCluster"); + return IndexError_NoExist; + } + + // Initialize cluster + LOG_INFO( + "Initializing KmeansCluster with meta: dim=%u, data_type=%d, metric=%s", + meta_.dimension(), (int)meta_.data_type(), meta_.metric_name().c_str()); + ailego::Params cluster_params; + int ret = cluster->init(meta_, cluster_params); + if (ret != 0) { + LOG_ERROR("Failed to initialize KmeansCluster: %d", ret); + return ret; + } + + ret = cluster->mount(sampler); + cluster->suggest(num_clusters_); + + // Perform clustering + IndexCluster::CentroidList cents; + // TODO: support specify threads with argument + auto threads = std::make_shared(0, false); + ret = cluster->cluster(threads, cents); + if (ret != 0) { + LOG_ERROR("Failed to perform clustering: %d", ret); + return ret; + } + + if (cents.size() != num_clusters_) { + LOG_WARN("Expected %zu clusters, got %zu", num_clusters_, cents.size()); + num_clusters_ = cents.size(); + } + // Extract original centroids (for LinearSeeker query) + centroids_.resize(num_clusters_ * dimension_); + // Extract rotated centroids (for quantization) + rotated_centroids_.resize(num_clusters_ * padded_dim_); + for (uint32_t i = 0; i < num_clusters_; ++i) { + const float *cent_data = static_cast(cents[i].feature()); + // Save original centroids + std::memcpy(¢roids_[i * dimension_], cent_data, + dimension_ * sizeof(float)); + // Save rotated centroids + this->rotator_->rotate(cent_data, &rotated_centroids_[i * padded_dim_]); + } + + stats_.set_trained_count(sampler->count()); + stats_.set_trained_costtime(timer.milli_seconds()); + + LOG_INFO("Training completed: %zu centroids, 
cost %zu ms", num_clusters_, + static_cast(timer.milli_seconds())); + + return 0; +} + + +int RabitqConverter::transform(IndexHolder::Pointer holder) { + if (!holder) { + LOG_ERROR("Null holder for transformation"); + return IndexError_InvalidArgument; + } + + if (rotated_centroids_.empty()) { + LOG_ERROR("Centroids not trained yet"); + return IndexError_NoReady; + } + + LOG_ERROR("Not implemented"); + return IndexError_NotImplemented; +} + +int RabitqConverter::dump(const IndexDumper::Pointer &dumper) { + if (!dumper) { + LOG_ERROR("Null dumper"); + return IndexError_InvalidArgument; + } + + if (rotated_centroids_.empty() || centroids_.empty()) { + LOG_ERROR("No centroids to dump"); + return IndexError_NoReady; + } + + ailego::ElapsedTime timer; + size_t dumped_size = 0; + + int ret = dump_rabitq_centroids( + dumper, dimension_, padded_dim_, ex_bits_, num_clusters_, rotator_type_, + rotated_centroids_, centroids_, rotator_, &dumped_size); + if (ret != 0) { + return ret; + } + + stats_.set_dumped_size(dumped_size); + stats_.set_dumped_costtime(timer.milli_seconds()); + + LOG_INFO("Dump completed: %zu bytes, cost %zu ms", stats_.dumped_size(), + static_cast(timer.milli_seconds())); + return 0; +} + +int RabitqConverter::to_reformer(IndexReformer::Pointer *reformer) { + auto memory_dumper = IndexFactory::CreateDumper("MemoryDumper"); + memory_dumper->init(ailego::Params()); + std::string file_id = ailego::StringHelper::Concat( + "rabitq_converter_", ailego::Monotime::MilliSeconds(), rand()); + int ret = memory_dumper->create(file_id); + if (ret != 0) { + LOG_ERROR("Failed to create memory dumper: %d", ret); + return ret; + } + ret = this->dump(memory_dumper); + if (ret != 0) { + LOG_ERROR("Failed to dump RabitqConverter: %d", ret); + return ret; + } + ret = memory_dumper->close(); + if (ret != 0) { + LOG_ERROR("Failed to close memory dumper: %d", ret); + return ret; + } + + auto res = std::make_shared(); + ailego::Params reformer_params; + 
reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name()); + ret = res->init(reformer_params); + if (ret != 0) { + LOG_ERROR("Failed to initialize RabitqReformer: %d", ret); + return ret; + } + auto memory_storage = IndexFactory::CreateStorage("MemoryReadStorage"); + ret = memory_storage->open(file_id, false); + if (ret != 0) { + LOG_ERROR("Failed to open memory storage: %d", ret); + return ret; + } + ret = res->load(memory_storage); + if (ret != 0) { + LOG_ERROR("Failed to load RabitqReformer: %d", ret); + return ret; + } + *reformer = std::move(res); + // TODO: release memory of memory_storage + return 0; +} + +INDEX_FACTORY_REGISTER_CONVERTER_ALIAS(RabitqConverter, RabitqConverter); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/rabitq_converter.h b/src/core/algorithm/hnsw-rabitq/rabitq_converter.h new file mode 100644 index 00000000..d7e52a78 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/rabitq_converter.h @@ -0,0 +1,101 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include "zvec/core/framework/index_cluster.h" +#include "zvec/core/framework/index_converter.h" +#include "zvec/core/framework/index_reformer.h" +#include "zvec/core/framework/index_threads.h" +#include "rabitq_params.h" + +namespace zvec { +namespace core { + +class RabitqReformer; + +/*! 
RaBitQ Converter + * Trains KMeans centroids and quantizes vectors using RaBitQ + */ +class RabitqConverter : public IndexConverter { + public: + //! Constructor + RabitqConverter() = default; + + //! Destructor + ~RabitqConverter() override; + + //! Initialize Converter + int init(const IndexMeta &meta, const ailego::Params ¶ms) override; + + //! Cleanup Converter + int cleanup(void) override; + + //! Train the data - perform KMeans clustering + int train(IndexHolder::Pointer holder) override; + + //! Transform the data - quantize vectors using RaBitQ + int transform(IndexHolder::Pointer holder) override; + + //! Dump centroids and config into storage + int dump(const IndexDumper::Pointer &dumper) override; + + //! Retrieve statistics + const Stats &stats(void) const override { + return stats_; + } + + //! Retrieve a holder as result + IndexHolder::Pointer result(void) const override { + return result_holder_; + } + + //! Retrieve Index Meta + const IndexMeta &meta(void) const override { + return meta_; + } + + int to_reformer(IndexReformer::Pointer *reformer) override; + + private: + static inline size_t AlignSize(size_t size) { + return (size + 0x1F) & (~0x1F); + } + + private: + IndexMeta meta_; + IndexHolder::Pointer result_holder_; + Stats stats_; + size_t sample_count_{0}; + + // RaBitQ parameters + size_t num_clusters_{0}; + size_t ex_bits_{0}; + size_t dimension_{0}; + size_t padded_dim_{0}; + + // Original centroids: num_clusters * dimension (for LinearSeeker query) + std::vector centroids_; + // Rotated centroids: num_clusters * padded_dim (for quantization) + std::vector rotated_centroids_; + + // Rotator for vector transformation + rabitqlib::RotatorType rotator_type_{rabitqlib::RotatorType::FhtKacRotator}; + std::unique_ptr> rotator_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/rabitq_params.h b/src/core/algorithm/hnsw-rabitq/rabitq_params.h new file mode 100644 index 00000000..7468f0ca --- /dev/null +++ 
b/src/core/algorithm/hnsw-rabitq/rabitq_params.h @@ -0,0 +1,40 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +namespace zvec { +namespace core { + +// RaBitQ Converter parameters +static const std::string PARAM_RABITQ_NUM_CLUSTERS( + "proxima.rabitq.num_clusters"); +static const std::string PARAM_RABITQ_TOTAL_BITS("proxima.rabitq.total_bits"); +static const std::string PARAM_RABITQ_METRIC_NAME("proxima.rabitq.metric_name"); +static const std::string PARAM_RABITQ_ROTATOR_TYPE( + "proxima.rabitq.rotator.type"); +static const std::string PARAM_RABITQ_SAMPLE_COUNT( + "proxima.rabitq.sample_count"); + +// Default values +static constexpr size_t kDefaultNumClusters = 16; +// 4-bit, 5-bit, and 7-bit quantization typically achieve 90%, 95%, and 99% +// recall, respectively—without accessing raw vectors for reranking +static constexpr size_t kDefaultRabitqTotalBits = 7; + + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/rabitq_reformer.cc b/src/core/algorithm/hnsw-rabitq/rabitq_reformer.cc new file mode 100644 index 00000000..0cf4ea3c --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/rabitq_reformer.cc @@ -0,0 +1,430 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rabitq_reformer.h" +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_factory.h" +#include "zvec/core/framework/index_meta.h" +#include "zvec/core/framework/index_storage.h" +#include "rabitq_converter.h" +#include "rabitq_utils.h" + +namespace zvec { +namespace core { + +RabitqReformer::~RabitqReformer() { + this->cleanup(); +} + +int RabitqReformer::init(const ailego::Params ¶ms) { + std::string metric_name = params.get_as_string(PARAM_RABITQ_METRIC_NAME); + if (metric_name == "SquaredEuclidean") { + metric_type_ = rabitqlib::METRIC_L2; + } else if (metric_name == "InnerProduct") { + metric_type_ = rabitqlib::METRIC_IP; + } else if (metric_name == "Cosine") { + metric_type_ = rabitqlib::METRIC_IP; + } else { + LOG_ERROR("Unsupported metric name: %s", metric_name.c_str()); + return IndexError_InvalidArgument; + } + LOG_DEBUG("Rabitq reformer init done. 
metric_name=%s metric_type=%d", + metric_name.c_str(), static_cast(metric_type_)); + return 0; +} + +int RabitqReformer::cleanup() { + centroids_.clear(); + rotated_centroids_.clear(); + centroid_seeker_.reset(); + centroid_features_.reset(); + loaded_ = false; + rotator_.reset(); + return 0; +} + +int RabitqReformer::unload() { + return this->cleanup(); +} + +int RabitqReformer::load(IndexStorage::Pointer storage) { + if (!storage) { + LOG_ERROR("Invalid storage for load"); + return IndexError_InvalidArgument; + } + + auto segment = storage->get(RABITQ_CONVERER_SEG_ID); + if (!segment) { + LOG_ERROR("Failed to get segment %s", RABITQ_CONVERER_SEG_ID.c_str()); + return IndexError_InvalidFormat; + } + + size_t offset = 0; + RabitqConverterHeader header; + IndexStorage::MemoryBlock block; + size_t size = segment->read(offset, block, sizeof(header)); + if (size != sizeof(header)) { + LOG_ERROR("Failed to read header"); + return IndexError_InvalidFormat; + } + memcpy(&header, block.data(), sizeof(header)); + dimension_ = header.dim; + padded_dim_ = header.padded_dim; + ex_bits_ = header.ex_bits; + num_clusters_ = header.num_clusters; + rotator_type_ = static_cast(header.rotator_type); + offset += sizeof(header); + + // Read rotated centroids + size_t rotated_centroids_size = + sizeof(float) * header.num_clusters * header.padded_dim; + size = segment->read(offset, block, rotated_centroids_size); + if (size != rotated_centroids_size) { + LOG_ERROR("Failed to read rotated centroids"); + return IndexError_InvalidFormat; + } + rotated_centroids_.resize(header.num_clusters * header.padded_dim); + memcpy(rotated_centroids_.data(), block.data(), rotated_centroids_size); + offset += size; + + // Read original centroids (for LinearSeeker query) + size_t centroids_size = sizeof(float) * header.num_clusters * header.dim; + size = segment->read(offset, block, centroids_size); + if (size != centroids_size) { + LOG_ERROR("Failed to read centroids"); + return IndexError_InvalidFormat; 
+ } + centroids_.resize(header.num_clusters * header.dim); + memcpy(centroids_.data(), block.data(), centroids_size); + offset += size; + + // Read rotator + size_t rotator_size = header.rotator_size; + size = segment->read(offset, block, rotator_size); + if (size != rotator_size) { + LOG_ERROR("Failed to read rotator"); + return IndexError_InvalidFormat; + } + // Create rotator + rotator_.reset( + rabitqlib::choose_rotator(dimension_, rotator_type_, padded_dim_)); + rotator_->load(reinterpret_cast(block.data())); + offset += size; + + this->query_config_ = rabitqlib::quant::faster_config( + padded_dim_, rabitqlib::SplitSingleQuery::kNumBits); + this->config_ = rabitqlib::quant::faster_config(padded_dim_, ex_bits_ + 1); + + size_bin_data_ = rabitqlib::BinDataMap::data_bytes(padded_dim_); + size_ex_data_ = + rabitqlib::ExDataMap::data_bytes(padded_dim_, ex_bits_); + + // Initialize LinearSeeker for centroid search + IndexMeta centroid_meta; + centroid_meta.set_data_type(IndexMeta::DataType::DT_FP32); + centroid_meta.set_dimension(static_cast(dimension_)); + centroid_meta.set_metric("SquaredEuclidean", 0, ailego::Params()); + + centroid_features_ = std::make_shared(); + centroid_features_->mount(centroid_meta, centroids_.data(), + centroids_.size() * sizeof(float)); + + centroid_seeker_ = std::make_shared(); + int ret = centroid_seeker_->init(centroid_meta); + if (ret != 0) { + LOG_ERROR("Failed to init centroid seeker. ret[%d]", ret); + return ret; + } + ret = centroid_seeker_->mount(centroid_features_); + if (ret != 0) { + LOG_ERROR("Failed to mount centroid features. ret[%d]", ret); + return ret; + } + + LOG_INFO( + "Rabitq reformer load done. 
dimension=%zu, padded_dim=%zu, " + "ex_bits=%zu, num_clusters=%zu, size_bin_data=%zu, size_ex_data=%zu " + "rotator_type=%d", + dimension_, padded_dim_, ex_bits_, num_clusters_, size_bin_data_, + size_ex_data_, (int)rotator_type_); + loaded_ = true; + return 0; +} + +int RabitqReformer::convert(const void *record, const IndexQueryMeta &rmeta, + std::string *out, IndexQueryMeta *ometa) const { + if (!loaded_) { + LOG_ERROR("Centroids not loaded yet"); + return IndexError_NoReady; + } + + if (!record || !out) { + LOG_ERROR("Invalid arguments for convert"); + return IndexError_InvalidArgument; + } + + // Validate input + // input may be transformed, require rmeta.dimension >= dimension_ + if (rmeta.dimension() < dimension_ || + rmeta.data_type() != IndexMeta::DataType::DT_FP32) { + LOG_ERROR("Invalid record meta: dimension=%zu, data_type=%d", + static_cast(rmeta.dimension()), (int)rmeta.data_type()); + return IndexError_InvalidArgument; + } + + // Find nearest centroid using LinearSeeker + Seeker::Document doc; + int ret = centroid_seeker_->seek(record, dimension_ * sizeof(float), &doc); + if (ret != 0) { + LOG_ERROR("Failed to seek centroid. 
ret[%d]", ret); + return ret; + } + uint32_t cluster_id = doc.index; + + // Quantize vector + const float *vector = static_cast(record); + ret = quantize_vector(vector, cluster_id, out); + if (ret != 0) { + LOG_ERROR("Failed to quantize vector"); + return ret; + } + + ometa->set_meta(IndexMeta::DataType::DT_INT8, (uint32_t)out->size()); + + return 0; +} + +int RabitqReformer::transform(const void *, const IndexQueryMeta &, + std::string *, IndexQueryMeta *) const { + return IndexError_NotImplemented; +} + +int RabitqReformer::transform_to_entity(const void *query, + HnswRabitqQueryEntity *entity) const { + if (!loaded_) { + LOG_ERROR("Centroids not loaded yet"); + return IndexError_NoReady; + } + + if (!query) { + LOG_ERROR("Invalid arguments for transform"); + return IndexError_InvalidArgument; + } + + const float *query_vector = static_cast(query); + + // Apply rotator + entity->rotated_query.resize(padded_dim_); + rotator_->rotate(query_vector, entity->rotated_query.data()); + + // Quantize query to 4-bit representation + // TODO: add IP support + + entity->query_wrapper = std::make_unique>( + entity->rotated_query.data(), padded_dim_, ex_bits_, query_config_, + metric_type_); + + // Preprocess - get the distance from query to all centroids + entity->q_to_centroids.resize(num_clusters_); + + if (metric_type_ == rabitqlib::METRIC_L2) { + for (size_t i = 0; i < num_clusters_; i++) { + entity->q_to_centroids[i] = std::sqrt(rabitqlib::euclidean_sqr( + entity->rotated_query.data(), + rotated_centroids_.data() + (i * padded_dim_), padded_dim_)); + } + } else if (metric_type_ == rabitqlib::METRIC_IP) { + entity->q_to_centroids.resize(num_clusters_ * 2); + // first half as g_add, second half as g_error + for (size_t i = 0; i < num_clusters_; i++) { + entity->q_to_centroids[i] = rabitqlib::dot_product( + entity->rotated_query.data(), + rotated_centroids_.data() + (i * padded_dim_), padded_dim_); + entity->q_to_centroids[i + num_clusters_] = + 
std::sqrt(rabitqlib::euclidean_sqr( + entity->rotated_query.data(), + rotated_centroids_.data() + (i * padded_dim_), padded_dim_)); + } + } + + return 0; +} + +size_t RabitqReformer::find_nearest_centroid(const float *vector) const { + size_t nearest_id = 0; + float min_dist = std::numeric_limits::max(); + + for (size_t i = 0; i < num_clusters_; ++i) { + const float *centroid = &rotated_centroids_[i * padded_dim_]; + + // Compute L2 distance + float dist = 0.0f; + for (size_t d = 0; d < dimension_; ++d) { + float diff = vector[d] - centroid[d]; + dist += diff * diff; + } + + if (dist < min_dist) { + min_dist = dist; + nearest_id = i; + } + } + + return nearest_id; +} + +int RabitqReformer::quantize_vector(const float *raw_vector, + uint32_t cluster_id, + std::string *quantized_data) const { + // Quantize raw data and initialize quantized data + std::vector rotated_data(padded_dim_); + rotator_->rotate(raw_vector, rotated_data.data()); + + // quantized format: + // cluster_id + bin_data + ex_data + quantized_data->resize(sizeof(cluster_id) + size_bin_data_ + size_ex_data_); + memcpy(&(*quantized_data)[0], &cluster_id, sizeof(cluster_id)); + int bin_data_offset = sizeof(cluster_id); + int ex_data_offset = bin_data_offset + size_bin_data_; + rabitqlib::quant::quantize_split_single( + rotated_data.data(), + rotated_centroids_.data() + (cluster_id * padded_dim_), padded_dim_, + ex_bits_, &(*quantized_data)[bin_data_offset], + &(*quantized_data)[ex_data_offset], metric_type_, config_); + + return 0; +} + +int RabitqReformer::dump(const IndexDumper::Pointer &dumper) { + if (!dumper) { + LOG_ERROR("Null dumper"); + return IndexError_InvalidArgument; + } + + if (!loaded_ || rotated_centroids_.empty() || centroids_.empty()) { + LOG_ERROR("No centroids to dump"); + return IndexError_NoReady; + } + + size_t dumped_size = 0; + int ret = dump_rabitq_centroids( + dumper, dimension_, padded_dim_, ex_bits_, num_clusters_, rotator_type_, + rotated_centroids_, centroids_, rotator_, 
&dumped_size); + if (ret != 0) { + return ret; + } + + LOG_INFO("RabitqReformer dump completed: %zu bytes", dumped_size); + return 0; +} + +int RabitqReformer::dump(const IndexStorage::Pointer &storage) { + if (!storage) { + LOG_ERROR("Null storage"); + return IndexError_InvalidArgument; + } + + if (!loaded_ || rotated_centroids_.empty() || centroids_.empty()) { + LOG_ERROR("No centroids to dump"); + return IndexError_NoReady; + } + + auto align_size = [](size_t size) -> size_t { + return (size + 0x1F) & (~0x1F); + }; + + // Calculate total size + size_t header_size = sizeof(RabitqConverterHeader); + size_t rotated_centroids_size = rotated_centroids_.size() * sizeof(float); + size_t centroids_size = centroids_.size() * sizeof(float); + size_t rotator_size = rotator_->dump_bytes(); + size_t data_size = + header_size + rotated_centroids_size + centroids_size + rotator_size; + size_t total_size = align_size(data_size); + + // Append segment + int ret = storage->append(RABITQ_CONVERER_SEG_ID, total_size); + if (ret != 0) { + LOG_ERROR("Failed to append segment %s, ret=%d", + RABITQ_CONVERER_SEG_ID.c_str(), ret); + return ret; + } + + // Get segment + auto segment = storage->get(RABITQ_CONVERER_SEG_ID); + if (!segment) { + LOG_ERROR("Failed to get segment %s", RABITQ_CONVERER_SEG_ID.c_str()); + return IndexError_ReadData; + } + + size_t offset = 0; + + // Write header + RabitqConverterHeader header; + header.dim = static_cast(dimension_); + header.padded_dim = static_cast(padded_dim_); + header.num_clusters = static_cast(num_clusters_); + header.ex_bits = static_cast(ex_bits_); + header.rotator_type = static_cast(rotator_type_); + header.rotator_size = static_cast(rotator_size); + size_t written = segment->write(offset, &header, header_size); + if (written != header_size) { + LOG_ERROR("Failed to write header: written=%zu, expected=%zu", written, + header_size); + return IndexError_WriteData; + } + offset += header_size; + + // Write rotated centroids + written = + 
segment->write(offset, rotated_centroids_.data(), rotated_centroids_size); + if (written != rotated_centroids_size) { + LOG_ERROR("Failed to write rotated centroids: written=%zu, expected=%zu", + written, rotated_centroids_size); + return IndexError_WriteData; + } + offset += rotated_centroids_size; + + // Write original centroids + written = segment->write(offset, centroids_.data(), centroids_size); + if (written != centroids_size) { + LOG_ERROR("Failed to write centroids: written=%zu, expected=%zu", written, + centroids_size); + return IndexError_WriteData; + } + offset += centroids_size; + + // Write rotator data + std::vector buffer(rotator_size); + rotator_->save(buffer.data()); + written = segment->write(offset, buffer.data(), rotator_size); + if (written != rotator_size) { + LOG_ERROR("Failed to write rotator data: written=%zu, expected=%zu", + written, rotator_size); + return IndexError_WriteData; + } + + LOG_INFO("RabitqReformer dump to storage completed: %zu bytes", data_size); + return 0; +} + +INDEX_FACTORY_REGISTER_REFORMER_ALIAS(RabitqReformer, RabitqReformer); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/rabitq_reformer.h b/src/core/algorithm/hnsw-rabitq/rabitq_reformer.h new file mode 100644 index 00000000..57df5be8 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/rabitq_reformer.h @@ -0,0 +1,116 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include +#include "core/algorithm/cluster/linear_seeker.h" +#include "zvec/core/framework/index_features.h" +#include "zvec/core/framework/index_reformer.h" +#include "zvec/core/framework/index_storage.h" +#include "hnsw_rabitq_query_entity.h" +#include "rabitq_params.h" + +namespace zvec { +namespace core { + +/*! RaBitQ Reformer + * Loads centroids and performs query transformation and vector quantization + */ +class RabitqReformer : public IndexReformer { + public: + typedef std::shared_ptr Pointer; + + //! Constructor + RabitqReformer() = default; + + //! Destructor + ~RabitqReformer() override; + + //! Initialize Reformer + int init(const ailego::Params ¶ms) override; + + //! Cleanup Reformer + int cleanup(void) override; + + //! Load centroids from storage + int load(IndexStorage::Pointer storage) override; + + //! Unload index + int unload(void) override; + + //! Transform query - rotate and quantize for search + int transform(const void *query, const IndexQueryMeta &qmeta, + std::string *out, IndexQueryMeta *ometa) const override; + + //! Convert record - quantize vector for add operation + int convert(const void *record, const IndexQueryMeta &rmeta, std::string *out, + IndexQueryMeta *ometa) const override; + + //! Dump reformer into dumper + int dump(const IndexDumper::Pointer &dumper); + + //! Dump reformer into storage + int dump(const IndexStorage::Pointer &dumper); + + int transform_to_entity(const void *query, + HnswRabitqQueryEntity *entity) const; + + size_t num_clusters() const { + return num_clusters_; + } + + rabitqlib::MetricType rabitq_metric_type() const { + return metric_type_; + } + + private: + //! Find nearest centroid for a vector + size_t find_nearest_centroid(const float *vector) const; + + //! 
Quantize a single vector + int quantize_vector(const float *raw_vector, uint32_t cluster_id, + std::string *quantized_data) const; + + private: + // RaBitQ parameters + size_t num_clusters_{0}; + size_t ex_bits_{0}; + size_t dimension_{0}; + size_t padded_dim_{0}; + + // Original centroids: num_clusters * dimension (for LinearSeeker query) + std::vector centroids_; + // Rotated centroids: num_clusters * padded_dim (for quantization) + std::vector rotated_centroids_; + + // Rotator for vector transformation + rabitqlib::RotatorType rotator_type_{rabitqlib::RotatorType::FhtKacRotator}; + std::unique_ptr> rotator_; + rabitqlib::quant::RabitqConfig query_config_; + rabitqlib::quant::RabitqConfig config_; + rabitqlib::MetricType metric_type_{rabitqlib::METRIC_L2}; + size_t size_bin_data_{0}; + size_t size_ex_data_{0}; + + // LinearSeeker for centroid search + LinearSeeker::Pointer centroid_seeker_; + CoherentIndexFeatures::Pointer centroid_features_; + + bool loaded_{false}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/rabitq_utils.cc b/src/core/algorithm/hnsw-rabitq/rabitq_utils.cc new file mode 100644 index 00000000..58a3a4d9 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/rabitq_utils.cc @@ -0,0 +1,115 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "rabitq_utils.h" +#include +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_logger.h" + +namespace zvec { +namespace core { + +int dump_rabitq_centroids( + const IndexDumper::Pointer &dumper, size_t dimension, size_t padded_dim, + size_t ex_bits, size_t num_clusters, rabitqlib::RotatorType rotator_type, + const std::vector &rotated_centroids, + const std::vector ¢roids, + const std::unique_ptr> &rotator, + size_t *out_dumped_size) { + auto align_size = [](size_t size) -> size_t { + return (size + 0x1F) & (~0x1F); + }; + + uint32_t crc = 0; + size_t dumped_size = 0; + + // Write header + RabitqConverterHeader header; + header.dim = static_cast(dimension); + header.padded_dim = static_cast(padded_dim); + header.num_clusters = static_cast(num_clusters); + header.ex_bits = static_cast(ex_bits); + header.rotator_type = static_cast(rotator_type); + header.rotator_size = static_cast(rotator->dump_bytes()); + size_t size = dumper->write(&header, sizeof(header)); + if (size != sizeof(header)) { + LOG_ERROR("Failed to write header: written=%zu, expected=%zu", size, + sizeof(header)); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(&header, sizeof(header), crc); + dumped_size += size; + + // Write rotated centroids + size = dumper->write(rotated_centroids.data(), + rotated_centroids.size() * sizeof(float)); + if (size != rotated_centroids.size() * sizeof(float)) { + LOG_ERROR("Failed to write rotated centroids: written=%zu, expected=%zu", + size, rotated_centroids.size() * sizeof(float)); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(rotated_centroids.data(), + rotated_centroids.size() * sizeof(float), crc); + dumped_size += size; + + // Write original centroids + size = dumper->write(centroids.data(), centroids.size() * sizeof(float)); + if (size != centroids.size() * sizeof(float)) { + LOG_ERROR("Failed to write centroids: written=%zu, expected=%zu", size, + centroids.size() * 
sizeof(float)); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(centroids.data(), centroids.size() * sizeof(float), + crc); + dumped_size += size; + + // Write rotator data + std::vector buffer(rotator->dump_bytes()); + rotator->save(buffer.data()); + size = dumper->write(buffer.data(), buffer.size()); + if (size != buffer.size()) { + LOG_ERROR("Failed to write rotator data: written=%zu, expected=%zu", size, + buffer.size()); + return IndexError_WriteData; + } + crc = ailego::Crc32c::Hash(buffer.data(), buffer.size(), crc); + dumped_size += size; + + // Write padding + size_t padding_size = align_size(dumped_size) - dumped_size; + if (padding_size > 0) { + std::string padding(padding_size, '\0'); + if (dumper->write(padding.data(), padding_size) != padding_size) { + LOG_ERROR("Append padding failed, size %lu", padding_size); + return IndexError_WriteData; + } + } + + int ret = + dumper->append(RABITQ_CONVERER_SEG_ID, dumped_size, padding_size, crc); + if (ret != 0) { + LOG_ERROR("Dump segment %s meta failed, ret=%d", + RABITQ_CONVERER_SEG_ID.c_str(), ret); + return ret; + } + + if (out_dumped_size) { + *out_dumped_size = dumped_size; + } + return 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/hnsw-rabitq/rabitq_utils.h b/src/core/algorithm/hnsw-rabitq/rabitq_utils.h new file mode 100644 index 00000000..035676b6 --- /dev/null +++ b/src/core/algorithm/hnsw-rabitq/rabitq_utils.h @@ -0,0 +1,53 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include "zvec/core/framework/index_dumper.h" + +namespace zvec { +namespace core { + +inline const std::string RABITQ_CONVERER_SEG_ID{"rabitq.converter"}; + +struct RabitqConverterHeader { + uint32_t num_clusters; + uint32_t dim; + uint32_t padded_dim; + uint32_t rotator_size; + uint8_t ex_bits; + uint8_t rotator_type; + uint8_t padding[2]; + uint32_t reserve[3]; + + RabitqConverterHeader() { + memset(this, 0, sizeof(RabitqConverterHeader)); + } +}; +static_assert(sizeof(RabitqConverterHeader) % 32 == 0, + "RabitqConverterHeader must be aligned with 32 bytes"); + +// Common dump implementation for RabitqConverter and RabitqReformer +int dump_rabitq_centroids( + const IndexDumper::Pointer &dumper, size_t dimension, size_t padded_dim, + size_t ex_bits, size_t num_clusters, rabitqlib::RotatorType rotator_type, + const std::vector &rotated_centroids, + const std::vector ¢roids, + const std::unique_ptr> &rotator, + size_t *out_dumped_size = nullptr); + +} // namespace core +} // namespace zvec diff --git a/src/core/interface/index.cc b/src/core/interface/index.cc index 038f67d4..40d99deb 100644 --- a/src/core/interface/index.cc +++ b/src/core/interface/index.cc @@ -134,6 +134,14 @@ int Index::CreateAndInitConverterReformer(const QuantizerParam ¶m, return core::IndexError_Unsupported; } break; + case QuantizerType::kRabitq: + if (index_param.data_type == DataType::DT_FP32) { + converter_name = "CosineNormalizeConverter"; + } else { + LOG_ERROR("Unsupported data type: "); + return core::IndexError_Unsupported; + } + break; case QuantizerType::kFP16: converter_name = "CosineFp16Converter"; break; @@ -160,6 +168,9 @@ int Index::CreateAndInitConverterReformer(const QuantizerParam ¶m, case QuantizerType::kInt4: converter_name = "Int4StreamingConverter"; break; + case QuantizerType::kRabitq: + // no converter here + return 
0; default: LOG_ERROR("Unsupported quantizer type: "); return core::IndexError_Unsupported; @@ -800,4 +811,30 @@ int Index::_get_coarse_search_topk( return floor(search_param->topk * scale_factor); } +std::string Index::get_metric_name(MetricType metric_type, bool is_sparse) { + if (is_sparse) { + switch (metric_type) { + case MetricType::kInnerProduct: + return "InnerProductSparse"; + case MetricType::kMIPSL2sq: + return "MipsSquaredEuclideanSparse"; + default: + return ""; + } + } else { + switch (metric_type) { + case MetricType::kL2sq: + return "SquaredEuclidean"; + case MetricType::kInnerProduct: + return "InnerProduct"; + case MetricType::kCosine: + return "Cosine"; + case MetricType::kMIPSL2sq: + return "MipsSquaredEuclidean"; + default: + return ""; + } + } +} + } // namespace zvec::core_interface diff --git a/src/core/interface/index_factory.cc b/src/core/interface/index_factory.cc index 699c9ce0..0d815728 100644 --- a/src/core/interface/index_factory.cc +++ b/src/core/interface/index_factory.cc @@ -45,6 +45,8 @@ Index::Pointer IndexFactory::CreateAndInitIndex(const BaseIndexParam ¶m) { ptr = std::make_shared(); } else if (param.index_type == IndexType::kIVF) { ptr = std::make_shared(); + } else if (param.index_type == IndexType::kHNSWRabitq) { + ptr = std::make_shared(); } else { LOG_ERROR("Unsupported index type: "); return nullptr; @@ -104,6 +106,15 @@ BaseIndexParam::Pointer IndexFactory::DeserializeIndexParamFromJson( } return param; } + case IndexType::kHNSWRabitq: { + HNSWRabitqIndexParam::Pointer param = + std::make_shared(); + if (!param->DeserializeFromJson(json_str)) { + LOG_ERROR("Failed to deserialize hnsqrabitq index param"); + return nullptr; + } + return param; + } default: LOG_ERROR("Unsupported index type: %s", magic_enum::enum_name(index_type).data()); @@ -151,6 +162,11 @@ std::string IndexFactory::QueryParamSerializeToJson(const QueryParamType ¶m, // ailego::JsonValue(QueryParamSerializeToJson(param.l1QueryParam))); // 
json_obj.set("l2QueryParam", // ailego::JsonValue(QueryParamSerializeToJson(param.l2QueryParam))); + } else if constexpr (std::is_same_v) { + if (!omit_empty_value || param.ef_search != 0) { + json_obj.set("ef_search", ailego::JsonValue(param.ef_search)); + } + index_type = IndexType::kHNSWRabitq; } json_obj.set("index_type", @@ -245,6 +261,17 @@ typename QueryParamType::Pointer IndexFactory::QueryParamDeserializeFromJson( return nullptr; } return param; + } else if (index_type == IndexType::kHNSWRabitq) { + auto param = std::make_shared(); + if (!parse_common_fields(param)) { + return nullptr; + } + if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, + tmp_json_value)) { + LOG_ERROR("Failed to deserialize ef_search"); + return nullptr; + } + return param; } else { LOG_ERROR("Unsupported index type: %s", magic_enum::enum_name(index_type).data()); @@ -268,6 +295,12 @@ typename QueryParamType::Pointer IndexFactory::QueryParamDeserializeFromJson( LOG_ERROR("Failed to deserialize nprobe"); return nullptr; } + } else if constexpr (std::is_same_v) { + if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, + tmp_json_value)) { + LOG_ERROR("Failed to deserialize ef_search"); + return nullptr; + } } else { LOG_ERROR("Unsupported index type: %s", magic_enum::enum_name(index_type).data()); diff --git a/src/core/interface/index_param.cc b/src/core/interface/index_param.cc index 71e40123..e21256ba 100644 --- a/src/core/interface/index_param.cc +++ b/src/core/interface/index_param.cc @@ -141,6 +141,39 @@ bool HNSWIndexParam::DeserializeFromJsonObject( return true; } +bool HNSWRabitqIndexParam::DeserializeFromJsonObject( + const ailego::JsonObject &json_obj) { + if (!BaseIndexParam::DeserializeFromJsonObject(json_obj)) { + return false; + } + + if (index_type != IndexType::kHNSWRabitq) { + LOG_ERROR("index_type is not kHNSWRabitq"); + return false; + } + + DESERIALIZE_VALUE_FIELD(json_obj, m); + DESERIALIZE_VALUE_FIELD(json_obj, ef_construction); 
+ DESERIALIZE_VALUE_FIELD(json_obj, total_bits); + DESERIALIZE_VALUE_FIELD(json_obj, num_clusters); + DESERIALIZE_VALUE_FIELD(json_obj, sample_count); + + return true; +} + +ailego::JsonObject HNSWRabitqIndexParam::SerializeToJsonObject( + bool omit_empty_value) const { + auto json_obj = BaseIndexParam::SerializeToJsonObject(omit_empty_value); + json_obj.set("m", ailego::JsonValue(m)); + json_obj.set("ef_construction", ailego::JsonValue(ef_construction)); + json_obj.set("total_bits", ailego::JsonValue(total_bits)); + json_obj.set("num_clusters", ailego::JsonValue(num_clusters)); + if (!omit_empty_value || sample_count != 0) { + json_obj.set("sample_count", ailego::JsonValue(sample_count)); + } + return json_obj; +} + ailego::JsonObject QuantizerParam::SerializeToJsonObject( bool omit_empty_value) const { ailego::JsonObject json_obj; diff --git a/src/core/interface/indexes/hnsw_rabitq_index.cc b/src/core/interface/indexes/hnsw_rabitq_index.cc new file mode 100644 index 00000000..894e6e89 --- /dev/null +++ b/src/core/interface/indexes/hnsw_rabitq_index.cc @@ -0,0 +1,133 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#if RABITQ_SUPPORTED +#include "algorithm/hnsw-rabitq/hnsw_rabitq_params.h" +#include "algorithm/hnsw-rabitq/hnsw_rabitq_streamer.h" +#include "algorithm/hnsw-rabitq/rabitq_params.h" +#endif + +namespace zvec::core_interface { + +int HNSWRabitqIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { +#if !RABITQ_SUPPORTED + LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); + return core::IndexError_Unsupported; +#else + param_ = dynamic_cast(param); + + if (is_sparse_) { + LOG_ERROR("Sparse index is not supported"); + return core::IndexError_Runtime; + } + + // validate parameters + param_.ef_construction = std::max(1, std::min(2048, param_.ef_construction)); + param_.m = std::max(5, std::min(1024, param_.m)); + + proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_EFCONSTRUCTION, + param_.ef_construction); + proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT, + param_.m); + proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_GET_VECTOR_ENABLE, + true); + proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_EF, + kDefaultHnswEfSearch); + proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_USE_ID_MAP, + param_.use_id_map); + proxima_index_params_.set(core::PARAM_HNSW_RABITQ_GENERAL_DIMENSION, + input_vector_meta_.dimension()); + proxima_index_params_.set(core::PARAM_RABITQ_TOTAL_BITS, param_.total_bits); + // num_clusters, sample_count are parameters for rabitq converter + // proxima_index_params_.set(core::PARAM_RABITQ_NUM_CLUSTERS, + // param_.num_clusters); + + auto streamer = std::make_shared(); + streamer->set_provider(param_.provider); + streamer->set_reformer(param_.reformer); + streamer_ = streamer; + + if (ailego_unlikely(!streamer_)) { + LOG_ERROR("Failed to create HnswRabitqStreamer"); + return core::IndexError_Runtime; + } + if (ailego_unlikely( + streamer_->init(proxima_index_meta_, proxima_index_params_) != 0)) { + LOG_ERROR("Failed to init 
HnswRabitqStreamer"); + return core::IndexError_Runtime; + } + return 0; +#endif // RABITQ_SUPPORTED +} + +int HNSWRabitqIndex::_prepare_for_search( + const VectorData & /*vector_data*/, + const BaseIndexQueryParam::Pointer &search_param, + core::IndexContext::Pointer &context) { +#if !RABITQ_SUPPORTED + LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); + return core::IndexError_Unsupported; +#else + const auto &hnsw_search_param = + std::dynamic_pointer_cast(search_param); + + if (ailego_unlikely(!hnsw_search_param)) { + LOG_ERROR("Invalid search param type, expected HNSWRabitqQueryParam"); + return core::IndexError_Runtime; + } + + if (0 >= hnsw_search_param->ef_search || + hnsw_search_param->ef_search > 2048) { + LOG_ERROR( + "ef_search must be greater than 0 and less than or equal to 2048."); + return core::IndexError_Runtime; + } + + context->set_topk(hnsw_search_param->topk); + context->set_fetch_vector(hnsw_search_param->fetch_vector); + if (hnsw_search_param->filter) { + context->set_filter(std::move(*hnsw_search_param->filter)); + } + if (hnsw_search_param->radius > 0.0f) { + context->set_threshold(hnsw_search_param->radius); + } + ailego::Params params; + const int real_search_ef = + std::max(1u, std::min(2048u, hnsw_search_param->ef_search)); + params.set(core::PARAM_HNSW_RABITQ_STREAMER_EF, real_search_ef); + context->update(params); + return 0; +#endif // RABITQ_SUPPORTED +} + +int HNSWRabitqIndex::_get_coarse_search_topk( + const BaseIndexQueryParam::Pointer &search_param) { +#if !RABITQ_SUPPORTED + LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); + return -1; +#else + const auto &hnsw_search_param = + std::dynamic_pointer_cast(search_param); + + auto ret = std::max(search_param->topk, hnsw_search_param->ef_search); + return ret; +#endif // RABITQ_SUPPORTED +} + + +} // namespace zvec::core_interface diff --git a/src/db/index/column/vector_column/engine_helper.hpp 
b/src/db/index/column/vector_column/engine_helper.hpp index de1cfc6c..5d53a25c 100644 --- a/src/db/index/column/vector_column/engine_helper.hpp +++ b/src/db/index/column/vector_column/engine_helper.hpp @@ -20,6 +20,8 @@ #include #include #include +#include "zvec/db/index_params.h" +#include "zvec/db/type.h" #include "vector_column_params.h" @@ -162,6 +164,25 @@ class ProximaEngineHelper { return std::move(hnsw_query_param); } + case IndexType::HNSW_RABITQ: { + auto hnsw_query_param_result = + _build_common_query_param( + query_params); + if (!hnsw_query_param_result.has_value()) { + return tl::make_unexpected(Status::InvalidArgument( + "failed to build query param: " + + hnsw_query_param_result.error().message())); + } + auto &hnsw_query_param = hnsw_query_param_result.value(); + if (query_params.query_params) { + auto db_hnsw_rabitq_query_params = + dynamic_cast( + query_params.query_params.get()); + if (db_hnsw_rabitq_query_params) { hnsw_query_param->ef_search = db_hnsw_rabitq_query_params->ef(); } + } + return std::move(hnsw_query_param); + } + case IndexType::IVF: { auto ivf_query_param_result = _build_common_query_param( @@ -212,6 +233,8 @@ class ProximaEngineHelper { return core_interface::QuantizerType::kInt8; case QuantizeType::INT4: return core_interface::QuantizerType::kInt4; + case QuantizeType::RABITQ: + return core_interface::QuantizerType::kRabitq; default: return tl::make_unexpected( Status::InvalidArgument("unsupported quantize type")); @@ -244,6 +267,9 @@ class ProximaEngineHelper { _build_common_index_param(const FieldSchema &field_schema) { auto db_index_params = dynamic_cast( field_schema.index_params().get()); + if (db_index_params == nullptr) { + return tl::make_unexpected(Status::InvalidArgument("bad_cast")); + } auto index_param_builder = std::make_shared(); // db will ensure the id is consecutive @@ -322,6 +348,32 @@ class ProximaEngineHelper { return index_param_builder->Build(); } + case IndexType::HNSW_RABITQ: { + auto index_param_builder_result = 
_build_common_index_param< + HnswRabitqIndexParams, core_interface::HNSWRabitqIndexParamBuilder>( + field_schema); + if (!index_param_builder_result.has_value()) { + return tl::make_unexpected(Status::InvalidArgument( + "failed to build index param: " + + index_param_builder_result.error().message())); + } + auto index_param_builder = index_param_builder_result.value(); + + auto db_index_params = dynamic_cast( + field_schema.index_params().get()); + index_param_builder->WithM(db_index_params->m()); + index_param_builder->WithEFConstruction( + db_index_params->ef_construction()); + index_param_builder->WithTotalBits(db_index_params->total_bits()); + index_param_builder->WithNumClusters(db_index_params->num_clusters()); + index_param_builder->WithSampleCount(db_index_params->sample_count()); + index_param_builder->WithProvider( + db_index_params->raw_vector_provider()); + index_param_builder->WithReformer(db_index_params->rabitq_reformer()); + + return index_param_builder->Build(); + } + case IndexType::IVF: { auto index_param_builder_result = _build_common_index_param< IVFIndexParams, core_interface::IVFIndexParamBuilder>(field_schema); diff --git a/src/db/index/column/vector_column/vector_column_indexer.h b/src/db/index/column/vector_column/vector_column_indexer.h index 80766e1d..0006080e 100644 --- a/src/db/index/column/vector_column/vector_column_indexer.h +++ b/src/db/index/column/vector_column/vector_column_indexer.h @@ -26,6 +26,7 @@ #include "db/common/typedef.h" #include "db/index/column/common/index_results.h" #include "db/index/common/meta.h" +#include "zvec/core/framework/index_provider.h" #include "vector_column_params.h" #include "vector_index_results.h" @@ -88,6 +89,9 @@ class VectorColumnIndexer { // Result BatchFetch(const std::vector &doc_ids) // const; + core::IndexProvider::Pointer create_index_provider() const { + return index->create_index_provider(); + } public: std::string index_file_path() const { diff --git 
a/src/db/index/common/proto_converter.cc b/src/db/index/common/proto_converter.cc index 16516c55..46eb93f5 100644 --- a/src/db/index/common/proto_converter.cc +++ b/src/db/index/common/proto_converter.cc @@ -37,6 +37,32 @@ proto::HnswIndexParams ProtoConverter::ToPb(const HnswIndexParams *params) { return params_pb; } +// HnswRabitqIndexParams +HnswRabitqIndexParams::OPtr ProtoConverter::FromPb( + const proto::HnswRabitqIndexParams ¶ms_pb) { + auto params = std::make_shared( + MetricTypeCodeBook::Get(params_pb.base().metric_type()), + params_pb.total_bits(), params_pb.num_clusters(), params_pb.m(), + params_pb.ef_construction(), params_pb.sample_count()); + + return params; +} + +proto::HnswRabitqIndexParams ProtoConverter::ToPb( + const HnswRabitqIndexParams *params) { + proto::HnswRabitqIndexParams params_pb; + params_pb.mutable_base()->set_metric_type( + MetricTypeCodeBook::Get(params->metric_type())); + params_pb.mutable_base()->set_quantize_type( + QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.set_m(params->m()); + params_pb.set_ef_construction(params->ef_construction()); + params_pb.set_total_bits(params->total_bits()); + params_pb.set_num_clusters(params->num_clusters()); + params_pb.set_sample_count(params->sample_count()); + return params_pb; +} + // FlatIndexParams FlatIndexParams::OPtr ProtoConverter::FromPb( const proto::FlatIndexParams ¶ms_pb) { @@ -157,6 +183,8 @@ IndexParams::Ptr ProtoConverter::FromPb(const proto::IndexParams ¶ms_pb) { return ProtoConverter::FromPb(params_pb.ivf()); } else if (params_pb.has_flat()) { return ProtoConverter::FromPb(params_pb.flat()); + } else if (params_pb.has_hnsw_rabitq()) { + return ProtoConverter::FromPb(params_pb.hnsw_rabitq()); } return nullptr; @@ -211,6 +239,14 @@ proto::IndexParams ProtoConverter::ToPb(const IndexParams *params) { } break; } + case IndexType::HNSW_RABITQ: { + auto hnsw_rabitq_params = + dynamic_cast(params); + if (hnsw_rabitq_params) { + 
params_pb.mutable_hnsw_rabitq()->CopyFrom( + ProtoConverter::ToPb(hnsw_rabitq_params)); + } break; + } default: break; } diff --git a/src/db/index/common/proto_converter.h index 48e17016..ad96007a 100644 --- a/src/db/index/common/proto_converter.h +++ b/src/db/index/common/proto_converter.h @@ -25,6 +25,11 @@ struct ProtoConverter { static proto::HnswIndexParams ToPb(const HnswIndexParams *params); + // HnswRabitqIndexParams + static HnswRabitqIndexParams::OPtr FromPb( + const proto::HnswRabitqIndexParams ¶ms_pb); + static proto::HnswRabitqIndexParams ToPb(const HnswRabitqIndexParams *params); + // FlatIndexParams static FlatIndexParams::OPtr FromPb(const proto::FlatIndexParams ¶ms_pb); static proto::FlatIndexParams ToPb(const FlatIndexParams *params); diff --git a/src/db/index/common/schema.cc index 02789c61..b087c678 100644 --- a/src/db/index/common/schema.cc +++ b/src/db/index/common/schema.cc @@ -29,7 +29,8 @@ namespace zvec { std::unordered_map> quantize_type_map = { {DataType::VECTOR_FP32, - {QuantizeType::FP16, QuantizeType::INT4, QuantizeType::INT8}}, + {QuantizeType::FP16, QuantizeType::INT4, QuantizeType::INT8, + QuantizeType::RABITQ}}, // {DataType::VECTOR_FP64, {QuantizeType::FP16}}, {DataType::SPARSE_VECTOR_FP32, {QuantizeType::FP16}}, }; @@ -46,7 +47,7 @@ std::unordered_set support_sparse_vector_type = { }; std::unordered_set support_dense_vector_index = { - IndexType::FLAT, IndexType::HNSW, IndexType::IVF}; + IndexType::FLAT, IndexType::HNSW, IndexType::HNSW_RABITQ, IndexType::IVF}; std::unordered_set support_sparse_vector_index = {IndexType::FLAT, IndexType::HNSW}; @@ -126,6 +127,21 @@ Status FieldSchema::validate() const { } } + if (index_params_->type() == IndexType::HNSW_RABITQ) { + if (data_type_ != DataType::VECTOR_FP32) { + return Status::InvalidArgument( + "schema validate failed: HNSW_RABITQ index only support FP32 " + "data types"); + } + auto metric_type = 
vector_index_params->metric_type(); + if (metric_type != MetricType::L2 && metric_type != MetricType::IP && + metric_type != MetricType::COSINE) { + return Status::InvalidArgument( + "schema validate failed: HNSW_RABITQ index only support " + "L2/IP/COSINE metric"); + } + } + if (vector_index_params->quantize_type() != QuantizeType::UNDEFINED) { auto iter = quantize_type_map.find(data_type_); if (iter == quantize_type_map.end()) { diff --git a/src/db/index/common/type_helper.h b/src/db/index/common/type_helper.h index 33440dc5..33d2ee34 100644 --- a/src/db/index/common/type_helper.h +++ b/src/db/index/common/type_helper.h @@ -27,6 +27,8 @@ struct IndexTypeCodeBook { switch (type) { case proto::IT_HNSW: return IndexType::HNSW; + case proto::IT_HNSW_RABITQ: + return IndexType::HNSW_RABITQ; case proto::IT_FLAT: return IndexType::FLAT; case proto::IT_IVF: @@ -44,6 +46,8 @@ struct IndexTypeCodeBook { switch (type) { case IndexType::HNSW: return proto::IT_HNSW; + case IndexType::HNSW_RABITQ: + return proto::IT_HNSW_RABITQ; case IndexType::FLAT: return proto::IT_FLAT; case IndexType::IVF: @@ -61,8 +65,12 @@ struct IndexTypeCodeBook { switch (type) { case IndexType::HNSW: return "HNSW"; - // case IndexType::SPARSE_HNSW: - // return "SPARSE_HNSW"; + case IndexType::HNSW_RABITQ: + return "HNSW_RABITQ"; + case IndexType::FLAT: + return "FLAT"; + case IndexType::IVF: + return "IVF"; case IndexType::INVERT: return "INVERT"; default: @@ -414,6 +422,8 @@ struct QuantizeTypeCodeBook { return QuantizeType::INT4; case proto::QuantizeType::QT_INT8: return QuantizeType::INT8; + case proto::QuantizeType::QT_RABITQ: + return QuantizeType::RABITQ; default: return QuantizeType::UNDEFINED; } @@ -427,6 +437,8 @@ struct QuantizeTypeCodeBook { return proto::QuantizeType::QT_INT4; case QuantizeType::INT8: return proto::QuantizeType::QT_INT8; + case QuantizeType::RABITQ: + return proto::QuantizeType::QT_RABITQ; default: return proto::QuantizeType::QT_UNDEFINED; } @@ -440,6 +452,8 @@ struct 
QuantizeTypeCodeBook { return "INT4"; case QuantizeType::INT8: return "INT8"; + case QuantizeType::RABITQ: + return "RABITQ"; default: return "UNDEFINED"; } diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 517215a3..bd070011 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -37,11 +37,15 @@ #include #include #include +#if RABITQ_SUPPORTED +#include "core/algorithm/hnsw-rabitq/rabitq_params.h" +#endif #include "db/common/constants.h" #include "db/common/file_helper.h" #include "db/common/global_resource.h" #include "db/common/typedef.h" #include "db/index/column/inverted_column/inverted_indexer.h" +#include "db/index/column/vector_column/engine_helper.hpp" #include "db/index/column/vector_column/vector_column_indexer.h" #include "db/index/column/vector_column/vector_column_params.h" #include "db/index/common/index_filter.h" @@ -53,6 +57,11 @@ #include "db/index/storage/mmap_forward_store.h" #include "db/index/storage/store_helper.h" #include "db/index/storage/wal/wal_file.h" +#include "zvec/ailego/container/params.h" +#include "zvec/core/framework/index_factory.h" +#include "zvec/core/framework/index_meta.h" +#include "zvec/core/framework/index_provider.h" +#include "zvec/core/framework/index_reformer.h" #include "column_merging_reader.h" #include "sql_expr_parser.h" @@ -1646,6 +1655,8 @@ Status SegmentImpl::create_vector_index( auto original_index_params = std::dynamic_pointer_cast(field->index_params()); + core::IndexProvider::Pointer raw_vector_provider; + if (!(vector_index_params->metric_type() == original_index_params->metric_type() && vector_indexers_[column].size() == 1)) { @@ -1674,31 +1685,112 @@ Status SegmentImpl::create_vector_index( block.set_max_doc_id(meta()->max_doc_id()); block.set_doc_count(meta()->doc_count()); new_segment_meta->add_persisted_block(block); + if (vector_index_params->quantize_type() == QuantizeType::RABITQ) { + raw_vector_provider = 
vector_indexer.value()->create_index_provider(); + } + } else { + raw_vector_provider = + vector_indexers_[column][0]->create_index_provider(); } - auto quant_block_id = allocate_block_id(); - auto field_with_new_index_params = std::make_shared(*field); - field_with_new_index_params->set_index_params(index_params); + if (vector_index_params->quantize_type() != QuantizeType::RABITQ) { + auto quant_block_id = allocate_block_id(); + auto field_with_new_index_params = std::make_shared(*field); + field_with_new_index_params->set_index_params(index_params); - std::string index_file_path = FileHelper::MakeQuantizeVectorIndexPath( - path_, column, segment_meta_->id(), quant_block_id); - auto vector_indexer = merge_vector_indexer( - index_file_path, column, *field_with_new_index_params, concurrency); - if (!vector_indexer.has_value()) { - return vector_indexer.error(); - } + std::string index_file_path = FileHelper::MakeQuantizeVectorIndexPath( + path_, column, segment_meta_->id(), quant_block_id); + auto vector_indexer = merge_vector_indexer( + index_file_path, column, *field_with_new_index_params, concurrency); + if (!vector_indexer.has_value()) { + return vector_indexer.error(); + } - quant_vector_indexers->insert({column, vector_indexer.value()}); + quant_vector_indexers->insert({column, vector_indexer.value()}); - new_segment_meta->remove_vector_persisted_block(column, true); - BlockMeta block; - block.set_id(quant_block_id); - block.set_type(BlockType::VECTOR_INDEX_QUANTIZE); - block.set_columns({column}); - block.set_min_doc_id(meta()->min_doc_id()); - block.set_max_doc_id(meta()->max_doc_id()); - block.set_doc_count(meta()->doc_count()); - new_segment_meta->add_persisted_block(block); + new_segment_meta->remove_vector_persisted_block(column, true); + BlockMeta block; + block.set_id(quant_block_id); + block.set_type(BlockType::VECTOR_INDEX_QUANTIZE); + block.set_columns({column}); + block.set_min_doc_id(meta()->min_doc_id()); + 
block.set_max_doc_id(meta()->max_doc_id()); + block.set_doc_count(meta()->doc_count()); + new_segment_meta->add_persisted_block(block); + } else { +#if !RABITQ_SUPPORTED + LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); + return Status::NotSupported("RaBitQ is not supported on this platform"); +#else + // rabitq + auto rabitq_params = std::dynamic_pointer_cast( + vector_index_params->clone()); + if (!rabitq_params) { + return Status::InternalError("Expect HnswRabitqIndexParams"); + } + // train rabitq converter + auto converter = core::IndexFactory::CreateConverter("RabitqConverter"); + if (!converter) { + return Status::NotSupported("RabitqConverter not found"); + } + core::IndexMeta index_meta; + index_meta.set_meta( + ProximaEngineHelper::convert_to_engine_data_type(field->data_type()) + .value(), + // use field dimension + field->dimension()); + index_meta.set_metric( + core_interface::Index::get_metric_name( + ProximaEngineHelper::convert_to_engine_metric_type( + vector_index_params->metric_type()) + .value(), + false), + 0, ailego::Params{}); + ailego::Params converter_params; + converter_params.set(core::PARAM_RABITQ_TOTAL_BITS, + rabitq_params->total_bits()); + converter_params.set(core::PARAM_RABITQ_NUM_CLUSTERS, + rabitq_params->num_clusters()); + converter_params.set(core::PARAM_RABITQ_SAMPLE_COUNT, + rabitq_params->sample_count()); + if (int ret = converter->init(index_meta, converter_params); ret != 0) { + return Status::InternalError("Failed to init rabitq converter:", ret); + } + if (int ret = converter->train(raw_vector_provider); ret != 0) { + return Status::InternalError("Failed to train rabitq converter:", ret); + } + core::IndexReformer::Pointer reformer; + if (int ret = converter->to_reformer(&reformer); ret != 0) { + return Status::InternalError("Failed to get rabitq reformer:", ret); + } + rabitq_params->set_rabitq_reformer(reformer); + rabitq_params->set_raw_vector_provider(raw_vector_provider); + + auto 
quant_block_id = allocate_block_id(); + auto field_with_new_index_params = std::make_shared(*field); + field_with_new_index_params->set_index_params(rabitq_params); + + std::string index_file_path = FileHelper::MakeQuantizeVectorIndexPath( + path_, column, segment_meta_->id(), quant_block_id); + auto vector_indexer = merge_vector_indexer( + index_file_path, column, *field_with_new_index_params, concurrency); + if (!vector_indexer.has_value()) { + return vector_indexer.error(); + } + + quant_vector_indexers->insert({column, vector_indexer.value()}); + + new_segment_meta->remove_vector_persisted_block(column, true); + BlockMeta block; + block.set_id(quant_block_id); + block.set_type(BlockType::VECTOR_INDEX_QUANTIZE); + block.set_columns({column}); + block.set_min_doc_id(meta()->min_doc_id()); + block.set_max_doc_id(meta()->max_doc_id()); + block.set_doc_count(meta()->doc_count()); + new_segment_meta->add_persisted_block(block); +#endif + } *segment_meta = new_segment_meta; } diff --git a/src/db/proto/zvec.proto b/src/db/proto/zvec.proto index a2b310d3..3c9d3331 100644 --- a/src/db/proto/zvec.proto +++ b/src/db/proto/zvec.proto @@ -56,6 +56,8 @@ enum IndexType { IT_IVF = 2; // Proxima FLAT Index IT_FLAT = 3; + // Proxima HNSW RABITQ Index + IT_HNSW_RABITQ = 4; // Invert Index IT_INVERT = 10; }; @@ -65,6 +67,7 @@ enum QuantizeType { QT_FP16 = 1; QT_INT8 = 2; QT_INT4 = 3; + QT_RABITQ = 4; }; enum MetricType { @@ -74,9 +77,11 @@ enum MetricType { MT_COSINE = 3; }; -message InvertIndexParams { bool enable_range_optimization = 1; }; +message InvertIndexParams { + bool enable_range_optimization = 1; +}; -message BaseIndexParams { +message BaseIndexParams { MetricType metric_type = 1; QuantizeType quantize_type = 2; }; @@ -85,7 +90,16 @@ message HnswIndexParams { BaseIndexParams base = 1; int32 m = 2; int32 ef_construction = 3; -}; +} + +message HnswRabitqIndexParams { + BaseIndexParams base = 1; + int32 m = 2; + int32 ef_construction = 3; + int32 total_bits = 4; + int32 
num_clusters = 5; + int32 sample_count = 6; +} message FlatIndexParams { BaseIndexParams base = 1; @@ -104,6 +118,7 @@ message IndexParams { HnswIndexParams hnsw = 2; FlatIndexParams flat = 3; IVFIndexParams ivf = 4; + HnswRabitqIndexParams hnsw_rabitq = 5; }; }; @@ -131,11 +146,11 @@ enum BlockType { message BlockMeta { uint32 block_id = 1; - BlockType block_type = 2; // for getting filename prefix + BlockType block_type = 2; // for getting filename prefix uint64 min_doc_id = 3; uint64 max_doc_id = 4; uint64 doc_count = 5; - repeated string columns = 6; // columns contained in this block + repeated string columns = 6; // columns contained in this block }; // message AlterColumnMeta { diff --git a/src/include/zvec/core/framework/index_converter.h b/src/include/zvec/core/framework/index_converter.h index 272bc315..f25aa030 100644 --- a/src/include/zvec/core/framework/index_converter.h +++ b/src/include/zvec/core/framework/index_converter.h @@ -18,6 +18,7 @@ #include #include #include +#include "zvec/core/framework/index_reformer.h" namespace zvec { namespace core { @@ -219,6 +220,11 @@ class IndexConverter : public IndexModule { static int TrainTransformAndDump(const IndexConverter::Pointer &converter, IndexHolder::Pointer holder, const IndexDumper::Pointer &dumper); + + //! Convert to reformer + virtual int to_reformer(IndexReformer::Pointer *) { + return IndexError_NotImplemented; + } }; } // namespace core diff --git a/src/include/zvec/core/framework/index_holder.h b/src/include/zvec/core/framework/index_holder.h index 7bb9ba9c..f7f6bba8 100644 --- a/src/include/zvec/core/framework/index_holder.h +++ b/src/include/zvec/core/framework/index_holder.h @@ -398,13 +398,22 @@ class MultiPassNumericalIndexHolder : public IndexHolder { features_.reserve(size); } - private: - //! Disable them - MultiPassNumericalIndexHolder(void) = delete; + //! 
Get vector data pointer by index + const void *get_vector_by_index(size_t index) const { + if (index >= features_.size()) { + return nullptr; + } + return features_[index].second.data(); + } + protected: //! Members size_t dimension_{0}; std::vector>> features_; + + private: + //! Disable them + MultiPassNumericalIndexHolder(void) = delete; }; /*! One-Pass Binary Index Holder @@ -617,13 +626,22 @@ class MultiPassBinaryIndexHolder : public IndexHolder { features_.reserve(size); } - private: - //! Disable them - MultiPassBinaryIndexHolder(void) = delete; + //! Get vector data pointer by index + const void *get_vector_by_index(size_t index) const { + if (index >= features_.size()) { + return nullptr; + } + return features_[index].second.data(); + } + protected: //! Members size_t dimension_{0}; std::vector>> features_; + + private: + //! Disable them + MultiPassBinaryIndexHolder(void) = delete; }; /*! One-Pass Index Hybrid Holder diff --git a/src/include/zvec/core/framework/index_provider.h b/src/include/zvec/core/framework/index_provider.h index b5ce143b..66e34e5a 100644 --- a/src/include/zvec/core/framework/index_provider.h +++ b/src/include/zvec/core/framework/index_provider.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -71,5 +72,399 @@ struct IndexSparseProvider : IndexSparseHolder { virtual const std::string &owner_class(void) const = 0; }; +/*! Multi-Pass Numerical Index Provider + */ +template +class MultiPassNumericalIndexProvider : public IndexProvider { + public: + //! Constructor + explicit MultiPassNumericalIndexProvider(size_t dim) + : holder_(dim), owner_class_("MultiPassNumericalIndexProvider") {} + + //! Destructor + virtual ~MultiPassNumericalIndexProvider(void) {} + + //! Retrieve count of elements in holder + size_t count(void) const override { + return holder_.count(); + } + + //! Retrieve dimension + size_t dimension(void) const override { + return holder_.dimension(); + } + + //! 
Retrieve element size in bytes + size_t element_size(void) const override { + return holder_.element_size(); + } + + //! Create a new iterator + IndexHolder::Iterator::Pointer create_iterator(void) override { + return holder_.create_iterator(); + } + + //! Retrieve a vector using a primary key + const void *get_vector(const uint64_t key) const override { + auto it = indice_map_.find(key); + if (it == indice_map_.end()) { + return nullptr; + } + return holder_.get_vector_by_index(it->second); + } + + //! Retrieve a vector using a primary key + int get_vector(const uint64_t key, + IndexStorage::MemoryBlock &block) const override { + const void *data = get_vector(key); + if (data == nullptr) { + return IndexError_NoExist; + } + block.reset(const_cast(data)); + return 0; + } + + //! Retrieve the owner class + const std::string &owner_class(void) const override { + return owner_class_; + } + + //! Append an element into holder + bool emplace(uint64_t key, const ailego::NumericalVector &vec) { + if (!holder_.emplace(key, vec)) { + return false; + } + indice_map_[key] = static_cast(holder_.count() - 1); + return true; + } + + //! Append an element into holder + bool emplace(uint64_t key, ailego::NumericalVector &&vec) { + if (!holder_.emplace(key, std::move(vec))) { + return false; + } + indice_map_[key] = static_cast(holder_.count() - 1); + return true; + } + + private: + //! Members + MultiPassNumericalIndexHolder holder_; + std::unordered_map indice_map_; + std::string owner_class_; +}; + +/*! Multi-Pass Binary Index Provider + */ +template +class MultiPassBinaryIndexProvider : public IndexProvider { + public: + //! Constructor + explicit MultiPassBinaryIndexProvider(size_t dim) + : holder_(dim), owner_class_("MultiPassBinaryIndexProvider") {} + + //! Destructor + virtual ~MultiPassBinaryIndexProvider(void) {} + + //! Retrieve count of elements in holder + size_t count(void) const override { + return holder_.count(); + } + + //! 
Retrieve dimension + size_t dimension(void) const override { + return holder_.dimension(); + } + + //! Retrieve element size in bytes + size_t element_size(void) const override { + return holder_.element_size(); + } + + //! Create a new iterator + IndexHolder::Iterator::Pointer create_iterator(void) override { + return holder_.create_iterator(); + } + + //! Retrieve a vector using a primary key + const void *get_vector(const uint64_t key) const override { + auto it = indice_map_.find(key); + if (it == indice_map_.end()) { + return nullptr; + } + return holder_.get_vector_by_index(it->second); + } + + //! Retrieve a vector using a primary key + int get_vector(const uint64_t key, + IndexStorage::MemoryBlock &block) const override { + const void *data = get_vector(key); + if (data == nullptr) { + return IndexError_NoExist; + } + block.reset(const_cast(data)); + return 0; + } + + //! Retrieve the owner class + const std::string &owner_class(void) const override { + return owner_class_; + } + + //! Append an element into holder + bool emplace(uint64_t key, const ailego::BinaryVector &vec) { + if (!holder_.emplace(key, vec)) { + return false; + } + indice_map_[key] = static_cast(holder_.count() - 1); + return true; + } + + //! Append an element into holder + bool emplace(uint64_t key, ailego::BinaryVector &&vec) { + if (!holder_.emplace(key, std::move(vec))) { + return false; + } + indice_map_[key] = static_cast(holder_.count() - 1); + return true; + } + + private: + //! Members + MultiPassBinaryIndexHolder holder_; + std::unordered_map indice_map_; + std::string owner_class_; +}; + +/*! Multi-Pass Index Provider + */ +template +struct MultiPassIndexProvider; + +/*! Multi-Pass Index Provider (BINARY32) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassBinaryIndexProvider { + //! Constructor + using MultiPassBinaryIndexProvider::MultiPassBinaryIndexProvider; + + //! 
Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_BINARY32; + } +}; + +/*! Multi-Pass Index Provider (BINARY64) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassBinaryIndexProvider { + //! Constructor + using MultiPassBinaryIndexProvider::MultiPassBinaryIndexProvider; + + //! Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_BINARY64; + } +}; + +/*! Multi-Pass Index Provider (FP16) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassNumericalIndexProvider { + //! Constructor + using MultiPassNumericalIndexProvider::MultiPassNumericalIndexProvider; + + //! Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_FP16; + } +}; + +/*! Multi-Pass Index Provider (FP32) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassNumericalIndexProvider { + //! Constructor + using MultiPassNumericalIndexProvider::MultiPassNumericalIndexProvider; + + //! Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_FP32; + } +}; + +/*! Multi-Pass Index Provider (FP64) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassNumericalIndexProvider { + //! Constructor + using MultiPassNumericalIndexProvider::MultiPassNumericalIndexProvider; + + //! Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_FP64; + } +}; + +/*! Multi-Pass Index Provider (INT8) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassNumericalIndexProvider { + //! Constructor + using MultiPassNumericalIndexProvider::MultiPassNumericalIndexProvider; + + //! Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_INT8; + } +}; + +/*! 
Multi-Pass Index Provider (INT16) + */ +template <> +struct MultiPassIndexProvider + : public MultiPassNumericalIndexProvider { + //! Constructor + using MultiPassNumericalIndexProvider::MultiPassNumericalIndexProvider; + + //! Retrieve type information + IndexMeta::DataType data_type(void) const override { + return IndexMeta::DataType::DT_INT16; + } +}; + +/*! Convert IndexHolder to IndexProvider + * @param holder The IndexHolder to convert + * @return IndexProvider::Pointer + */ +inline IndexProvider::Pointer convert_holder_to_provider( + const IndexHolder::Pointer &holder) { + if (!holder) { + return nullptr; + } + + IndexMeta::DataType data_type = holder->data_type(); + size_t dimension = holder->dimension(); + + switch (data_type) { + case IndexMeta::DataType::DT_FP16: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const ailego::Float16 *data = + static_cast(iter->data()); + ailego::NumericalVector vec(dimension); + std::memcpy(vec.data(), data, dimension * sizeof(ailego::Float16)); + provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + case IndexMeta::DataType::DT_FP32: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const float *data = static_cast(iter->data()); + ailego::NumericalVector vec(dimension); + std::memcpy(vec.data(), data, dimension * sizeof(float)); + provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + case IndexMeta::DataType::DT_FP64: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const double *data = static_cast(iter->data()); + ailego::NumericalVector vec(dimension); + 
std::memcpy(vec.data(), data, dimension * sizeof(double)); + provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + case IndexMeta::DataType::DT_INT8: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const int8_t *data = static_cast(iter->data()); + ailego::NumericalVector vec(dimension); + std::memcpy(vec.data(), data, dimension * sizeof(int8_t)); + provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + case IndexMeta::DataType::DT_INT16: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const int16_t *data = static_cast(iter->data()); + ailego::NumericalVector vec(dimension); + std::memcpy(vec.data(), data, dimension * sizeof(int16_t)); + provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + case IndexMeta::DataType::DT_BINARY32: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const uint32_t *data = static_cast(iter->data()); + size_t binary_size = (dimension + 31) / 32; + ailego::BinaryVector vec(dimension); + std::memcpy(vec.data(), data, binary_size * sizeof(uint32_t)); + provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + case IndexMeta::DataType::DT_BINARY64: { + auto provider = std::make_shared< + MultiPassIndexProvider>(dimension); + auto iter = holder->create_iterator(); + while (iter->is_valid()) { + uint64_t key = iter->key(); + const uint64_t *data = static_cast(iter->data()); + size_t binary_size = (dimension + 63) / 64; + ailego::BinaryVector vec(dimension); + std::memcpy(vec.data(), data, binary_size * sizeof(uint64_t)); + 
provider->emplace(key, std::move(vec)); + iter->next(); + } + return provider; + } + + default: + return nullptr; + } +} + } // namespace core } // namespace zvec diff --git a/src/include/zvec/core/interface/constants.h b/src/include/zvec/core/interface/constants.h index 79d563bf..1bc61dce 100644 --- a/src/include/zvec/core/interface/constants.h +++ b/src/include/zvec/core/interface/constants.h @@ -23,5 +23,8 @@ constexpr static uint32_t kDefaultHnswNeighborCnt = 50; constexpr static uint32_t kDefaultHnswEfSearch = 300; +constexpr const uint32_t kDefaultRabitqTotalBits = 7; +constexpr const uint32_t kDefaultRabitqNumClusters = 16; + } // namespace zvec::core_interface \ No newline at end of file diff --git a/src/include/zvec/core/interface/index.h b/src/include/zvec/core/interface/index.h index 71258cb0..8634e390 100644 --- a/src/include/zvec/core/interface/index.h +++ b/src/include/zvec/core/interface/index.h @@ -31,6 +31,7 @@ #include #include #include +#include "zvec/core/framework/index_provider.h" namespace zvec::core_interface { @@ -154,6 +155,12 @@ class Index { return streamer_; } + core::IndexProvider::Pointer create_index_provider() const { + return streamer_->create_provider(); + } + + static std::string get_metric_name(MetricType metric_type, bool is_sparse); + protected: int _sparse_fetch(const uint32_t doc_id, VectorDataBuffer *vector_data_buffer); @@ -292,5 +299,22 @@ class HNSWIndex : public Index { HNSWIndexParam param_{}; }; +class HNSWRabitqIndex : public Index { + public: + HNSWRabitqIndex() = default; + + protected: + virtual int CreateAndInitStreamer(const BaseIndexParam ¶m) override; + + virtual int _prepare_for_search( + const VectorData &query, const BaseIndexQueryParam::Pointer &search_param, + core::IndexContext::Pointer &context) override; + int _get_coarse_search_topk( + const BaseIndexQueryParam::Pointer &search_param) override; + + private: + HNSWRabitqIndexParam param_{}; +}; + } // namespace zvec::core_interface diff --git 
a/src/include/zvec/core/interface/index_param.h b/src/include/zvec/core/interface/index_param.h index 98da5b12..0d7bf301 100644 --- a/src/include/zvec/core/interface/index_param.h +++ b/src/include/zvec/core/interface/index_param.h @@ -23,6 +23,7 @@ #include #include #include +#include "zvec/core/framework/index_framework.h" namespace zvec::core_interface { #define MAX_DIMENSION 65536 @@ -61,6 +62,7 @@ enum class IndexType { kFlat, kIVF, // it's actual a two-layer index kHNSW, + kHNSWRabitq, }; enum class IVFSearchMethod { kBF, kHNSW }; @@ -80,7 +82,8 @@ enum class QuantizerType { kAQ, kFP16, kInt8, - kInt4 + kInt4, + kRabitq, }; struct SerializableBase { @@ -186,6 +189,16 @@ struct HNSWQueryParam : public BaseIndexQueryParam { } }; +struct HNSWRabitqQueryParam : public BaseIndexQueryParam { + using Pointer = std::shared_ptr; + + uint32_t ef_search = kDefaultHnswEfSearch; + + BaseIndexQueryParam::Pointer Clone() const override { + return std::make_shared(*this); + } +}; + struct IVFQueryParam : public BaseIndexQueryParam { int nprobe = 10; std::shared_ptr l1QueryParam = nullptr; @@ -316,4 +329,37 @@ struct HNSWIndexParam : public BaseIndexParam { bool omit_empty_value = false) const override; }; +struct HNSWRabitqIndexParam : public BaseIndexParam { + using Pointer = std::shared_ptr; + + // HNSW parameters + int m = kDefaultHnswNeighborCnt; + int ef_construction = kDefaultHnswEfConstruction; + + // Rabitq parameters + int total_bits = kDefaultRabitqTotalBits; + int num_clusters = kDefaultRabitqNumClusters; + int sample_count = 0; + core::IndexProvider::Pointer provider = nullptr; + core::IndexReformer::Pointer reformer = nullptr; + + // Constructors with delegation + HNSWRabitqIndexParam() : BaseIndexParam(IndexType::kHNSWRabitq) {} + + HNSWRabitqIndexParam(int m, int ef_construction) + : BaseIndexParam(IndexType::kHNSWRabitq), + m(m), + ef_construction(ef_construction) {} + + HNSWRabitqIndexParam(MetricType metric, int dim, int m, int ef_construction) + : 
BaseIndexParam(IndexType::kHNSWRabitq, metric, dim), + m(m), + ef_construction(ef_construction) {} + + protected: + bool DeserializeFromJsonObject(const ailego::JsonObject &json_obj) override; + ailego::JsonObject SerializeToJsonObject( + bool omit_empty_value = false) const override; +}; + } // namespace zvec::core_interface \ No newline at end of file diff --git a/src/include/zvec/core/interface/index_param_builders.h b/src/include/zvec/core/interface/index_param_builders.h index f319450d..f2eb48d3 100644 --- a/src/include/zvec/core/interface/index_param_builders.h +++ b/src/include/zvec/core/interface/index_param_builders.h @@ -16,6 +16,9 @@ #include #include +#include "zvec/core/framework/index_provider.h" +#include "zvec/core/framework/index_reformer.h" +#include "zvec/core/interface/index.h" namespace zvec::core_interface { @@ -148,6 +151,46 @@ class HNSWIndexParamBuilder } }; +class HNSWRabitqIndexParamBuilder + : public BaseIndexParamBuilder { + public: + HNSWRabitqIndexParamBuilder() = default; + HNSWRabitqIndexParamBuilder &WithM(int m) { + param->m = m; + return *this; + } + HNSWRabitqIndexParamBuilder &WithEFConstruction(int ef_construction) { + param->ef_construction = ef_construction; + return *this; + } + HNSWRabitqIndexParamBuilder &WithTotalBits(int total_bits) { + param->total_bits = total_bits; + return *this; + } + HNSWRabitqIndexParamBuilder &WithNumClusters(int num_clusters) { + param->num_clusters = num_clusters; + return *this; + } + HNSWRabitqIndexParamBuilder &WithSampleCount(int sample_count) { + param->sample_count = sample_count; + return *this; + } + HNSWRabitqIndexParamBuilder &WithReformer( + core::IndexReformer::Pointer reformer) { + param->reformer = std::move(reformer); + return *this; + } + HNSWRabitqIndexParamBuilder &WithProvider( + core::IndexProvider::Pointer provider) { + param->provider = std::move(provider); + return *this; + } + std::shared_ptr Build() override { + return param; + } +}; + // class 
CompositeIndexParamBuilder : public // BaseIndexParamBuilder // { public: diff --git a/src/include/zvec/db/index_params.h b/src/include/zvec/db/index_params.h index dc0a22ae..fcccf080 100644 --- a/src/include/zvec/db/index_params.h +++ b/src/include/zvec/db/index_params.h @@ -18,6 +18,8 @@ #include #include #include +#include "zvec/core/framework/index_provider.h" +#include "zvec/core/framework/index_reformer.h" namespace zvec { @@ -44,7 +46,7 @@ class IndexParams { bool is_vector_index_type() const { return type_ == IndexType::FLAT || type_ == IndexType::HNSW || - type_ == IndexType::IVF; + type_ == IndexType::HNSW_RABITQ || type_ == IndexType::IVF; } IndexType type() const { @@ -200,9 +202,121 @@ class HnswIndexParams : public VectorIndexParams { return ef_construction_; } + protected: + int m_; + int ef_construction_; +}; + +class HnswRabitqIndexParams : public VectorIndexParams { + public: + HnswRabitqIndexParams( + MetricType metric_type, + int total_bits = core_interface::kDefaultRabitqTotalBits, + int num_clusters = core_interface::kDefaultRabitqNumClusters, + int m = core_interface::kDefaultHnswNeighborCnt, + int ef_construction = core_interface::kDefaultHnswEfConstruction, + int sample_count = 0) + : VectorIndexParams(IndexType::HNSW_RABITQ, metric_type, + QuantizeType::RABITQ), + total_bits_(total_bits), + num_clusters_(num_clusters), + sample_count_(sample_count), + m_(m), + ef_construction_(ef_construction) {} + + using OPtr = std::shared_ptr; + + Ptr clone() const override { + auto obj = std::make_shared( + metric_type_, total_bits_, num_clusters_, m_, ef_construction_, + sample_count_); + obj->set_rabitq_reformer(rabitq_reformer_); + obj->set_raw_vector_provider(raw_vector_provider_); + return obj; + } + + std::string to_string() const override { + auto base_str = vector_index_params_to_string("HnswRabitqIndexParams", + metric_type_, quantize_type_); + std::ostringstream oss; + oss << base_str << ",total_bits:" << total_bits_ + << ",num_clusters:" << 
num_clusters_ + << ",sample_count:" << sample_count_ << ",m:" << m_ + << ",ef_construction:" << ef_construction_ << "}"; + return oss.str(); + } + + bool operator==(const IndexParams &other) const override { + if (type() != other.type()) { + return false; + } + auto &other_rabitq = dynamic_cast(other); + return metric_type() == other_rabitq.metric_type() && + quantize_type_ == other_rabitq.quantize_type_ && + total_bits_ == other_rabitq.total_bits_ && + num_clusters_ == other_rabitq.num_clusters_ && + sample_count_ == other_rabitq.sample_count_ && + m_ == other_rabitq.m_ && + ef_construction_ == other_rabitq.ef_construction_; + } + + void set_m(int m) { + m_ = m; + } + int m() const { + return m_; + } + void set_ef_construction(int ef_construction) { + ef_construction_ = ef_construction; + } + int ef_construction() const { + return ef_construction_; + } + + void set_raw_vector_provider( + core::IndexProvider::Pointer raw_vector_provider) { + raw_vector_provider_ = std::move(raw_vector_provider); + } + + void set_rabitq_reformer(core::IndexReformer::Pointer rabitq_reformer) { + rabitq_reformer_ = std::move(rabitq_reformer); + } + core::IndexReformer::Pointer rabitq_reformer() const { + return rabitq_reformer_; + } + core::IndexProvider::Pointer raw_vector_provider() const { + return raw_vector_provider_; + } + + void set_total_bits(int total_bits) { + total_bits_ = total_bits; + } + int total_bits() const { + return total_bits_; + } + + void set_num_clusters(int num_clusters) { + num_clusters_ = num_clusters; + } + int num_clusters() const { + return num_clusters_; + } + + void set_sample_count(int sample_count) { + sample_count_ = sample_count; + } + int sample_count() const { + return sample_count_; + } + private: + int total_bits_; + int num_clusters_; + int sample_count_; int m_; int ef_construction_; + core::IndexProvider::Pointer raw_vector_provider_; + core::IndexReformer::Pointer rabitq_reformer_; }; class FlatIndexParams : public VectorIndexParams { diff 
--git a/src/include/zvec/db/query_params.h b/src/include/zvec/db/query_params.h index d187d762..ba62dab9 100644 --- a/src/include/zvec/db/query_params.h +++ b/src/include/zvec/db/query_params.h @@ -125,6 +125,31 @@ class IVFQueryParams : public QueryParams { float scale_factor_{10}; }; +class HnswRabitqQueryParams : public QueryParams { + public: + HnswRabitqQueryParams(int ef = core_interface::kDefaultHnswEfSearch, + float radius = 0.0f, bool is_linear = false, + bool is_using_refiner = false) + : QueryParams(IndexType::HNSW_RABITQ), ef_(ef) { + set_radius(radius); + set_is_linear(is_linear); + set_is_using_refiner(is_using_refiner); + } + + virtual ~HnswRabitqQueryParams() = default; + + int ef() const { + return ef_; + } + + void set_ef(int ef) { + ef_ = ef; + } + + private: + int ef_; +}; + class FlatQueryParams : public QueryParams { public: FlatQueryParams(bool is_using_refiner = false, float scale_factor = 10) diff --git a/src/include/zvec/db/type.h b/src/include/zvec/db/type.h index 188c1bdc..1578f81d 100644 --- a/src/include/zvec/db/type.h +++ b/src/include/zvec/db/type.h @@ -23,8 +23,9 @@ namespace zvec { enum class IndexType : uint32_t { UNDEFINED = 0, HNSW = 1, - IVF = 3, - FLAT = 4, + IVF = 2, + FLAT = 3, + HNSW_RABITQ = 4, INVERT = 10, }; @@ -72,6 +73,7 @@ enum class QuantizeType : uint32_t { FP16 = 1, INT8 = 2, INT4 = 3, + RABITQ = 4, }; enum class MetricType : uint32_t { diff --git a/tests/core/algorithm/CMakeLists.txt b/tests/core/algorithm/CMakeLists.txt index 0e9aa725..0abfb506 100644 --- a/tests/core/algorithm/CMakeLists.txt +++ b/tests/core/algorithm/CMakeLists.txt @@ -7,3 +7,6 @@ cc_directories(flat_sparse) cc_directories(ivf) cc_directories(hnsw) cc_directories(hnsw_sparse) +if(RABITQ_SUPPORTED) +cc_directories(hnsw_rabitq) +endif() diff --git a/tests/core/algorithm/hnsw_rabitq/CMakeLists.txt b/tests/core/algorithm/hnsw_rabitq/CMakeLists.txt new file mode 100644 index 00000000..930b1766 --- /dev/null +++ 
b/tests/core/algorithm/hnsw_rabitq/CMakeLists.txt @@ -0,0 +1,34 @@ +include(${CMAKE_SOURCE_DIR}/cmake/bazel.cmake) +include(${CMAKE_SOURCE_DIR}/cmake/option.cmake) + +if(APPLE) + set(APPLE_FRAMEWORK_LIBS + -framework CoreFoundation + -framework CoreGraphics + -framework CoreData + -framework CoreText + -framework Security + -framework Foundation + -Wl,-U,_MallocExtension_ReleaseFreeMemory + -Wl,-U,_ProfilerStart + -Wl,-U,_ProfilerStop + -Wl,-U,_RegisterThriftProtocol + ) +endif() + +file(GLOB_RECURSE ALL_TEST_SRCS *_test.cc) + +foreach(CC_SRCS ${ALL_TEST_SRCS}) + get_filename_component(CC_TARGET ${CC_SRCS} NAME_WE) + cc_gtest( + NAME ${CC_TARGET} + STRICT + LIBS zvec_ailego core_framework core_utility core_metric core_quantizer core_knn_hnsw_rabitq core_knn_flat core_knn_cluster + ${CMAKE_THREAD_LIBS_INIT} + ${CMAKE_DL_LIBS} + SRCS ${CC_SRCS} + INCS . ${CMAKE_SOURCE_DIR}/src/core ${CMAKE_SOURCE_DIR}/src/core/algorithm/hnsw-rabitq + LDFLAGS ${APPLE_FRAMEWORK_LIBS} + ) + cc_test_suite(hnsw_rabitq ${CC_TARGET}) +endforeach() diff --git a/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder_test.cc b/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder_test.cc new file mode 100644 index 00000000..9c446e13 --- /dev/null +++ b/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder_test.cc @@ -0,0 +1,369 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hnsw_rabitq_builder.h" +#include +#include +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_framework.h" +#include "zvec/core/framework/index_logger.h" +#include "zvec/core/framework/index_provider.h" + +#if defined(__GNUC__) || defined(__GNUG__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-result" +#endif + +using namespace std; +using namespace zvec::ailego; + +namespace zvec { +namespace core { + +constexpr size_t static dim = 128; + +class HnswRabitqBuilderTest : public testing::Test { + protected: + void SetUp(void); + void TearDown(void); + + static std::string _dir; + static shared_ptr _index_meta_ptr; +}; + +std::string HnswRabitqBuilderTest::_dir("hnswRabitqBuilderTest"); +shared_ptr HnswRabitqBuilderTest::_index_meta_ptr; + +void HnswRabitqBuilderTest::SetUp(void) { + IndexLoggerBroker::SetLevel(0); + _index_meta_ptr.reset(new (nothrow) + IndexMeta(IndexMeta::DataType::DT_FP32, dim)); + _index_meta_ptr->set_metric("SquaredEuclidean", 0, ailego::Params()); +} + +void HnswRabitqBuilderTest::TearDown(void) { + char cmdBuf[100]; + snprintf(cmdBuf, 100, "rm -rf %s", _dir.c_str()); + // system(cmdBuf); +} + +TEST_F(HnswRabitqBuilderTest, TestGeneral) { + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + + ASSERT_EQ(0, builder->train(holder)); + + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = 
IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestGeneral"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(0UL, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_EQ(0UL, stats.trained_costtime()); + ASSERT_GT(stats.built_costtime(), 0UL); +} + +TEST_F(HnswRabitqBuilderTest, TestLoad) { + // Load index with searcher and verify search + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + string path = _dir + "/TestGeneral"; + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + // Perform search verification + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(j) / 1000.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + auto context = searcher->create_context(); + ASSERT_NE(context, nullptr); + context->set_topk(10); + + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &result = context->result(0); + ASSERT_GT(result.size(), 0UL); + ASSERT_LE(result.size(), 10UL); +} + +TEST_F(HnswRabitqBuilderTest, TestMemquota) { + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + 
vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + params.set("proxima.hnsw_rabitq.builder.memory_quota", 100000UL); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(IndexError_NoMemory, builder->build(holder)); +} + +TEST_F(HnswRabitqBuilderTest, TestIndexThreads) { + IndexBuilder::Pointer builder1 = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder1, nullptr); + IndexBuilder::Pointer builder2 = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder2, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + std::srand(ailego::Realtime::MilliSeconds()); + auto threads = + std::make_shared(std::rand() % 4, false); + ASSERT_EQ(0, builder1->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder2->init(*_index_meta_ptr, params)); + + auto build_index1 = [&]() { + ASSERT_EQ(0, builder1->train(threads, holder)); + ASSERT_EQ(0, builder1->build(threads, holder)); + }; + auto build_index2 = [&]() { + ASSERT_EQ(0, builder2->train(threads, holder)); + ASSERT_EQ(0, builder2->build(threads, holder)); + }; + + auto t1 = std::async(std::launch::async, build_index1); + auto t2 = std::async(std::launch::async, build_index2); + t1.wait(); + t2.wait(); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, 
nullptr); + + string path = _dir + "/TestIndexThreads"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder1->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder2->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats1 = builder1->stats(); + ASSERT_EQ(doc_cnt, stats1.built_count()); + auto &stats2 = builder2->stats(); + ASSERT_EQ(doc_cnt, stats2.built_count()); +} + +TEST_F(HnswRabitqBuilderTest, TestCosine) { + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); + index_meta_raw.set_metric("Cosine", 0, ailego::Params()); + + ailego::Params converter_params; + auto converter = IndexFactory::CreateConverter("CosineFp32Converter"); + converter->init(index_meta_raw, converter_params); + + IndexMeta index_meta = converter->meta(); + + converter->transform(holder); + + auto converted_holder = converter->result(); + converted_holder = convert_holder_to_provider(converted_holder); + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(index_meta, params)); + + ASSERT_EQ(0, builder->train(converted_holder)); + + ASSERT_EQ(0, builder->build(converted_holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestCosine"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + 
ASSERT_EQ(0UL, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_EQ(0UL, stats.trained_costtime()); + ASSERT_GT(stats.built_costtime(), 0UL); +} + +TEST_F(HnswRabitqBuilderTest, TestCleanupAndRebuild) { + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestCleanupAndRebuild"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + auto &stats = builder->stats(); + ASSERT_EQ(0UL, stats.trained_count()); + ASSERT_EQ(doc_cnt, stats.built_count()); + ASSERT_EQ(doc_cnt, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_EQ(0UL, stats.trained_costtime()); + ASSERT_GT(stats.built_costtime(), 0UL); + + // Cleanup and rebuild with more documents + ASSERT_EQ(0, builder->cleanup()); + + auto holder2 = + make_shared>(dim); + size_t doc_cnt2 = 2000UL; + for (size_t i = 0; i < doc_cnt2; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder2->emplace(i, std::move(vec))); + } + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, 
params)); + ASSERT_EQ(0, builder->train(holder2)); + ASSERT_EQ(0, builder->build(holder2)); + + auto dumper2 = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper2, nullptr); + ASSERT_EQ(0, dumper2->create(path)); + ASSERT_EQ(0, builder->dump(dumper2)); + ASSERT_EQ(0, dumper2->close()); + + ASSERT_EQ(0UL, stats.trained_count()); + ASSERT_EQ(doc_cnt2, stats.built_count()); + ASSERT_EQ(doc_cnt2, stats.dumped_count()); + ASSERT_EQ(0UL, stats.discarded_count()); + ASSERT_EQ(0UL, stats.trained_costtime()); + ASSERT_GT(stats.built_costtime(), 0UL); +} + +} // namespace core +} // namespace zvec + +#if defined(__GNUC__) || defined(__GNUG__) +#pragma GCC diagnostic pop +#endif diff --git a/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher_test.cc b/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher_test.cc new file mode 100644 index 00000000..f2b38956 --- /dev/null +++ b/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher_test.cc @@ -0,0 +1,573 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hnsw_rabitq_searcher.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "zvec/core/framework/index_framework.h" +#include "zvec/core/framework/index_logger.h" +#include "hnsw_rabitq_builder.h" + +#if defined(__GNUC__) || defined(__GNUG__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-result" +#endif + +using namespace std; +using namespace zvec::ailego; + +namespace zvec { +namespace core { + +constexpr size_t static dim = 128; + +class HnswRabitqSearcherTest : public testing::Test { + protected: + void SetUp(void); + void TearDown(void); + + static std::string _dir; + static shared_ptr _index_meta_ptr; +}; + +std::string HnswRabitqSearcherTest::_dir("HnswRabitqSearcherTest"); +shared_ptr HnswRabitqSearcherTest::_index_meta_ptr; + +void HnswRabitqSearcherTest::SetUp(void) { + IndexLoggerBroker::SetLevel(0); + _index_meta_ptr.reset(new (nothrow) + IndexMeta(IndexMeta::DataType::DT_FP32, dim)); + _index_meta_ptr->set_metric("SquaredEuclidean", 0, ailego::Params()); +} + +void HnswRabitqSearcherTest::TearDown(void) { + char cmdBuf[100]; + snprintf(cmdBuf, 100, "rm -rf %s", _dir.c_str()); + // system(cmdBuf); +} + +TEST_F(HnswRabitqSearcherTest, TestBasicSearch) { + // Build index first + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(0, 
builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestBasicSearch"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + // Test searcher + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + // Perform search + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(j) / 1000.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + auto context = searcher->create_context(); + ASSERT_TRUE(!!context); + context->set_topk(10); + + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &result = context->result(0); + ASSERT_GT(result.size(), 0UL); + ASSERT_LE(result.size(), 10UL); + + // Verify results are sorted by distance + for (size_t i = 1; i < result.size(); ++i) { + ASSERT_LE(result[i - 1].score(), result[i].score()); + } +} + +TEST_F(HnswRabitqSearcherTest, TestRnnSearch) { + // Build index first + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i); + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 
2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestRnnSearch"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + // Test searcher with radius search + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = 0.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + auto context = searcher->create_context(); + ASSERT_NE(context, nullptr); + + size_t topk = 50; + context->set_topk(topk); + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &results = context->result(0); + ASSERT_EQ(topk, results.size()); + + // Test with radius threshold + float radius = results[topk / 2].score(); + context->set_threshold(radius); + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + ASSERT_GT(topk, results.size()); + for (size_t k = 0; k < results.size(); ++k) { + ASSERT_GE(radius, results[k].score()); + } + + // Test reset threshold + context->reset_threshold(); + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + ASSERT_EQ(topk, results.size()); + ASSERT_LT(radius, results[topk - 1].score()); +} + 
+TEST_F(HnswRabitqSearcherTest, DISABLED_TestSearchInnerProduct) { + // Build index with InnerProduct metric + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i); + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + IndexMeta index_meta(IndexMeta::DataType::DT_FP32, dim); + index_meta.set_metric("InnerProduct", 0, ailego::Params()); + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(index_meta, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestSearchInnerProduct"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + // Test searcher + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = 1.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + auto context = searcher->create_context(); + ASSERT_TRUE(!!context); + + size_t topk = 50; + context->set_topk(topk); + ASSERT_EQ(0, 
searcher->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &results = context->result(0); + ASSERT_EQ(topk, results.size()); + + // Test with radius threshold (note: InnerProduct uses negative scores) + float radius = -results[topk / 2].score(); + context->set_threshold(radius); + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + ASSERT_GT(topk, results.size()); + for (size_t k = 0; k < results.size(); ++k) { + LOG_ERROR("radius: %f, score: %f", radius, results[k].score()); + EXPECT_GE(radius, results[k].score()); + } + + // Test reset threshold + context->reset_threshold(); + ASSERT_EQ(0, searcher->search_impl(query_vec.data(), query_meta, 1, context)); + ASSERT_EQ(topk, results.size()); + ASSERT_LT(-radius, results[topk - 1].score()); +} + +TEST_F(HnswRabitqSearcherTest, TestSearchCosine) { + // Build index with Cosine metric + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(-1.0, 1.0); + + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = dist(gen); + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, dim); + index_meta_raw.set_metric("Cosine", 0, ailego::Params()); + + ailego::Params converter_params; + auto converter = IndexFactory::CreateConverter("CosineFp32Converter"); + converter->init(index_meta_raw, converter_params); + + IndexMeta index_meta = converter->meta(); + + converter->transform(holder); + + auto converted_holder = converter->result(); + converted_holder = convert_holder_to_provider(converted_holder); + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + 
params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(index_meta, params)); + ASSERT_EQ(0, builder->train(converted_holder)); + ASSERT_EQ(0, builder->build(converted_holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestSearchCosine"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + // Test searcher + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = 1.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_TRUE(reformer != nullptr); + + ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); + + std::string new_query; + IndexQueryMeta new_meta; + ASSERT_EQ(0, reformer->transform(query_vec.data(), query_meta, &new_query, + &new_meta)); + + auto context = searcher->create_context(); + ASSERT_TRUE(!!context); + + size_t topk = 50; + context->set_topk(topk); + ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, 1, context)); + + const auto &results = context->result(0); + ASSERT_EQ(topk, results.size()); + + // Test with radius threshold + float radius = 0.5f; + context->set_threshold(radius); + ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, 1, context)); + ASSERT_GT(topk, results.size()); + for (size_t k = 0; k < results.size(); ++k) { + 
ASSERT_GE(radius, results[k].score()); + } + + // Test reset threshold + context->reset_threshold(); + ASSERT_EQ(0, searcher->search_impl(new_query.data(), new_meta, 1, context)); + ASSERT_EQ(topk, results.size()); + ASSERT_LT(radius, results[topk - 1].score()); +} + +TEST_F(HnswRabitqSearcherTest, TestMultipleQueries) { + // Build index first + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestMultipleQueries"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + // Test searcher with multiple queries + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + // Test with different query vectors + for (size_t query_id = 0; query_id < 5; ++query_id) { + NumericalVector query_vec(dim); + for (size_t j = 0; j < 
dim; ++j) { + query_vec[j] = static_cast(query_id * dim + j) / 1000.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + auto context = searcher->create_context(); + ASSERT_TRUE(!!context); + context->set_topk(20); + + ASSERT_EQ(0, + searcher->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &result = context->result(0); + ASSERT_GT(result.size(), 0UL); + ASSERT_LE(result.size(), 20UL); + + // Verify results are sorted + for (size_t i = 1; i < result.size(); ++i) { + ASSERT_LE(result[i - 1].score(), result[i].score()); + } + } +} + +TEST_F(HnswRabitqSearcherTest, TestDifferentTopK) { + // Build index first + IndexBuilder::Pointer builder = + IndexFactory::CreateBuilder("HnswRabitqBuilder"); + ASSERT_NE(builder, nullptr); + + auto holder = + make_shared>(dim); + size_t doc_cnt = 10000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, std::move(vec))); + } + + ailego::Params params; + params.set("proxima.rabitq.num_clusters", 16UL); + params.set("proxima.rabitq.total_bits", 2UL); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + ASSERT_EQ(0, builder->init(*_index_meta_ptr, params)); + ASSERT_EQ(0, builder->train(holder)); + ASSERT_EQ(0, builder->build(holder)); + + auto dumper = IndexFactory::CreateDumper("FileDumper"); + ASSERT_NE(dumper, nullptr); + + string path = _dir + "/TestDifferentTopK"; + ASSERT_EQ(0, dumper->create(path)); + ASSERT_EQ(0, builder->dump(dumper)); + ASSERT_EQ(0, dumper->close()); + + // Test searcher with different topk values + auto searcher = IndexFactory::CreateSearcher("HnswRabitqSearcher"); + ASSERT_NE(searcher, nullptr); + + ailego::Params search_params; + search_params.set("proxima.hnsw_rabitq.searcher.ef", 100UL); + ASSERT_EQ(0, searcher->init(search_params)); + + auto loader = IndexFactory::CreateStorage("FileReadStorage"); + 
ASSERT_NE(loader, nullptr); + ASSERT_EQ(0, loader->init(ailego::Params())); + ASSERT_EQ(0, loader->open(path, false)); + + ASSERT_EQ(0, searcher->load(loader, nullptr)); + + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(j) / 1000.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + // Test with different topk values + std::vector topk_values = {1, 5, 10, 20, 50, 100}; + for (size_t topk : topk_values) { + auto context = searcher->create_context(); + ASSERT_TRUE(!!context); + context->set_topk(topk); + + ASSERT_EQ(0, + searcher->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &result = context->result(0); + ASSERT_GT(result.size(), 0UL); + ASSERT_LE(result.size(), topk); + + // Verify results are sorted + for (size_t i = 1; i < result.size(); ++i) { + ASSERT_LE(result[i - 1].score(), result[i].score()); + } + } +} + +} // namespace core +} // namespace zvec + +#if defined(__GNUC__) || defined(__GNUG__) +#pragma GCC diagnostic pop +#endif diff --git a/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer_test.cc b/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer_test.cc new file mode 100644 index 00000000..f570b649 --- /dev/null +++ b/tests/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer_test.cc @@ -0,0 +1,536 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hnsw_rabitq_streamer.h" +#include +#include +#include "zvec/ailego/container/params.h" +#include "zvec/core/framework/index_holder.h" +#include "zvec/core/framework/index_streamer.h" +#include "hnsw_rabitq_streamer.h" +#include "rabitq_converter.h" +#include "rabitq_reformer.h" + +using namespace std; +using namespace zvec::ailego; + +namespace zvec { +namespace core { + +constexpr size_t static dim = 128; + +class HnswRabitqStreamerTest : public testing::Test { + protected: + void SetUp(void); + void TearDown(void); + + static std::string dir_; + static shared_ptr index_meta_ptr_; +}; + +std::string HnswRabitqStreamerTest::dir_("hnswRabitqStreamerTest"); +shared_ptr HnswRabitqStreamerTest::index_meta_ptr_; + +void HnswRabitqStreamerTest::SetUp(void) { + index_meta_ptr_.reset(new (nothrow) + IndexMeta(IndexMeta::DataType::DT_FP32, dim)); + index_meta_ptr_->set_metric("SquaredEuclidean", 0, ailego::Params()); +} + +void HnswRabitqStreamerTest::TearDown(void) { + char cmdBuf[100]; + snprintf(cmdBuf, 100, "rm -rf %s", dir_.c_str()); + system(cmdBuf); +} + +TEST_F(HnswRabitqStreamerTest, TestBuildAndSearch) { + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i * dim + j) / 1000.0f; + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + RabitqConverter converter; + converter.init(*index_meta_ptr_, ailego::Params()); + ASSERT_EQ(converter.train(holder), 0); + std::shared_ptr index_reformer; + ASSERT_EQ(converter.to_reformer(&index_reformer), 0); + auto reformer = std::dynamic_pointer_cast(index_reformer); + IndexStreamer::Pointer streamer = + std::make_shared(holder, reformer); + + ailego::Params params; + params.set("proxima.hnsw_rabitq.streamer.max_neighbor_count", 16U); + params.set("proxima.hnsw_rabitq.streamer.upper_neighbor_count", 8U); + params.set("proxima.hnsw_rabitq.streamer.scaling_factor", 5U); + 
params.set("proxima.hnsw_rabitq.general.dimension", dim); + ASSERT_EQ(0, streamer->init(*index_meta_ptr_, params)); + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "/Test/AddVector", true)); + ASSERT_EQ(0, streamer->open(storage)); + + auto context = streamer->create_context(); + for (auto it = holder->create_iterator(); it->is_valid(); it->next()) { + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + ASSERT_EQ(0, + streamer->add_impl(it->key(), it->data(), query_meta, context)); + } + streamer->flush(0UL); + + // Perform search verification + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(j) / 1000.0f; + } + + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + context->set_topk(10); + ASSERT_EQ(0, streamer->search_impl(query_vec.data(), query_meta, 1, context)); + + const auto &result = context->result(0); + ASSERT_GT(result.size(), 0UL); + ASSERT_LE(result.size(), 10UL); + + // reopen and load reformer from storage + ASSERT_EQ(0, streamer->close()); + IndexStreamer::Pointer new_streamer = + std::make_shared(holder); + ASSERT_EQ(0, new_streamer->init(*index_meta_ptr_, params)); + ASSERT_EQ(0, new_streamer->open(storage)); +} + +TEST_F(HnswRabitqStreamerTest, TestLinearSearch) { + auto holder = + make_shared>(dim); + size_t doc_cnt = 1000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i); + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + RabitqConverter converter; + converter.init(*index_meta_ptr_, ailego::Params()); + ASSERT_EQ(converter.train(holder), 0); + std::shared_ptr index_reformer; + ASSERT_EQ(converter.to_reformer(&index_reformer), 0); + auto reformer = std::dynamic_pointer_cast(index_reformer); + IndexStreamer::Pointer streamer = + 
std::make_shared(holder, reformer); + + ailego::Params params; + params.set("proxima.hnsw_rabitq.streamer.max_neighbor_count", 16U); + params.set("proxima.hnsw_rabitq.streamer.upper_neighbor_count", 8U); + params.set("proxima.hnsw_rabitq.streamer.scaling_factor", 5U); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + ASSERT_EQ(0, streamer->init(*index_meta_ptr_, params)); + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "/TestLinearSearch", true)); + ASSERT_EQ(0, streamer->open(storage)); + + auto context = streamer->create_context(); + for (auto it = holder->create_iterator(); it->is_valid(); it->next()) { + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + ASSERT_EQ(0, + streamer->add_impl(it->key(), it->data(), query_meta, context)); + } + + // Test linear search with exact match + size_t topk = 3; + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + NumericalVector query_vec(dim); + + for (size_t i = 0; i < doc_cnt; i += 100) { + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(i); + } + context->set_topk(1U); + ASSERT_EQ(0, + streamer->search_bf_impl(query_vec.data(), query_meta, context)); + auto &result1 = context->result(); + ASSERT_EQ(1UL, result1.size()); + ASSERT_EQ(i, result1[0].key()); + + // Test with slight offset + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(i) + 0.1f; + } + context->set_topk(topk); + ASSERT_EQ(0, + streamer->search_bf_impl(query_vec.data(), query_meta, context)); + auto &result2 = context->result(); + ASSERT_EQ(topk, result2.size()); + ASSERT_EQ(i, result2[0].key()); + } +} + +TEST_F(HnswRabitqStreamerTest, TestKnnSearch) { + auto holder = + make_shared>(dim); + size_t doc_cnt = 2000UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = 
static_cast(i); + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + RabitqConverter converter; + converter.init(*index_meta_ptr_, ailego::Params()); + ASSERT_EQ(converter.train(holder), 0); + std::shared_ptr index_reformer; + ASSERT_EQ(converter.to_reformer(&index_reformer), 0); + auto reformer = std::dynamic_pointer_cast(index_reformer); + IndexStreamer::Pointer streamer = + std::make_shared(holder, reformer); + + ailego::Params params; + params.set("proxima.hnsw_rabitq.streamer.max_neighbor_count", 16U); + params.set("proxima.hnsw_rabitq.streamer.upper_neighbor_count", 8U); + params.set("proxima.hnsw_rabitq.streamer.scaling_factor", 10U); + params.set("proxima.hnsw_rabitq.streamer.efconstruction", 100U); + params.set("proxima.hnsw_rabitq.streamer.ef", 50U); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + ASSERT_EQ(0, streamer->init(*index_meta_ptr_, params)); + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "/TestKnnSearch", true)); + ASSERT_EQ(0, streamer->open(storage)); + + auto context = streamer->create_context(); + for (auto it = holder->create_iterator(); it->is_valid(); it->next()) { + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + ASSERT_EQ(0, + streamer->add_impl(it->key(), it->data(), query_meta, context)); + } + + // Compare KNN search with brute force search + auto linear_ctx = streamer->create_context(); + auto knn_ctx = streamer->create_context(); + size_t topk = 50; + linear_ctx->set_topk(topk); + knn_ctx->set_topk(topk); + + int total_hits = 0; + int total_cnts = 0; + int topk1_hits = 0; + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + NumericalVector query_vec(dim); + + for (size_t i = 0; i < doc_cnt; i += 100) { + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(i) + 0.1f; + } + + ASSERT_EQ(0, + 
streamer->search_impl(query_vec.data(), query_meta, 1, knn_ctx)); + ASSERT_EQ( + 0, streamer->search_bf_impl(query_vec.data(), query_meta, linear_ctx)); + + auto &knn_result = knn_ctx->result(0); + ASSERT_EQ(topk, knn_result.size()); + topk1_hits += (i == knn_result[0].key()); + + auto &linear_result = linear_ctx->result(); + ASSERT_EQ(topk, linear_result.size()); + ASSERT_EQ(i, linear_result[0].key()); + + for (size_t k = 0; k < topk; ++k) { + total_cnts++; + for (size_t j = 0; j < topk; ++j) { + if (linear_result[j].key() == knn_result[k].key()) { + total_hits++; + break; + } + } + } + } + + float recall = total_hits * 1.0f / total_cnts; + float topk1_recall = topk1_hits * 100.0f / static_cast(doc_cnt); + EXPECT_GT(recall, 0.60f); + // actual: no guarantee + // TODO(jiliang.ljl): check if ok? + EXPECT_GT(topk1_recall, 0.00f); +} + +TEST_F(HnswRabitqStreamerTest, TestRandomData) { + auto holder = + make_shared>(dim); + size_t doc_cnt = 1500UL; + + // Add random vectors + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(rand()) / static_cast(RAND_MAX); + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + RabitqConverter converter; + converter.init(*index_meta_ptr_, ailego::Params()); + ASSERT_EQ(converter.train(holder), 0); + std::shared_ptr index_reformer; + ASSERT_EQ(converter.to_reformer(&index_reformer), 0); + auto reformer = std::dynamic_pointer_cast(index_reformer); + IndexStreamer::Pointer streamer = + std::make_shared(holder, reformer); + + ailego::Params params; + params.set("proxima.hnsw_rabitq.streamer.max_neighbor_count", 32U); + params.set("proxima.hnsw_rabitq.streamer.upper_neighbor_count", 16U); + params.set("proxima.hnsw_rabitq.streamer.scaling_factor", 20U); + params.set("proxima.hnsw_rabitq.streamer.efconstruction", 200U); + params.set("proxima.hnsw_rabitq.streamer.ef", 100U); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + ASSERT_EQ(0, 
streamer->init(*index_meta_ptr_, params)); + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "/TestRandomData", true)); + ASSERT_EQ(0, streamer->open(storage)); + + auto context = streamer->create_context(); + for (auto it = holder->create_iterator(); it->is_valid(); it->next()) { + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + ASSERT_EQ(0, + streamer->add_impl(it->key(), it->data(), query_meta, context)); + } + + // Test with random queries + auto linear_ctx = streamer->create_context(); + auto knn_ctx = streamer->create_context(); + size_t topk = 50; + linear_ctx->set_topk(topk); + knn_ctx->set_topk(topk); + + int total_hits = 0; + int total_cnts = 0; + int topk1_hits = 0; + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + NumericalVector query_vec(dim); + + size_t query_cnt = 200; + for (size_t i = 0; i < query_cnt; i++) { + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = static_cast(rand()) / static_cast(RAND_MAX); + } + + ASSERT_EQ( + 0, streamer->search_bf_impl(query_vec.data(), query_meta, linear_ctx)); + ASSERT_EQ(0, + streamer->search_impl(query_vec.data(), query_meta, 1, knn_ctx)); + + auto &knn_result = knn_ctx->result(0); + ASSERT_EQ(topk, knn_result.size()); + + auto &linear_result = linear_ctx->result(); + ASSERT_EQ(topk, linear_result.size()); + + topk1_hits += (linear_result[0].key() == knn_result[0].key()); + + for (size_t k = 0; k < topk; ++k) { + total_cnts++; + for (size_t j = 0; j < topk; ++j) { + if (linear_result[j].key() == knn_result[k].key()) { + total_hits++; + break; + } + } + } + } + + float recall = total_hits * 1.0f / total_cnts; + float topk1_recall = topk1_hits * 1.0f / query_cnt; + EXPECT_GT(recall, 0.50f); + EXPECT_GT(topk1_recall, 0.70f); +} + +TEST_F(HnswRabitqStreamerTest, TestOpenClose) { + auto holder = + make_shared>(dim); + size_t 
doc_cnt = 500UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i); + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + RabitqConverter converter; + converter.init(*index_meta_ptr_, ailego::Params()); + ASSERT_EQ(converter.train(holder), 0); + std::shared_ptr index_reformer; + ASSERT_EQ(converter.to_reformer(&index_reformer), 0); + auto reformer = std::dynamic_pointer_cast(index_reformer); + + ailego::Params params; + params.set("proxima.hnsw_rabitq.streamer.max_neighbor_count", 16U); + params.set("proxima.hnsw_rabitq.streamer.upper_neighbor_count", 8U); + params.set("proxima.hnsw_rabitq.streamer.scaling_factor", 5U); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "/TestOpenClose", true)); + + IndexStreamer::Pointer streamer = + std::make_shared(holder, reformer); + ASSERT_EQ(0, streamer->init(*index_meta_ptr_, params)); + ASSERT_EQ(0, streamer->open(storage)); + + auto context = streamer->create_context(); + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + // Add first half of vectors + for (size_t i = 0; i < doc_cnt / 2; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i); + } + ASSERT_EQ(0, streamer->add_impl(i, vec.data(), query_meta, context)); + } + + ASSERT_EQ(0, streamer->flush(0UL)); + ASSERT_EQ(0, streamer->close()); + + // Reopen and add second half + IndexStreamer::Pointer streamer2 = + std::make_shared(holder); + ASSERT_EQ(0, streamer2->init(*index_meta_ptr_, params)); + ASSERT_EQ(0, streamer2->open(storage)); + + auto context2 = streamer2->create_context(); + for (size_t i = doc_cnt / 2; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = 
static_cast(i); + } + ASSERT_EQ(0, streamer2->add_impl(i, vec.data(), query_meta, context2)); + } + + ASSERT_EQ(0, streamer2->flush(0UL)); + + // Verify search works after reopen + NumericalVector query_vec(dim); + for (size_t j = 0; j < dim; ++j) { + query_vec[j] = 10.0f; + } + + context2->set_topk(5); + ASSERT_EQ(0, + streamer2->search_impl(query_vec.data(), query_meta, 1, context2)); + const auto &result = context2->result(0); + ASSERT_EQ(5UL, result.size()); + ASSERT_EQ(10UL, result[0].key()); +} + +TEST_F(HnswRabitqStreamerTest, TestCreateIterator) { + auto holder = + make_shared>(dim); + size_t doc_cnt = 300UL; + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; ++j) { + vec[j] = static_cast(i); + } + ASSERT_TRUE(holder->emplace(i, vec)); + } + + RabitqConverter converter; + converter.init(*index_meta_ptr_, ailego::Params()); + ASSERT_EQ(converter.train(holder), 0); + std::shared_ptr index_reformer; + ASSERT_EQ(converter.to_reformer(&index_reformer), 0); + auto reformer = std::dynamic_pointer_cast(index_reformer); + IndexStreamer::Pointer streamer = + std::make_shared(holder, reformer); + + ailego::Params params; + params.set("proxima.hnsw_rabitq.streamer.max_neighbor_count", 16U); + params.set("proxima.hnsw_rabitq.streamer.upper_neighbor_count", 8U); + params.set("proxima.hnsw_rabitq.streamer.scaling_factor", 5U); + params.set("proxima.hnsw_rabitq.general.dimension", dim); + ASSERT_EQ(0, streamer->init(*index_meta_ptr_, params)); + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "/TestCreateIterator", true)); + ASSERT_EQ(0, streamer->open(storage)); + + auto context = streamer->create_context(); + IndexQueryMeta query_meta(IndexMeta::DataType::DT_FP32, dim); + + for (size_t i = 0; i < doc_cnt; i++) { + NumericalVector vec(dim); + for (size_t j = 0; j < dim; 
++j) { + vec[j] = static_cast(i); + } + ASSERT_EQ(0, streamer->add_impl(i, vec.data(), query_meta, context)); + } + + streamer->flush(0UL); + + // Test iterator + auto provider = streamer->create_provider(); + auto iter = provider->create_iterator(); + ASSERT_TRUE(!!iter); + + size_t count = 0; + while (iter->is_valid()) { + ASSERT_EQ(count, iter->key()); + // const float *data = (const float *)iter->data(); + // for (size_t j = 0; j < dim; ++j) { + // ASSERT_EQ(static_cast(count), data[j]); + // } + iter->next(); + count++; + } + ASSERT_EQ(doc_cnt, count); + + // Test get_vector + // for (size_t i = 0; i < doc_cnt; i++) { + // const float *data = (const float *)provider->get_vector(i); + // ASSERT_NE(data, nullptr); + // for (size_t j = 0; j < dim; ++j) { + // ASSERT_EQ(static_cast(i), data[j]); + // } + // } +} + +} // namespace core +} // namespace zvec diff --git a/tests/db/CMakeLists.txt b/tests/db/CMakeLists.txt index 29f7fcc3..22a8830c 100644 --- a/tests/db/CMakeLists.txt +++ b/tests/db/CMakeLists.txt @@ -30,6 +30,7 @@ foreach(CC_SRCS ${ALL_TEST_SRCS}) core_knn_flat core_knn_flat_sparse core_knn_hnsw + core_knn_hnsw_rabitq core_knn_hnsw_sparse core_knn_ivf core_mix_reducer diff --git a/tests/db/collection_test.cc b/tests/db/collection_test.cc index e39c8615..7dd06650 100644 --- a/tests/db/collection_test.cc +++ b/tests/db/collection_test.cc @@ -4104,6 +4104,84 @@ TEST_F(CollectionTest, Feature_Column_MixOperation_Empty) { } } +#if RABITQ_SUPPORTED +TEST_F(CollectionTest, Feature_Optimize_HNSW_RABITQ) { + auto func = [](MetricType metric_type, int concurrency) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 1000; + + // create simple schema with only FP32 dense vector for HNSW_RABITQ + auto schema = std::make_shared("demo"); + schema->set_max_doc_count_per_segment(MAX_DOC_COUNT_PER_SEGMENT); + + auto hnsw_rabitq_params = std::make_shared( + metric_type, 7, 256, 16, 200, 0); + schema->add_field(std::make_shared( + "dense_fp32", 
DataType::VECTOR_FP32, 128, false, hnsw_rabitq_params)); + + auto options = CollectionOptions{false, true, 64 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc( + col_path, *schema, options, 0, doc_count, false); + + auto check_doc = [&]() { + for (int i = 0; i < doc_count; i++) { + auto expect_doc = TestHelper::CreateDoc(i, *schema); + auto result = collection->Fetch({expect_doc.pk()}); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value().size(), 1); + ASSERT_EQ(result.value().count(expect_doc.pk()), 1); + auto doc = result.value()[expect_doc.pk()]; + ASSERT_NE(doc, nullptr); + if (*doc != expect_doc) { + std::cout << " doc:" << doc->to_detail_string() << std::endl; + std::cout << "expect_doc:" << expect_doc.to_detail_string() + << std::endl; + } + ASSERT_EQ(*doc, expect_doc); + } + }; + + check_doc(); + std::cout << "check success 1" << std::endl; + + ASSERT_TRUE(collection->Flush().ok()); + auto stats = collection->Stats().value(); + ASSERT_EQ(stats.doc_count, doc_count); + ASSERT_EQ(stats.index_completeness["dense_fp32"], 0); + + auto s = collection->Optimize(OptimizeOptions{concurrency}); + if (!s.ok()) { + std::cout << s.message() << std::endl; + } + ASSERT_TRUE(s.ok()); + + stats = collection->Stats().value(); + ASSERT_EQ(stats.doc_count, doc_count); + ASSERT_EQ(stats.index_completeness["dense_fp32"], 1); + + check_doc(); + std::cout << "check success 2" << std::endl; + + collection.reset(); + auto result = Collection::Open(col_path, options); + ASSERT_TRUE(result.has_value()); + collection = std::move(result.value()); + + check_doc(); + std::cout << "check success 3" << std::endl; + }; + + func(MetricType::L2, 0); + func(MetricType::L2, 4); + func(MetricType::IP, 0); + func(MetricType::IP, 4); + // TODO: cosine dense not match, may be accuracy issue + // func(MetricType::COSINE, 0); + // func(MetricType::COSINE, 4); +} +#endif + // **** CORNER CASES **** // TEST_F(CollectionTest, CornerCase_CreateAndOpen) { // 
Collection::CreateAndOpen diff --git a/tests/db/index/common/schema_test.cc b/tests/db/index/common/schema_test.cc index fe026a6f..5d9afbdd 100644 --- a/tests/db/index/common/schema_test.cc +++ b/tests/db/index/common/schema_test.cc @@ -248,15 +248,6 @@ TEST(FieldSchemaTest, ComparisonOperators) { // Different name EXPECT_FALSE(field1 == field5); EXPECT_TRUE(field1 != field5); - - // Compare with nullptr index params - FieldSchema field6("no_index", DataType::STRING); - FieldSchema field7("no_index", DataType::STRING); - FieldSchema field8("no_index", DataType::STRING, false, 0, index_params1); - - EXPECT_TRUE(field6 == field7); - EXPECT_FALSE(field6 == field8); - EXPECT_TRUE(field6 != field8); } TEST(FieldSchemaTest, Validate) { @@ -823,4 +814,117 @@ TEST(CollectionSchemaTest, Validate) { ASSERT_FALSE(s.ok()); ASSERT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } +} + +TEST(FieldSchemaTest, HnswRabitqIndexValidationMetricTypes) { + // Test supported combinations: FP32 + (L2/IP/COSINE) + + // FP32 + L2 + { + auto index_params = std::make_shared( + MetricType::L2, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_FP32, 128, false, + index_params); + auto status = field.validate(); + EXPECT_TRUE(status.ok()) + << "FP32 + L2 should be supported, but got error: " << status.message(); + } + + // FP32 + IP + { + auto index_params = std::make_shared( + MetricType::IP, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_FP32, 128, false, + index_params); + auto status = field.validate(); + EXPECT_TRUE(status.ok()) + << "FP32 + IP should be supported, but got error: " << status.message(); + } + + // FP32 + COSINE + { + auto index_params = std::make_shared( + MetricType::COSINE, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_FP32, 128, false, + index_params); + auto status = field.validate(); + EXPECT_TRUE(status.ok()) + << "FP32 + COSINE should be supported, but got error: " + << status.message(); + } + 
+ // FP32 + MIPSL2 + { + auto index_params = std::make_shared( + MetricType::MIPSL2, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_FP32, 128, false, + index_params); + auto status = field.validate(); + EXPECT_FALSE(status.ok()) + << "FP32 + MIPSL2 should not be supported, but got error: " + << status.message(); + } +} + +TEST(FieldSchemaTest, HnswRabitqIndexValidation_UnsupportedDataTypes) { + // Test unsupported data types with HNSW_RABITQ index + + // FP16 is not supported + { + auto index_params = std::make_shared( + MetricType::L2, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_FP16, 128, false, + index_params); + auto status = field.validate(); + EXPECT_FALSE(status.ok()) + << "FP16 should not be supported with HNSW_RABITQ"; + EXPECT_NE( + status.message().find("HNSW_RABITQ index only support FP32 data type"), + std::string::npos) + << "Error message should mention FP32 support only, got: " + << status.message(); + } + + // INT8 is not supported + { + auto index_params = std::make_shared( + MetricType::L2, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_INT8, 128, false, + index_params); + auto status = field.validate(); + EXPECT_FALSE(status.ok()) + << "INT8 should not be supported with HNSW_RABITQ"; + EXPECT_NE( + status.message().find("HNSW_RABITQ index only support FP32 data type"), + std::string::npos) + << "Error message should mention FP32 support only, got: " + << status.message(); + } + + // FP64 is not supported + { + auto index_params = std::make_shared( + MetricType::L2, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", DataType::VECTOR_FP64, 128, false, + index_params); + auto status = field.validate(); + EXPECT_FALSE(status.ok()) + << "FP64 should not be supported with HNSW_RABITQ"; + } + + // Sparse vector is not supported with HNSW_RABITQ + { + auto index_params = std::make_shared( + MetricType::IP, 7, 256, 16, 200, 0); + FieldSchema field("vector_field", 
DataType::SPARSE_VECTOR_FP32, 128, false, + index_params); + auto status = field.validate(); + EXPECT_FALSE(status.ok()) + << "Sparse vector should not be supported with HNSW_RABITQ"; + EXPECT_NE( + status.message().find("sparse_vector's index_params only support"), + std::string::npos) + << "Error message should mention sparse vector index support, got: " + << status.message(); + } } \ No newline at end of file diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt index a32eac5e..d06d9ca4 100644 --- a/thirdparty/CMakeLists.txt +++ b/thirdparty/CMakeLists.txt @@ -24,4 +24,7 @@ add_subdirectory(rocksdb rocksdb EXCLUDE_FROM_ALL) add_subdirectory(CRoaring CRoaring EXCLUDE_FROM_ALL) add_subdirectory(arrow arrow EXCLUDE_FROM_ALL) add_subdirectory(magic_enum magic_enum EXCLUDE_FROM_ALL) +if(RABITQ_SUPPORTED) +add_subdirectory(RaBitQ-Library RaBitQ-Library EXCLUDE_FROM_ALL) +endif() diff --git a/thirdparty/RaBitQ-Library/CMakeLists.txt b/thirdparty/RaBitQ-Library/CMakeLists.txt new file mode 100644 index 00000000..7ac2d2ee --- /dev/null +++ b/thirdparty/RaBitQ-Library/CMakeLists.txt @@ -0,0 +1,4 @@ +add_library(rabitqlib INTERFACE) +target_include_directories( + sparsehash INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/RaBitQ-Library-0.1/include" + ) diff --git a/thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 b/thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 new file mode 160000 index 00000000..858b0d6c --- /dev/null +++ b/thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 @@ -0,0 +1 @@ +Subproject commit 858b0d6c480766d0e4f08fc5e02f34b53d698fad diff --git a/tools/core/CMakeLists.txt b/tools/core/CMakeLists.txt index 46efc39f..1c655c67 100644 --- a/tools/core/CMakeLists.txt +++ b/tools/core/CMakeLists.txt @@ -6,7 +6,7 @@ cc_binary( STRICT PACKED SRCS txt2vecs.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags core_framework zvec_ailego + LIBS gflags core_framework zvec_ailego ) cc_binary( @@ -14,7 +14,7 @@ cc_binary( STRICT PACKED SRCS local_builder.cc INCS 
${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_cluster core_knn_ivf core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface ) cc_binary( @@ -22,7 +22,7 @@ cc_binary( STRICT PACKED SRCS recall.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface ) cc_binary( @@ -30,7 +30,7 @@ cc_binary( STRICT PACKED SRCS bench.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface ) @@ -39,7 +39,7 @@ cc_binary( STRICT PACKED SRCS recall_original.cc flow.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat 
core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface ) cc_binary( @@ -47,7 +47,7 @@ cc_binary( STRICT PACKED SRCS bench_original.cc flow.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface ) cc_binary( @@ -55,5 +55,5 @@ cc_binary( STRICT PACKED SRCS local_builder_original.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_cluster core_knn_ivf core_interface + LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface ) diff --git a/tools/core/index_meta_helper.h b/tools/core/index_meta_helper.h index 95044635..ba0fc34c 100644 --- a/tools/core/index_meta_helper.h +++ b/tools/core/index_meta_helper.h @@ -108,6 +108,18 @@ class IndexMetaHelper { std::cerr << "Not supported type(" << type << ") for IP" << std::endl; return false; } + } else if (method == std::string("Cosine")) { + if (feature_type == IndexMeta::DataType::DT_FP32) { + meta.set_metric("Cosine", 0, std::move(params)); + } else if (feature_type == IndexMeta::DataType::DT_INT8) { + meta.set_metric("Cosine", 0, std::move(params)); + } else if (feature_type == IndexMeta::DataType::DT_FP16) { + meta.set_metric("Cosine", 0, std::move(params)); + } else { + std::cerr << "Not supported type(" << type 
<< ") for Cosine" + << std::endl; + return false; + } } else if (method == std::string("HAMMING")) { if (feature_type == IndexMeta::DataType::DT_BINARY32) { meta.set_metric("Hamming", 0, std::move(params)); diff --git a/tools/core/local_builder.cc b/tools/core/local_builder.cc index 9d502a1e..4a157e55 100644 --- a/tools/core/local_builder.cc +++ b/tools/core/local_builder.cc @@ -15,16 +15,22 @@ #include #include #include +#include #include #include #include #include "algorithm/flat/flat_utility.h" +#if RABITQ_SUPPORTED +#include "algorithm/hnsw-rabitq/hnsw_rabitq_streamer.h" +#include "algorithm/hnsw-rabitq/rabitq_converter.h" +#endif #include "algorithm/hnsw/hnsw_params.h" #include "zvec/ailego/logger/logger.h" #include "zvec/core/framework/index_dumper.h" #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_logger.h" #include "zvec/core/framework/index_plugin.h" +#include "zvec/core/framework/index_provider.h" #include "zvec/core/framework/index_reformer.h" #include "zvec/core/framework/index_streamer.h" #include "index_meta_helper.h" @@ -113,6 +119,57 @@ bool prepare_params(YAML::Node &&config_params, ailego::Params ¶ms) { return true; } +int setup_hnsw_rabitq_streamer(const IndexStreamer::Pointer &streamer, + const IndexMeta &meta, YAML::Node &config_root, + const std::string &converter_name, + IndexHolder::Pointer *build_holder) { +#if RABITQ_SUPPORTED + RabitqConverter rabitq_converter; + ailego::Params rabitq_converter_params; + if (config_root["RabitqConverterParams"]) { + auto rabitq_params_node = config_root["RabitqConverterParams"]; + if (!prepare_params(std::move(rabitq_params_node), + rabitq_converter_params)) { + cerr << "Failed to prepare rabitq converter params" << endl; + return -1; + } + } + if (rabitq_converter.init(meta, rabitq_converter_params) != 0) { + cerr << "rabitq converter init failed" << std::endl; + return -1; + } + if (rabitq_converter.train(*build_holder) != 0) { + cerr << "rabitq converter train failed" 
<< std::endl; + return -1; + } + IndexReformer::Pointer rabitq_reformer; + rabitq_converter.to_reformer(&rabitq_reformer); + HnswRabitqStreamer *hnsw_rabitq_streamer = + dynamic_cast(streamer.get()); + hnsw_rabitq_streamer->set_reformer(std::move(rabitq_reformer)); + IndexProvider::Pointer provider; + if (converter_name.empty()) { + // build_holder is VecsIndexHolder + provider = std::dynamic_pointer_cast(*build_holder); + } else { + // build_holder is ordinary IndexHolder, need to convert + provider = convert_holder_to_provider(*build_holder); + // reuse provider to release memory + *build_holder = provider; + } + + if (!provider) { + cerr << "Failed to cast build holder to provider" << endl; + return -1; + } + hnsw_rabitq_streamer->set_provider(provider); + return 0; +#else + cerr << "HNSW RaBitQ is not supported on this platform" << endl; + return -1; +#endif +} + bool check_config(YAML::Node &config_root) { auto common = config_root["BuilderCommon"]; if (!common) { @@ -465,7 +522,8 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, uint64_t key = holder->get_key(id); if (retrieval_mode == RM_DENSE) { if (reformer) { - ret = reformer->convert(holder->get_vector(id), qmeta, &ovec, &ometa); + ret = reformer->convert(holder->get_vector_by_index(id), qmeta, &ovec, + &ometa); if (ret != 0) { LOG_ERROR("Failed to convert vector for %s", IndexError::What(ret)); errcode = ret; @@ -473,7 +531,8 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, } ret = add_to_streamer(key, ovec.data(), ometa, ctx); } else { - ret = add_to_streamer(key, holder->get_vector(id), qmeta, ctx); + ret = + add_to_streamer(key, holder->get_vector_by_index(id), qmeta, ctx); } } else { LOG_ERROR("Retrieval mode not supported"); @@ -874,6 +933,7 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { build_holder->set_metric(metric_name, metric_params); meta.set_metric(metric_name, 0, metric_params); } + IndexMeta input_meta = meta; string converter_name; 
ailego::Params converter_params; if (config_common["ConverterName"] && @@ -1086,6 +1146,15 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Skip train procedure" << endl; } + if (builder_class == "HnswRabitqStreamer") { + if (setup_hnsw_rabitq_streamer(streamer, input_meta, config_root, + converter_name, &cv_build_holder) != 0) { + return -1; + } + } else if (builder_class == "HnswRabitqBuilder" && !converter_name.empty()) { + cv_build_holder = convert_holder_to_provider(cv_build_holder); + } + // BUILD holder = build_holder; signal(SIGINT, stop); diff --git a/tools/core/local_builder_original.cc b/tools/core/local_builder_original.cc index f6a44014..378c072d 100644 --- a/tools/core/local_builder_original.cc +++ b/tools/core/local_builder_original.cc @@ -15,13 +15,20 @@ #include #include #include +#include #include #include #include +#if RABITQ_SUPPORTED +#include "algorithm/hnsw-rabitq/hnsw_rabitq_streamer.h" +#include "algorithm/hnsw-rabitq/rabitq_converter.h" +#include "algorithm/hnsw-rabitq/rabitq_reformer.h" +#endif #include "zvec/core/framework/index_dumper.h" #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_logger.h" #include "zvec/core/framework/index_plugin.h" +#include "zvec/core/framework/index_provider.h" #include "zvec/core/framework/index_reformer.h" #include "zvec/core/framework/index_streamer.h" #include "index_meta_helper.h" @@ -108,6 +115,60 @@ bool prepare_params(YAML::Node &&config_params, ailego::Params ¶ms) { return true; } +int setup_hnsw_rabitq_streamer(const IndexStreamer::Pointer &streamer, + const IndexMeta &meta, YAML::Node &config_root, + const std::string &converter_name, + IndexHolder::Pointer *build_holder) { +#if RABITQ_SUPPORTED + RabitqConverter rabitq_converter; + ailego::Params rabitq_converter_params; + if (config_root["RabitqConverterParams"] && + !prepare_params(std::move(config_root["RabitqConverterParams"]), + rabitq_converter_params)) { + cerr << "Failed to 
prepare rabitq converter params" << endl; + return -1; + } + if (rabitq_converter.init(meta, rabitq_converter_params) != 0) { + cerr << "rabitq converter init failed" << std::endl; + return -1; + } + if (rabitq_converter.train(*build_holder) != 0) { + cerr << "rabitq converter train failed" << std::endl; + return -1; + } + IndexReformer::Pointer rabitq_reformer; + rabitq_converter.to_reformer(&rabitq_reformer); + HnswRabitqStreamer *hnsw_rabitq_streamer = + dynamic_cast(streamer.get()); + hnsw_rabitq_streamer->set_reformer(std::move(rabitq_reformer)); + IndexProvider::Pointer provider; + if (converter_name.empty()) { + // build_holder is VecsIndexHolder + provider = std::dynamic_pointer_cast(*build_holder); + } else { + // build_holder is ordinary IndexHolder, need to convert + provider = convert_holder_to_provider(*build_holder); + // reuse provider to release memory + *build_holder = provider; + } + + if (!provider) { + cerr << "Failed to cast build holder to provider" << endl; + return -1; + } + hnsw_rabitq_streamer->set_provider(provider); + return 0; +#else + (void)streamer; + (void)meta; + (void)config_root; + (void)converter_name; + (void)build_holder; + cerr << "HNSW RaBitQ is not supported on this platform" << endl; + return -1; +#endif +} + bool check_config(YAML::Node &config_root) { auto common = config_root["BuilderCommon"]; if (!common) { @@ -421,7 +482,8 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, uint64_t key = holder->get_key(id); if (retrieval_mode == RM_DENSE) { if (reformer) { - ret = reformer->convert(holder->get_vector(id), qmeta, &ovec, &ometa); + ret = reformer->convert(holder->get_vector_by_index(id), qmeta, &ovec, + &ometa); if (ret != 0) { LOG_ERROR("Failed to convert vector for %s", IndexError::What(ret)); errcode = ret; @@ -429,7 +491,8 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, } ret = streamer->add_impl(key, ovec.data(), ometa, ctx); } else { - ret = streamer->add_impl(key, holder->get_vector(id), 
qmeta, ctx); + ret = streamer->add_impl(key, holder->get_vector_by_index(id), qmeta, + ctx); } } else { cerr << "Retrieval mode not supported"; @@ -828,6 +891,7 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { build_holder->set_metric(metric_name, metric_params); meta.set_metric(metric_name, 0, metric_params); } + IndexMeta input_meta = meta; string converter_name; ailego::Params converter_params; if (config_common["ConverterName"] && @@ -1036,6 +1100,15 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Skip train procedure" << endl; } + if (builder_class == "HnswRabitqStreamer") { + if (setup_hnsw_rabitq_streamer(streamer, input_meta, config_root, + converter_name, &cv_build_holder) != 0) { + return -1; + } + } else if (builder_class == "HnswRabitqBuilder" && !converter_name.empty()) { + cv_build_holder = convert_holder_to_provider(cv_build_holder); + } + // BUILD holder = build_holder; signal(SIGINT, stop); diff --git a/tools/core/vecs_index_holder.h b/tools/core/vecs_index_holder.h index 4d5c8e16..4f743fc8 100644 --- a/tools/core/vecs_index_holder.h +++ b/tools/core/vecs_index_holder.h @@ -15,8 +15,12 @@ #pragma once #include +#include #include +#include "zvec/core/framework/index_error.h" #include "zvec/core/framework/index_holder.h" +#include "zvec/core/framework/index_provider.h" +#include "zvec/core/framework/index_storage.h" #include "vecs_reader.h" namespace zvec { @@ -30,12 +34,16 @@ namespace core { * data = iter->data(); * } */ -class VecsIndexHolder : public IndexHybridHolder { +class VecsIndexHolder : public IndexProvider { public: typedef std::shared_ptr Pointer; bool load(const std::string &file_path) { - return vecs_reader_.load(file_path); + if (!vecs_reader_.load(file_path)) { + return false; + } + build_key_index_map(); + return true; } const IndexMeta &index_meta(void) const { @@ -110,8 +118,7 @@ class VecsIndexHolder : public IndexHybridHolder { return iter; } - virtual 
IndexHybridHolder::Iterator::Pointer create_hybrid_iterator( - void) override { + virtual IndexHybridHolder::Iterator::Pointer create_hybrid_iterator(void) { // make sure iter has value whenn create_iterator finished IndexHybridHolder::Iterator::Pointer iter( new VecsIndexHolder::Iterator(*this, start_cursor_)); @@ -157,10 +164,6 @@ class VecsIndexHolder : public IndexHybridHolder { return vecs_reader_.get_key(idx); } - const void *get_vector(size_t idx) const { - return vecs_reader_.get_vector(idx); - } - uint32_t get_sparse_count(size_t idx) const { return vecs_reader_.get_sparse_count(idx); } @@ -185,7 +188,7 @@ class VecsIndexHolder : public IndexHybridHolder { return start_cursor_; } - size_t total_sparse_count(void) const override { + size_t total_sparse_count(void) const { return vecs_reader_.get_total_sparse_count(); } @@ -209,11 +212,53 @@ class VecsIndexHolder : public IndexHybridHolder { return vecs_reader_.key_base(); } + const void *get_vector_by_index(size_t idx) const { + return vecs_reader_.get_vector(idx); + } + + public: // IndexProvider interface implementation + //! Retrieve a vector using a primary key + const void *get_vector(const uint64_t key) const override { + auto it = key_to_index_map_.find(key); + if (it == key_to_index_map_.end()) { + return nullptr; + } + return vecs_reader_.get_vector(it->second); + } + + //! Retrieve a vector using a primary key + virtual int get_vector(const uint64_t key, + IndexStorage::MemoryBlock &block) const override { + const void *vector = get_vector(key); + if (vector == nullptr) { + return IndexError_NoExist; + } + block.reset((void *)vector); + return 0; + } + + //! Retrieve the owner class + virtual const std::string &owner_class(void) const override { + static std::string owner_class_name = "VecsIndexHolder"; + return owner_class_name; + } + private: + //! 
Build key to index mapping + void build_key_index_map() { + key_to_index_map_.clear(); + size_t num_vecs = vecs_reader_.num_vecs(); + for (size_t i = 0; i < num_vecs; ++i) { + uint64_t key = vecs_reader_.get_key(i); + key_to_index_map_[key] = i; + } + } + bool stop_{false}; uint32_t start_cursor_{0}; VecsReader vecs_reader_; size_t max_doc_count_{0}; + std::unordered_map key_to_index_map_; };