Skip to content

Commit a459fd2

Browse files
authored
update structure for better weighting ergonomics (#14)
1 parent d0549f2 commit a459fd2

File tree

9 files changed

+51
-25
lines changed

9 files changed

+51
-25
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,10 @@ raw_data_path: "label_app/data/2008-mazda3-chunks.json"
150150
input_data_type: "json"
151151
labeled_data_path: "label_app/data/mazda_labeled_items.json"
152152
# metrics to be used in objective function
153-
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
154-
# weight of each metric
155-
weights: [1, 1, 1]
153+
metric_weights:
154+
f1_at_k: 1
155+
embedding_latency: 1
156+
total_indexing_time: 1
156157
# constraints for the optimization
157158
n_trials: 10
158159
n_jobs: 1
@@ -176,14 +177,13 @@ embedding_models:
176177
|----------------------|------------------------------------------------|--------------------------------------------------|----------|
177178
| **raw_data_path** | `label_app/data/2008-mazda3-chunks.json` | Path to raw data file | ✅ |
178179
| **labeled_data_path** | `label_app/data/mazda-labeled-rewritten.json` | Path to labeled data file | ✅ |
179-
| **metrics** | f1_at_k, embedding_latency, total_indexing_time | Metrics used in the objective function | ✅ |
180-
| **weights** | [1, 1, 1] | Weights for f1_at_k, embedding_latency, total_indexing_time respectively. | ✅ |
181180
| **algorithms** | flat, hnsw | Indexing algorithms to be tested in optimization | ✅ |
182181
| **vector_data_types** | float32, float16 | Data types to be tested for vectors | ✅ |
183182
| **n_trials** | 15 | Number of optimization trials | ✅ |
184183
| **n_jobs** | 1 | Number of parallel jobs | ✅ |
185184
| **ret_k** | [1, 10] | Range of values to be tested for `k` in retrieval | ✅ |
186185
| **embedding_models** | **Provider:** hf <br> **Model:** sentence-transformers/all-MiniLM-L6-v2 <br> **Dim:** 384 | List of embedding models and their dimensions | ✅ |
186+
| **metric_weights** | **f1_at_k:** 1 <br> **embedding_latency:** 1 <br> **total_indexing_time:** 1 | Weight for respective metric used in the objective function | defaults to example |
187187
| **input_data_type** | json | Type of input data | defaults to example |
188188
| **redis_url** | `redis://localhost:6379` | Connection string for redis instance | defaults to example |
189189
| **ef_runtime** | [10, 20, 30, 50] | Max top candidates during search for HNSW | defaults to example |

examples/dbpedia/dbpedia_study_config.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ labeled_data_path: "data/dbpedia_labeled.json" # labeled data
66
n_trials: 3
77
n_jobs: 1
88

9-
# metrics to be used in objective function
10-
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
11-
weights: [1, 1, 1] # weight of each metric respectively
9+
# metric weights to be used in objective function
10+
metric_weights:
11+
f1_at_k: 1
12+
embedding_latency: 1
13+
total_indexing_time: 1
1214

1315
# optimization decision variables
1416
algorithms: ["flat", "hnsw"] # indexing algorithms variables

examples/getting_started/custom_retriever_optimizer.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,11 @@
177177
"n_trials: 20\n",
178178
"n_jobs: 1\n",
179179
"\n",
180-
"# metrics to be used in objective function\n",
181-
"metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n",
182-
"weights: [1, 1, 1] # weight of each metric respectively \n",
180+
"# Metric weights to be used in objective function\n",
181+
"metric_weights:\n",
182+
" f1_at_k: 1\n",
183+
" embedding_latency: 1\n",
184+
" total_indexing_time: 1\n",
183185
"\n",
184186
"# optimization decision variables\n",
185187
"algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n",

examples/getting_started/retrieval_optimizer.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,11 @@
256256
"n_trials: 20\n",
257257
"n_jobs: 1\n",
258258
"\n",
259-
"# metrics to be used in objective function\n",
260-
"metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n",
261-
"weights: [1, 1, 1] # weight of each metric respectively \n",
259+
"# Metric weights to be used in objective function\n",
260+
"metric_weights:\n",
261+
" f1_at_k: 1\n",
262+
" embedding_latency: 1\n",
263+
" total_indexing_time: 1\n",
262264
"\n",
263265
"# optimization decision variables\n",
264266
"algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n",

examples/getting_started/study_config.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ n_trials: 20
77
n_jobs: 1
88

99
# metrics to be used in objective function
10-
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
11-
weights: [1, 1, 1] # weight of each metric respectively
10+
metric_weights:
11+
f1_at_k: 1
12+
embedding_latency: 1
13+
total_indexing_time: 1
1214

1315
# optimization decision variables
1416
algorithms: ["flat", "hnsw"] # indexing algorithms variables

optimize/ex_study_config.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ input_data_type: "json"
33
raw_data_path: "label_app/data/2008-mazda3-chunks.json"
44
labeled_data_path: "label_app/data/mazda-labeled-rewritten.json"
55
# metrics to be used in objective function
6-
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
7-
# weight of each metric
8-
weights: [1, 1, 1]
6+
metric_weights:
7+
f1_at_k: 1
8+
embedding_latency: 1
9+
total_indexing_time: 1
910
algorithms: ["flat", "hnsw"]
1011
vector_data_types: ["float32", "float16"]
1112
# constraints for the optimization

optimize/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ class EmbeddingModel(BaseModel):
4545
dim: int
4646

4747

48+
class MetricWeights(BaseModel):
49+
f1_at_k: int = 1
50+
embedding_latency: int = 1
51+
total_indexing_time: int = 1
52+
53+
4854
class StudyConfig(BaseModel):
4955
study_id: str = str(uuid4())
5056
redis_url: str = "redis://localhost:6379/0"
@@ -54,10 +60,9 @@ class StudyConfig(BaseModel):
5460
input_data_type: str
5561
labeled_data_path: str
5662
embedding_models: list[EmbeddingModel]
57-
metrics: list[str]
58-
weights: list[float]
5963
n_trials: int
6064
n_jobs: int
65+
metric_weights: MetricWeights = MetricWeights()
6166
ret_k: tuple[int, int] = [1, 10] # type: ignore # pydantic vs mypy
6267
ef_runtime: list = [10, 50]
6368
ef_construction: list = [100, 300]

optimize/study.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,14 @@ def objective(trial, study_config, custom_retrievers, redis_client):
155155

156156
metric_values = [e.f1_at_k, norm_index_time, norm_latency]
157157

158-
e.obj_val = cost_fn(metric_values, study_config.weights)
158+
e.obj_val = cost_fn(
159+
metric_values,
160+
[
161+
study_config.metric_weights.f1_at_k,
162+
study_config.metric_weights.total_indexing_time,
163+
study_config.metric_weights.embedding_latency,
164+
],
165+
)
159166

160167
# save results as we go in case of failure
161168
persist_metrics(redis_client, e, study_config.study_id)

optimize/tests/conftest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88
EmbeddingModel,
99
EmbeddingSettings,
1010
IndexSettings,
11+
MetricWeights,
1112
Settings,
1213
StudyConfig,
1314
)
1415

1516
TEST_REDIS_URL = os.getenv("TEST_REDIS_URL", "redis://localhost:6379/0")
1617

1718

19+
@pytest.fixture
20+
def metric_weights():
21+
return MetricWeights(f1_at_k=1, embedding_latency=1, total_indexing_time=1)
22+
23+
1824
@pytest.fixture
1925
def settings():
2026
return Settings(
@@ -51,7 +57,7 @@ def test_db_client():
5157

5258

5359
@pytest.fixture
54-
def study_config(embedding_model):
60+
def study_config(embedding_model, metric_weights):
5561
return StudyConfig(
5662
study_id="study_id",
5763
redis_url=TEST_REDIS_URL,
@@ -67,6 +73,5 @@ def study_config(embedding_model):
6773
embedding_models=[embedding_model],
6874
n_trials=1,
6975
n_jobs=1,
70-
metrics=["f1_at_k", "embedding_latency", "total_indexing_time"],
71-
weights=[1, 1, 1],
76+
metric_weights=metric_weights,
7277
)

0 commit comments

Comments
 (0)