diff --git a/examples/weight_transform.py b/examples/weight_transform.py
index 83e54920a..c980d1945 100644
--- a/examples/weight_transform.py
+++ b/examples/weight_transform.py
@@ -1,41 +1,62 @@
+import torch
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.transforms import Hadamard, RandomHadamard, Transforms
 from compressed_tensors.transforms.transform_args import (
     ModuleTarget,
     TransformationArgs,
 )
 from compressed_tensors.transforms.transform_config import TransformationConfig
-from compressed_tensors.transforms.transform_data import TransformData
 from compressed_tensors.transforms.transform_scheme import TransformationScheme
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
 
-ignore = ["re:*.mlp.down_proj$"]
-module_targets = [ModuleTarget.WEIGHTS]
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# U(W)V.T
+
+ignore = ["re:.*.mlp.down_proj$"]
+module_targets = [ModuleTarget.WEIGHT.value]
 
-# Start with a processed
-targets = ["Linear"] # 2048 * 2048
+# Start with a processed
+targets = ["Linear"]  # 2048 * 2048
 v_linear_args = TransformationArgs(
-    targets=targets, module_targets=module_targets, ignore=ignore, call_args={"transpose": True, "first": False}
+    targets=targets,
+    module_targets=module_targets,
+    ignore=ignore,
+    call_args={"transpose": True, "first": False},
 )
 
-targets = ["re:*.mlp.down_proj$"] # 5632 * 5632
+targets = ["re:.*.mlp.down_proj$"]  # 8192 * 8192
 v_down_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets, call_args={"transpose": True, "first": False}
+    targets=targets,
+    module_targets=module_targets,
+    call_args={"transpose": True, "first": False},
 )
 
-targets = ["re:*.attn.q_proj$", "re:*.attn.o_proj$", "re:*.mlp.down_proj$"] # 2048 * 2048
+targets = [
+    "re:.*.attn.q_proj$",
+    "re:.*.attn.o_proj$",
+    "re:.*.mlp.down_proj$",
+]  # 2048 * 2048
 u_q_o_down_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets,
+    targets=targets,
+    module_targets=module_targets,
 )
 
-targets = ["re:*.attn.gate_proj$", "re:*.mlp.up_proj$"] # 5632 * 5632
+targets = ["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"]  # 8192 * 8192
 u_gate_up_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets,
+    targets=targets,
+    module_targets=module_targets,
 )
 
-targets = ["re:*.attn.k_proj$", "re:*.attn.v_proj$"] # 256 * 256
+targets = ["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"]  # 512 * 512
 u_k_v_proj = TransformationArgs(
-    targets=targets, module_targets=module_targets,
+    targets=targets,
+    module_targets=module_targets,
 )
@@ -51,7 +72,7 @@
 v_scheme_down_proj = TransformationScheme(
     transform_type="random-hadamard",
     groups=[v_down_proj],
-    transform_creation_args={"size": 5632},
+    transform_creation_args={"size": 8192},
 )
 
 # We could combine multiple args to the same scheme but then would make it more difficult to consolidate order of transforms
@@ -64,35 +85,65 @@
 u_scheme_gate_up_proj = TransformationScheme(
     transform_type="random-hadamard",
     groups=[u_gate_up_proj],
-    transform_creation_args={"size": 5632},
+    transform_creation_args={"size": 8192},
 )
 
 u_scheme_k_v_proj = TransformationScheme(
     transform_type="random-hadamard",
     groups=[u_k_v_proj],
-    transform_creation_args={"size": 256},
+    transform_creation_args={"size": 512},
 )
 
 # QuIP Recipe with weight only quantization
 config = TransformationConfig(
     transform_groups={
         "u_transform_q_o_down_proj": u_scheme_q_o_down_proj,
-        "u_transform_gate_up_proj": u_scheme_gate_up_proj,
         "u_transform_k_v_proj": u_scheme_k_v_proj,
+        "u_transform_gate_up_proj": u_scheme_gate_up_proj,
         "v_transform_linear": v_scheme,
-        "v_transform_down_proj": v_scheme_down_proj
+        "v_transform_down_proj": v_scheme_down_proj,
     }
 )
 
-#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+recipe = QuantizationModifier(
+    targets="Linear",
+    ignore=["lm_head"],
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(
+                num_bits=4,
+                symmetric=True,
+                strategy=QuantizationStrategy.GROUP,
+                group_size=128,
+            ),
+        )
+    },
+    transforms_config=config,
+)
+
+MODEL_ID = "meta-llama/Llama-3.2-1B"
 
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    device_map="auto",
-    torch_dtype="auto",
+    MODEL_ID, device_map="auto", torch_dtype="auto"
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+oneshot(model=model, recipe=recipe)
+
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-Transforms"
+model.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
+
+"""
 x = model.model.layers[0]
 attn = x.self_attn
 mlp = x.mlp
@@ -104,16 +155,26 @@
     attn.o_proj,
     mlp.gate_proj,
     mlp.down_proj,
-    mlp.up_proj
+    mlp.up_proj,
 ]
 
-for layer in layers:
+from compressed_tensors.transforms.hadamard_utils import (
+    deterministic_hadamard_matrix,
+    random_hadamard_matrix,
+)
+
+for layer in layers:
     current_weight = layer.weight
+    original_weight = current_weight.data.clone()
     (n, m) = current_weight.shape
-    U = torch.eye(n).to("cuda").to(torch.bfloat16)
-    V = torch.eye(m).to("cuda").to(torch.bfloat16)
-    print(n, layer)
+
+    U = torch.Tensor(random_hadamard_matrix(n)).to("cuda").to(torch.float32)
+    V = torch.Tensor(random_hadamard_matrix(m)).to("cuda").to(torch.float32)
 
     output = torch.matmul(U, current_weight)
     output = torch.matmul(output, V.T)
+
+    # apply untransform
+    x = torch.matmul(U.T, torch.matmul(output, V))
+    print(torch.max(abs(x - original_weight)))
+"""
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 3245c8604..f20951b88 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -4,6 +4,7 @@
 from compressed_tensors.quantization import QuantizationStatus, is_attention_module
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
+from compressed_tensors.transforms.apply import apply_transforms_to_parameter
 from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
 from loguru import logger
 from torch.nn import Module
@@ -123,25 +124,15 @@ def update_weight_zp_scale(module: Module):
     transform_data = getattr(module, "transform_data", None)
     if transform_data is not None:
-        # order that the transforms were added to match the order they should be applied
         untransformed_weight = module.weight.data.clone()
-        for transform_name, transform_values in transform_data.data.items():
-            transform = getattr(module, transform_name)
-            apply = transform_values.get("apply")
-            call_args = transform_values.get("call_args")
-            if call_args:
-                transformed_weight = apply(
-                    input_tensor=module.weight, transform=transform, **call_args
-                )
-            else:
-                transformed_weight = apply(
-                    input_tensor=module.weight, transform=transform
-                )
-            module.weight.data.copy_(transformed_weight)
+        apply_transforms_to_parameter(
+            module=module,
+            module_parameter=module.weight,
+            transform_data=transform_data,
+        )
 
     call_observer(module=module, base_name="weight")
 
-    # TODO: what do we do here?
     if transform_data is not None:
         module.weight.data.copy_(untransformed_weight)
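Both halves of this diff lean on the same property: a normalized Hadamard matrix is orthogonal, so rewriting a weight as U @ W @ V.T is reversible up to floating-point error. That is why update_weight_zp_scale can clone the untransformed weight, calibrate scale and zero-point on the transformed one, and then copy the original data back, and it is what the commented-out verification loop at the bottom of the example checks. A minimal, self-contained sketch of that round trip follows; the sylvester_hadamard helper and the 2048 x 2048 shape are illustrative stand-ins for the library's random_hadamard_matrix and the "# 2048 * 2048" sizes noted in the example, not part of this PR.

import torch


def sylvester_hadamard(size: int) -> torch.Tensor:
    # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]], entries +/-1.
    # Dividing by sqrt(size) makes the result orthogonal: H @ H.T == I.
    assert size > 0 and (size & (size - 1)) == 0, "size must be a power of two"
    h = torch.ones(1, 1)
    while h.shape[0] < size:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / size**0.5


n, m = 2048, 2048  # mirrors the 2048 x 2048 transform size used in the example
weight = torch.randn(n, m)

U = sylvester_hadamard(n)
V = sylvester_hadamard(m)

# Transform the weight as U @ W @ V.T, the same form the example applies.
transformed = U @ weight @ V.T

# Orthogonality gives U.T @ (U @ W @ V.T) @ V == W up to rounding, so restoring
# the cloned, untransformed weight after calibration loses no information.
recovered = U.T @ transformed @ V
print(torch.max((recovered - weight).abs()))  # small float32 rounding error only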