diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index b6b0915afa9..ce6e5f4d695 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -454,11 +454,10 @@ def lookup_data(self) -> Data: """Lookup cached data on experiment for this trial. Returns: - If not merging across timestamps, the latest ``Data`` object - associated with the trial. If merging, all data for trial, merged. + All ``Data`` on the experiment that is associated with this trial. """ - return self.experiment.lookup_data_for_trial(trial_index=self.index) + return self.experiment.lookup_data(trial_indices={self.index}) def _check_existing_and_name_arm(self, arm: Arm) -> None: """Sets name for given arm; if this arm is already in the diff --git a/ax/core/experiment.py b/ax/core/experiment.py index fd0ca7bf8b5..b06465caf96 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -968,25 +968,6 @@ def attach_fetch_results( data = Data.from_multiple_data(data=[ok.ok for ok in oks]) self.attach_data(data=data) - def lookup_data_for_trial(self, trial_index: int) -> Data: - """Look up stored data for a specific trial. - - Returns data for this trial. Returns empty data if no data - is present. This method will not fetch data from metrics - to do that, - use `fetch_data()` instead. - - Args: - trial_index: The index of the trial to lookup data for. - - Returns: - The requested data object. - """ - if trial_index not in self.data.trial_indices: - return Data() - elif {trial_index} == self.data.trial_indices: - return self.data - return self.data.filter(trial_indices=[trial_index]) - def lookup_data( self, trial_indices: Iterable[int] | None = None, @@ -1009,6 +990,8 @@ def lookup_data( trial_indices = list(trial_indices) if len(trial_indices) == 0: return Data() + if set(trial_indices) == self.data.trial_indices: + return self.data return self.data.filter(trial_indices=trial_indices) @@ -1392,7 +1375,7 @@ def warm_start_from_old_experiment( none_throws(trial.arm).parameters, raise_error=search_space_check_membership_raise_error, ) - dat = old_experiment.lookup_data_for_trial(trial_index=trial.index) + dat = old_experiment.lookup_data(trial_indices={trial.index}) # Set trial index and arm name to their values in new trial. new_trial = self.new_trial() add_arm_and_prevent_naming_collision( diff --git a/ax/core/metric.py b/ax/core/metric.py index c2e14a28bd6..f7a34f1713e 100644 --- a/ax/core/metric.py +++ b/ax/core/metric.py @@ -362,8 +362,8 @@ def fetch_data_prefer_lookup( # first identify trial + metric combos to fetch, then fetch them all # at once. for trial in completed_trials: - cached_trial_data = experiment.lookup_data_for_trial( - trial_index=trial.index, + cached_trial_data = experiment.lookup_data( + trial_indices={trial.index}, ) cached_metric_signatures = cached_trial_data.metric_signatures diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index a2fa2f68d77..ec422629a84 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -610,7 +610,7 @@ def test_fetch_and_store_data(self) -> None: ) # Verify data lookup includes trials attached from `fetch_data`. - self.assertEqual(len(exp.lookup_data_for_trial(1).df), 30) + self.assertEqual(len(exp.lookup_data(trial_indices={1}).df), 30) # Test local storage exp.attach_data(batch_data) @@ -625,7 +625,7 @@ def test_fetch_and_store_data(self) -> None: self.assertEqual(len(exp.lookup_data().df), len(exp_data.df) + 1) # Test retrieving original batch 0 data - trial_0_df = exp.lookup_data_for_trial(0).df + trial_0_df = exp.lookup_data(trial_indices={0}).df self.assertEqual((trial_0_df["metric_name"] == "b").sum(), n) self.assertEqual( (trial_0_df["metric_name"] == "not_yet_on_experiment").sum(), 1 @@ -640,7 +640,7 @@ def test_fetch_and_store_data(self) -> None: # same result as `lookup_data_for_trial(0)` self.assertEqual( (df["trial_index"] == 0).sum(), - len(exp.lookup_data_for_trial(trial_index=0).df), + len(exp.lookup_data(trial_indices={0}).df), ) new_data = Data( df=pd.DataFrame.from_records( @@ -1374,7 +1374,8 @@ def test_clone_with(self) -> None: self.assertEqual(len(cloned_experiment.trials[0].arms), 16) self.assertEqual( - cloned_experiment.lookup_data_for_trial(1).df["trial_index"].iloc[0], 1 + cloned_experiment.lookup_data(trial_indices={1}).df["trial_index"].iloc[0], + 1, ) # Save the cloned experiment to db and make sure the original @@ -1387,7 +1388,7 @@ def test_clone_with(self) -> None: # With existing data. cloned_experiment = experiment.clone_with(trial_indices=[1]) self.assertEqual(len(cloned_experiment.trials), 1) - cloned_df = cloned_experiment.lookup_data_for_trial(0).df + cloned_df = cloned_experiment.lookup_data(trial_indices={0}).df self.assertEqual(cloned_df["trial_index"].iloc[0], 0) # Clone with data with "step" column diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 58dead781e5..ce339ff2258 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -1333,7 +1333,9 @@ def test_orchestrator_with_metric_with_new_data_after_completion(self) -> None: return_value=timedelta(hours=1), ): orchestrator.run_all_trials() - self.assertFalse(orchestrator.experiment.lookup_data_for_trial(0).full_df.empty) + self.assertFalse( + orchestrator.experiment.lookup_data(trial_indices={0}).full_df.empty + ) def test_run_trials_in_batches(self) -> None: gs = self.two_sobol_steps_GS diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index a1bdcdd5f21..1294d1bf436 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -1697,7 +1697,7 @@ def test_update_trial_data(self) -> None: ) ax_client.stop_trial_early(trial_index=idx) - df = ax_client.experiment.lookup_data_for_trial(idx).df + df = ax_client.experiment.lookup_data(trial_indices={idx}).df self.assertEqual(len(df), 1) # Failed trial. @@ -1706,7 +1706,7 @@ def test_update_trial_data(self) -> None: ax_client._update_trial_with_raw_data( trial_index=idx, raw_data=[(0, {"branin": (3, 0.0)})] ) - df = ax_client.experiment.lookup_data_for_trial(idx).df + df = ax_client.experiment.lookup_data(trial_indices={idx}).df self.assertEqual(df["mean"].item(), 3.0) # Incomplete trial fails diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index 32d6134ce54..5d3e7fbde64 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -667,9 +667,7 @@ def test_derelativize_opt_config(self) -> None: input_obj = assert_is_instance( input_optimization_config.objective, MultiObjective ) - status_quo_df = exp.lookup_data_for_trial( - trial_index=status_quo_trial_index - ).df + status_quo_df = exp.lookup_data(trial_indices={status_quo_trial_index}).df # This is not a real test of `derelativize_opt_config` but rather # making sure the values on the experiment have't drifted self.assertEqual(status_quo_df["metric_name"].tolist(), ["m1", "m2", "m3"]) diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index e47d2d7232c..c3df41f568a 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -299,7 +299,7 @@ def add_experiment_id(sqa: SQAData) -> None: # of merging each trial's Data into one Experiment-level Data, so we # must save it again, replacing the previously saved data for that # trial. - data = experiment.lookup_data_for_trial(trial_index=trial.index) + data = experiment.lookup_data(trial_indices={trial.index}) datas.append(data) data_encode_args.append({"trial_index": trial.index, "timestamp": 0}) diff --git a/ax/storage/sqa_store/tests/test_utils.py b/ax/storage/sqa_store/tests/test_utils.py index 92f486f3323..9a4a8d11eae 100644 --- a/ax/storage/sqa_store/tests/test_utils.py +++ b/ax/storage/sqa_store/tests/test_utils.py @@ -68,7 +68,7 @@ def test_CopyDBIDsDataExp(self) -> None: self.assertEqual(exp1, exp2) # empty some of exp2 db_ids - data = exp2.lookup_data_for_trial(0) + data = exp2.lookup_data(trial_indices={0}) # pyre-fixme[8]: Attribute has type `int`; used as `None`. data.db_id = None