complete test test_lora_bnb_4bit_quantization
GM-git-dotcom committed Feb 29, 2024
1 parent 412bf81 commit 637c828
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions tests/test_common_gpu.py
@@ -321,7 +321,7 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self):
     @pytest.mark.single_gpu_tests
     def test_lora_bnb_4bit_quantization(self):
         r"""
-        Test that tests if the 4bit quantization using LoRA works as expected
+        Test that tests if the 4bit quantization for Linear and Embedding using LoRA works as expected
         """
         whisper_4bit = WhisperForConditionalGeneration.from_pretrained(
             self.audio_model_id,
@@ -342,28 +342,48 @@ def test_lora_bnb_4bit_quantization(self):
         )

         flan_lora_config = LoraConfig(
-            r=16, lora_alpha=32, target_modules=["q", "v"], lora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM"
+            r=16, lora_alpha=32, target_modules=["q", "v", "embed_tokens"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="SEQ_2_SEQ_LM"
         )

         opt_lora_config = LoraConfig(
             r=16,
             lora_alpha=32,
-            target_modules=["q_proj", "v_proj"],
+            target_modules=["q_proj", "v_proj", "embed_tokens"],
             lora_dropout=0.05,
             bias="none",
             task_type="CAUSAL_LM",
         )

-        config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
+        whisper_lora_config = LoraConfig(
+            r=32,
+            lora_alpha=64,
+            target_modules=["q_proj", "v_proj", "embed_tokens"],
+            lora_dropout=0.05,
+            bias="none")
+
+        quantize_embedding(flan_4bit, ["embed_tokens"])
+        quantize_embedding(opt_4bit, ["embed_tokens"])
+        quantize_embedding(whisper_4bit, ["embed_tokens"])
+
+        # Embedding4bit: embed_tokens has to be listed both in quantize_embedding and in target_modules.
+        # This arises from the fact that get_peft_model only dispatches on `target_modules` in the LoraConfig,
+        # so unless "embed_tokens" is a target, the bnb Embedding4bit (quantized embedding) won't be converted
+        # to LoraEmbedding4bit

         flan_4bit = get_peft_model(flan_4bit, flan_lora_config)
         assert isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, LoraLinear4bit)
+        assert isinstance(flan_4bit.base_model.encoder.embed_tokens, LoraEmbedding4bit)

         opt_4bit = get_peft_model(opt_4bit, opt_lora_config)
         assert isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)
+        assert isinstance(opt_4bit.base_model.model.model.decoder.embed_tokens, LoraEmbedding4bit)

-        whisper_4bit = get_peft_model(whisper_4bit, config)
+        whisper_4bit = get_peft_model(whisper_4bit, whisper_lora_config)
         assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)
+        assert isinstance(whisper_4bit.base_model.model.model.decoder.embed_tokens, LoraEmbedding4bit)

     @require_bitsandbytes
     @pytest.mark.multi_gpu_tests
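For readers outside the test harness, the pattern this test verifies can be reproduced with a short standalone script. The sketch below is illustrative, not part of the commit: facebook/opt-350m is an assumed stand-in model, and on a stock PEFT install (without this branch's quantize_embedding helper and LoraEmbedding4bit class) embed_tokens stays in float and is wrapped by PEFT's regular lora.Embedding layer rather than a 4-bit one.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load the base model with bitsandbytes 4-bit quantization; bnb quantizes
# Linear layers, so embed_tokens remains a float nn.Embedding at this point.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
)

# Target the attention projections and the token embedding in one config,
# mirroring opt_lora_config in the test above.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "embed_tokens"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

# Check what each target was dispatched to: v_proj becomes a LoRA-wrapped
# Linear4bit; embed_tokens would be LoraEmbedding4bit on this branch, or
# PEFT's standard lora.Embedding on a stock install.
print(type(model.base_model.model.model.decoder.layers[0].self_attn.v_proj))
print(type(model.base_model.model.model.decoder.embed_tokens))

Note the dispatch rule the commit's comment calls out: get_peft_model decides which LoRA layer to build purely from target_modules, so an already-quantized embedding is only converted if "embed_tokens" is listed there as well.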
