Commit

clean-up
dsikka committed Mar 6, 2025
1 parent a7bf319 commit 4a58bb1
Showing 2 changed files with 97 additions and 45 deletions.
121 changes: 91 additions & 30 deletions examples/weight_transform.py
@@ -1,41 +1,62 @@
+import torch
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.transforms import Hadamard, RandomHadamard, Transforms
 from compressed_tensors.transforms.transform_args import (
     ModuleTarget,
     TransformationArgs,
 )
 from compressed_tensors.transforms.transform_config import TransformationConfig
 from compressed_tensors.transforms.transform_data import TransformData
 from compressed_tensors.transforms.transform_scheme import TransformationScheme
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
 
-ignore = ["re:*.mlp.down_proj$"]
-module_targets = [ModuleTarget.WEIGHTS]
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# U(W)V.T
+
+ignore = ["re:.*.mlp.down_proj$"]
+module_targets = [ModuleTarget.WEIGHT.value]
 
 # Start with a processed
 targets = ["Linear"] # 2048 * 2048
 v_linear_args = TransformationArgs(
-    targets=targets, module_targets=module_targets, ignore=ignore, call_args={"transpose": True, "first": False}
+    targets=targets,
+    module_targets=module_targets,
+    ignore=ignore,
+    call_args={"transpose": True, "first": False},
 )
 
-targets = ["re:*.mlp.down_proj$"] # 5632 * 5632
+targets = ["re:.*.mlp.down_proj$"] # 8192 * 8192
 v_down_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets, call_args={"transpose": True, "first": False}
+    targets=targets,
+    module_targets=module_targets,
+    call_args={"transpose": True, "first": False},
 )
 
-targets = ["re:*.attn.q_proj$", "re:*.attn.o_proj$", "re:*.mlp.down_proj$"] # 2048 * 2048
+targets = [
+    "re:.*.attn.q_proj$",
+    "re:.*.attn.o_proj$",
+    "re:.*.mlp.down_proj$",
+] # 2048 * 2048
 u_q_o_down_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets,
+    targets=targets,
+    module_targets=module_targets,
 )
 
-targets = ["re:*.attn.gate_proj$", "re:*.mlp.up_proj$"] # 5632 * 5632
+targets = ["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"] # 8192 * 8192
 u_gate_up_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets,
+    targets=targets,
+    module_targets=module_targets,
 )
 
-targets = ["re:*.attn.k_proj$", "re:*.attn.v_proj$"] # 256 * 256
+targets = ["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"] # 512 * 512
 u_k_v_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets,
+    targets=targets,
+    module_targets=module_targets,
 )


@@ -51,7 +72,7 @@
 v_scheme_down_proj = TransformationScheme(
     transform_type="random-hadamard",
     groups=[v_down_proj],
-    transform_creation_args={"size": 5632},
+    transform_creation_args={"size": 8192},
 )
 
 # We could combine multiple args to the same scheme but then would make it more difficult to consolidate order of transforms
@@ -64,35 +85,65 @@
 u_scheme_gate_up_proj = TransformationScheme(
     transform_type="random-hadamard",
     groups=[u_gate_up_proj],
-    transform_creation_args={"size": 5632},
+    transform_creation_args={"size": 8192},
 )
 
 u_scheme_k_v_proj = TransformationScheme(
     transform_type="random-hadamard",
     groups=[u_k_v_proj],
-    transform_creation_args={"size": 256},
+    transform_creation_args={"size": 512},
 )
 
+# QuIP Recipe with weight only quantization
 config = TransformationConfig(
     transform_groups={
         "u_transform_q_o_down_proj": u_scheme_q_o_down_proj,
-        "u_transform_gate_up_proj": u_scheme_gate_up_proj,
         "u_transform_k_v_proj": u_scheme_k_v_proj,
+        "u_transform_gate_up_proj": u_scheme_gate_up_proj,
         "v_transform_linear": v_scheme,
-        "v_transform_down_proj": v_scheme_down_proj
+        "v_transform_down_proj": v_scheme_down_proj,
     }
 )
 
-#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+recipe = QuantizationModifier(
+    targets="Linear",
+    ignore=["lm_head"],
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(
+                num_bits=4,
+                symmetric=True,
+                strategy=QuantizationStrategy.GROUP,
+                group_size=128,
+            ),
+        )
+    },
+    transforms_config=config,
+)
+
+MODEL_ID = "meta-llama/Llama-3.2-1B"
 
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    device_map="auto",
-    torch_dtype="auto",
+    MODEL_ID, device_map="auto", torch_dtype="auto"
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
+oneshot(model=model, recipe=recipe)
+
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-Transforms"
+model.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
+
 """
 x = model.model.layers[0]
 attn = x.self_attn
 mlp = x.mlp
@@ -104,16 +155,26 @@
     attn.o_proj,
     mlp.gate_proj,
     mlp.down_proj,
-    mlp.up_proj
+    mlp.up_proj,
 ]
-for layer in layers:
+from compressed_tensors.transforms.hadamard_utils import (
+    deterministic_hadamard_matrix,
+    random_hadamard_matrix,
+)
+for layer in layers:
     current_weight = layer.weight
+    original_weight = current_weight.data.clone()
     (n, m) = current_weight.shape
-    U = torch.eye(n).to("cuda").to(torch.bfloat16)
-    V = torch.eye(m).to("cuda").to(torch.bfloat16)
+    print(n, layer)
+    U = torch.Tensor(random_hadamard_matrix(n)).to("cuda").to(torch.float32)
+    V = torch.Tensor(random_hadamard_matrix(m)).to("cuda").to(torch.float32)
     output = torch.matmul(U, current_weight)
     output = torch.matmul(output, V.T)
+    # apply untransform
+    x = torch.matmul(U.T, torch.matmul(output, V))
+    print(torch.max(abs(x - original_weight)))
 """
21 changes: 6 additions & 15 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -4,6 +4,7 @@
 from compressed_tensors.quantization import QuantizationStatus, is_attention_module
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
+from compressed_tensors.transforms.apply import apply_transforms_to_parameter
 from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
 from loguru import logger
 from torch.nn import Module
@@ -123,25 +124,15 @@ def update_weight_zp_scale(module: Module):
 
     transform_data = getattr(module, "transform_data", None)
     if transform_data is not None:
-        # order that the transforms were added to match the order they should be applied
         untransformed_weight = module.weight.data.clone()
-        for transform_name, transform_values in transform_data.data.items():
-            transform = getattr(module, transform_name)
-            apply = transform_values.get("apply")
-            call_args = transform_values.get("call_args")
-            if call_args:
-                transformed_weight = apply(
-                    input_tensor=module.weight, transform=transform, **call_args
-                )
-            else:
-                transformed_weight = apply(
-                    input_tensor=module.weight, transform=transform
-                )
-            module.weight.data.copy_(transformed_weight)
+        apply_transforms_to_parameter(
+            module=module,
+            module_parameter=module.weight,
+            transform_data=transform_data,
+        )
 
     call_observer(module=module, base_name="weight")
 
     # TODO: what do we do here?
     if transform_data is not None:
         module.weight.data.copy_(untransformed_weight)
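
The hunk keeps the surrounding flow intact: the weight is transformed in place, the observer computes quantization parameters from the transformed weight, and the original weight is copied back afterwards. Below is a rough sketch of that order of operations only; the helper name and the toy per-tensor scale standing in for the real observer are illustrative, not part of llmcompressor.

import torch
from torch.nn import Linear


def calibrate_on_transformed_weight(module: Linear, transform: torch.Tensor) -> torch.Tensor:
    untransformed_weight = module.weight.data.clone()          # keep the original weight
    module.weight.data.copy_(transform @ module.weight.data)   # transform in place
    scale = module.weight.data.abs().max() / 7.0               # toy symmetric int4 "observer"
    module.weight.data.copy_(untransformed_weight)             # restore the untransformed weight
    return scale


layer = Linear(8, 8, bias=False)
hadamard_like = torch.eye(8)  # identity stands in for a real (random) Hadamard matrix
print(calibrate_on_transformed_weight(layer, hadamard_like))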
