84 changes: 61 additions & 23 deletions botorch/models/fully_bayesian.py
@@ -122,13 +122,16 @@ class PyroModel:
the subclasses are used as inputs to a `SaasFullyBayesianSingleTaskGP`,
which should then have its hyperparameters fit with
`fit_fully_bayesian_model_nuts`. (By default, its subclass `SaasPyroModel`
is used). A `PyroModel`s `sample` method should specify lightweight
is used). A `PyroModel`'s `sample` method should specify lightweight
PyTorch functionality, which will be used for fast model fitting with NUTS.
The utility of `PyroModel` is in enabling fast fitting with NUTS, since we
would otherwise need to use GPyTorch, which is computationally infeasible
in combination with Pyro.
"""

_prior_mode: bool = False
_noiseless_eps_for_sampleability: float = 1e-7

def __init__(
self,
use_input_warping: bool = False,
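
For readers landing on this diff, a minimal sketch of how a model backed by a `PyroModel` subclass is typically fit, per the docstring above (the NUTS settings below are illustrative assumptions, not values prescribed by this PR):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.rand(20, 1, dtype=torch.double)
# Uses the default SaasPyroModel under the hood.
gp = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    gp, warmup_steps=256, num_samples=128, thinning=16, disable_progbar=True
)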
@@ -242,6 +245,56 @@ def sample_concentrations(self, **tkwargs: Any) -> tuple[Tensor, Tensor]:

return c0, c1

def sample_observations(
self,
mean: Tensor,
K_noiseless: Tensor,
noise: Tensor,
**tkwargs: Any,
) -> None:
r"""Sample the observations Y (or prior samples in prior mode).

Args:
mean: The mean constant.
K_noiseless: The kernel matrix without noise.
noise: The noise variance.
**tkwargs: dtype and device keyword arguments.
"""
if self.train_Y.shape[-2] == 0:
# Do not attempt to sample Y if the data is empty.
return

n = self.train_X.shape[0]
K = K_noiseless + noise * torch.eye(n, **tkwargs)

if self._prior_mode:
self.f_prior_sample = pyro.sample(
"f",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(n),
covariance_matrix=K_noiseless
+ self._noiseless_eps_for_sampleability * torch.eye(n, **tkwargs),
# A small jitter is added so the covariance matrix stays
# positive definite and can be sampled from.
),
)
self.Y_prior_sample = pyro.sample(
"Y",
pyro.distributions.Normal(
loc=self.f_prior_sample,
scale=noise.sqrt(),
),
)
else:
pyro.sample(
"Y",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(n),
covariance_matrix=K,
),
obs=self.train_Y.squeeze(-1),
)


class MaternPyroModel(PyroModel):
r"""Implementation of the a fully Bayesian model with a dimension-scaling prior.
@@ -269,19 +322,10 @@ def sample(self) -> None:
noise = self.sample_noise(**tkwargs)
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
X_tf = self._maybe_input_warp(self.train_X, **tkwargs)
if self.train_Y.shape[-2] > 0:
# Do not attempt to sample Y if the data is empty.
# This leads to errors with empty data.
K = matern52_kernel(X=X_tf, lengthscale=lengthscale)
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
pyro.sample(
"Y",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=K,
),
obs=self.train_Y.squeeze(-1),
)
K_noiseless = outputscale * matern52_kernel(X=X_tf, lengthscale=lengthscale)
self.sample_observations(
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
)

def sample_lengthscale(self, dim: int, **tkwargs: Any) -> Tensor:
r"""Sample the lengthscale."""
@@ -503,16 +547,10 @@ def sample(self) -> None:
weight_variance = self.sample_weight_variance(**tkwargs)
X_tf = self._maybe_input_warp(X=self.train_X, **tkwargs)
X_tf = X_tf - 0.5 # center transformed data at 0 (for linear model)
K = linear_kernel(X=X_tf, weight_variance=weight_variance)
K_noiseless = linear_kernel(X=X_tf, weight_variance=weight_variance)
noise = self.sample_noise(**tkwargs)
K = K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
pyro.sample(
"Y",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=K,
),
obs=self.train_Y.squeeze(-1),
self.sample_observations(
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
)

def sample_weight_variance(self, alpha: float = 0.1, **tkwargs: Any) -> Tensor:
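A rough end-to-end sketch of what the new `_prior_mode` flag enables (based on the changes above and the tests below; this assumes `sample()` is called outside of an MCMC run, so `pyro.sample` simply draws from the priors, and that `train_Yvar` may be omitted):

import torch
from botorch.models.fully_bayesian import MaternPyroModel

pyro_model = MaternPyroModel()
pyro_model.set_inputs(
    train_X=torch.rand(5, 3, dtype=torch.double),
    train_Y=torch.rand(5, 1, dtype=torch.double),
)
pyro_model._prior_mode = True
# One forward pass of the Pyro model: hyperparameters are drawn from their
# priors, and the latent f and noisy Y prior draws are stored on the model.
pyro_model.sample()
print(pyro_model.f_prior_sample.shape, pyro_model.Y_prior_sample.shape)
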
134 changes: 134 additions & 0 deletions test/models/test_fully_bayesian.py
@@ -92,6 +92,140 @@ def load_mcmc_samples(self, mcmc_samples) -> None:
pass


class TestPyroModelPriorMode(BotorchTestCase):
"""Tests for the _prior_mode attribute and sample_observations method."""

def test_prior_mode_attribute_on_base_class(self) -> None:
"""Test that _prior_mode is accessible on the base PyroModel class."""
# Test that _prior_mode defaults to False
self.assertFalse(PyroModel._prior_mode)

# Test that subclasses inherit the attribute
self.assertFalse(MaternPyroModel._prior_mode)
self.assertFalse(SaasPyroModel._prior_mode)
self.assertFalse(LinearPyroModel._prior_mode)

def test_sample_observations_normal_mode(self) -> None:
"""Test sample_observations in normal (non-prior) mode."""
tkwargs = {"dtype": torch.double, "device": self.device}
n, d = 5, 3

# Create a PyroModel subclass instance
pyro_model = MaternPyroModel()
train_X = torch.rand(n, d, **tkwargs)
train_Y = torch.rand(n, 1, **tkwargs)
pyro_model.set_inputs(train_X=train_X, train_Y=train_Y)

# Ensure _prior_mode is False
self.assertFalse(pyro_model._prior_mode)

mean = torch.zeros(1, **tkwargs)
K_noiseless = torch.eye(n, **tkwargs)
noise = torch.tensor(0.1, **tkwargs)

# In normal mode, sample_observations should call pyro.sample with obs
with patch.object(pyro, "sample") as mock_sample:
pyro_model.sample_observations(
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
)
# Verify pyro.sample was called with obs argument
mock_sample.assert_called_once()
call_kwargs = mock_sample.call_args[1]
self.assertIn("obs", call_kwargs)
self.assertEqual(mock_sample.call_args[0][0], "Y")

def test_sample_observations_prior_mode(self) -> None:
"""Test sample_observations in prior mode."""
tkwargs = {"dtype": torch.double, "device": self.device}
n, d = 5, 3

# Create a PyroModel subclass instance
pyro_model = MaternPyroModel()
train_X = torch.rand(n, d, **tkwargs)
train_Y = torch.rand(n, 1, **tkwargs)
pyro_model.set_inputs(train_X=train_X, train_Y=train_Y)

# Set _prior_mode to True
pyro_model._prior_mode = True

mean = torch.zeros(1, **tkwargs)
K_noiseless = torch.eye(n, **tkwargs)
noise = torch.tensor(0.1, **tkwargs)

# In prior mode, sample_observations should sample both "f" and "Y"
with patch.object(pyro, "sample") as mock_sample:
mock_sample.return_value = torch.randn(n, **tkwargs)
pyro_model.sample_observations(
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
)
# Verify pyro.sample was called twice (for "f" and "Y")
self.assertEqual(mock_sample.call_count, 2)
# First call should be for "f"
self.assertEqual(mock_sample.call_args_list[0][0][0], "f")
# Second call should be for "Y"
self.assertEqual(mock_sample.call_args_list[1][0][0], "Y")
# Neither call should have obs argument
for call in mock_sample.call_args_list:
self.assertNotIn("obs", call[1])

def test_sample_observations_empty_data(self) -> None:
"""Test that sample_observations returns early for empty data."""
tkwargs = {"dtype": torch.double, "device": self.device}
d = 3

# Create a PyroModel subclass instance with empty data
pyro_model = MaternPyroModel()
train_X = torch.rand(0, d, **tkwargs)
train_Y = torch.rand(0, 1, **tkwargs)
pyro_model.set_inputs(train_X=train_X, train_Y=train_Y)

mean = torch.zeros(1, **tkwargs)
K_noiseless = torch.eye(0, **tkwargs)
noise = torch.tensor(0.1, **tkwargs)

# sample_observations should return early without calling pyro.sample
with patch.object(pyro, "sample") as mock_sample:
pyro_model.sample_observations(
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
)
mock_sample.assert_not_called()

def test_matern_pyro_model_sample_with_prior_mode(self) -> None:
"""Test MaternPyroModel.sample() with _prior_mode enabled."""
tkwargs = {"dtype": torch.double, "device": self.device}
n, d = 5, 3

pyro_model = MaternPyroModel()
train_X = torch.rand(n, d, **tkwargs)
train_Y = torch.rand(n, 1, **tkwargs)
pyro_model.set_inputs(train_X=train_X, train_Y=train_Y)

# Enable prior mode
pyro_model._prior_mode = True

# Mock pyro.sample to return valid tensors
def mock_sample_fn(name, dist):
if name == "mean":
return torch.tensor(0.0, **tkwargs)
elif name == "noise":
return torch.tensor(0.01, **tkwargs)
elif name == "lengthscale":
return torch.ones(d, **tkwargs)
elif name == "f":
return torch.randn(n, **tkwargs)
elif name == "Y":
return torch.randn(n, **tkwargs)
else:
return torch.tensor(1.0, **tkwargs)

with patch.object(pyro, "sample", side_effect=mock_sample_fn):
# Should not raise any errors
pyro_model.sample()
# Check that prior samples are stored
self.assertIsNotNone(pyro_model.f_prior_sample)
self.assertIsNotNone(pyro_model.Y_prior_sample)


class TestSaasFullyBayesianSingleTaskGP(BotorchTestCase):
model_cls: type[FullyBayesianSingleTaskGP] = SaasFullyBayesianSingleTaskGP
pyro_model_cls: type[PyroModel] = SaasPyroModel
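
As a possible complement to the mock-based tests above, the prior-mode behavior could also be verified by tracing one execution of the model (a sketch, not part of this PR; it relies on `pyro.poutine.trace` and the site names "f" and "Y" introduced above):

import torch
import pyro
from botorch.models.fully_bayesian import MaternPyroModel

pyro_model = MaternPyroModel()
pyro_model.set_inputs(
    train_X=torch.rand(5, 3, dtype=torch.double),
    train_Y=torch.rand(5, 1, dtype=torch.double),
)
pyro_model._prior_mode = True
# Record all sample sites from one execution of the model.
trace = pyro.poutine.trace(pyro_model.sample).get_trace()
# In prior mode, both "f" and "Y" are present and unobserved.
assert "f" in trace.nodes and "Y" in trace.nodes
assert not trace.nodes["Y"]["is_observed"]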