@@ -22,12 +22,62 @@ def __init__(self, linear: nn.Linear):
2222 weight = linear .weight
2323 assert weight .data .dtype .is_floating_point
2424 assert linear .bias is None
25- CB , SCB , _ = bnbF .int8_vectorwise_quant (weight .data .to (torch .float16 )) # type: ignore
26- self .weight = nn .Parameter (CB , requires_grad = False )
27- self .weight_scb = nn .Parameter (SCB , requires_grad = False )
25+
26+ # Check if the weight is on a meta device
27+ if weight .device .type == 'meta' :
28+ # For meta device, we need to preserve the shape information
29+ # Create tensors with the same shape as the original weight
30+ # We'll use the shape information from the linear layer
31+ out_features , in_features = weight .shape
32+
33+ # Create CB tensor with shape [out_features * 8/8, in_features]
34+ # The first dimension is rounded up to a multiple of 8
35+ # This matches the shape that would be produced by int8_vectorwise_quant
36+ padded_out_features = ((out_features + 7 ) // 8 ) * 8
37+ self .weight = nn .Parameter (
38+ torch .zeros ((padded_out_features , in_features ),
39+ dtype = torch .int8 , device = 'meta' ),
40+ requires_grad = False
41+ )
42+
43+ # Create SCB tensor with shape [out_features]
44+ self .weight_scb = nn .Parameter (
45+ torch .zeros (out_features , dtype = torch .float , device = 'meta' ),
46+ requires_grad = False
47+ )
48+ self .is_meta = True
49+ else :
50+ # Normal quantization for non-meta tensors
51+ CB , SCB , _ = bnbF .int8_vectorwise_quant (weight .data .to (torch .float16 )) # type: ignore
52+ self .weight = nn .Parameter (CB , requires_grad = False )
53+ self .weight_scb = nn .Parameter (SCB , requires_grad = False )
54+ self .is_meta = False
55+
56+ def _check_meta_status (self ):
57+ """Check if the weights are still meta tensors and update is_meta flag accordingly."""
58+ if hasattr (self , 'is_meta' ) and self .is_meta :
59+ # Check if the weights have been loaded (no longer on meta device)
60+ if self .weight .device .type != 'meta' and self .weight .numel () > 0 :
61+ self .is_meta = False
62+
63+ # Ensure the scale tensor is float32, regardless of the model's dtype
64+ if self .weight_scb .dtype != torch .float :
65+ self .weight_scb .data = self .weight_scb .data .float ()
2866
2967 def forward (self , x ):
3068 import bitsandbytes as bnb # type: ignore
69+
70+ # Update meta status based on actual tensor properties
71+ self ._check_meta_status ()
72+
73+ # Check if this is a meta tensor that hasn't been properly initialized yet
74+ if hasattr (self , 'is_meta' ) and self .is_meta :
75+ # If we're still in meta mode but trying to do a forward pass,
76+ # this means the weights weren't properly loaded
77+ raise RuntimeError (
78+ "Attempting to run forward pass with meta tensors. "
79+ "The model weights need to be loaded before running inference." )
80+
3181 state = bnb .MatmulLtState ()
3282 state .CB = self .weight # type: ignore
3383 assert isinstance (state .CB , torch .Tensor )
0 commit comments