From aa4a1006ad166a5fb5eb14f8d4559cb7f666d354 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Fri, 31 Jan 2025 15:43:41 -0500 Subject: [PATCH] update structure for better weighting ergonomics --- README.md | 10 +++++----- examples/dbpedia/dbpedia_study_config.yaml | 8 +++++--- .../getting_started/custom_retriever_optimizer.ipynb | 8 +++++--- examples/getting_started/retrieval_optimizer.ipynb | 8 +++++--- examples/getting_started/study_config.yaml | 6 ++++-- optimize/ex_study_config.yaml | 7 ++++--- optimize/models.py | 9 +++++++-- optimize/study.py | 9 ++++++++- optimize/tests/conftest.py | 11 ++++++++--- 9 files changed, 51 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 70e65ab..40323a1 100644 --- a/README.md +++ b/README.md @@ -150,9 +150,10 @@ raw_data_path: "label_app/data/2008-mazda3-chunks.json" input_data_type: "json" labeled_data_path: "label_app/data/mazda_labeled_items.json" # metrics to be used in objective function -metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"] -# weight of each metric -weights: [1, 1, 1] +metric_weights: + f1_at_k: 1 + embedding_latency: 1 + total_indexing_time: 1 # constraints for the optimization n_trials: 10 n_jobs: 1 @@ -176,14 +177,13 @@ embedding_models: |----------------------|------------------------------------------------|--------------------------------------------------|----------| | **raw_data_path** | `label_app/data/2008-mazda3-chunks.json` | Path to raw data file | ✅ | | **labeled_data_path** | `label_app/data/mazda-labeled-rewritten.json` | Path to labeled data file | ✅ | -| **metrics** | f1_at_k, embedding_latency, total_indexing_time | Metrics used in the objective function | ✅ | -| **weights** | [1, 1, 1] | Weights for f1_at_k, embedding_latency, total_indexing_time respectively. | ✅ | | **algorithms** | flat, hnsw | Indexing algorithms to be tested in optimization | ✅ | | **vector_data_types** | float32, float16 | Data types to be tested for vectors | ✅ | | **n_trials** | 15 | Number of optimization trials | ✅ | | **n_jobs** | 1 | Number of parallel jobs | ✅ | | **ret_k** | [1, 10] | Range of values to be tested for `k` in retrieval | ✅ | | **embedding_models** | **Provider:** hf
**Model:** sentence-transformers/all-MiniLM-L6-v2
**Dim:** 384 | List of embedding models and their dimensions | ✅ | +| **metric_weights** | **f1_at_k:** 1
**embedding_latency:** 1
**total_indexing_time:** 1 | Weight for respective metric used in the objective function | defaults to example | | **input_data_type** | json | Type of input data | defaults to example | | **redis_url** | `redis://localhost:6379` | Connection string for redis instance | defaults to example | | **ef_runtime** | [10, 20, 30, 50] | Max top candidates during search for HNSW | defaults to example | diff --git a/examples/dbpedia/dbpedia_study_config.yaml b/examples/dbpedia/dbpedia_study_config.yaml index 73ba4bf..74d9538 100644 --- a/examples/dbpedia/dbpedia_study_config.yaml +++ b/examples/dbpedia/dbpedia_study_config.yaml @@ -6,9 +6,11 @@ labeled_data_path: "data/dbpedia_labeled.json" # labeled data n_trials: 3 n_jobs: 1 -# metrics to be used in objective function -metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"] -weights: [1, 1, 1] # weight of each metric respectively +# metric weights to be used in objective function +metric_weights: + f1_at_k: 1 + embedding_latency: 1 + total_indexing_time: 1 # optimization decision variables algorithms: ["flat", "hnsw"] # indexing algorithms variables diff --git a/examples/getting_started/custom_retriever_optimizer.ipynb b/examples/getting_started/custom_retriever_optimizer.ipynb index 444671b..ac0d2aa 100644 --- a/examples/getting_started/custom_retriever_optimizer.ipynb +++ b/examples/getting_started/custom_retriever_optimizer.ipynb @@ -177,9 +177,11 @@ "n_trials: 20\n", "n_jobs: 1\n", "\n", - "# metrics to be used in objective function\n", - "metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n", - "weights: [1, 1, 1] # weight of each metric respectively \n", + "# Metric weights to be used in objective function\n", + "metric_weights:\n", + " f1_at_k: 1\n", + " embedding_latency: 1\n", + " total_indexing_time: 1\n", "\n", "# optimization decision variables\n", "algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n", diff --git a/examples/getting_started/retrieval_optimizer.ipynb b/examples/getting_started/retrieval_optimizer.ipynb index a28a4cd..6256a8a 100644 --- a/examples/getting_started/retrieval_optimizer.ipynb +++ b/examples/getting_started/retrieval_optimizer.ipynb @@ -256,9 +256,11 @@ "n_trials: 20\n", "n_jobs: 1\n", "\n", - "# metrics to be used in objective function\n", - "metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n", - "weights: [1, 1, 1] # weight of each metric respectively \n", + "# Metric weights to be used in objective function\n", + "metric_weights:\n", + " f1_at_k: 1\n", + " embedding_latency: 1\n", + " total_indexing_time: 1\n", "\n", "# optimization decision variables\n", "algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n", diff --git a/examples/getting_started/study_config.yaml b/examples/getting_started/study_config.yaml index 6fed12d..9a6cb90 100644 --- a/examples/getting_started/study_config.yaml +++ b/examples/getting_started/study_config.yaml @@ -7,8 +7,10 @@ n_trials: 20 n_jobs: 1 # metrics to be used in objective function -metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"] -weights: [1, 1, 1] # weight of each metric respectively +metric_weights: + f1_at_k: 1 + embedding_latency: 1 + total_indexing_time: 1 # optimization decision variables algorithms: ["flat", "hnsw"] # indexing algorithms variables diff --git a/optimize/ex_study_config.yaml b/optimize/ex_study_config.yaml index 3e68a83..e68a54c 100644 --- a/optimize/ex_study_config.yaml +++ b/optimize/ex_study_config.yaml @@ -3,9 +3,10 @@ input_data_type: "json" raw_data_path: "label_app/data/2008-mazda3-chunks.json" labeled_data_path: "label_app/data/mazda-labeled-rewritten.json" # metrics to be used in objective function -metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"] -# weight of each metric -weights: [1, 1, 1] +metric_weights: + f1_at_k: 1 + embedding_latency: 1 + total_indexing_time: 1 algorithms: ["flat", "hnsw"] vector_data_types: ["float32", "float16"] # constraints for the optimization diff --git a/optimize/models.py b/optimize/models.py index 2a4f4d6..49c4505 100644 --- a/optimize/models.py +++ b/optimize/models.py @@ -45,6 +45,12 @@ class EmbeddingModel(BaseModel): dim: int +class MetricWeights(BaseModel): + f1_at_k: int = 1 + embedding_latency: int = 1 + total_indexing_time: int = 1 + + class StudyConfig(BaseModel): study_id: str = str(uuid4()) redis_url: str = "redis://localhost:6379/0" @@ -54,10 +60,9 @@ class StudyConfig(BaseModel): input_data_type: str labeled_data_path: str embedding_models: list[EmbeddingModel] - metrics: list[str] - weights: list[float] n_trials: int n_jobs: int + metric_weights: MetricWeights = MetricWeights() ret_k: tuple[int, int] = [1, 10] # type: ignore # pydantic vs mypy ef_runtime: list = [10, 50] ef_construction: list = [100, 300] diff --git a/optimize/study.py b/optimize/study.py index 9826acd..60df904 100644 --- a/optimize/study.py +++ b/optimize/study.py @@ -155,7 +155,14 @@ def objective(trial, study_config, custom_retrievers, redis_client): metric_values = [e.f1_at_k, norm_index_time, norm_latency] - e.obj_val = cost_fn(metric_values, study_config.weights) + e.obj_val = cost_fn( + metric_values, + [ + study_config.metric_weights.f1_at_k, + study_config.metric_weights.total_indexing_time, + study_config.metric_weights.embedding_latency, + ], + ) # save results as we go in case of failure persist_metrics(redis_client, e, study_config.study_id) diff --git a/optimize/tests/conftest.py b/optimize/tests/conftest.py index 0faeb85..64f4c58 100644 --- a/optimize/tests/conftest.py +++ b/optimize/tests/conftest.py @@ -8,6 +8,7 @@ EmbeddingModel, EmbeddingSettings, IndexSettings, + MetricWeights, Settings, StudyConfig, ) @@ -15,6 +16,11 @@ TEST_REDIS_URL = os.getenv("TEST_REDIS_URL", "redis://localhost:6379/0") +@pytest.fixture +def metric_weights(): + return MetricWeights(f1_at_k=1, embedding_latency=1, total_indexing_time=1) + + @pytest.fixture def settings(): return Settings( @@ -51,7 +57,7 @@ def test_db_client(): @pytest.fixture -def study_config(embedding_model): +def study_config(embedding_model, metric_weights): return StudyConfig( study_id="study_id", redis_url=TEST_REDIS_URL, @@ -67,6 +73,5 @@ def study_config(embedding_model): embedding_models=[embedding_model], n_trials=1, n_jobs=1, - metrics=["f1_at_k", "embedding_latency", "total_indexing_time"], - weights=[1, 1, 1], + metric_weights=metric_weights, )