Problem
Currently, `base_metrics.py` provides `TorchMetricWrapper` to wrap `torchmetrics` metrics. However, there is no equivalent for standard PyTorch `nn.Module` losses such as:
- `torch.nn.HuberLoss`
- `torch.nn.MSELoss`
- `torch.nn.SmoothL1Loss`
- `torch.nn.L1Loss`
If a user passes one of these directly to a pytorch-forecasting model, it fails because these losses:
- Don't implement `update()` / `compute()`, which `LightningMetric` requires
- Don't handle the `(target, weight)` tuple that pytorch-forecasting passes internally (see `MultiHorizonMetric.update()` in `base_metrics.py`)
- Don't implement `to_prediction()`, which models call at inference time
- Can't be composed with `MultiLoss` or `CompositeMetric`
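The failure modes above can be reproduced with plain PyTorch. This minimal sketch assumes only `torch`; the random tensors stand in for a batch of decoder predictions and targets, and the `(target, None)` tuple mimics the `(target, weight)` form described above when no sample weights are used.

```python
import torch
from torch import nn

torch.manual_seed(0)
loss_fn = nn.HuberLoss(delta=1.5)
y_pred = torch.randn(4, 6)   # (batch, decoder time steps)
target = torch.randn(4, 6)
y = (target, None)           # (target, weight) tuple; weight may be None

# A plain nn.Module loss exposes none of the hooks pytorch-forecasting calls.
missing_hooks = [
    h for h in ("update", "compute", "to_prediction") if not hasattr(loss_fn, h)
]

# Forwarding the tuple unchanged fails inside the functional loss,
# because a tuple has no .size() / tensor semantics.
try:
    loss_fn(y_pred, y)
    tuple_accepted = True
except (TypeError, AttributeError):
    tuple_accepted = False

# Unpacking the tuple first, as MultiHorizonMetric.update() does, works.
target_only, _weight = y
loss = loss_fn(y_pred, target_only)
```

The wrapper therefore has to do two things a bare loss cannot: provide the metric hooks and strip the weight from the target tuple before calling the loss.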
Proposed Solution
Add `NNLossWrapper` in `pytorch_forecasting/metrics/base_metrics.py` as a sibling to the existing `TorchMetricWrapper`, following the same pattern:
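```python
import torch
from torch import nn
from torch.nn.utils import rnn

# Not needed when the class lives inside base_metrics.py itself.
from pytorch_forecasting.metrics.base_metrics import Metric


class NNLossWrapper(Metric):
    """
    Wrap a standard PyTorch nn.Module loss for use with pytorch-forecasting.

    Example
    -------
    >>> loss = NNLossWrapper(nn.HuberLoss(delta=1.5))
    >>> combined = NNLossWrapper(nn.HuberLoss()) + MAE()
    """

    def __init__(self, loss_fn: nn.Module, **kwargs):
        if not isinstance(loss_fn, nn.Module):
            raise TypeError(f"loss_fn must be an nn.Module, got {type(loss_fn)}")
        super().__init__(**kwargs)
        self.loss_fn = loss_fn
        # Register accumulators as torchmetrics states (as MultiHorizonMetric does)
        # so reset() clears them and they are summed across processes in DDP.
        # A plain attribute such as self._loss would survive reset() and make
        # compute() return only the last batch.
        self.add_state("loss_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_batches", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, y_pred: torch.Tensor, y_actual) -> None:
        # pytorch-forecasting passes (target, weight); the weight is ignored here.
        # PackedSequence is a tuple subclass, so it must be excluded explicitly.
        if isinstance(y_actual, (list, tuple)) and not isinstance(
            y_actual, rnn.PackedSequence
        ):
            target, _ = y_actual
        else:
            target = y_actual
        y_pred = self.to_prediction(y_pred)
        self.loss_sum = self.loss_sum + self.loss_fn(y_pred, target)
        self.n_batches = self.n_batches + 1

    def compute(self) -> torch.Tensor:
        # Mean of the per-batch losses accumulated since the last reset().
        return self.loss_sum / self.n_batches

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        # Drop the trailing output dimension: (batch, time, 1) -> (batch, time).
        if y_pred.ndim == 3:
            y_pred = y_pred[..., 0]
        return y_pred

    def __repr__(self):
        return f"NNLossWrapper({repr(self.loss_fn)})"
```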
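The core of the proposed `update()` / `compute()` logic can be checked without pytorch-forecasting installed. The sketch below is a hypothetical, torch-only stand-in (`LossAccumulatorSketch` is an illustrative name, not part of either library) that mirrors the wrapper's tuple unpacking and `to_prediction()` squeeze, and averages the per-batch losses seen since the last reset (an assumption about the desired epoch-level behaviour). It deliberately omits torchmetrics state registration, distributed sync, and `+` composition.

```python
import torch
from torch import nn
from torch.nn.utils import rnn


class LossAccumulatorSketch:
    """Hypothetical stand-in for NNLossWrapper's update/compute logic only."""

    def __init__(self, loss_fn: nn.Module):
        self.loss_fn = loss_fn
        self.reset()

    def reset(self) -> None:
        self.loss_sum = torch.tensor(0.0)
        self.n_batches = 0

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        # (batch, time, 1) -> (batch, time); 2-D predictions pass through unchanged.
        return y_pred[..., 0] if y_pred.ndim == 3 else y_pred

    def update(self, y_pred: torch.Tensor, y_actual) -> None:
        # Accept either a bare target tensor or a (target, weight) tuple.
        if isinstance(y_actual, (list, tuple)) and not isinstance(
            y_actual, rnn.PackedSequence
        ):
            target, _ = y_actual
        else:
            target = y_actual
        batch_loss = self.loss_fn(self.to_prediction(y_pred), target)
        self.loss_sum = self.loss_sum + batch_loss
        self.n_batches += 1

    def compute(self) -> torch.Tensor:
        return self.loss_sum / self.n_batches


torch.manual_seed(0)
y_pred = torch.randn(4, 6, 1)   # (batch, decoder time steps, 1 output)
target = torch.randn(4, 6)

acc = LossAccumulatorSketch(nn.HuberLoss(delta=1.5))
acc.update(y_pred, (target, None))   # tuple form, as pytorch-forecasting passes it
acc.update(y_pred, target)           # bare tensor form
result = acc.compute()
expected = nn.HuberLoss(delta=1.5)(y_pred[..., 0], target)
```

Because both batches are identical here, the average equals the single-batch Huber loss, which is what the final comparison relies on.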
Files to Change
- `pytorch_forecasting/metrics/base_metrics.py`: add `NNLossWrapper`
- `pytorch_forecasting/metrics/__init__.py`: export `NNLossWrapper`

Related
- `TorchMetricWrapper` in `base_metrics.py`: the parallel class this follows as a pattern
GitHub: @vinitjain2005