Add basic GaussianProcessSurrogate.from_prior constructor method#717

Open
kalama-ai wants to merge 8 commits into main from feature/from_prior_gp_constructor

Conversation


@kalama-ai kalama-ai commented Dec 18, 2025

New constructor method that enables transfer learning for Gaussian process surrogates.

# Train source GP
source_gp = GaussianProcessSurrogate()
source_gp.fit(source_space, source_objective, source_data)

# Transfer mean
target_gp = GaussianProcessSurrogate.from_prior(
    prior_gp=source_gp,  # Use source GP as prior
)

Mean Function Transfer:

  • Implements full mean transfer from the prior GP to the new GP
  • The posterior mean predictions of the pre-trained GP are used as the mean module for the new GP
  • Frozen hyperparameters
  • Mean function is evaluated at source training points
  • New PriorMean class that wraps the prior GP as a BoTorch-compatible mean module
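The frozen-mean mechanism can be illustrated outside BayBE with a minimal numpy sketch (all names here are hypothetical, and an RBF kernel on 1-D data is assumed): the source GP's posterior mean m(x) = k(x, X)(K + σ²I)⁻¹y is precomputed once from the source training data and then only queried, never retrained.

```python
import numpy as np

def rbf(a, b, ls=1.0):
    # Squared-exponential kernel between two sets of 1-D points
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)

class FrozenPosteriorMean:
    """Posterior mean of a fitted GP, frozen for reuse as a prior mean."""

    def __init__(self, x_train, y_train, noise=1e-2):
        self.x_train = x_train
        K = rbf(x_train, x_train) + noise * np.eye(len(x_train))
        # Precompute alpha = (K + noise*I)^-1 y once; the weights stay fixed
        self.alpha = np.linalg.solve(K, y_train)

    def __call__(self, x):
        return rbf(x, self.x_train) @ self.alpha

# Source data: samples of sin(x)
x_src = np.linspace(0, 3, 20)
y_src = np.sin(x_src)
mean = FrozenPosteriorMean(x_src, y_src)

# The frozen mean tracks the source function at new inputs
print(mean(np.array([1.5])))  # close to sin(1.5)
```

The actual `PriorMean` class plays this role for the target GP, with gpytorch tensors in place of numpy arrays.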

In the upcoming PR we will introduce a new TL surrogate that takes a search space with a TaskParameter and internally falls back to the new constructor, training a source GP on the source data and a target GP on the target data using the source mean as a prior.

Further extensions:

  • the interface can be extended to enhanced mean transfer (initializing the hyperparameters, evaluation at the target points) and covariance transfer
  • this will happen in later PRs once the other configurations are tested and the dispatching to different transfer modes is agreed on

- construct a GP by transferring knowledge from a pre-trained prior GP
- basic implementation for full mean transfer
- the posterior mean of the pretrained GP is used as the mean module for the new GP
- hyperparameters are frozen and the mean is evaluated at the source points
- interface might later be extended to other mean transfers (initialize
hyperparameters) or covariance transfer
- new `PriorMean` class that implements the mean of the prior GP as a BoTorch-compatible module
class PriorMean(gpytorch.means.Mean):
"""GPyTorch mean module using a trained GP as prior mean.

This mean module wraps a trained Gaussian Process and uses its predictions
Collaborator

"process" in "Gaussian process" should not be capitalized (unless it's a headline or similar)

New GaussianProcessSurrogate instance with transfer learning

Raises:
ValueError: If prior_gp is not fitted
Collaborator

there is also a raise if the variable is not a SingleTaskGP, which is not mentioned here?

@@ -113,11 +115,57 @@ class GaussianProcessSurrogate(Surrogate):
_model = field(init=False, default=None, eq=False)
Collaborator

given that we already have this _model attribute, can you explain why we need to introduce yet another attribute like _prior_gp? Naively I would suspect the first contains the latter. Or at least we should strive to avoid putting a lot of additional attributes in this class (because they will essentially be irrelevant for non-TL cases)

Collaborator Author

I see your point and I thought the same before, but unfortunately I couldn't find a good solution to this: The problem is that there is a gap between creation via from_prior and fitting the model via fit. The instance must somehow remember it should use transfer learning and its prior to be able to create the _model. I'd be happy to change this and will give it another thought. Maybe the logic could be moved to some KernelFactory or MeanFactory. Do you have any suggestions how to get rid of this attribute?

Collaborator Author

@AVHopp , do you maybe have an idea?

Collaborator Author

How about introducing a new mean factory similar to the kernel factories in BayBE?

  # New default
  class ConstantMeanFactory(MeanFactory):
      def __call__(self, batch_shape: torch.Size) -> gpytorch.means.Mean:
          return gpytorch.means.ConstantMean(batch_shape=batch_shape)

  class PriorMeanFactory(MeanFactory):
      def __init__(self, prior_gp: GPSurrogate):
          self.prior_gp = deepcopy(prior_gp)

      def __call__(self, batch_shape: torch.Size) -> gpytorch.means.Mean:
          return PriorMean(self.prior_gp)

Then in from_prior I'd just replace the mean factory by the new PriorMeanFactory and could remove the attribute from the surrogate class, but this would add an entirely new factory pattern to BayBE.
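A dependency-free sketch of that swap (class and method names hypothetical; gpytorch return values replaced by strings to keep it self-contained) shows how from_prior could exchange the factory instead of storing an extra attribute on the surrogate:

```python
from copy import deepcopy

class MeanFactory:
    """Base class: builds a mean module when the surrogate is fitted."""
    def __call__(self):
        raise NotImplementedError

class ConstantMeanFactory(MeanFactory):
    def __call__(self):
        return "ConstantMean()"  # stands in for gpytorch.means.ConstantMean

class PriorMeanFactory(MeanFactory):
    def __init__(self, prior_gp):
        self.prior_gp = deepcopy(prior_gp)  # freeze a copy of the source GP

    def __call__(self):
        return f"PriorMean({self.prior_gp})"

class Surrogate:
    def __init__(self):
        self.mean_factory = ConstantMeanFactory()  # default, never empty

    @classmethod
    def from_prior(cls, prior_gp):
        obj = cls()
        obj.mean_factory = PriorMeanFactory(prior_gp)  # swap, no _prior_gp field
        return obj

print(Surrogate().mean_factory())                 # ConstantMean()
print(Surrogate.from_prior("gp").mean_factory())  # PriorMean(gp)
```

The non-TL path keeps its default factory, so no attribute sits unused.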

Collaborator

yeah I would prefer that a bit

although the main reason for the factory was that search space info is needed when creating the kernels, which is not available yet when specifying the attribute here to the surrogate. That doesn't seem to be the case here with the means, right?

So a factory is not strictly needed but I still see two advantages why I would prefer it:

  • it wouldn't be empty unused content if no prior gp is used, as it would hold the default factory
  • it would be more consistent to have all kinds of factories rather than having a mixture of factories and other optional model-related attributes

About _model: This is supposed to hold the fitted botorch model, right? So would it make any sense to only partially initialize it with the means? If no, then forget that idea

Collaborator

Is this thread now still relevant, given that we agreed to a Factory approach in our meeting (iirc)?

"""The actual model."""

# Transfer learning fields
_prior_gp = field(init=False, default=None, eq=False)
Collaborator

no type?

Collaborator

i see it's probably the same issue as with _model, so ideally you can paste the same comment that's there also here

depending on the design this attribute might also be removed and the comment is obsolete

if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"):
    # Use prior's transforms for consistency in transfer learning
    input_transform = self._prior_gp.input_transform
    outcome_transform = self._prior_gp.outcome_transform
Collaborator

since there is an explicit check for input_transform, is it always guaranteed to have outcome_transform?

Why is the check for input_transform even needed?
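One symmetric alternative to the hasattr guard (a sketch only, not the PR's code; the dummy class stands in for a fitted model) fetches both transforms defensively, so neither attribute is silently assumed to exist:

```python
class DummyPrior:
    """Stand-in for a fitted model carrying both transforms."""
    input_transform = "Normalize"
    outcome_transform = "Standardize"

def get_transforms(prior_gp):
    # None signals absence; the caller can then raise or fall back
    return (
        getattr(prior_gp, "input_transform", None),
        getattr(prior_gp, "outcome_transform", None),
    )

print(get_transforms(DummyPrior()))  # ('Normalize', 'Standardize')
print(get_transforms(object()))      # (None, None)
```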

from torch import Tensor


class PriorMean(gpytorch.means.Mean):
Collaborator

Some questions for understanding:

  1. What does this new class achieve that is not possible with the existing botorch constant mean class?
  2. When the incoming mean is a constant mean, this class would also effectively produce a constant mean?
  3. Afaik all our GPs have constant mean, so everything would forever be constant mean. Is this class here then necessary? Couldn't we just use the botorch constant mean class for the new TL case as well, except that the number is fixed and predetermined, i.e. somehow "set"?

Collaborator Author

@kalama-ai kalama-ai Jan 8, 2026

I think there might be some misunderstanding here. The incoming mean is not constant since the prior GP is fitted on some data already and we are using its posterior here. Even if the prior GP originally had a ConstantMean, once trained, its posterior mean will not be constant anymore. Or am I misunderstanding your comment?
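A tiny numpy calculation makes the point concrete (illustrative only; a zero prior mean and an RBF kernel are assumed): even though the prior mean is constant, the posterior mean m(x) = k(x, X)K⁻¹y interpolates the training data and is clearly input-dependent.

```python
import numpy as np

# Two training points; the prior mean is the constant 0
X = np.array([0.0, 1.0])
y = np.array([0.0, 2.0])

def k(a, b):
    # RBF kernel with unit lengthscale
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2)

alpha = np.linalg.solve(k(X, X) + 1e-6 * np.eye(2), y)

def post_mean(x):
    return k(x, X) @ alpha

# Posterior mean differs across inputs, hence non-constant
print(post_mean(np.array([0.0, 0.5, 1.0])))
```

The values near the training points recover y = 0 and y = 2, while intermediate inputs interpolate between them.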

Collaborator

@Scienfitz Scienfitz Jan 9, 2026

I see, thanks for clarifying, I see now the need for the class.
Please let's just make sure this is optimized and does not impose any computational bottleneck.

Is it also right that this implementation is the variant with a completely frozen prior mean? I.e. the mean is not just a prior but is forever the mean for our actual GP used in the campaign?

@AVHopp AVHopp requested a review from Copilot January 7, 2026 13:38
Contributor

Copilot AI left a comment

Pull request overview

This PR introduces a transfer learning capability for Gaussian Process surrogates through a new from_prior constructor method. The implementation enables mean function transfer from a pre-trained GP to a new GP by using the source GP's posterior mean predictions as the mean module for the target GP.

Key changes:

  • New PriorMean class that wraps a trained GP as a BoTorch-compatible mean module with frozen hyperparameters
  • New from_prior class method for constructing a GP surrogate with transfer learning capabilities
  • Modified _fit method to conditionally use the prior GP's mean and transforms when available

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
baybe/surrogates/gaussian_process/prior_modules.py Introduces PriorMean class to wrap a trained GP as a mean module for transfer learning
baybe/surrogates/gaussian_process/core.py Adds from_prior constructor and updates _fit to support mean function transfer


input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=numerical_idxs,
Copilot AI Jan 7, 2026

The indices parameter expects a list but numerical_idxs is a tuple. While this may work in practice, it's inconsistent with the previous implementation that used list(numerical_idxs) on line 217 in the original code. For consistency and to match the expected type, convert the tuple to a list.

Suggested change
indices=numerical_idxs,
indices=list(numerical_idxs),

Comment on lines +147 to +149
from copy import deepcopy

from botorch.models import SingleTaskGP
Copilot AI Jan 7, 2026

The import statements are placed inside the method. Since deepcopy is already imported at the module level (line 5 in prior_modules.py) and SingleTaskGP is imported in the TYPE_CHECKING block (line 31), these local imports are redundant and should be removed in favor of the module-level imports.

Collaborator

@AVHopp AVHopp left a comment

Just some minor comments as @Scienfitz raised some questions that might impact the design of this code, I did not fully review everything yet.

kernel_factory: KernelFactory | None = None,
**kwargs,
) -> GaussianProcessSurrogate:
"""Create a GP surrogate with mean function transfer learning.
Collaborator

I think this docstring needs a bit more explanation on what exactly is done and transferred. Also, the description in the Returns: part could contain more information (but might not be needed if you add 2-3 sentences here describing what this does in more detail)

Collaborator

the name x_modules is a bit inconsistent compared to our other naming
just means.py?

Returns:
Mean predictions from the wrapped GP.
"""
self.gp.eval()
Collaborator

wouldn't it make sense to move these eval statements into __init__ because they are only needed once?

Args:
gp: Trained Gaussian Process to use as mean function.
batch_shape: Batch shape for the mean module.
**kwargs: Additional keyword arguments.
Collaborator

Is it necessary to include those in this class/the __init__? Currently they seem to be silently ignored, so I would propose to either remove them completely if possible or at least mention that they are being ignored.


@kalama-ai
Collaborator Author

Note: On hold until mean factory is implemented.
