diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 5d9983545..d0963a1e9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -291,7 +291,7 @@ def forward( B: torch.Tensor, out=None, bias: Optional[torch.Tensor] = None, - state: MatmulLtState = None, + state: Optional[MatmulLtState] = None, ): state = state or MatmulLtState()