2 changes: 1 addition & 1 deletion generate/adapter_v2.py
@@ -65,7 +65,7 @@ def main(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ):
         model = LLaMA.from_name(name)
-        add_adapter_v2_parameters_to_linear_layers(model)
+        add_adapter_v2_parameters_to_linear_layers(model, dtype)
@rasbt (Contributor) commented on May 26, 2023:
Thanks for the update on the PR! Eager to give this a try!
Btw, here I noticed that you'd also have to modify the finetune/adapter_v2.py script so that it includes the dtype in the function call as well.


     # 1. Load the pretrained weights
     model.load_state_dict(pretrained_checkpoint, strict=False)
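As rasbt notes above, the same call site exists in finetune/adapter_v2.py and would need the extra argument too. A minimal sketch of that follow-up change, assuming the finetuning script builds the model the same way as generate/adapter_v2.py; the surrounding function and variable names below are illustrative, not taken from this diff:

import torch
from lit_llama.adapter import LLaMA
from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers

def build_adapter_v2_model(name: str, dtype: torch.dtype) -> LLaMA:
    # Sketch only: the actual context in finetune/adapter_v2.py may differ.
    model = LLaMA.from_name(name)
    # was: add_adapter_v2_parameters_to_linear_layers(model)
    add_adapter_v2_parameters_to_linear_layers(model, dtype)  # pass dtype through
    return model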
19 changes: 13 additions & 6 deletions lit_llama/adapter_v2.py
@@ -2,6 +2,7 @@
 from torch import Tensor
 import torch.nn as nn
 from torch.nn import functional as F
+from lit_llama.quantization import Linear8bitLt
 
 from lit_llama.adapter import LLaMA
 
@@ -26,20 +27,26 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
 
 
 def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
+    weight = self.weight
+    if isinstance(self, Linear8bitLt):
+        weight = self.dequantize(input.dtype)
     return self.adapter_scale * (
-        F.linear(input, self.weight, self.bias) + self.adapter_bias
+        F.linear(input, weight, self.bias) + self.adapter_bias
     )
 
 
-def adapter_v2_linear_with_bias_and_scale(layer):
-    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
-    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+def adapter_v2_linear_with_bias_and_scale(layer, dtype):
+    weight = layer.weight
+    if isinstance(layer, Linear8bitLt):
+        weight = layer.dequantize(dtype)
+    layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True)
+    layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True)
     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
     setattr(layer, 'forward', bound_method)
     return layer
 
 
-def add_adapter_v2_parameters_to_linear_layers(model):
+def add_adapter_v2_parameters_to_linear_layers(model, dtype):
     for module in model.modules():
         if isinstance(module, nn.Linear):
-            adapter_v2_linear_with_bias_and_scale(module)
+            adapter_v2_linear_with_bias_and_scale(module, dtype)
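For readers skimming the diff, a self-contained sketch of what the patched forward computes for a plain (non-quantized) nn.Linear; the Linear8bitLt branch differs only in dequantizing the weight first. This is an illustration of the math, not code from the repository:

import torch
import torch.nn as nn
from torch.nn import functional as F

# Mirror adapter_v2_linear_with_bias_and_scale on a regular nn.Linear.
layer = nn.Linear(8, 4)
layer.adapter_bias = nn.Parameter(torch.zeros(layer.weight.shape[0]))
layer.adapter_scale = nn.Parameter(torch.ones(layer.weight.shape[0]))

x = torch.randn(2, 8)
# adapter_v2_new_forward: scale * (linear(x) + adapter_bias)
out = layer.adapter_scale * (F.linear(x, layer.weight, layer.bias) + layer.adapter_bias)

# With zero adapter_bias and unit adapter_scale this matches the unpatched layer.
assert torch.allclose(out, layer(x))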
11 changes: 11 additions & 0 deletions lit_llama/quantization.py
@@ -74,6 +74,17 @@ def _quantize_weight(self, weight: torch.Tensor) -> None:
         setattr(self.weight, "CB", CB)
         setattr(self.weight, "SCB", SCB)
 
+    def dequantize(self, dtype):
+        if dtype not in [torch.bfloat16, torch.float16, torch.float32]:
+            raise ValueError(f"Invalid dtype: {dtype}. Allowed dtypes are: bfloat16, float16, float32")
+        weight_CB = self.weight.CB
+        weight_SCB = self.weight.SCB
+        # Modify SCB shape if it doesn't match CB
+        if weight_CB.shape[1] != weight_SCB.shape[0]:
+            weight_SCB = weight_SCB.view(weight_SCB.shape[0], 1)
+        result = (weight_CB * weight_SCB) / 127
+        result = result.to(dtype)
+        return result
 
 if triton is not None:
     # This is adapted from the OpenAI Triton matmul example.
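A small numeric sketch of the row-wise absmax scheme the new dequantize method inverts: CB holds the int8 codes, SCB the per-row absmax statistics, and the weight is reconstructed as CB * SCB / 127. Standalone illustration with made-up tensors, not a call into bitsandbytes (SCB is kept 2-D here so it broadcasts; the PR instead reshapes a 1-D SCB when needed):

import torch

weight = torch.randn(4, 8)
SCB = weight.abs().amax(dim=1, keepdim=True)          # per-row absmax scale
CB = torch.round(weight / SCB * 127).to(torch.int8)   # int8 codes in [-127, 127]

reconstructed = (CB.float() * SCB) / 127              # what dequantize() computes
print((reconstructed - weight).abs().max())           # small quantization error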