Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus # Used as a return type
from ax.core.utils import get_pending_observation_features_based_on_trial_status
from ax.early_stopping.strategies import (
BaseEarlyStoppingStrategy,
PercentileEarlyStoppingStrategy,
Expand Down Expand Up @@ -386,11 +385,6 @@ def get_next_trials(
with with_rng_seed(seed=self._random_seed):
grs_for_trials = self._generation_strategy_or_choose().gen(
experiment=self._experiment,
pending_observations=(
get_pending_observation_features_based_on_trial_status(
experiment=self._experiment
)
),
n=1,
fixed_features=(
# pyre-fixme[6]: Type narrowing broken because core Ax
Expand Down
37 changes: 37 additions & 0 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,43 @@ def flatten_observation_features(
)
return obs_feats

def get_disabled_parameter_fixed_features(
self,
fixed_features_to_overlay_on: core.observation.ObservationFeatures
| None = None,
) -> core.observation.ObservationFeatures | None:
"""Get the fixed features that should be used to disable parameters in the
search space. This is used to ensure that parameters that are not part of the
search space are not used in the model.

Args:
fixed_features_to_overlay_on: Fixed features to overlay the disabled
parameter fixed features on top of. This is useful when we want to
disable some parameters but still have some fixed features that are
not part of the search space.

Returns:
``ObservationFeatures`` with disabled parameters set to their default
values, or ``None`` if there are no disabled parameters and no
``fixed_features_to_overlay_on`` was provided.
"""
disabled_parameters_parameterization = {
n: p.default_value for n, p in self.parameters.items() if p.is_disabled
}
if not fixed_features_to_overlay_on:
if not disabled_parameters_parameterization:
return None
return core.observation.ObservationFeatures(
parameters=disabled_parameters_parameterization
)

return fixed_features_to_overlay_on.clone(
replace_parameters={
**disabled_parameters_parameterization,
**fixed_features_to_overlay_on.parameters,
}
)

def _cast_parameterization(
self,
parameters: Mapping[str, TParamValue],
Expand Down
10 changes: 7 additions & 3 deletions ax/generation_strategy/center_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ def update_generator_state(self, experiment: Experiment, data: Data) -> None:

def gen(
self,
*,
experiment: Experiment,
data: Data | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None,
skip_fit: bool = False,
data: Data | None = None,
n: int | None = None,
arms_per_node: dict[str, int] | None = None,
**gs_gen_kwargs: Any,
) -> GeneratorRun | None:
"""Generate candidates or skip if search space is exhausted.
Expand Down Expand Up @@ -102,9 +105,10 @@ def gen(
# Otherwise, proceed with normal generation
return super().gen(
experiment=experiment,
data=data,
pending_observations=pending_observations,
skip_fit=skip_fit,
data=data,
arms_per_node=arms_per_node,
**gs_gen_kwargs,
)

Expand Down
Loading