diff --git a/.github/release-drafter-config.yml b/.github/release-drafter-config.yml new file mode 100644 index 0000000..9ccb28a --- /dev/null +++ b/.github/release-drafter-config.yml @@ -0,0 +1,48 @@ +name-template: '$NEXT_MINOR_VERSION' +tag-template: 'v$NEXT_MINOR_VERSION' +autolabeler: + - label: 'maintenance' + files: + - '*.md' + - '.github/*' + - label: 'bug' + branch: + - '/bug-.+' + - label: 'maintenance' + branch: + - '/maintenance-.+' + - label: 'feature' + branch: + - '/feature-.+' +categories: + - title: 'Breaking Changes' + labels: + - 'breakingchange' + - title: 'πŸ§ͺ Experimental Features' + labels: + - 'experimental' + - title: 'πŸš€ New Features' + labels: + - 'feature' + - 'enhancement' + - title: 'πŸ› Bug Fixes' + labels: + - 'fix' + - 'bugfix' + - 'bug' + - 'BUG' + - title: '🧰 Maintenance' + label: 'maintenance' +change-template: '- $TITLE (#$NUMBER)' +exclude-labels: + - 'skip-changelog' +template: | + # Changes + + $CHANGES + + ## Contributors + We'd like to thank all the contributors who worked on this release! + + $CONTRIBUTORS + diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml new file mode 100644 index 0000000..e25f7a1 --- /dev/null +++ b/.github/workflows/release-drafter.yml @@ -0,0 +1,24 @@ +name: Release Drafter - Redis Retrieval Optimizer + +on: + push: + # branches to consider in the event; optional, defaults to all + branches: + - main + +permissions: {} +jobs: + update_release_draft: + permissions: + pull-requests: write # to add label to PR (release-drafter/release-drafter) + contents: write # to create a github release (release-drafter/release-drafter) + + runs-on: ubuntu-latest + steps: + # Drafts your next Release notes as Pull Requests are merged into "main" + - uses: release-drafter/release-drafter@v5 + with: + # (Optional) specify config name to use, relative to .github/. 
Default: release-drafter.yml + config-name: release-drafter-config.yml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..1741688 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,90 @@ +name: Publish Release + +on: + release: + types: [published] + +env: + PYTHON_VERSION: "3.11" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: latest + virtualenvs-create: true + virtualenvs-in-project: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('**/poetry.lock') }} + + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --only=main + + - name: Build package + run: poetry build + + - name: Upload build + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + publish: + needs: build + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: latest + virtualenvs-create: true + virtualenvs-in-project: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('**/poetry.lock') }} + + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --only=main + 
+ - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Publish to PyPI + env: + POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI }} + run: poetry publish \ No newline at end of file diff --git a/README.md b/README.md index 7270c52..9947f75 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,9 @@ embedding_models: # embedding cache would be awesome here. embedding_cache_name: "vec-cache" # avoid names with including 'ret-opt' as this can cause collisions search_methods: ["bm25", "vector", "hybrid", "rerank", "weighted_rrf"] # must match what is passed in search_method_map + +# data types to be included in the study (optional, defaults to ["float32"]) +vector_data_types: ["float16", "float32"] ``` #### Code diff --git a/docs/examples/grid_study/00_grid_study.ipynb b/docs/examples/grid_study/00_grid_study.ipynb index 48b21c0..c9db08b 100644 --- a/docs/examples/grid_study/00_grid_study.ipynb +++ b/docs/examples/grid_study/00_grid_study.ipynb @@ -120,40 +120,48 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "b66894d7", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "09:44:07 beir.datasets.data_loader INFO Loading Corpus...\n" - ] - }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3633/3633 [00:00<00:00, 163265.62it/s]" + "/Users/robert.shelton/.pyenv/versions/3.12.8/lib/python3.12/site-packages/beir/util.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. 
in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "09:44:07 beir.datasets.data_loader INFO Loaded 3633 TEST Documents.\n", - "09:44:07 beir.datasets.data_loader INFO Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of which 3,619 (60.2%) was due to breast cancer. After adjustment for age, tumor characteristics, and treatment selection, both post-diagnostic and pre-diagnostic statin use were associated with lowered risk of breast cancer death (HR 0.46, 95% CI 0.38–0.55 and HR 0.54, 95% CI 0.44–0.67, respectively). The risk decrease by post-diagnostic statin use was likely affected by healthy adherer bias; that is, the greater likelihood of dying cancer patients to discontinue statin use as the association was not clearly dose-dependent and observed already at low-dose/short-term use. 
The dose- and time-dependence of the survival benefit among pre-diagnostic statin users suggests a possible causal effect that should be evaluated further in a clinical trial testing statins’ effect on survival in breast cancer patients.', 'title': 'Statin Use and Breast Cancer Survival: A Nationwide Cohort Study from Finland'}\n", - "09:44:07 beir.datasets.data_loader INFO Loading Queries...\n", - "09:44:07 beir.datasets.data_loader INFO Loaded 323 TEST Queries.\n", - "09:44:07 beir.datasets.data_loader INFO Query Example: Do Cholesterol Statin Drugs Cause Breast Cancer?\n" + "15:38:23 beir.datasets.data_loader INFO Loading Corpus...\n" ] }, { - "name": "stderr", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "916589ecbd6c41c6b49927bff7c6c0aa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3633 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + "application/vnd.jupyter.widget-view+json": { + "model_id": "6a942f45b96d4bacb9668211be4cc288", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Batches: 0%| | 0/1 [00:00\n", + "\n", + "
search_methodmodelavg_query_timerecall@kprecisionndcg@k
4weighted_rrfsentence-transformers/all-MiniLM-L6-v20.0029970.1649640.2445820.212325
3reranksentence-transformers/all-MiniLM-L6-v20.1707450.1669970.2538700.203366
2hybridsentence-transformers/all-MiniLM-L6-v20.0018720.1549880.2433440.202778
1vectorsentence-transformers/all-MiniLM-L6-v20.0038790.1549880.2433440.196586
0
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", "
search_methodmodelvector_data_typeavg_query_timerecallprecision
4weighted_rrfsentence-transformers/all-MiniLM-L6-v2float160.0069010.1671630.244272
9weighted_rrfsentence-transformers/all-MiniLM-L6-v2float320.0062490.1671630.244272
3reranksentence-transformers/all-MiniLM-L6-v2float160.1229770.1661150.254799
8reranksentence-transformers/all-MiniLM-L6-v2float320.0937960.1661150.254799
1vectorsentence-transformers/all-MiniLM-L6-v2float160.0055200.1549880.243344
2hybridsentence-transformers/all-MiniLM-L6-v2float160.0047140.1549880.243344
6vectorsentence-transformers/all-MiniLM-L6-v2float320.0023930.1549880.243344
7hybridsentence-transformers/all-MiniLM-L6-v2float320.0029760.1549880.243344
0bm25sentence-transformers/all-MiniLM-L6-v2float160.0012750.1387660.281526
5bm25sentence-transformers/all-MiniLM-L6-v20.001269float320.0012480.1387660.2815260.191400
\n", "" ], "text/plain": [ - " search_method model avg_query_time \\\n", - "4 weighted_rrf sentence-transformers/all-MiniLM-L6-v2 0.002997 \n", - "3 rerank sentence-transformers/all-MiniLM-L6-v2 0.170745 \n", - "2 hybrid sentence-transformers/all-MiniLM-L6-v2 0.001872 \n", - "1 vector sentence-transformers/all-MiniLM-L6-v2 0.003879 \n", - "0 bm25 sentence-transformers/all-MiniLM-L6-v2 0.001269 \n", + " search_method model vector_data_type \\\n", + "4 weighted_rrf sentence-transformers/all-MiniLM-L6-v2 float16 \n", + "9 weighted_rrf sentence-transformers/all-MiniLM-L6-v2 float32 \n", + "3 rerank sentence-transformers/all-MiniLM-L6-v2 float16 \n", + "8 rerank sentence-transformers/all-MiniLM-L6-v2 float32 \n", + "1 vector sentence-transformers/all-MiniLM-L6-v2 float16 \n", + "2 hybrid sentence-transformers/all-MiniLM-L6-v2 float16 \n", + "6 vector sentence-transformers/all-MiniLM-L6-v2 float32 \n", + "7 hybrid sentence-transformers/all-MiniLM-L6-v2 float32 \n", + "0 bm25 sentence-transformers/all-MiniLM-L6-v2 float16 \n", + "5 bm25 sentence-transformers/all-MiniLM-L6-v2 float32 \n", "\n", - " recall@k precision ndcg@k \n", - "4 0.164964 0.244582 0.212325 \n", - "3 0.166997 0.253870 0.203366 \n", - "2 0.154988 0.243344 0.202778 \n", - "1 0.154988 0.243344 0.196586 \n", - "0 0.138766 0.281526 0.191400 " + " avg_query_time recall precision \n", + "4 0.006901 0.167163 0.244272 \n", + "9 0.006249 0.167163 0.244272 \n", + "3 0.122977 0.166115 0.254799 \n", + "8 0.093796 0.166115 0.254799 \n", + "1 0.005520 0.154988 0.243344 \n", + "2 0.004714 0.154988 0.243344 \n", + "6 0.002393 0.154988 0.243344 \n", + "7 0.002976 0.154988 0.243344 \n", + "0 0.001275 0.138766 0.281526 \n", + "5 0.001248 0.138766 0.281526 " ] }, - "execution_count": 10, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)" + 
"metrics[[\"search_method\", \"model\", \"vector_data_type\", \"avg_query_time\", \"recall\", \"precision\"]].sort_values(by=\"recall\", ascending=False)" ] } ], "metadata": { "kernelspec": { - "display_name": "redis-retrieval-optimizer-Z5sMIYJj-py3.11", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1521,7 +19304,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.12.8" } }, "nbformat": 4, diff --git a/docs/examples/grid_study/grid_study_config.yaml b/docs/examples/grid_study/grid_study_config.yaml index aaae249..694ee21 100644 --- a/docs/examples/grid_study/grid_study_config.yaml +++ b/docs/examples/grid_study/grid_study_config.yaml @@ -22,4 +22,7 @@ embedding_models: # embedding cache would be awesome here. dim: 384 embedding_cache_name: "vec-cache" # avoid names with including 'ret-opt' as this can cause collisions -search_methods: ["bm25", "vector", "hybrid", "rerank", "weighted_rrf"] # must match what is passed in search_method_map \ No newline at end of file +search_methods: ["bm25", "vector", "hybrid", "rerank", "weighted_rrf"] # must match what is passed in search_method_map + +# data types to be included in the study (optional, defaults to ["float32"]) +vector_data_types: ["float16", "float32"] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d33ba12..c6c89c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-retrieval-optimizer" -version = "0.2.1" +version = "0.3.0" description = "A tool to help optimize information retrieval with the Redis Query Engine." 
authors = [ "Robert Shelton " ] license = "MIT" diff --git a/redis_retrieval_optimizer/__init__.py b/redis_retrieval_optimizer/__init__.py index e69de29..96e359e 100644 --- a/redis_retrieval_optimizer/__init__.py +++ b/redis_retrieval_optimizer/__init__.py @@ -0,0 +1,3 @@ +__version__ = "0.3.0" + +__all__ = ["__version__"] diff --git a/redis_retrieval_optimizer/grid_study.py b/redis_retrieval_optimizer/grid_study.py index 056835f..50fd2df 100644 --- a/redis_retrieval_optimizer/grid_study.py +++ b/redis_retrieval_optimizer/grid_study.py @@ -15,7 +15,12 @@ def update_metric_row( - metrics, grid_study_config, search_method, embedding_settings, trial_metrics: dict + metrics, + grid_study_config, + search_method, + embedding_settings, + trial_metrics: dict, + vector_data_type: str = None, ): metrics["search_method"].append(search_method) metrics["ret_k"].append(grid_study_config.ret_k) @@ -25,7 +30,7 @@ def update_metric_row( metrics["m"].append(grid_study_config.index_settings.m) metrics["distance_metric"].append(grid_study_config.index_settings.distance_metric) metrics["vector_data_type"].append( - grid_study_config.index_settings.vector_data_type + vector_data_type or grid_study_config.index_settings.vector_data_type ) metrics["model"].append(embedding_settings.model) metrics["model_dim"].append(embedding_settings.dim) @@ -44,12 +49,20 @@ def persist_metrics(metrics, redis_url, study_id): def init_index_from_grid_settings( - grid_study_config: GridStudyConfig, redis_url: str, corpus_processor: Callable + grid_study_config: GridStudyConfig, + redis_url: str, + corpus_processor: Callable, + dtype: str = None, ) -> SearchIndex: index_settings = grid_study_config.index_settings.model_dump() embed_settings = grid_study_config.embedding_models[0] index_settings["embedding"] = embed_settings.model_dump() + # Use provided dtype or default from embedding model + if dtype: + index_settings["vector_data_type"] = dtype + embed_settings.dtype = dtype + if 
grid_study_config.index_settings.from_existing: print(f"Connecting to existing index: {grid_study_config.index_settings.name}") @@ -65,10 +78,12 @@ def init_index_from_grid_settings( ) if ( embed_settings.dim - != index.schema.fields[grid_study_config.vector_field_name].attrs.dims + != index.schema.fields[ + grid_study_config.index_settings.vector_field_name + ].attrs.dims ): raise ValueError( - f"Embedding model dimension {emb_model.dims} does not match index dimension {index.schema.fields[grid_study_config.vector_field_name].attrs['dims']}" + f"Embedding model dimension {embed_settings.dim} does not match index dimension {index.schema.fields[grid_study_config.index_settings.vector_field_name].attrs['dims']}" ) utils.set_last_index_settings(redis_url, index_settings) else: @@ -77,16 +92,19 @@ def init_index_from_grid_settings( index_settings, last_index_settings ) - schema = utils.schema_from_settings( - grid_study_config.index_settings, - ) + # Create a copy of index settings with current dtype + current_index_settings = grid_study_config.index_settings.model_copy() + if dtype: + current_index_settings.vector_data_type = dtype + + schema = utils.schema_from_settings(current_index_settings) index = SearchIndex.from_dict(schema, redis_url=redis_url) index.create(overwrite=recreate_index, drop=recreate_data) if recreate_data: emb_model = utils.get_embedding_model( - grid_study_config.embedding_models[0], redis_url + grid_study_config.embedding_models[0], redis_url, dtype=dtype ) print("Recreating: loading corpus from file") corpus = utils.load_json(grid_study_config.corpus) @@ -117,10 +135,6 @@ def run_grid_study( queries = utils.load_json(grid_study_config.queries) qrels = Qrels(utils.load_json(grid_study_config.qrels)) - index = init_index_from_grid_settings( - grid_study_config, redis_url, corpus_processor - ) - metrics: dict = { "search_method": [], "total_indexing_time": [], @@ -141,73 +155,87 @@ def run_grid_study( } for i, embedding_model in 
enumerate(grid_study_config.embedding_models): - if i > 0: - # assuming that you didn't pass the same embedding model twice like a fool - print("Recreating index with new embedding model") - - # delete old index and data with embedding cache it's not expensive to recreate - # consider potential of pre-fixing studies with study_id for separation - index_settings = grid_study_config.index_settings - - # assign new vector info to index_settings - index_settings.vector_data_type = embedding_model.dtype - index_settings.vector_dim = embedding_model.dim - - schema = utils.schema_from_settings(index_settings) - index = utils.index_from_schema( - schema, redis_url, recreate_index=True, recreate_data=True - ) - - # TODO: be able to dump existing index corpus to file automatically which shouldn't be too hard - print( - "If using multiple embedding models assuming there is a json version of corpus available." - ) - print("Recreating: loading corpus from file") - emb_model = utils.get_embedding_model(embedding_model, redis_url) - corpus = utils.load_json(grid_study_config.corpus) - - # corpus processing functions should be user defined - corpus_data = corpus_processor(corpus, emb_model) - index.load(corpus_data) - - while float(index.info()["percent_indexed"]) < 1: - time.sleep(1) - logging.info(f"Indexing progress: {index.info()['percent_indexed']}") - - # check if matches with last index settings - emb_model = utils.get_embedding_model(embedding_model, redis_url) - - for search_method in grid_study_config.search_methods: - print(f"Running search method: {search_method}") - # get search method to try - search_fn = search_method_map[search_method] - search_input = SearchMethodInput( - index=index, - raw_queries=queries, - emb_model=emb_model, - id_field_name=grid_study_config.index_settings.id_field_name, - vector_field_name=grid_study_config.index_settings.vector_field_name, - text_field_name=grid_study_config.index_settings.text_field_name, - ) - - search_method_output = 
search_fn(search_input) - - trial_metrics = utils.eval_trial_metrics(qrels, search_method_output.run) - trial_metrics["total_indexing_time"] = round( - float(index.info()["total_indexing_time"]) / 1000, 5 - ) - trial_metrics["query_stats"] = utils.get_query_time_stats( - search_method_output.query_metrics.query_times - ) - - metrics = update_metric_row( - metrics, - grid_study_config, - search_method, - embedding_model, - trial_metrics, + for dtype in grid_study_config.vector_data_types: + # Update index settings for current dtype + current_index_settings = grid_study_config.index_settings.model_copy() + current_index_settings.vector_data_type = dtype + + # Create or get index for current settings + if i == 0 and dtype == grid_study_config.vector_data_types[0]: + # First iteration - initialize index + index = init_index_from_grid_settings( + grid_study_config, redis_url, corpus_processor, dtype=dtype + ) + else: + # Recreate index with new settings + print(f"Recreating index with dtype: {dtype}") + + # Update index settings for current embedding model and dtype + current_index_settings.vector_dim = embedding_model.dim + + schema = utils.schema_from_settings(current_index_settings) + index = utils.index_from_schema( + schema, redis_url, recreate_index=True, recreate_data=True + ) + + print("Recreating: loading corpus from file") + emb_model = utils.get_embedding_model( + embedding_model, redis_url, dtype=dtype + ) + corpus = utils.load_json(grid_study_config.corpus) + + # corpus processing functions should be user defined + corpus_data = corpus_processor(corpus, emb_model) + index.load(corpus_data) + + while float(index.info()["percent_indexed"]) < 1: + time.sleep(1) + logging.info( + f"Indexing progress: {index.info()['percent_indexed']}" + ) + + # Get embedding model with current dtype + emb_model = utils.get_embedding_model( + embedding_model, redis_url, dtype=dtype ) - persist_metrics(metrics, redis_url, grid_study_config.study_id) - + for search_method in 
grid_study_config.search_methods: + print(f"Running search method: {search_method} with dtype: {dtype}") + # get search method to try + search_fn = search_method_map[search_method] + search_input = SearchMethodInput( + index=index, + raw_queries=queries, + emb_model=emb_model, + id_field_name=grid_study_config.index_settings.id_field_name, + vector_field_name=grid_study_config.index_settings.vector_field_name, + text_field_name=grid_study_config.index_settings.text_field_name, + ) + + search_method_output = search_fn(search_input) + + trial_metrics = utils.eval_trial_metrics( + qrels, search_method_output.run + ) + trial_metrics["total_indexing_time"] = round( + float(index.info()["total_indexing_time"]) / 1000, 5 + ) + trial_metrics["query_stats"] = utils.get_query_time_stats( + search_method_output.query_metrics.query_times + ) + + # Create embedding settings with current dtype for metrics + embedding_settings_with_dtype = embedding_model.model_copy() + embedding_settings_with_dtype.dtype = dtype + + metrics = update_metric_row( + metrics, + grid_study_config, + search_method, + embedding_settings_with_dtype, + trial_metrics, + vector_data_type=dtype, + ) + + persist_metrics(metrics, redis_url, grid_study_config.study_id) return pd.DataFrame(metrics) diff --git a/redis_retrieval_optimizer/schema.py b/redis_retrieval_optimizer/schema.py index b846f10..3389b09 100644 --- a/redis_retrieval_optimizer/schema.py +++ b/redis_retrieval_optimizer/schema.py @@ -139,6 +139,7 @@ class GridStudyConfig(BaseModel): embedding_models: list[EmbeddingModel] search_methods: list[str] ret_k: int = 6 + vector_data_types: list[str] = ["float32"] # data types to be included in the study def get_trial_settings(trial, study_config): diff --git a/tests/integration/grid_data/test_grid_study_config.yaml b/tests/integration/grid_data/test_grid_study_config.yaml index 2eaa9c3..1dc22b6 100644 --- a/tests/integration/grid_data/test_grid_study_config.yaml +++ 
b/tests/integration/grid_data/test_grid_study_config.yaml @@ -19,3 +19,6 @@ search_methods: - hybrid - weighted_rrf - rerank +vector_data_types: +- float16 +- float32 diff --git a/tests/integration/test_grid.py b/tests/integration/test_grid.py index fff732f..5454e5e 100644 --- a/tests/integration/test_grid.py +++ b/tests/integration/test_grid.py @@ -20,6 +20,56 @@ def test_run_grid_study(redis_url): study_config["queries"] = f"{TEST_DIR}/grid_data/queries.json" study_config["qrels"] = f"{TEST_DIR}/grid_data/qrels.json" + # Add vector_data_types to test the new dtype functionality + study_config["vector_data_types"] = ["float32"] + + with open(config_path, "w") as f: + yaml.dump(study_config, f) + + metrics = run_grid_study( + config_path=config_path, + redis_url=redis_url, + corpus_processor=eval_beir.process_corpus, + ) + + # Calculate expected number of trials: embedding_models * vector_data_types * search_methods + expected_trials = ( + len(study_config["embedding_models"]) + * len(study_config["vector_data_types"]) + * len(study_config["search_methods"]) + ) + + assert metrics.shape[0] == expected_trials + + for score in metrics["f1"].tolist(): + assert score > 0.0 + + last_schema = utils.get_last_index_settings(redis_url) + assert last_schema is not None + + index = SearchIndex.from_existing(last_schema["name"], redis_url=redis_url) + + assert index.info()["num_docs"] == 5 + + # clean up + index.client.json().delete("ret-opt:last_schema") + index.delete(drop=True) + + +def test_run_grid_study_with_multiple_dtypes(redis_url): + """Test grid study with multiple vector data types.""" + config_path = f"{TEST_DIR}/grid_data/test_grid_study_config.yaml" + + with open(config_path, "r") as f: + study_config = yaml.safe_load(f) + + study_config["corpus"] = f"{TEST_DIR}/grid_data/corpus.json" + study_config["queries"] = f"{TEST_DIR}/grid_data/queries.json" + study_config["qrels"] = f"{TEST_DIR}/grid_data/qrels.json" + + # Test with multiple dtypes + 
study_config["vector_data_types"] = ["float16", "float32"] + with open(config_path, "w") as f: yaml.dump(study_config, f) @@ -29,10 +79,20 @@ def test_run_grid_study(redis_url): corpus_processor=eval_beir.process_corpus, ) - assert metrics.shape[0] == len(study_config["search_methods"]) * len( - study_config["embedding_models"] + # Calculate expected number of trials: embedding_models * vector_data_types * search_methods + expected_trials = ( + len(study_config["embedding_models"]) + * len(study_config["vector_data_types"]) + * len(study_config["search_methods"]) ) + assert metrics.shape[0] == expected_trials + + # Verify that both dtypes are present in the results + unique_dtypes = metrics["vector_data_type"].unique() + assert "float16" in unique_dtypes + assert "float32" in unique_dtypes + for score in metrics["f1"].tolist(): assert score > 0.0