assert dynamic rank
Signed-off-by: ZX-ModelCloud <[email protected]>
ZX-ModelCloud committed Feb 28, 2025
1 parent a1803ab commit 21dbdf6
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions tests/test_quant_and_eora_transformers.py
@@ -40,8 +40,9 @@

log = LogBar.shared()


class Test(ModelTest):
-#NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
+# NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
# NATIVE_MODEL_ID = "/monster/data/model/tinyllama-15M-stories"
NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B"

@@ -60,7 +61,7 @@ def test_quant_and_eora(self):
rank = 128
batch_size = 1
calibration_dataset_rows = 512
-calibration_dataset_concat_size = 0 # disable
+calibration_dataset_concat_size = 0  # disable
auto_gc = False
adapter_path = "eora"
dataset_id = "allenai/c4"
@@ -88,7 +89,6 @@ def test_quant_and_eora(self):
).select(range(calibration_dataset_rows))["text"]

with tempfile.TemporaryDirectory() as tmpdir:
-tmpdir="test_adapter"
eora = Lora(
# for quant, path is save path. for load, it is loading path
path=os.path.join(tmpdir, adapter_path),
@@ -100,6 +100,13 @@
group_size=group_size,
desc_act=desc_act, # bitblas only supports DESC_ACT=False
adapter=eora,
+dynamic={
+    ".*\\.gate_proj.*": {
+        "adapter": {
+            "rank": 256
+        }
+    }
+},
)

model = GPTQModel.load(
@@ -112,7 +119,7 @@
batch_size=batch_size,
auto_gc=auto_gc,
calibration_dataset_concat_size=calibration_dataset_concat_size,
-) #
+)  #

# EoRA adapter is saved according to Lora.path property
# if Lora.path is not set, we will save the lora as "lora.safetensors" in the same path as quant model
@@ -123,9 +130,9 @@
torch_empty_cache()

# BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
-for backend in [ BACKEND.MARLIN ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
-eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)
-base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
+for backend in [BACKEND.MARLIN]:  # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
+eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora)  # inference using eora (lora)
+base_bench = self.bench(path=tmpdir, backend=backend, adapter=None)  # inference using qweights only

print('--------GPTQModel + EoRA Config ---------')

@@ -147,8 +154,10 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
# test post-quant inference
if adapter:
adapter_weights = load_file(os.path.join(adapter.path, HF_ADAPTER_FILE_NAME))
-origin_lora_a_weight = adapter_weights[f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.6.self_attn.v_proj.lora_A.weight"]
-origin_lora_b_weight = adapter_weights[f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.6.self_attn.v_proj.lora_B.weight"]
+origin_lora_a_weight = adapter_weights[
+    f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.5.self_attn.v_proj.lora_A.weight"]
+origin_lora_b_weight = adapter_weights[
+    f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.5.self_attn.v_proj.lora_B.weight"]

model = AutoModelForCausalLM.from_pretrained(path, device_map="cuda")
log.info("PEFT: converting model to lora model")
@@ -163,6 +172,14 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
self.assert_adapter_load(model, origin_lora_a_weight, origin_lora_b_weight)
print("peft model", model)

+# assert dynamic rank
+v_proj_module = model.model.layers[5].self_attn.v_proj
+assert v_proj_module.lora_A["default"].weight.data.shape[0] == 128
+assert v_proj_module.lora_B["default"].weight.data.shape[1] == 128
+gate_proj_module = model.model.layers[5].mlp.gate_proj
+assert gate_proj_module.lora_A["default"].weight.data.shape[0] == 256
+assert gate_proj_module.lora_B["default"].weight.data.shape[1] == 256

del origin_lora_a_weight, origin_lora_b_weight, adapter_weights
else:
model = AutoModelForCausalLM.from_pretrained(path, device_map="cuda")
@@ -173,7 +190,7 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
tokens = model.generate(**inp)[0]
result = tokenizer.decode(tokens)
print(f"BACKEND: {backend}, Result: {result}")
#assert "paris" in result.lower(), f"`paris` not found in `{result}`"
# assert "paris" in result.lower(), f"`paris` not found in `{result}`"

bench_result = GPTQModel.eval(
model_or_id_or_path=model,
@@ -188,7 +205,7 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
return bench_result

def assert_adapter_load(self, model, origin_lora_a_weight, origin_lora_b_weight):
-module = model.model.layers[6].self_attn.v_proj
+module = model.model.layers[5].self_attn.v_proj
assert isinstance(module, GPTQLoraLinear)
assert torch.equal(origin_lora_a_weight.to(model.device), module.lora_A["default"].weight.data)
assert torch.equal(origin_lora_b_weight.to(model.device), module.lora_B["default"].weight.data)
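For context, a minimal sketch of the per-module rank override that this commit now asserts: the `dynamic` mapping in the quantize config matches module names by regex and overrides the adapter `rank` for matching modules (256 for `gate_proj` here, versus the default 128). The import paths and the `QuantizeConfig` name below follow GPTQModel's published examples and are assumptions rather than part of this commit; the model path and calibration text are placeholders.

```python
# Sketch only (not part of this commit): quantize with EoRA rank 128 by default,
# overriding gate_proj modules to rank 256 via the `dynamic` regex mapping.
# Import paths / class names are assumed from GPTQModel's examples.
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.adapter.adapter import Lora

eora = Lora(path="eora_adapter", rank=128)  # default rank for every targeted module

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    adapter=eora,
    dynamic={
        # modules whose name matches this regex get rank 256 instead of 128
        ".*\\.gate_proj.*": {"adapter": {"rank": 256}},
    },
)

model = GPTQModel.load("/path/to/Llama-3.2-1B", quant_config)      # placeholder path
model.quantize(["example calibration text"] * 512, batch_size=1)   # placeholder data
model.save("Llama-3.2-1B-gptq-eora")
```

This mirrors what the new assertions in `bench()` check after the adapter is reloaded through PEFT: `v_proj` keeps the default rank of 128 while `gate_proj` picks up the 256 override.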
