breaks for 4 dimensions #980

HDCharles opened this issue Jan 23, 2024 · 1 comment

System Info

A100, CUDA 12.1

If the input has more than three dimensions, `rows` is never assigned. This causes an error whenever `to_order` is not "row" or "column".
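The failure pattern suggests a branch on the number of dimensions that only assigns `rows` for 2-D and 3-D shapes. Below is a minimal sketch of a dimension-agnostic alternative. It assumes the intent is to count every leading dimension as part of the row count. `rows_and_cols` is a hypothetical helper for illustration, not a bitsandbytes function.

```python
from math import prod
from typing import Sequence, Tuple


def rows_and_cols(shape: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper: fold every leading dimension into `rows`.

    The 2-D and 3-D cases are special cases of this rule:
    (r, c) -> rows = r, and (b, s, c) -> rows = b * s.
    """
    if len(shape) < 2:
        raise ValueError(f"expected at least 2 dimensions, got {len(shape)}")
    rows = prod(shape[:-1])  # product of all dimensions except the last
    cols = shape[-1]         # the feature (input) dimension
    return rows, cols


print(rows_and_cols((400, 14, 14, 1280)))  # (78400, 1280)
```

With this rule, the 4-D shape from the repro maps to a 78400 x 1280 matrix instead of leaving `rows` undefined.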

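Reproduce with:

```python
import torch
import bitsandbytes

mod = bitsandbytes.nn.Linear8bitLt(1280, 3840, bias=True, has_fp16_weights=False).cuda()
# a 4-D input of shape (400, 14, 14, 1280) triggers the error below
mod(torch.randn((400, 14, 14, 1280), device="cuda", dtype=torch.float16))
```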

UnboundLocalError: local variable 'rows' referenced before assignment

Note that even after fixing that, there are further issues:

AttributeError: 'NoneType' object has no attribute 'dtype'

Expected behavior

The layer should accept inputs with more than three dimensions instead of raising an error.

As of v0.45.0, due to #1401, we'll instead raise:
AssertionError: Only two or three dimensional matrices are supported for argument A.

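A workaround for inputs with more than three dimensions could be to fold the additional leading dimensions into one before calling the layer, then unfold the output afterwards.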
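Here is a minimal sketch of that folding workaround. `apply_with_folded_dims` is a hypothetical wrapper, not part of bitsandbytes. It folds everything into a 2-D matrix, which the assertion above permits. The demonstration uses a plain `nn.Linear` on CPU as a stand-in. With the repro, `Linear8bitLt` on a CUDA device would be passed as `layer` instead, which is assumed but not tested here.

```python
import torch
from torch import nn


def apply_with_folded_dims(layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: run `layer` on an N-D input by folding the
    leading dimensions into a single batch dimension first.

    (d0, ..., dk, in_features) -> (d0 * ... * dk, in_features)
    -> layer -> (d0, ..., dk, out_features)
    """
    leading = x.shape[:-1]
    x2d = x.reshape(-1, x.shape[-1])  # fold all leading dims into one
    y2d = layer(x2d)                  # the layer only ever sees a 2-D matrix
    return y2d.reshape(*leading, y2d.shape[-1])  # restore the leading dims


# Demonstration with a plain nn.Linear on CPU (stand-in for Linear8bitLt).
layer = nn.Linear(1280, 3840)
x = torch.randn(2, 3, 4, 1280)
print(apply_with_folded_dims(layer, x).shape)  # torch.Size([2, 3, 4, 3840])
```

`reshape` is used instead of `view` so the wrapper also works when the input is not contiguous.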

