Commit ec354db: cost_fn fixes and clarity
1 parent b2521c6

15 files changed: +984 / -757 lines

.github/workflows/test.yml

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+name: Tests
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    services:
+      redis:
+        image: redis/redis-stack:latest
+        ports:
+          - 6379:6379
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+
+    strategy:
+      matrix:
+        python-version: ["3.11", "3.12"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: latest
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+
+      - name: Load cached venv
+        id: cached-poetry-dependencies
+        uses: actions/cache@v3
+        with:
+          path: .venv
+          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
+
+      - name: Install dependencies
+        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
+        run: poetry install --all-extras
+
+      - name: Run tests
+        run: poetry run test
+        env:
+          REDIS_URL: redis://localhost:6379/0
+
+      - name: Run tests with coverage
+        run: poetry run pytest
+        env:
+          REDIS_URL: redis://localhost:6379/0

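The workflow starts a redis-stack service and hands its address to the test suite through the REDIS_URL environment variable. As a rough sketch of how a test could pick that variable up (the helper below and the redis-py client usage are illustrative assumptions, not code from this repository):

import os

from redis import Redis  # assumes redis-py is available in the test environment


def get_test_client() -> Redis:
    # Mirror the CI job: read REDIS_URL if set, otherwise fall back to the
    # same local default the workflow passes to the test steps.
    redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
    return Redis.from_url(redis_url)


if __name__ == "__main__":
    client = get_test_client()
    print(client.ping())  # True when the Redis service is reachable

Exporting the same REDIS_URL locally and running poetry run pytest reproduces what the CI job does.
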
docs/examples/bayesian_optimization/00_bayes_study.ipynb

Lines changed: 695 additions & 712 deletions
Large diffs are not rendered by default.

docs/examples/bayesian_optimization/bayes_study_config.yaml

Lines changed: 6 additions & 2 deletions
@@ -17,8 +17,12 @@ index_settings:
 optimization_settings:
   # defines weight of each metric in optimization function
   metric_weights:
-    f1_at_k: 1
-    total_indexing_time: 1
+    f1: 2
+    total_indexing_time: 2
+    avg_query_time: 2
+    recall: 2
+    ndcg: 2
+    precision: 2
   algorithms: ["hnsw"] # indexing algorithm to be included in the study
   vector_data_types: ["float16", "float32"] # data types to be included in the study
   distance_metrics: ["cosine"] # distance metrics to be included in the study

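Every metric the optimizer can reward or penalize now appears explicitly in metric_weights. A quick sketch of reading this block (assumes PyYAML is available; the relative file path is illustrative):

import yaml  # PyYAML, assumed installed

with open("bayes_study_config.yaml") as f:
    config = yaml.safe_load(f)

weights = config["optimization_settings"]["metric_weights"]
print(weights)
# {'f1': 2, 'total_indexing_time': 2, 'avg_query_time': 2, 'recall': 2, 'ndcg': 2, 'precision': 2}

Metrics left out of this block carry no weight, because the cost function below looks weights up with weights.get(key, 0).
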
docs/examples/comparison/00_comparison.ipynb

Lines changed: 1 addition & 1 deletion
@@ -3641,7 +3641,7 @@
 }
 ],
 "source": [
-"metrics[[\"search_method\", \"model\", \"model_dim\", 'total_indexing_time', \"avg_query_time\", \"recall@k\", \"precision\", \"ndcg@k\"]].sort_values(by=\"ndcg@k\", ascending=False)"
+"metrics[[\"search_method\", \"model\", \"model_dim\", 'total_indexing_time', \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
 ]
 },
 {

docs/examples/grid_study/00_grid_study.ipynb

Lines changed: 1 addition & 1 deletion
@@ -1501,7 +1501,7 @@
 }
 ],
 "source": [
-"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall@k\", \"precision\", \"ndcg@k\"]].sort_values(by=\"ndcg@k\", ascending=False)"
+"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
 ]
 }
 ],

docs/examples/grid_study/01_custom_grid_study.ipynb

Lines changed: 1 addition & 1 deletion
@@ -562,7 +562,7 @@
 }
 ],
 "source": [
-"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall@k\", \"precision\", \"ndcg@k\"]].sort_values(by=\"ndcg@k\", ascending=False)"
+"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
 ]
 }
 ],

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "redis-retrieval-optimizer"
-version = "0.2.0"
+version = "0.2.1"
 description = "A tool to help optimize information retrieval with the Redis Query Engine."
 authors = [ "Robert Shelton <[email protected]>" ]
 license = "MIT"

redis_retrieval_optimizer/bayes_study.py

Lines changed: 34 additions & 19 deletions
@@ -28,16 +28,17 @@
     "model": [],
     "model_dim": [],
     "ret_k": [],
-    "recall@k": [],
-    "ndcg@k": [],
-    "f1@k": [],
+    "recall": [],
+    "ndcg": [],
+    "f1": [],
     "precision": [],
     "algorithm": [],
     "ef_construction": [],
     "ef_runtime": [],
     "m": [],
     "distance_metric": [],
     "vector_data_type": [],
+    "objective_value": [],
 }


@@ -52,12 +53,13 @@ def update_metric_row(trial_settings: TrialSettings, trial_metrics: dict):
     METRICS["vector_data_type"].append(trial_settings.index_settings.vector_data_type)
     METRICS["model"].append(trial_settings.embedding.model)
     METRICS["model_dim"].append(trial_settings.embedding.dim)
-    METRICS["recall@k"].append(trial_metrics["recall"])
-    METRICS["ndcg@k"].append(trial_metrics["ndcg"])
+    METRICS["recall"].append(trial_metrics["recall"])
+    METRICS["ndcg"].append(trial_metrics["ndcg"])
     METRICS["precision"].append(trial_metrics["precision"])
-    METRICS["f1@k"].append(trial_metrics["f1"])
+    METRICS["f1"].append(trial_metrics["f1"])
     METRICS["total_indexing_time"].append(trial_metrics["total_indexing_time"])
     METRICS["avg_query_time"].append(trial_metrics["avg_query_time"])
+    METRICS["objective_value"].append(trial_metrics["objective_value"])


 def persist_metrics(
@@ -70,17 +72,30 @@ def persist_metrics(
     client.json().set(f"study:{study_id}", Path.root_path(), METRICS)


+def norm_metric(value: float):
+    """Normalize a metric value using 1/(1+value) formula.
+
+    Handles edge cases:
+    - When value is -1, returns a large positive number (infinity equivalent)
+    - When value is very negative, returns a large positive number
+    - When value is very positive, returns a small positive number
+    """
+    if value == -1:
+        # Return a large positive number to represent "infinity" for optimization
+        return 1000.0
+    return 1 / (1 + value)
+
+
 def cost_fn(metrics: dict, weights: dict):
     objective = 0
     for key in metrics:
-        objective += weights.get(key, 0) * metrics[key]
+        if key == "avg_query_time" or key == "total_indexing_time":
+            objective += weights.get(key, 0) * -norm_metric(metrics[key])
+        else:
+            objective += weights.get(key, 0) * metrics[key]
     return objective


-def norm_metric(value: float):
-    return 1 / (1 + value)
-
-
 def objective(trial, study_config, redis_url, corpus_processor, search_method_map):

     # optimizer will select hyperparameters from available option in study_config
@@ -152,19 +167,19 @@ def objective(trial, study_config, redis_url, corpus_processor, search_method_map):
     search_method_output = search_fn(search_input)

     trial_metrics = utils.eval_trial_metrics(qrels, search_method_output.run)
-    trial_metrics["total_indexing_time"] = -(total_indexing_time)
-    trial_metrics["avg_query_time"] = -(
-        utils.get_query_time_stats(search_method_output.query_metrics.query_times)[
-            "avg_query_time"
-        ]
+    trial_metrics["total_indexing_time"] = total_indexing_time
+    trial_metrics["avg_query_time"] = utils.get_query_time_stats(
+        search_method_output.query_metrics.query_times
+    )["avg_query_time"]
+
+    trial_metrics["objective_value"] = cost_fn(
+        trial_metrics, study_config.optimization_settings.metric_weights.model_dump()
     )

     # save results as we go in case of failure
     persist_metrics(redis_url, trial_settings, trial_metrics, study_config.study_id)

-    return cost_fn(
-        trial_metrics, study_config.optimization_settings.metric_weights.model_dump()
-    )
+    return trial_metrics["objective_value"]


 def run_bayes_study(

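To make the new sign handling concrete, here is a small worked example of the updated cost function (assumes the redis-retrieval-optimizer package is installed). The metric values are invented for illustration; only cost_fn and norm_metric come from the module above, and the weights match the example study config:

from redis_retrieval_optimizer.bayes_study import cost_fn, norm_metric

# Invented trial metrics: quality scores in [0, 1], timings in seconds.
trial_metrics = {
    "f1": 0.72,
    "recall": 0.80,
    "ndcg": 0.75,
    "precision": 0.65,
    "total_indexing_time": 12.0,
    "avg_query_time": 0.05,
}

weights = {
    "f1": 2,
    "recall": 2,
    "ndcg": 2,
    "precision": 2,
    "total_indexing_time": 2,
    "avg_query_time": 2,
}

# Quality metrics contribute weight * value directly; the two timing metrics
# are first squashed with norm_metric (1 / (1 + value)) and then negated.
print(norm_metric(12.0))   # 1 / 13   ~= 0.0769
print(norm_metric(0.05))   # 1 / 1.05 ~= 0.9524
print(cost_fn(trial_metrics, weights))
# 2 * (0.72 + 0.80 + 0.75 + 0.65) - 2 * 0.0769 - 2 * 0.9524 ~= 3.78

Because the timings are no longer negated inside objective() before they reach cost_fn, the normalization and sign flip now live in one place, and the per-trial objective_value is persisted alongside the other metrics.
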
redis_retrieval_optimizer/grid_study.py

Lines changed: 6 additions & 6 deletions
@@ -29,10 +29,10 @@ def update_metric_row(
     )
     metrics["model"].append(embedding_settings.model)
     metrics["model_dim"].append(embedding_settings.dim)
-    metrics["recall@k"].append(trial_metrics["recall"])
-    metrics["ndcg@k"].append(trial_metrics["ndcg"])
+    metrics["recall"].append(trial_metrics["recall"])
+    metrics["ndcg"].append(trial_metrics["ndcg"])
     metrics["precision"].append(trial_metrics["precision"])
-    metrics["f1@k"].append(trial_metrics["f1"])
+    metrics["f1"].append(trial_metrics["f1"])
     metrics["total_indexing_time"].append(trial_metrics["total_indexing_time"])
     metrics["avg_query_time"].append(trial_metrics["query_stats"]["avg_query_time"])
     return metrics
@@ -125,9 +125,9 @@ def run_grid_study(
         "search_method": [],
         "total_indexing_time": [],
         "avg_query_time": [],
-        "recall@k": [],
-        "ndcg@k": [],
-        "f1@k": [],
+        "recall": [],
+        "ndcg": [],
+        "f1": [],
         "precision": [],
         "ret_k": [],
         "algorithm": [],

redis_retrieval_optimizer/schema.py

Lines changed: 6 additions & 3 deletions
@@ -85,9 +85,12 @@ class EmbeddingModel(BaseModel):


 class MetricWeights(BaseModel):
-    f1_at_k: int = 1
-    embedding_latency: int = 1
-    total_indexing_time: int = 1
+    f1: float = 0
+    recall: float = 0
+    ndcg: float = 0
+    precision: float = 0
+    total_indexing_time: float = 0
+    avg_query_time: float = 0


 class TrialSettings(BaseModel):

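With every field defaulting to 0, a study only optimizes the metrics it explicitly weights. A minimal sketch of how MetricWeights feeds the cost function (model_dump is the pydantic v2 serializer the bayes study above already calls):

from redis_retrieval_optimizer.schema import MetricWeights

# Reward ranking quality only; leave both timing weights at their 0 default.
weights = MetricWeights(ndcg=2, recall=1)

print(weights.model_dump())
# {'f1': 0.0, 'recall': 1.0, 'ndcg': 2.0, 'precision': 0.0,
#  'total_indexing_time': 0.0, 'avg_query_time': 0.0}
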