Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ raw_data_path: "label_app/data/2008-mazda3-chunks.json"
input_data_type: "json"
labeled_data_path: "label_app/data/mazda_labeled_items.json"
# metrics to be used in objective function
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
# weight of each metric
weights: [1, 1, 1]
metric_weights:
f1_at_k: 1
embedding_latency: 1
total_indexing_time: 1
# constraints for the optimization
n_trials: 10
n_jobs: 1
Expand All @@ -176,14 +177,13 @@ embedding_models:
|----------------------|------------------------------------------------|--------------------------------------------------|----------|
| **raw_data_path** | `label_app/data/2008-mazda3-chunks.json` | Path to raw data file | ✅ |
| **labeled_data_path** | `label_app/data/mazda-labeled-rewritten.json` | Path to labeled data file | ✅ |
| **metrics** | f1_at_k, embedding_latency, total_indexing_time | Metrics used in the objective function | ✅ |
| **weights** | [1, 1, 1] | Weights for f1_at_k, embedding_latency, total_indexing_time respectively. | ✅ |
| **algorithms** | flat, hnsw | Indexing algorithms to be tested in optimization | ✅ |
| **vector_data_types** | float32, float16 | Data types to be tested for vectors | ✅ |
| **n_trials** | 15 | Number of optimization trials | ✅ |
| **n_jobs** | 1 | Number of parallel jobs | ✅ |
| **ret_k** | [1, 10] | Range of values to be tested for `k` in retrieval | ✅ |
| **embedding_models** | **Provider:** hf <br> **Model:** sentence-transformers/all-MiniLM-L6-v2 <br> **Dim:** 384 | List of embedding models and their dimensions | ✅ |
| **metric_weights** | **f1_at_k:** 1 <br> **embedding_latency:** 1 <br> **total_indexing_time:** 1 | Weight for respective metric used in the objective function | defaults to example |
| **input_data_type** | json | Type of input data | defaults to example |
| **redis_url** | `redis://localhost:6379` | Connection string for redis instance | defaults to example |
| **ef_runtime** | [10, 20, 30, 50] | Max top candidates during search for HNSW | defaults to example |
Expand Down
8 changes: 5 additions & 3 deletions examples/dbpedia/dbpedia_study_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ labeled_data_path: "data/dbpedia_labeled.json" # labeled data
n_trials: 3
n_jobs: 1

# metrics to be used in objective function
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
weights: [1, 1, 1] # weight of each metric respectively
# metric weights to be used in objective function
metric_weights:
f1_at_k: 1
embedding_latency: 1
total_indexing_time: 1

# optimization decision variables
algorithms: ["flat", "hnsw"] # indexing algorithms variables
Expand Down
8 changes: 5 additions & 3 deletions examples/getting_started/custom_retriever_optimizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@
"n_trials: 20\n",
"n_jobs: 1\n",
"\n",
"# metrics to be used in objective function\n",
"metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n",
"weights: [1, 1, 1] # weight of each metric respectively \n",
"# Metric weights to be used in objective function\n",
"metric_weights:\n",
" f1_at_k: 1\n",
" embedding_latency: 1\n",
" total_indexing_time: 1\n",
"\n",
"# optimization decision variables\n",
"algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n",
Expand Down
8 changes: 5 additions & 3 deletions examples/getting_started/retrieval_optimizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,11 @@
"n_trials: 20\n",
"n_jobs: 1\n",
"\n",
"# metrics to be used in objective function\n",
"metrics: [\"f1_at_k\", \"embedding_latency\", \"total_indexing_time\"] \n",
"weights: [1, 1, 1] # weight of each metric respectively \n",
"# Metric weights to be used in objective function\n",
"metric_weights:\n",
" f1_at_k: 1\n",
" embedding_latency: 1\n",
" total_indexing_time: 1\n",
"\n",
"# optimization decision variables\n",
"algorithms: [\"flat\", \"hnsw\"] # indexing algorithms variables\n",
Expand Down
6 changes: 4 additions & 2 deletions examples/getting_started/study_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ n_trials: 20
n_jobs: 1

# metrics to be used in objective function
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
weights: [1, 1, 1] # weight of each metric respectively
metric_weights:
f1_at_k: 1
embedding_latency: 1
total_indexing_time: 1

# optimization decision variables
algorithms: ["flat", "hnsw"] # indexing algorithms variables
Expand Down
7 changes: 4 additions & 3 deletions optimize/ex_study_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ input_data_type: "json"
raw_data_path: "label_app/data/2008-mazda3-chunks.json"
labeled_data_path: "label_app/data/mazda-labeled-rewritten.json"
# metrics to be used in objective function
metrics: ["f1_at_k", "embedding_latency", "total_indexing_time"]
# weight of each metric
weights: [1, 1, 1]
metric_weights:
f1_at_k: 1
embedding_latency: 1
total_indexing_time: 1
algorithms: ["flat", "hnsw"]
vector_data_types: ["float32", "float16"]
# constraints for the optimization
Expand Down
9 changes: 7 additions & 2 deletions optimize/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class EmbeddingModel(BaseModel):
dim: int


class MetricWeights(BaseModel):
f1_at_k: int = 1
embedding_latency: int = 1
total_indexing_time: int = 1


class StudyConfig(BaseModel):
study_id: str = str(uuid4())
redis_url: str = "redis://localhost:6379/0"
Expand All @@ -54,10 +60,9 @@ class StudyConfig(BaseModel):
input_data_type: str
labeled_data_path: str
embedding_models: list[EmbeddingModel]
metrics: list[str]
weights: list[float]
n_trials: int
n_jobs: int
metric_weights: MetricWeights = MetricWeights()
ret_k: tuple[int, int] = [1, 10] # type: ignore # pydantic vs mypy
ef_runtime: list = [10, 50]
ef_construction: list = [100, 300]
Expand Down
9 changes: 8 additions & 1 deletion optimize/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,14 @@ def objective(trial, study_config, custom_retrievers, redis_client):

metric_values = [e.f1_at_k, norm_index_time, norm_latency]

e.obj_val = cost_fn(metric_values, study_config.weights)
e.obj_val = cost_fn(
metric_values,
[
study_config.metric_weights.f1_at_k,
study_config.metric_weights.total_indexing_time,
study_config.metric_weights.embedding_latency,
],
)

# save results as we go in case of failure
persist_metrics(redis_client, e, study_config.study_id)
Expand Down
11 changes: 8 additions & 3 deletions optimize/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
EmbeddingModel,
EmbeddingSettings,
IndexSettings,
MetricWeights,
Settings,
StudyConfig,
)

TEST_REDIS_URL = os.getenv("TEST_REDIS_URL", "redis://localhost:6379/0")


@pytest.fixture
def metric_weights():
return MetricWeights(f1_at_k=1, embedding_latency=1, total_indexing_time=1)


@pytest.fixture
def settings():
return Settings(
Expand Down Expand Up @@ -51,7 +57,7 @@ def test_db_client():


@pytest.fixture
def study_config(embedding_model):
def study_config(embedding_model, metric_weights):
return StudyConfig(
study_id="study_id",
redis_url=TEST_REDIS_URL,
Expand All @@ -67,6 +73,5 @@ def study_config(embedding_model):
embedding_models=[embedding_model],
n_trials=1,
n_jobs=1,
metrics=["f1_at_k", "embedding_latency", "total_indexing_time"],
weights=[1, 1, 1],
metric_weights=metric_weights,
)