diff --git a/README.md b/README.md
index 70e65ab..40323a1 100644
--- a/README.md
+++ b/README.md
@@ -150,9 +150,10 @@ raw_data_path: "label_app/data/2008-mazda3-chunks.json"
input_data_type: "json"
labeled_data_path: "label_app/data/mazda_labeled_items.json"
-# metrics to be used in objective function
-metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
-# weight of each metric
-weights: [1, 1, 1]
+# metric weights to be used in objective function
+metric_weights:
+ f1_at_k: 1
+ embedding_latency: 1
+ total_indexing_time: 1
# constraints for the optimization
n_trials: 10
n_jobs: 1
@@ -176,14 +177,13 @@ embedding_models:
|----------------------|------------------------------------------------|--------------------------------------------------|----------|
| **raw_data_path** | `label_app/data/2008-mazda3-chunks.json` | Path to raw data file | ✅ |
| **labeled_data_path** | `label_app/data/mazda-labeled-rewritten.json` | Path to labeled data file | ✅ |
-| **metrics** | f1_at_k, embedding_latency, total_indexing_time | Metrics used in the objective function | ✅ |
-| **weights** | [1, 1, 1] | Weights for f1_at_k, embedding_latency, total_indexing_time respectively. | ✅ |
| **algorithms** | flat, hnsw | Indexing algorithms to be tested in optimization | ✅ |
| **vector_data_types** | float32, float16 | Data types to be tested for vectors | ✅ |
| **n_trials** | 15 | Number of optimization trials | ✅ |
| **n_jobs** | 1 | Number of parallel jobs | ✅ |
| **ret_k** | [1, 10] | Range of values to be tested for `k` in retrieval | ✅ |
| **embedding_models** | **Provider:** hf <br> **Model:** sentence-transformers/all-MiniLM-L6-v2 <br> **Dim:** 384 | List of embedding models and their dimensions | ✅ |
+| **metric_weights** | **f1_at_k:** 1 <br> **embedding_latency:** 1 <br> **total_indexing_time:** 1 | Weight for each metric used in the objective function | defaults to example |
| **input_data_type** | json | Type of input data | defaults to example |
| **redis_url** | `redis://localhost:6379` | Connection string for redis instance | defaults to example |
| **ef_runtime** | [10, 20, 30, 50] | Max top candidates during search for HNSW | defaults to example |
diff --git a/examples/dbpedia/dbpedia_study_config.yaml b/examples/dbpedia/dbpedia_study_config.yaml
index 73ba4bf..74d9538 100644
--- a/examples/dbpedia/dbpedia_study_config.yaml
+++ b/examples/dbpedia/dbpedia_study_config.yaml
@@ -6,9 +6,11 @@ labeled_data_path: "data/dbpedia_labeled.json" # labeled data
n_trials: 3
n_jobs: 1
-# metrics to be used in objective function
-metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
-weights: [1, 1, 1] # weight of each metric respectively
+# metric weights to be used in objective function
+metric_weights:
+ f1_at_k: 1
+ embedding_latency: 1
+ total_indexing_time: 1
# optimization decision variables
algorithms: ["flat", "hnsw"] # indexing algorithms variables
diff --git a/examples/getting_started/custom_retriever_optimizer.ipynb b/examples/getting_started/custom_retriever_optimizer.ipynb
index 444671b..ac0d2aa 100644
--- a/examples/getting_started/custom_retriever_optimizer.ipynb
+++ b/examples/getting_started/custom_retriever_optimizer.ipynb
@@ -177,9 +177,11 @@
"n_trials: 20\n",
"n_jobs: 1\n",
"\n",
- "# metrics to be used in objective function\n",
- "metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n",
- "weights: [1, 1, 1] # weight of each metric respectively \n",
+ "# Metric weights to be used in objective function\n",
+ "metric_weights:\n",
+ " f1_at_k: 1\n",
+ " embedding_latency: 1\n",
+ " total_indexing_time: 1\n",
"\n",
"# optimization decision variables\n",
"algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n",
diff --git a/examples/getting_started/retrieval_optimizer.ipynb b/examples/getting_started/retrieval_optimizer.ipynb
index a28a4cd..6256a8a 100644
--- a/examples/getting_started/retrieval_optimizer.ipynb
+++ b/examples/getting_started/retrieval_optimizer.ipynb
@@ -256,9 +256,11 @@
"n_trials: 20\n",
"n_jobs: 1\n",
"\n",
- "# metrics to be used in objective function\n",
- "metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n",
- "weights: [1, 1, 1] # weight of each metric respectively \n",
+ "# Metric weights to be used in objective function\n",
+ "metric_weights:\n",
+ " f1_at_k: 1\n",
+ " embedding_latency: 1\n",
+ " total_indexing_time: 1\n",
"\n",
"# optimization decision variables\n",
"algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n",
diff --git a/examples/getting_started/study_config.yaml b/examples/getting_started/study_config.yaml
index 6fed12d..9a6cb90 100644
--- a/examples/getting_started/study_config.yaml
+++ b/examples/getting_started/study_config.yaml
@@ -7,8 +7,10 @@ n_trials: 20
n_jobs: 1
-# metrics to be used in objective function
-metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
-weights: [1, 1, 1] # weight of each metric respectively
+# metric weights to be used in objective function
+metric_weights:
+ f1_at_k: 1
+ embedding_latency: 1
+ total_indexing_time: 1
# optimization decision variables
algorithms: ["flat", "hnsw"] # indexing algorithms variables
diff --git a/optimize/ex_study_config.yaml b/optimize/ex_study_config.yaml
index 3e68a83..e68a54c 100644
--- a/optimize/ex_study_config.yaml
+++ b/optimize/ex_study_config.yaml
@@ -3,9 +3,10 @@ input_data_type: "json"
raw_data_path: "label_app/data/2008-mazda3-chunks.json"
labeled_data_path: "label_app/data/mazda-labeled-rewritten.json"
-# metrics to be used in objective function
-metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
-# weight of each metric
-weights: [1, 1, 1]
+# metric weights to be used in objective function
+metric_weights:
+ f1_at_k: 1
+ embedding_latency: 1
+ total_indexing_time: 1
algorithms: ["flat", "hnsw"]
vector_data_types: ["float32", "float16"]
# constraints for the optimization
diff --git a/optimize/models.py b/optimize/models.py
index 2a4f4d6..49c4505 100644
--- a/optimize/models.py
+++ b/optimize/models.py
@@ -45,6 +45,13 @@ class EmbeddingModel(BaseModel):
dim: int
+class MetricWeights(BaseModel):
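+ """Per-metric weights used in the optimization objective; each defaults to 1.0."""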
+ f1_at_k: float = 1.0
+ embedding_latency: float = 1.0
+ total_indexing_time: float = 1.0
+
+
class StudyConfig(BaseModel):
study_id: str = str(uuid4())
redis_url: str = "redis://localhost:6379/0"
@@ -54,10 +60,9 @@ class StudyConfig(BaseModel):
input_data_type: str
labeled_data_path: str
embedding_models: list[EmbeddingModel]
- metrics: list[str]
- weights: list[float]
n_trials: int
n_jobs: int
+ metric_weights: MetricWeights = MetricWeights()
ret_k: tuple[int, int] = [1, 10] # type: ignore # pydantic vs mypy
ef_runtime: list = [10, 50]
ef_construction: list = [100, 300]
diff --git a/optimize/study.py b/optimize/study.py
index 9826acd..60df904 100644
--- a/optimize/study.py
+++ b/optimize/study.py
@@ -155,7 +155,15 @@ def objective(trial, study_config, custom_retrievers, redis_client):
metric_values = [e.f1_at_k, norm_index_time, norm_latency]
- e.obj_val = cost_fn(metric_values, study_config.weights)
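+ # weights follow the order of metric_values: f1_at_k, total_indexing_time, embedding_latency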
+ e.obj_val = cost_fn(
+ metric_values,
+ [
+ study_config.metric_weights.f1_at_k,
+ study_config.metric_weights.total_indexing_time,
+ study_config.metric_weights.embedding_latency,
+ ],
+ )
# save results as we go in case of failure
persist_metrics(redis_client, e, study_config.study_id)
diff --git a/optimize/tests/conftest.py b/optimize/tests/conftest.py
index 0faeb85..64f4c58 100644
--- a/optimize/tests/conftest.py
+++ b/optimize/tests/conftest.py
@@ -8,6 +8,7 @@
EmbeddingModel,
EmbeddingSettings,
IndexSettings,
+ MetricWeights,
Settings,
StudyConfig,
)
@@ -15,6 +16,11 @@
TEST_REDIS_URL = os.getenv("TEST_REDIS_URL", "redis://localhost:6379/0")
+@pytest.fixture
+def metric_weights():
+ return MetricWeights(f1_at_k=1, embedding_latency=1, total_indexing_time=1)
+
+
@pytest.fixture
def settings():
return Settings(
@@ -51,7 +57,7 @@ def test_db_client():
@pytest.fixture
-def study_config(embedding_model):
+def study_config(embedding_model, metric_weights):
return StudyConfig(
study_id="study_id",
redis_url=TEST_REDIS_URL,
@@ -67,6 +73,5 @@ def study_config(embedding_model):
embedding_models=[embedding_model],
n_trials=1,
n_jobs=1,
- metrics=["f1_at_k", "embedding_latency", "total_indexing_time"],
- weights=[1, 1, 1],
+ metric_weights=metric_weights,
)
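
For reference, here is a minimal sketch of how the new `metric_weights` block maps onto the `MetricWeights` model added in `optimize/models.py`, and how the weights feed the objective in the same order that `optimize/study.py` passes them to `cost_fn`. The YAML snippet, the use of PyYAML, and the inline weighted sum (a stand-in for `cost_fn`, whose implementation is not part of this diff) are illustrative assumptions.

```python
# Illustrative sketch: PyYAML loading and the weighted sum are assumptions;
# only MetricWeights and its fields come from this change.
import yaml

from optimize.models import MetricWeights

config_yaml = """
metric_weights:
  f1_at_k: 2
  embedding_latency: 1
  total_indexing_time: 1
"""

raw = yaml.safe_load(config_yaml)

# Any field omitted from the YAML falls back to its default weight.
weights = MetricWeights(**raw.get("metric_weights", {}))

# study.py passes the weights in the same order as metric_values:
# [f1_at_k, total_indexing_time, embedding_latency].
metric_values = [0.82, 0.40, 0.10]  # f1@k, normalized index time, normalized latency
ordered_weights = [
    weights.f1_at_k,
    weights.total_indexing_time,
    weights.embedding_latency,
]

# Stand-in for cost_fn: a simple weighted combination of the metric values.
obj_val = sum(v * w for v, w in zip(metric_values, ordered_weights))
print(obj_val)
```

Because every field of `MetricWeights` has a default, an existing config that omits the `metric_weights` block entirely still validates, with all weights equal to 1.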