Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,10 @@ def lookup_data(self) -> Data:
"""Lookup cached data on experiment for this trial.

Returns:
If not merging across timestamps, the latest ``Data`` object
associated with the trial. If merging, all data for trial, merged.
All ``Data`` on the experiment that is associated with this trial.

"""
return self.experiment.lookup_data_for_trial(trial_index=self.index)
return self.experiment.lookup_data(trial_indices={self.index})

def _check_existing_and_name_arm(self, arm: Arm) -> None:
"""Sets name for given arm; if this arm is already in the
Expand Down
23 changes: 3 additions & 20 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,25 +968,6 @@ def attach_fetch_results(
data = Data.from_multiple_data(data=[ok.ok for ok in oks])
self.attach_data(data=data)

def lookup_data_for_trial(self, trial_index: int) -> Data:
"""Look up stored data for a specific trial.

Returns data for this trial. Returns empty data if no data
is present. This method will not fetch data from metrics - to do that,
use `fetch_data()` instead.

Args:
trial_index: The index of the trial to lookup data for.

Returns:
The requested data object.
"""
if trial_index not in self.data.trial_indices:
return Data()
elif {trial_index} == self.data.trial_indices:
return self.data
return self.data.filter(trial_indices=[trial_index])

def lookup_data(
self,
trial_indices: Iterable[int] | None = None,
Expand All @@ -1009,6 +990,8 @@ def lookup_data(
trial_indices = list(trial_indices)
if len(trial_indices) == 0:
return Data()
if set(trial_indices) == self.data.trial_indices:
return self.data

return self.data.filter(trial_indices=trial_indices)

Expand Down Expand Up @@ -1392,7 +1375,7 @@ def warm_start_from_old_experiment(
none_throws(trial.arm).parameters,
raise_error=search_space_check_membership_raise_error,
)
dat = old_experiment.lookup_data_for_trial(trial_index=trial.index)
dat = old_experiment.lookup_data(trial_indices={trial.index})
# Set trial index and arm name to their values in new trial.
new_trial = self.new_trial()
add_arm_and_prevent_naming_collision(
Expand Down
4 changes: 2 additions & 2 deletions ax/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ def fetch_data_prefer_lookup(
# first identify trial + metric combos to fetch, then fetch them all
# at once.
for trial in completed_trials:
cached_trial_data = experiment.lookup_data_for_trial(
trial_index=trial.index,
cached_trial_data = experiment.lookup_data(
trial_indices={trial.index},
)

cached_metric_signatures = cached_trial_data.metric_signatures
Expand Down
11 changes: 6 additions & 5 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def test_fetch_and_store_data(self) -> None:
)

# Verify data lookup includes trials attached from `fetch_data`.
self.assertEqual(len(exp.lookup_data_for_trial(1).df), 30)
self.assertEqual(len(exp.lookup_data(trial_indices={1}).df), 30)

# Test local storage
exp.attach_data(batch_data)
Expand All @@ -625,7 +625,7 @@ def test_fetch_and_store_data(self) -> None:
self.assertEqual(len(exp.lookup_data().df), len(exp_data.df) + 1)

# Test retrieving original batch 0 data
trial_0_df = exp.lookup_data_for_trial(0).df
trial_0_df = exp.lookup_data(trial_indices={0}).df
self.assertEqual((trial_0_df["metric_name"] == "b").sum(), n)
self.assertEqual(
(trial_0_df["metric_name"] == "not_yet_on_experiment").sum(), 1
Expand All @@ -640,7 +640,7 @@ def test_fetch_and_store_data(self) -> None:
# same result as `lookup_data_for_trial(0)`
self.assertEqual(
(df["trial_index"] == 0).sum(),
len(exp.lookup_data_for_trial(trial_index=0).df),
len(exp.lookup_data(trial_indices={0}).df),
)
new_data = Data(
df=pd.DataFrame.from_records(
Expand Down Expand Up @@ -1374,7 +1374,8 @@ def test_clone_with(self) -> None:
self.assertEqual(len(cloned_experiment.trials[0].arms), 16)

self.assertEqual(
cloned_experiment.lookup_data_for_trial(1).df["trial_index"].iloc[0], 1
cloned_experiment.lookup_data(trial_indices={1}).df["trial_index"].iloc[0],
1,
)

# Save the cloned experiment to db and make sure the original
Expand All @@ -1387,7 +1388,7 @@ def test_clone_with(self) -> None:
# With existing data.
cloned_experiment = experiment.clone_with(trial_indices=[1])
self.assertEqual(len(cloned_experiment.trials), 1)
cloned_df = cloned_experiment.lookup_data_for_trial(0).df
cloned_df = cloned_experiment.lookup_data(trial_indices={0}).df
self.assertEqual(cloned_df["trial_index"].iloc[0], 0)

# Clone with data with "step" column
Expand Down
4 changes: 3 additions & 1 deletion ax/orchestration/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,9 @@ def test_orchestrator_with_metric_with_new_data_after_completion(self) -> None:
return_value=timedelta(hours=1),
):
orchestrator.run_all_trials()
self.assertFalse(orchestrator.experiment.lookup_data_for_trial(0).full_df.empty)
self.assertFalse(
orchestrator.experiment.lookup_data(trial_indices={0}).full_df.empty
)

def test_run_trials_in_batches(self) -> None:
gs = self.two_sobol_steps_GS
Expand Down
4 changes: 2 additions & 2 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,7 +1697,7 @@ def test_update_trial_data(self) -> None:
)

ax_client.stop_trial_early(trial_index=idx)
df = ax_client.experiment.lookup_data_for_trial(idx).df
df = ax_client.experiment.lookup_data(trial_indices={idx}).df
self.assertEqual(len(df), 1)

# Failed trial.
Expand All @@ -1706,7 +1706,7 @@ def test_update_trial_data(self) -> None:
ax_client._update_trial_with_raw_data(
trial_index=idx, raw_data=[(0, {"branin": (3, 0.0)})]
)
df = ax_client.experiment.lookup_data_for_trial(idx).df
df = ax_client.experiment.lookup_data(trial_indices={idx}).df
self.assertEqual(df["mean"].item(), 3.0)

# Incomplete trial fails
Expand Down
4 changes: 1 addition & 3 deletions ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,7 @@ def test_derelativize_opt_config(self) -> None:
input_obj = assert_is_instance(
input_optimization_config.objective, MultiObjective
)
status_quo_df = exp.lookup_data_for_trial(
trial_index=status_quo_trial_index
).df
status_quo_df = exp.lookup_data(trial_indices={status_quo_trial_index}).df
# This is not a real test of `derelativize_opt_config` but rather
# making sure the values on the experiment have't drifted
self.assertEqual(status_quo_df["metric_name"].tolist(), ["m1", "m2", "m3"])
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def add_experiment_id(sqa: SQAData) -> None:
# of merging each trial's Data into one Experiment-level Data, so we
# must save it again, replacing the previously saved data for that
# trial.
data = experiment.lookup_data_for_trial(trial_index=trial.index)
data = experiment.lookup_data(trial_indices={trial.index})
datas.append(data)
data_encode_args.append({"trial_index": trial.index, "timestamp": 0})

Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_CopyDBIDsDataExp(self) -> None:
self.assertEqual(exp1, exp2)

# empty some of exp2 db_ids
data = exp2.lookup_data_for_trial(0)
data = exp2.lookup_data(trial_indices={0})
# pyre-fixme[8]: Attribute has type `int`; used as `None`.
data.db_id = None

Expand Down