Merge pull request #2 from KeremTurgutlu/peft_tests
Peft tests
warner-benjamin authored Jan 12, 2024
2 parents e8e8eed + 4aecbdf commit 4f15388
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -1627,7 +1627,7 @@ def gemv_4bit(
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)

-    if B.dtype in [torch.uint8, torch.bfloat16, torch.float32]:
+    if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
5 changes: 3 additions & 2 deletions tests/test_functional.py
@@ -2370,7 +2370,8 @@ def test_normal_map_tree():
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]:
#for dim in [4*1024]:
#for dim in [1*16]:
@@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
         A = torch.randn(1, dim, dtype=dtype, device='cuda')
         B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

-        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
         C3 = torch.matmul(A, B.t())
         C2 = F.gemv_4bit(A, qB.t(), state=state)
         A.requires_grad = True
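Taken together, these two hunks let gemv_4bit accept 4-bit weights whose packed data is stored in a float16 tensor, and extend the test matrix to cover every quant_storage dtype. Below is a minimal usage sketch of the new path, assuming a CUDA build of bitsandbytes; it reuses the F.quantize_4bit and F.gemv_4bit calls exercised by the test above, while the dimensions and the allclose tolerance are illustrative choices, not values from this commit.

```python
import math
import torch
import bitsandbytes.functional as F

dim = 512
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(dim * 3, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

# Quantize B to NF4, but keep the packed 4-bit data in a float16 buffer
# (quant_storage) instead of the default uint8. The widened dtype check in
# gemv_4bit now lets this storage dtype reach the fp16 inference kernel.
qB, state = F.quantize_4bit(B, quant_type='nf4', quant_storage=torch.float16)

C_ref = torch.matmul(A, B.t())                 # reference fp16 matmul
C_4bit = F.gemv_4bit(A, qB.t(), state=state)   # 4-bit inference GEMV

# Loose tolerance: the 4-bit result carries quantization error.
print(torch.allclose(C_ref, C_4bit, atol=0.1))
```

Storing the packed 4-bit data in a non-uint8 dtype matters when the quantized buffer must share a dtype with surrounding parameters (the PEFT use case this PR appears to target, judging by the branch name); the kernel dispatch still keys off A.dtype, which is why only the B.dtype guard needed to change.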
