Add basic GaussianProcessSurrogate.from_prior constructor method #717
- construct a GP by transferring knowledge from a pre-trained prior GP
- basic implementation for full mean transfer
- the posterior mean of the pretrained GP is used as the mean module for the new GP
- hyperparameters are frozen and the mean is evaluated at source points
- interface might later be extended to other mean transfers (initialize hyperparameters) or covariance transfer
- new `PriorMean` class that implements the mean of the prior GP as a BoTorch module
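The mean-transfer idea can be sketched with a toy numpy GP (all names and the RBF kernel here are illustrative, not BayBE's actual implementation): the source GP's posterior mean, frozen after fitting, becomes the mean function used for the target model.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    # squared-exponential kernel between two 1-D point sets
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def posterior_mean(x_new, x_train, y_train, noise=1e-6, prior_const=0.0):
    # GP posterior mean with a constant prior mean c:
    #   m(x) = c + k(x, X) (K + noise * I)^{-1} (y - c)
    K = rbf(x_train, x_train) + noise * np.eye(len(x_train))
    alpha = np.linalg.solve(K, y_train - prior_const)
    return prior_const + rbf(x_new, x_train) @ alpha

# "source" GP fitted on source-task data
x_src = np.array([0.0, 1.0, 2.0])
y_src = np.array([0.0, 1.0, 0.0])

# the target GP's frozen mean module evaluates the source posterior mean
def target_prior_mean(x):
    return posterior_mean(x, x_src, y_src)

# near the source training points the transferred mean reproduces the source data
print(np.allclose(target_prior_mean(x_src), y_src, atol=1e-3))  # True
```

Note that even though the source GP here has a constant prior mean, its posterior mean is non-constant, which is exactly what gets transferred.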
```python
class PriorMean(gpytorch.means.Mean):
    """GPyTorch mean module using a trained GP as prior mean.

    This mean module wraps a trained Gaussian Process and uses its predictions
```
The "process" in "Gaussian process" should not be capitalized (unless it's in a headline or similar).
```python
        New GaussianProcessSurrogate instance with transfer learning

    Raises:
        ValueError: If prior_gp is not fitted
```
There is also a raise if the variable is not a SingleTaskGP, which is not mentioned here?
```diff
@@ -113,11 +115,57 @@ class GaussianProcessSurrogate(Surrogate):
     _model = field(init=False, default=None, eq=False)
```
Given that we already have this _model attribute, can you explain why we need to introduce yet another attribute like _prior_gp? Naively I would suspect the first contains the latter. Or at least we should strive to avoid putting a lot of additional attributes in this class (because they will essentially be irrelevant for non-TL cases).
I see your point and I thought the same before, but unfortunately I couldn't find a good solution to this: The problem is that there is a gap between creation via from_prior and fitting the model via fit. The instance must somehow remember it should use transfer learning and its prior to be able to create the _model. I'd be happy to change this and will give it another thought. Maybe the logic could be moved to some KernelFactory or MeanFactory. Do you have any suggestions how to get rid of this attribute?
How about introducing a new mean factory similar to the kernel factories in BayBE?
```python
# New default
class ConstantMeanFactory(MeanFactory):
    def __call__(self, batch_shape: torch.Size):
        return gpytorch.means.ConstantMean()

class PriorMeanFactory(MeanFactory):
    def __init__(self, prior_gp: GPSurrogate):
        self.prior_gp = deepcopy(prior_gp)

    def __call__(self, batch_shape: torch.Size):
        return PriorMean(self.prior_gp, batch_shape=batch_shape)
```
Then in from_prior I'd just replace the mean factory by the new PriorMeanFactory and could remove the attribute from the surrogate class, but this would add an entirely new factory pattern to BayBE.
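The proposed swap could look roughly like this, as a minimal pure-Python sketch with hypothetical names (the real BayBE classes are attrs-based and the factories would build gpytorch mean modules):

```python
from copy import deepcopy

class ConstantMeanFactory:
    """Default factory; stands in for one building gpytorch.means.ConstantMean."""
    def __call__(self, batch_shape=None):
        return ("constant-mean", batch_shape)

class PriorMeanFactory:
    """Transfer-learning factory holding a frozen copy of the source GP."""
    def __init__(self, prior_gp):
        self.prior_gp = deepcopy(prior_gp)

    def __call__(self, batch_shape=None):
        return ("prior-mean", self.prior_gp, batch_shape)

class GaussianProcessSurrogateSketch:
    def __init__(self, mean_factory=None):
        # the default keeps the non-TL path free of empty extra attributes
        self.mean_factory = mean_factory or ConstantMeanFactory()

    @classmethod
    def from_prior(cls, prior_gp):
        # only the factory changes; no separate _prior_gp attribute needed
        return cls(mean_factory=PriorMeanFactory(prior_gp))
```

With this shape, _fit could simply call self.mean_factory(batch_shape) and never branch on a _prior_gp attribute.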
yeah I would prefer that a bit
although the main reason for the factory was that search space info is needed when creating the kernels, which is not available yet when specifying the attribute here to the surrogate. That doesn't seem to be the case here with the means, right?
So a factory is not strictly needed, but I still see two advantages why I would prefer it:
- it wouldn't be empty unused content if no prior GP is used, as it would hold the default factory
- it would be more consistent to have all kinds of factories rather than having a mixture of factories and other optional model-related attributes
About _model: This is supposed to hold the fitted botorch model, right? So would it make any sense to only partially initialize it with the means? If no, then forget that idea.
Is this thread now still relevant, given that we agreed to a Factory approach in our meeting (iirc)?
```python
    """The actual model."""

    # Transfer learning fields
    _prior_gp = field(init=False, default=None, eq=False)
```
I see, it's probably the same issue as with _model, so ideally you can paste the same comment that's there also here.
Depending on the design, this attribute might also be removed and the comment becomes obsolete.
```python
if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"):
    # Use prior's transforms for consistency in transfer learning
    input_transform = self._prior_gp.input_transform
    outcome_transform = self._prior_gp.outcome_transform
```
Since there is an explicit check for input_transform, is it always guaranteed to have outcome_transform?
Why is the check for input_transform even needed?
```python
from torch import Tensor


class PriorMean(gpytorch.means.Mean):
```
Some questions for understanding:
- What does this new class achieve that is not possible with the existing botorch constant mean class?
- When the incoming mean is a constant mean, this class would also effectively produce a constant mean?
- Afaik all our GPs have constant mean, so everything would forever be constant mean. Is this class here then necessary? Couldn't we just use the botorch constant mean class for the new TL case as well, except that the number is fixed and predetermined, i.e. somehow "set"?
I think there might be some misunderstanding here. The incoming mean is not constant since the prior GP is fitted on some data already and we are using its posterior here. Even if the prior GP originally had a ConstantMean, once trained, its posterior mean will not be constant anymore. Or am I misunderstanding your comment?
I see, thanks for clarifying, I see now the need for the class.
Please let's just make sure this is optimized and does not impose any computational bottleneck.
Is it also right that this implementation is the variant of a completely frozen prior mean? I.e. the mean is not just a prior but is forever the mean for our actual GP used in the campaign?
Pull request overview
This PR introduces a transfer learning capability for Gaussian Process surrogates through a new from_prior constructor method. The implementation enables mean function transfer from a pre-trained GP to a new GP by using the source GP's posterior mean predictions as the mean module for the target GP.
Key changes:
- New `PriorMean` class that wraps a trained GP as a BoTorch-compatible mean module with frozen hyperparameters
- New `from_prior` class method for constructing a GP surrogate with transfer learning capabilities
- Modified `_fit` method to conditionally use the prior GP's mean and transforms when available
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| baybe/surrogates/gaussian_process/prior_modules.py | Introduces PriorMean class to wrap a trained GP as a mean module for transfer learning |
| baybe/surrogates/gaussian_process/core.py | Adds from_prior constructor and updates _fit to support mean function transfer |
```python
input_transform = Normalize(
    train_x.shape[-1],
    bounds=context.parameter_bounds,
    indices=numerical_idxs,
```
The indices parameter expects a list but numerical_idxs is a tuple. While this may work in practice, it's inconsistent with the previous implementation that used list(numerical_idxs) on line 217 in the original code. For consistency and to match the expected type, convert the tuple to a list.
```diff
- indices=numerical_idxs,
+ indices=list(numerical_idxs),
```
```python
from copy import deepcopy

from botorch.models import SingleTaskGP
```
The import statements are placed inside the method. Since deepcopy is already imported at the module level (line 5 in prior_modules.py) and SingleTaskGP is imported in the TYPE_CHECKING block (line 31), these local imports are redundant and should be removed in favor of the module-level imports.
AVHopp left a comment
Just some minor comments as @Scienfitz raised some questions that might impact the design of this code, I did not fully review everything yet.
```python
        kernel_factory: KernelFactory | None = None,
        **kwargs,
    ) -> GaussianProcessSurrogate:
        """Create a GP surrogate with mean function transfer learning.
```
I think this docstring needs a bit more explanation on what exactly is done and transferred. Also, the description in the Returns: part could contain more information (but might not be needed if you add 2-3 sentences here describing what this does in more detail)
Co-authored-by: Martin Fitzner <martin.fitzner@merckgroup.com>
The name x_modules is a bit inconsistent compared to our other naming.
Just means.py?
```python
        Returns:
            Mean predictions from the wrapped GP.
        """
        self.gp.eval()
```
Wouldn't it make sense to move these eval statements into __init__ because they are only needed once?
```python
        Args:
            gp: Trained Gaussian Process to use as mean function.
            batch_shape: Batch shape for the mean module.
            **kwargs: Additional keyword arguments.
```
Is it necessary to include those in this class/the __init__? Currently they seem to be silently ignored, so I would propose to either remove them completely if possible or at least mention that they are being ignored.
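A small sketch of the suggested alternative to silently dropping them (hypothetical class name and simplified signature, not the PR's code):

```python
import warnings

class PriorMeanSketch:
    def __init__(self, gp, batch_shape=None, **kwargs):
        # surface unused keyword arguments instead of ignoring them silently
        if kwargs:
            warnings.warn(f"Ignoring unused keyword arguments: {sorted(kwargs)}")
        self.gp = gp
        self.batch_shape = batch_shape
```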
```diff
@@ -113,11 +115,57 @@ class GaussianProcessSurrogate(Surrogate):
     _model = field(init=False, default=None, eq=False)
```
Is this thread now still relevant, given that we agreed to a Factory approach in our meeting (iirc)?
|
Note: On hold until mean factory is implemented. |
New constructor method that enables transfer learning for Gaussian Process surrogates.

Mean Function Transfer:
- `PriorMean` class that wraps the prior GP as a BoTorch-compatible mean module

In the upcoming PR we will introduce a new TL surrogate that takes a search space with `TaskParameter` and internally falls back to the new constructor for training a source GP on the source data and a target GP on the target data using the source mean as a prior.

Further extensions: