Skip to content

Commit e4e39a3

Browse files
SamuelGabrielfacebook-github-bot
authored andcommitted
Make SAAS prior sampleable
Summary: I want to sample datasets from the SAAS prior. This is only a draft PR to ask whether this should yield the right datasets, I would engineer things a little nicer + of course add tests if this looks good david. Differential Revision: D86770833
1 parent 3ff4d24 commit e4e39a3

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

botorch/models/fully_bayesian.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ class MaternPyroModel(PyroModel):
252252

253253
_outputscale_prior_concentration: float | None = None
254254
_outputscale_prior_rate: float | None = None
255+
_prior_mode: bool = False
255256

256257
def sample(self) -> None:
257258
r"""Sample from the Matern pyro model.
@@ -272,16 +273,35 @@ def sample(self) -> None:
272273
if self.train_Y.shape[-2] > 0:
273274
# Do not attempt to sample Y if the data is empty.
274275
# This leads to errors with empty data.
275-
K = matern52_kernel(X=X_tf, lengthscale=lengthscale)
276-
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
277-
pyro.sample(
278-
"Y",
279-
pyro.distributions.MultivariateNormal(
280-
loc=mean.view(-1).expand(self.train_X.shape[0]),
281-
covariance_matrix=K,
282-
),
283-
obs=self.train_Y.squeeze(-1),
284-
)
276+
K_noiseless = outputscale * matern52_kernel(X=X_tf, lengthscale=lengthscale)
277+
K = K_noiseless + noise * torch.eye(self.train_X.shape[0], **tkwargs)
278+
if self._prior_mode:
279+
self.f = pyro.sample(
280+
"f",
281+
pyro.distributions.MultivariateNormal(
282+
loc=mean.view(-1).expand(self.train_X.shape[0]),
283+
covariance_matrix=K_noiseless
284+
+ 1e-7 * torch.eye(self.train_X.shape[0], **tkwargs),
285+
# sadly need to add a little bit of noise to be possible
286+
# to sample from this
287+
),
288+
)
289+
self.train_Y = pyro.sample(
290+
"Y",
291+
pyro.distributions.Normal(
292+
loc=self.f,
293+
scale=noise.sqrt(),
294+
),
295+
)
296+
else:
297+
pyro.sample(
298+
"Y",
299+
pyro.distributions.MultivariateNormal(
300+
loc=mean.view(-1).expand(self.train_X.shape[0]),
301+
covariance_matrix=K,
302+
),
303+
obs=self.train_Y.squeeze(-1),
304+
)
285305

286306
def sample_lengthscale(self, dim: int, **tkwargs: Any) -> Tensor:
287307
r"""Sample the lengthscale."""

0 commit comments

Comments
 (0)