From 0c6dda0842a8ee463518aa547fa0e4ab36b233db Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 13 Mar 2024 18:10:10 +0200 Subject: [PATCH] Mark some optimizer update arguments as Noneable (they were being called with Nones) --- bitsandbytes/functional.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8fa8f2f60..bb6a04892 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1618,18 +1618,18 @@ def optimizer_update_8bit( g: Tensor, p: Tensor, state1: Tensor, - state2: Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Tensor, + qmap2: Optional[torch.Tensor], max1: Tensor, - max2: Tensor, + max2: Optional[torch.Tensor], new_max1: Tensor, - new_max2: Tensor, + new_max2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, unorm_vec: Optional[torch.Tensor] = None, @@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise( g: Tensor, p: Tensor, state1: Tensor, - state2: Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Tensor, + qmap2: Optional[torch.Tensor], absmax1: Tensor, - absmax2: Tensor, + absmax2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False,