Add basic GaussianProcessSurrogate.from_prior constructor method (#717)
Changes from all commits: 830224d, e88f94b, cdf71a9, 1dff685, 7d17a12, eacdb28, 4cc837d, 63e385c
```diff
@@ -24,9 +24,11 @@
     DefaultKernelFactory,
     _default_noise_factory,
 )
+from baybe.surrogates.gaussian_process.prior_modules import PriorMean
 from baybe.utils.conversion import to_string

 if TYPE_CHECKING:
+    from botorch.models import SingleTaskGP
     from botorch.models.gpytorch import GPyTorchModel
     from botorch.models.transforms.input import InputTransform
     from botorch.models.transforms.outcome import OutcomeTransform
```
```diff
@@ -113,11 +115,60 @@ class GaussianProcessSurrogate(Surrogate):
     _model = field(init=False, default=None, eq=False)
     """The actual model."""

+    # Transfer learning fields
+    _prior_gp: SingleTaskGP | None = field(init=False, default=None, eq=False)
+    """Prior GP to extract mean/covariance for transfer learning."""
+
     @staticmethod
     def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
         """Create a Gaussian process surrogate from one of the defined presets."""
         return make_gp_from_preset(preset)

+    @classmethod
+    def from_prior(
+        cls,
+        prior_gp: GaussianProcessSurrogate,
+        kernel_factory: KernelFactory | None = None,
+        **kwargs,
+    ) -> GaussianProcessSurrogate:
+        """Create a GP surrogate using a prior GP's predictions as the mean function.
+
+        Transfers knowledge by using the prior GP's posterior mean predictions
+        as the mean function for a new GP, while learning the covariance from scratch.
+
+        Args:
+            prior_gp: The fitted GaussianProcessSurrogate to use as the prior.
+            kernel_factory: The kernel factory for the covariance components.
+            **kwargs: Additional arguments for the GaussianProcessSurrogate constructor.
+
+        Returns:
+            A new GaussianProcessSurrogate instance set up for transfer learning.
+
+        Raises:
+            ValueError: If prior_gp is not a GaussianProcessSurrogate or is not fitted.
+        """
+        from copy import deepcopy
+
+        # Validate that the prior GP is fitted
+        if not isinstance(prior_gp, cls):
+            raise ValueError(
+                "prior_gp must be a fitted GaussianProcessSurrogate instance"
+            )
+        if prior_gp._model is None:
+            raise ValueError("Prior GP must be fitted before use")
+
+        # Configure the kernel factory (always needed since only the mean is transferred)
+        if kernel_factory is None:
+            kernel_factory = DefaultKernelFactory()
+
+        # Create the new surrogate instance
+        instance = cls(kernel_or_factory=kernel_factory, **kwargs)
+
+        # Configure for transfer learning: store the prior's BoTorch model
+        instance._prior_gp = deepcopy(prior_gp.to_botorch())
+
+        return instance
+
     @override
     def to_botorch(self) -> GPyTorchModel:
         return self._model
```
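The idea behind `from_prior` — reuse the prior GP's posterior mean and learn only the residual structure on the new task — can be illustrated without BayBE or BoTorch. Below is a minimal numpy sketch of the same mechanism; all names (`prior_mean`, `transfer_predict`, the linear stand-in "models") are hypothetical illustrations, not BayBE API:

```python
import numpy as np

rng = np.random.default_rng(0)

# Source task: y = 2x, with the "prior" fitted on plenty of data
x_src = rng.uniform(0, 1, 50)
y_src = 2 * x_src
slope_prior = np.polyfit(x_src, y_src, 1)[0]

def prior_mean(x):
    """Stand-in for the frozen posterior mean of the fitted prior model."""
    return slope_prior * x

# Target task: same trend shifted by a constant, but only a few data points
x_tgt = np.array([0.1, 0.5, 0.9])
y_tgt = 2 * x_tgt + 3.0

# "Mean transfer": the new model only fits the residuals w.r.t. the prior mean
residual = y_tgt - prior_mean(x_tgt)
offset = residual.mean()  # trivial residual model: a learned constant

def transfer_predict(x):
    # Prediction = frozen prior mean + newly learned residual
    return prior_mean(x) + offset

print(transfer_predict(np.array([0.3])))  # close to 2*0.3 + 3 = 3.6
```

Even this toy version shows why so few target points suffice: the residual model only has to capture the deviation from the prior, not the full trend.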
```diff
@@ -152,22 +203,30 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
         assert self._searchspace is not None

         context = _ModelContext(self._searchspace)

         numerical_idxs = context.get_numerical_indices(train_x.shape[-1])

-        # For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
-        input_transform = Normalize(
-            train_x.shape[-1],
-            bounds=context.parameter_bounds,
-            indices=list(numerical_idxs),
-        )
-        outcome_transform = Standardize(train_y.shape[-1])
-
         # extract the batch shape of the training data
         batch_shape = train_x.shape[:-2]

+        # Configure input/output transforms
+        if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"):
+            # Use the prior's transforms for consistency in transfer learning
+            input_transform = self._prior_gp.input_transform
+            outcome_transform = self._prior_gp.outcome_transform
+        else:
+            # For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
+            input_transform = Normalize(
+                train_x.shape[-1],
+                bounds=context.parameter_bounds,
+                indices=numerical_idxs,
+            )
+            outcome_transform = Standardize(train_y.shape[-1])
```

- **Collaborator** (on the `if` condition): Since there is an explicit check for `self._prior_gp is not None`, why is the check for `input_transform` needed?
- **Collaborator** suggested a change in the `else` branch, restoring the `list(...)` conversion from the original code: `indices=numerical_idxs,` → `indices=list(numerical_idxs),`
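The rationale for reusing the prior's transforms is worth spelling out: the frozen prior mean was learned on inputs scaled with the prior's own bounds, so querying it through a `Normalize` parameterized with different bounds would evaluate it at the wrong locations. A minimal sketch of that failure mode, using hypothetical stand-ins rather than the actual BoTorch transform objects:

```python
# The "prior model" internally works on inputs scaled to [0, 1] with ITS bounds.
prior_bounds = (0.0, 10.0)

def prior_mean_normalized(z):
    # Prior's mean in its normalized space, learned as z = x / 10,
    # i.e. it represents y = 0.5 * x in the original space.
    return 5.0 * z

def normalize(x, bounds):
    lo, hi = bounds
    return (x - lo) / (hi - lo)

x = 4.0

# Consistent: reuse the prior's transform -> correct prediction 0.5 * 4 = 2.0
consistent = prior_mean_normalized(normalize(x, prior_bounds))

# Inconsistent: normalize with the new task's (different) bounds
new_bounds = (0.0, 5.0)
inconsistent = prior_mean_normalized(normalize(x, new_bounds))

print(consistent, inconsistent)  # same input, two different answers
```

The same argument applies to `Standardize` on the outcomes: the prior's mean predictions only make sense in the output scale it was trained with.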
New file `baybe/surrogates/gaussian_process/prior_modules.py` (`@@ -0,0 +1,55 @@`):

```python
"""Prior modules for Gaussian process transfer learning."""

from __future__ import annotations

from copy import deepcopy
from typing import Any

import gpytorch
import torch
from botorch.models import SingleTaskGP
from torch import Tensor


class PriorMean(gpytorch.means.Mean):
    """GPyTorch mean module using a trained GP as prior mean.

    This mean module wraps a trained Gaussian process and uses its predictions
    as the mean function for another GP.

    Args:
        gp: Trained Gaussian process to use as the mean function.
        batch_shape: Batch shape for the mean module.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any
    ) -> None:
        super().__init__()

        # Deep copy the GP so the wrapped prior cannot be modified from outside
        self.gp: SingleTaskGP = deepcopy(gp)
        self.batch_shape: torch.Size = batch_shape

        # Freeze the parameters of the prior
        for param in self.gp.parameters():
            param.requires_grad = False

    def forward(self, x: Tensor) -> Tensor:
        """Compute the mean function using the wrapped GP.

        Args:
            x: Input tensor for which to compute the mean.

        Returns:
            Mean predictions from the wrapped GP.
        """
        self.gp.eval()
        self.gp.likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            mean = self.gp(x).mean.detach()

        # Handle batch dimensions
        target_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-1])
        return mean.reshape(target_shape)
```

Review comments on this file:

- **Collaborator** (on the `PriorMean` class): Some questions for understanding: …
- **Author:** I think there might be some misunderstanding here. The incoming mean is not constant, since the prior GP is already fitted on some data and we are using its posterior here. Even if the prior GP originally had a …
- **Collaborator:** I see, thanks for clarifying; I now see the need for the class. Is it also right that this implementation is the fully frozen variant of the prior mean, i.e., the mean is not just a prior but remains the mean of the actual GP used in the campaign forever?
- **Collaborator** (on the `**kwargs` entry in the docstring): Is it necessary to include those in this class / the …?
- **Collaborator** (on the `eval()` calls in `forward`): Wouldn't it make sense to move these `eval()` calls to `__init__`?
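The reshape at the end of `forward` is the only shape-sensitive step: the PR reshapes the raw mean values to the broadcast of the module's `batch_shape` with the input's batch dimensions `x.shape[:-1]`. That broadcasting rule can be checked in isolation with numpy; the helper name below is ours, not part of the PR:

```python
import numpy as np

def broadcast_mean(mean_values, batch_shape, x_shape):
    """Mimic the final reshape in PriorMean.forward (illustrative helper).

    The target shape is the broadcast of the mean module's batch shape with
    the input's batch dimensions, i.e. x.shape[:-1].
    """
    target_shape = np.broadcast_shapes(batch_shape, x_shape[:-1])
    return np.asarray(mean_values).reshape(target_shape)

# A non-batched mean module (batch_shape = ()) applied to inputs of shape
# (2, 3, 5) -- a (2, 3) batch of 5-dimensional points: the six mean values
# are arranged into a (2, 3) array.
out = broadcast_mean(np.arange(6.0), (), (2, 3, 5))
print(out.shape)  # (2, 3)
```

This also hints at the reviewer's question about `batch_shape`: for a non-batched module the broadcast degenerates to `x.shape[:-1]`, so the argument only matters when the module itself carries batch dimensions.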
Review thread on the new `_prior_gp` attribute:

- **Collaborator:** Given that we already have the `_model` attribute, can you explain why we need to introduce yet another attribute like `_prior_gp`? Naively, I would suspect the first to contain the latter. At the least, we should strive to avoid putting a lot of additional attributes into this class, because they will essentially be irrelevant for the non-transfer-learning cases.
- **Author:** I see your point and thought the same before, but unfortunately I couldn't find a good solution: the problem is that there is a gap between creation via `from_prior` and fitting the model via `fit`. The instance must somehow remember that it should use transfer learning, and remember its prior, to be able to create the `_model`. I'd be happy to change this and will give it another thought. Maybe the logic could be moved to some `KernelFactory` or `MeanFactory`. Do you have any suggestions for how to get rid of this attribute?
- **Author:** @AVHopp, do you maybe have an idea?
- **Author:** How about introducing a new mean factory, similar to the kernel factories in BayBE? Then, in `from_prior`, I'd just replace the mean factory by the new `PriorMeanFactory` and could remove the attribute from the surrogate class. However, this would add an entirely new factory pattern to BayBE.
- **Collaborator:** Yeah, I would prefer that a bit. Although the main reason for the kernel factory was that search-space information is needed when creating the kernels, which is not yet available when specifying the attribute on the surrogate; that doesn't seem to be the case here with the means, right? So a factory is not strictly needed, but I still see two advantages for which I would prefer it: … About `_model`: this is supposed to hold the fitted BoTorch model, right? So would it make any sense to only partially initialize it with the means? If not, then forget that idea.
- **Collaborator:** Is this thread still relevant, given that we agreed on a `Factory` approach in our meeting (iirc)?