79 changes: 69 additions & 10 deletions baybe/surrogates/gaussian_process/core.py
@@ -24,9 +24,11 @@
DefaultKernelFactory,
_default_noise_factory,
)
from baybe.surrogates.gaussian_process.prior_modules import PriorMean
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
from botorch.models import SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
@@ -113,11 +115,60 @@ class GaussianProcessSurrogate(Surrogate):
_model = field(init=False, default=None, eq=False)
Collaborator:

given that we already have this _model attribute, can you explain why we need to introduce yet another attribute like _prior_gp? Naively, I would suspect the first contains the latter. Or at least we should strive to avoid putting a lot of additional attributes in this class (because they will essentially be irrelevant for non-TL cases)

Collaborator (Author):

I see your point and I thought the same before, but unfortunately I couldn't find a good solution to this: The problem is that there is a gap between creation via from_prior and fitting the model via fit. The instance must somehow remember it should use transfer learning and its prior to be able to create the _model. I'd be happy to change this and will give it another thought. Maybe the logic could be moved to some KernelFactory or MeanFactory. Do you have any suggestions how to get rid of this attribute?

Collaborator (Author):

@AVHopp , do you maybe have an idea?

Collaborator (Author):

How about introducing a new mean factory similar to the kernel factories in BayBE?

  # New default
  class ConstantMeanFactory(MeanFactory):
      def __call__(self, batch_shape: torch.Size):
          return gpytorch.means.ConstantMean(batch_shape=batch_shape)

  class PriorMeanFactory(MeanFactory):
      def __init__(self, prior_gp: GPSurrogate):
          self.prior_gp = deepcopy(prior_gp)

      def __call__(self, batch_shape: torch.Size):
          return PriorMean(self.prior_gp.to_botorch(), batch_shape=batch_shape)

Then in from_prior I'd just replace the mean factory by the new PriorMeanFactory and could remove the attribute from the surrogate class, but this would add an entirely new factory pattern to BayBE.
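For concreteness, the proposed pattern can be sketched without gpytorch as follows. This is a hypothetical, dependency-free stand-in (`MeanFactory`, `ConstantMeanFactory`, and `PriorMeanFactory` are names from this thread, not the current BayBE API), and the dict return values merely stand in for real gpytorch mean modules:

```python
from copy import deepcopy


class MeanFactory:
    """Hypothetical base class, mirroring BayBE's KernelFactory pattern."""

    def __call__(self, batch_shape):
        raise NotImplementedError


class ConstantMeanFactory(MeanFactory):
    """Default: would return gpytorch.means.ConstantMean(batch_shape=...)."""

    def __call__(self, batch_shape):
        # Stand-in for the real gpytorch mean module
        return {"type": "ConstantMean", "batch_shape": batch_shape}


class PriorMeanFactory(MeanFactory):
    """Wraps a fitted prior surrogate; would return a PriorMean instead."""

    def __init__(self, prior_gp):
        # Deep copy so later refits of the original surrogate don't leak in
        self.prior_gp = deepcopy(prior_gp)

    def __call__(self, batch_shape):
        return {"type": "PriorMean", "gp": self.prior_gp, "batch_shape": batch_shape}
```

With this shape, from_prior would only swap the default ConstantMeanFactory for a PriorMeanFactory, and _fit could call the factory uniformly with no TL-specific branching.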

Collaborator:

yeah I would prefer that a bit

although the main reason for the factory was that search space info is needed when creating the kernels, which is not available yet when specifying the attribute here on the surrogate. That doesn't seem to be the case here with the means, right?

So a factory is not strictly needed, but I still see two advantages why I would prefer it:

  • it wouldn't be an empty, unused attribute if no prior GP is used, as it would hold the default factory
  • it would be more consistent to have all kinds of factories rather than having a mixture of factories and other optional model-related attributes

About _model: This is supposed to hold the fitted botorch model, right? So would it make any sense to only partially initialize it with the means? If not, then forget that idea

Collaborator:

Is this thread now still relevant, given that we agreed to a Factory approach in our meeting (iirc)?

"""The actual model."""

# Transfer learning fields
_prior_gp: SingleTaskGP | None = field(init=False, default=None, eq=False)
"""Prior GP to extract mean/covariance for transfer learning."""

@staticmethod
def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
"""Create a Gaussian process surrogate from one of the defined presets."""
return make_gp_from_preset(preset)

@classmethod
def from_prior(
cls,
prior_gp: GaussianProcessSurrogate,
kernel_factory: KernelFactory | None = None,
**kwargs,
) -> GaussianProcessSurrogate:
"""Create a GP surrogate using a prior GP's predictions as the mean function.

Transfers knowledge by using the prior GP's posterior mean predictions
as the mean function for a new GP, while learning covariance from scratch.

Args:
prior_gp: Fitted GaussianProcessSurrogate to use as prior
kernel_factory: Kernel factory for covariance components
**kwargs: Additional arguments for GaussianProcessSurrogate constructor

Returns:
New GaussianProcessSurrogate instance with transfer learning

Raises:
ValueError: If prior_gp is not a GaussianProcessSurrogate or is not fitted
"""
from copy import deepcopy

# Validate prior GP is fitted
if not isinstance(prior_gp, cls):
raise ValueError(
"prior_gp must be a fitted GaussianProcessSurrogate instance"
)
if prior_gp._model is None:
raise ValueError("Prior GP must be fitted before use")

# Configure kernel factory (always needed since we only do mean transfer now)
if kernel_factory is None:
kernel_factory = DefaultKernelFactory()

# Create new surrogate instance
instance = cls(kernel_or_factory=kernel_factory, **kwargs)

# Configure for transfer learning - store the BoTorch model
instance._prior_gp = deepcopy(prior_gp.to_botorch())

return instance

@override
def to_botorch(self) -> GPyTorchModel:
return self._model
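The creation-vs-fit gap discussed in the thread above can be seen in a minimal stand-in (a hypothetical toy `Surrogate` class, not the BayBE implementation): from_prior must stash the prior somewhere, because _model only comes into existence once fit runs:

```python
from copy import deepcopy


class Surrogate:
    """Minimal stand-in illustrating the from_prior validation and state."""

    def __init__(self):
        self._model = None     # set only by fit()
        self._prior_gp = None  # bridges the gap between from_prior() and fit()

    def fit(self, data):
        # Stand-in for actual model fitting
        self._model = ("fitted", data, self._prior_gp)

    @classmethod
    def from_prior(cls, prior_gp):
        if not isinstance(prior_gp, cls):
            raise ValueError("prior_gp must be a fitted surrogate instance")
        if prior_gp._model is None:
            raise ValueError("Prior GP must be fitted before use")
        instance = cls()
        # Deep copy: the new surrogate must not share state with the prior
        instance._prior_gp = deepcopy(prior_gp._model)
        return instance
```

The deep copy is what makes the factory-based alternative attractive: the same snapshot could live in a PriorMeanFactory instead of an extra surrogate attribute.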
@@ -152,22 +203,30 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
assert self._searchspace is not None

context = _ModelContext(self._searchspace)

numerical_idxs = context.get_numerical_indices(train_x.shape[-1])

# For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=list(numerical_idxs),
)
outcome_transform = Standardize(train_y.shape[-1])

# extract the batch shape of the training data
batch_shape = train_x.shape[:-2]

# Configure input/output transforms
if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"):
# Use prior's transforms for consistency in transfer learning
input_transform = self._prior_gp.input_transform
outcome_transform = self._prior_gp.outcome_transform
Collaborator:

since there is an explicit check for input_transform, is it always guaranteed to have outcome_transform?

Why is the check for input_transform even needed?

else:
# For GPs, we let botorch handle scaling. See [Scaling Workaround] above.
input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=numerical_idxs,
Copilot AI (Jan 7, 2026):

The indices parameter expects a list but numerical_idxs is a tuple. While this may work in practice, it's inconsistent with the previous implementation that used list(numerical_idxs) on line 217 in the original code. For consistency and to match the expected type, convert the tuple to a list.

Suggested change:
- indices=numerical_idxs,
+ indices=list(numerical_idxs),

)
outcome_transform = Standardize(train_y.shape[-1])

# create GP mean
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
if self._prior_gp is not None:
mean_module = PriorMean(self._prior_gp, batch_shape=batch_shape)
else:
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)

# define the covariance module for the numeric dimensions
base_covar_module = self.kernel_factory(
Expand Down
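The motivation for reusing the prior's transforms in _fit above can be illustrated with a minimal sketch (a hypothetical `Scale` class standing in for botorch's Normalize, not the real transform API): a prior mean queried under a different input scaling effectively answers for the wrong points:

```python
class Scale:
    """Stand-in for an input transform mapping raw inputs to [0, 1]."""

    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def __call__(self, x):
        return (x - self.lo) / (self.hi - self.lo)


def prior_mean(z):
    """Prior GP's mean, defined on inputs scaled by the *prior's* transform."""
    return 10.0 * z


prior_transform = Scale(0.0, 10.0)  # prior was fitted on data spanning [0, 10]
new_transform = Scale(0.0, 2.0)     # new task's data spans only [0, 2]

x = 1.0
consistent = prior_mean(prior_transform(x))    # queried at the point the prior learned
inconsistent = prior_mean(new_transform(x))    # same raw x, but a different scaled point
```

Here the same raw input yields different mean values depending on which transform is applied, which is why the branch reuses the prior's input_transform and outcome_transform rather than fitting fresh ones.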
55 changes: 55 additions & 0 deletions baybe/surrogates/gaussian_process/prior_modules.py
Collaborator:

the name x_modules is a bit inconsistent compared to our other naming
just means.py?

@@ -0,0 +1,55 @@
"""Prior modules for Gaussian process transfer learning."""

from __future__ import annotations

from copy import deepcopy
from typing import Any

import gpytorch
import torch
from botorch.models import SingleTaskGP
from torch import Tensor


class PriorMean(gpytorch.means.Mean):
Collaborator:

Some questions for understanding:

  1. What does this new class achieve that is not possible with the existing botorch constant mean class?
  2. When the incoming mean is a constant mean, this class would also effectively produce a constant mean?
  3. Afaik all our GPs have constant means, so everything would forever be constant mean. Is this class here then necessary? Couldn't we just use the botorch constant mean class for the new TL case as well, except that the number is fixed and predetermined, i.e. somehow "set"?

Collaborator (Author) @kalama-ai (Jan 8, 2026):

I think there might be some misunderstanding here. The incoming mean is not constant since the prior GP is fitted on some data already and we are using its posterior here. Even if the prior GP originally had a ConstantMean, once trained, its posterior mean will not be constant anymore. Or am I misunderstanding your comment?

Collaborator @Scienfitz (Jan 9, 2026):

I see, thanks for clarifying, I now see the need for the class.
Please let's just make sure this is optimized and does not impose any computational bottleneck.

Is it also right that this implementation is the variant with a completely frozen prior mean? I.e. the mean is not just a prior, but is forever the mean for our actual GP used in the campaign?

"""GPyTorch mean module using a trained GP as prior mean.

This mean module wraps a trained Gaussian Process and uses its predictions
Collaborator:

process in Gaussian process should not be capitalized (unless it's a headline or similar)

as the mean function for another GP.

Args:
gp: Trained Gaussian Process to use as mean function.
batch_shape: Batch shape for the mean module.
**kwargs: Additional keyword arguments.
Collaborator:

Is it necessary to include those in this class/the __init__? Currently they seem to be silently ignored, so I would propose to either remove them completely if possible or at least mention that they are being ignored.

"""

def __init__(
self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any
) -> None:
super().__init__()

# Deep copy and freeze the GP
self.gp: SingleTaskGP = deepcopy(gp)
self.batch_shape: torch.Size = batch_shape

# Freeze parameters and set eval mode once
for param in self.gp.parameters():
param.requires_grad = False

def forward(self, x: Tensor) -> Tensor:
"""Compute the mean function using the wrapped GP.

Args:
x: Input tensor for which to compute the mean.

Returns:
Mean predictions from the wrapped GP.
"""
self.gp.eval()
Collaborator:

wouldn't it make sense to move these eval statements into __init__ because they are only needed once?

self.gp.likelihood.eval()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
mean = self.gp(x).mean.detach()

# Handle batch dimensions
target_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-1])
return mean.reshape(target_shape)
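The "completely frozen" semantics questioned in the thread can be demonstrated with a toy stand-in (`FrozenMean` and `ToyModel` are illustrative, not part of the codebase): because PriorMean deep-copies the GP at construction, later changes to the original model cannot affect the mean it produces:

```python
from copy import deepcopy


class ToyModel:
    """Stand-in for a fitted GP with a single learned weight."""

    def __init__(self, w):
        self.w = w

    def predict(self, v):
        return self.w * v


class FrozenMean:
    """Mimics PriorMean: snapshot the model at init, then use it read-only."""

    def __init__(self, model):
        self.model = deepcopy(model)  # frozen snapshot, as in PriorMean

    def __call__(self, xs):
        return [self.model.predict(v) for v in xs]


model = ToyModel(2.0)
mean = FrozenMean(model)
model.w = 100.0  # "refitting" the original model afterwards
```

After the mutation, `mean([1.0, 3.0])` still returns the snapshot's predictions, which is exactly the frozen-prior behavior asked about above; disabling requires_grad and calling eval() in PriorMean additionally keep the snapshot fixed during the new GP's training.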