[ENH] Add gradient_clip_val to BaseModel v2 so clipping is model-owned and checkpoint-portable #2237

@StrikerEureka34

Description

Is your feature request related to a problem? Please describe.

While training TimeXer on a multivariate dataset with exogenous variables, I noticed that the loss was quite noisy and training kept diverging. Nothing I tried improved it, so I did some digging.

The standard fix for transformer-based models like this is gradient clipping. But the only way to enable it right now is to pass gradient_clip_val in the Trainer config. That worked, but only because I remembered to re-add it manually every time I resumed training from a checkpoint. The one time I forgot, training quietly resumed without clipping and diverged again before I caught it.

The core issue is that gradient_clip_val has no home in the model itself.
It only lives in Trainer config (where it is hardcoded into the shared test setup). This means:

  • It is not saved in the checkpoint, so every resume requires you to reconstruct Trainer arguments from memory or docs.
  • Models like TimeXer, DLinear, SAMformer, TFT v2, and TIDE v2 cannot declare their own clipping needs. Every model is treated the same regardless of architecture, even though transformers are more sensitive to gradient explosions than linear models.

Describe the solution you'd like

Add gradient_clip_val: float | None = None and gradient_clip_algorithm: str = "norm" to BaseModel.__init__() and override configure_gradient_clipping (the Lightning hook called once per step, before the optimizer step):

def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
    if self.gradient_clip_val is not None:
        total_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=float("inf"))
        self.log("grad_norm", total_norm, on_step=True, ...)
        self.clip_gradients(optimizer, gradient_clip_val=self.gradient_clip_val, ...)
    else:
        super().configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm)

When gradient_clip_val=None (the default), nothing changes: existing Trainer-based clipping still works. When set on the model, clipping is model-owned, saved in the checkpoint, and inherited by every subclass.
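For what it's worth, the max_norm=float("inf") trick in the snippet above can be checked in isolation: with an infinite cap, clip_grad_norm_ measures the total gradient norm without modifying any gradients, which is what makes it safe to use purely for logging. A minimal standalone demo (plain PyTorch, no Lightning):

```python
import torch

# A single parameter with a known gradient of L2 norm 5.0 (3-4-5 triangle).
p = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
p.grad = torch.tensor([3.0, 4.0])

# max_norm=inf: the norm is computed and returned, but nothing is rescaled.
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=float("inf"))
print(float(total_norm))   # 5.0 -- measured...
print(p.grad.tolist())     # [3.0, 4.0] -- ...but untouched

# An actual clip to max_norm=1.0 rescales the gradient down to (about) unit norm.
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(p.grad.tolist())     # approximately [0.6, 0.8]
```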

TslibBaseModel also needs the same two params added and forwarded to super().__init__(). The existing save_hyperparameters() call there will pick them up automatically with no other changes.
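To make the forwarding concrete, here is a stripped-down sketch (the class names are stand-ins, and the hparams dict simulates what Lightning's save_hyperparameters() would capture; the real change would just add the two kwargs to the existing signatures):

```python
class BaseModelSketch:
    """Stand-in for BaseModel v2 with the two proposed parameters."""

    def __init__(self, gradient_clip_val=None, gradient_clip_algorithm="norm"):
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        # Simulates Lightning's save_hyperparameters(): the new params are
        # captured alongside everything else and restored on checkpoint load.
        self.hparams = {
            "gradient_clip_val": gradient_clip_val,
            "gradient_clip_algorithm": gradient_clip_algorithm,
        }


class TslibBaseModelSketch(BaseModelSketch):
    """Stand-in for TslibBaseModel: accept the params and forward them up."""

    def __init__(self, gradient_clip_val=None, gradient_clip_algorithm="norm", **kwargs):
        super().__init__(
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )


m = TslibBaseModelSketch(gradient_clip_val=0.5)
print(m.hparams["gradient_clip_val"])  # 0.5 -- travels with the "checkpoint"
```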

Add a few tests to prevent silent regressions in the future.
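Roughly the shape I have in mind: one test asserting that gradients are clipped when the model sets gradient_clip_val, and one asserting they are untouched when it is None. Sketched here with a stand-in nn.Module and an apply_clipping method standing in for the configure_gradient_clipping hook (the real test would go through a Lightning Trainer):

```python
import torch


class TinyModel(torch.nn.Module):
    """Stand-in for a BaseModel subclass carrying its own clip value."""

    def __init__(self, gradient_clip_val=None):
        super().__init__()
        self.gradient_clip_val = gradient_clip_val
        self.weight = torch.nn.Parameter(torch.ones(2))

    def apply_clipping(self):
        # Mirrors the proposed hook: clip only when the model asks for it.
        if self.gradient_clip_val is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_val)


def test_clipping_applied_only_when_set():
    clipped = TinyModel(gradient_clip_val=1.0)
    clipped.weight.grad = torch.tensor([30.0, 40.0])  # norm 50, far above the cap
    clipped.apply_clipping()
    assert torch.linalg.vector_norm(clipped.weight.grad).item() <= 1.0 + 1e-4

    untouched = TinyModel(gradient_clip_val=None)
    untouched.weight.grad = torch.tensor([30.0, 40.0])
    untouched.apply_clipping()
    assert torch.equal(untouched.weight.grad, torch.tensor([30.0, 40.0]))


test_clipping_applied_only_when_set()
print("ok")
```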

Describe alternatives you've considered

Keeping it Trainer-only works for simple cases, but it fails at checkpoint resumption and makes it impossible for individual models to declare their own clipping requirements. A custom callback could apply clipping per step, but it would bypass Lightning's clip_gradients(), which handles AMP gradient unscaling; getting that right manually is fragile, and it still would not solve the checkpoint-portability problem.
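To illustrate why the AMP-unscaling point matters: under mixed precision the loss (and hence the gradients) is multiplied by a scale factor before backward, so a naive callback that clips the scaled gradients to max_norm effectively clips the true gradients to max_norm / S. A back-of-the-envelope check (the scale factor 65536 is just a typical value, chosen here for illustration):

```python
S = 65536.0            # typical AMP loss-scale factor (illustrative assumption)
true_grad_norm = 5.0   # the gradient norm we actually want to cap
max_norm = 1.0         # the clip value the user configured

# What a naive per-step callback sees and does: it measures the *scaled*
# norm and rescales it down to max_norm.
scaled_norm = true_grad_norm * S
clip_coef = max_norm / scaled_norm

# The true gradient norm after this "clip", once AMP later unscales by 1/S:
effective_cap = (clip_coef * scaled_norm) / S
print(effective_cap)   # 1.0 / 65536, i.e. about 1.5e-05 -- massively over-clipped
```

This is exactly the bookkeeping that Lightning's clip_gradients() (which calls scaler.unscale_ first) gets right and a hand-rolled callback easily gets wrong.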

Additional context
The outcome of this will be:

  • DLinear and TimeXer get gradient_clip_val in their constructor for free, since they go through TslibBaseModel.
  • SAMformer, TFT v2, and TIDE v2 inherit the clipping behavior, but each needs its own __init__ updated individually to expose the param.
  • For the foundation models in #1959 ([ENH] add first batch of foundation models into PTF for v2), which extend BaseModel directly, having this in place before they land means no retrofit work is needed.

Out of scope:
I discovered that SAMformer, TFT v2, and TIDE v2 each manage their own save_hyperparameters() call independently, so each needs its own __init__ updated separately. I am deferring that work to keep this PR focused.

Related: #2132, #1959

Would really appreciate your insights on whether this is the right approach and if there's a better way to handle it before going further, thanks! @phoeenniixx @RecreationalMath @PranavBhatP
