Skip to content

Commit

Permalink
assert lora weight
Browse files Browse the repository at this point in the history
Signed-off-by: ZX-ModelCloud <[email protected]>
  • Loading branch information
ZX-ModelCloud committed Feb 28, 2025
1 parent a316cd2 commit 86a242c
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tests/test_quant_and_eora_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# -- do not touch
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft.tuners.lora.gptq import GPTQLoraLinear
Expand All @@ -28,7 +30,7 @@

from datasets import load_dataset # noqa: E402
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402
from gptqmodel.adapter.adapter import Lora # noqa: E402
from gptqmodel.adapter.adapter import Lora, HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX # noqa: E402
from gptqmodel.utils.eval import EVAL # noqa: E402
from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
from lm_eval.utils import make_table # noqa: E402
Expand Down Expand Up @@ -122,7 +124,7 @@ def test_quant_and_eora(self):

# BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
for backend in [ BACKEND.MARLIN ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V, BACKEND.MARLIN
# base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)

print('--------GPTQModel + EoRA Config ---------')
Expand All @@ -131,10 +133,10 @@ def test_quant_and_eora(self):
table_data = [[key, value] for key, value in config_dict.items()]
print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid"))

# print('--------Eval GPTQ Result---------')
# print(make_table(base_bench))
# if "groups" in base_bench:
# print(make_table(base_bench, "groups"))
print('--------Eval GPTQ Result---------')
print(make_table(base_bench))
if "groups" in base_bench:
print(make_table(base_bench, "groups"))

print('--------Eval GPTQ + EoRA Result---------')
print(make_table(eora_bench))
Expand All @@ -153,7 +155,17 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
model.load_adapter(adapter.path)
print("peft model", model)

assert isinstance(model.model.layers[0].self_attn.v_proj, GPTQLoraLinear)
adapter_weights = load_file(os.path.join(adapter.path, HF_ADAPTER_FILE_NAME))
origin_lora_a_weight = adapter_weights[f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.6.self_attn.v_proj.lora_A.weight"]
origin_lora_b_weight = adapter_weights[f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.6.self_attn.v_proj.lora_B.weight"]

module = model.model.layers[6].self_attn.v_proj

assert isinstance(module, GPTQLoraLinear)
assert torch.equal(origin_lora_a_weight.to(model.device), module.lora_A["default"].weight.data)
assert torch.equal(origin_lora_b_weight.to(model.device), module.lora_B["default"].weight.data)

del origin_lora_a_weight, origin_lora_b_weight, adapter_weights

tokenizer = AutoTokenizer.from_pretrained(path)
inp = tokenizer("Capital of France is", return_tensors="pt").to(model.device)
Expand Down

0 comments on commit 86a242c

Please sign in to comment.