Conversation

@electron271 commented Sep 5, 2025

Currently I'm using my own GitHub Actions builds of bitsandbytes, as the main bitsandbytes builds have multiple issues with ROCm (they don't support all architectures, plus the issues I noted in the repo https://github.com/electron271/bitsandbytes-index).

Once bitsandbytes-foundation/bitsandbytes#1519 is fixed, this can be changed.

Closes #37

@electron271 electron271 marked this pull request as draft September 5, 2025 22:18
@electron271 (Author)

Windows support may also be possible, but I would need some help testing it, as I do not have a Windows machine.

@electron271 electron271 marked this pull request as ready for review September 6, 2025 00:29
@electron271 electron271 marked this pull request as draft September 6, 2025 01:28
@electron271 (Author) commented Sep 6, 2025

Docs changes:

diff --git a/get-started/installing-+-updating/pip-install.md b/get-started/installing-+-updating/pip-install.md
index c1f0975..5f66dbf 100644
--- a/get-started/installing-+-updating/pip-install.md
+++ b/get-started/installing-+-updating/pip-install.md
@@ -24,6 +24,16 @@ pip uninstall unsloth unsloth_zoo -y && pip install --no-deps git+https://github
 
 If you're installing Unsloth in Jupyter, Colab, or other notebooks, be sure to prefix the command with `!`. This isn't necessary when using a terminal
 
+**To install Unsloth on AMD GPUs:**
+
+{% hint style="info" %}
+You can safely ignore errors about CUDA not being linked properly if you are installing Unsloth on AMD GPUs.
+{% endhint %}
+
+```bash
+pip install "unsloth[rocm64-torch280]"
+```
+
 ## Uninstall + Reinstall
 
 If you're still encountering dependency issues with Unsloth, many users have resolved them by forcing uninstalling and reinstalling Unsloth:
diff --git a/get-started/beginner-start-here/unsloth-requirements.md b/get-started/beginner-start-here/unsloth-requirements.md
index 793bd63..b5f5429 100644
--- a/get-started/beginner-start-here/unsloth-requirements.md
+++ b/get-started/beginner-start-here/unsloth-requirements.md
@@ -8,7 +8,7 @@ description: Here are Unsloth's requirements including system and GPU VRAM requi
 
 * **Operating System**: Works on Linux and Windows.
 * Supports NVIDIA GPUs since 2018+ including [Blackwell RTX 50](../../basics/training-llms-with-blackwell-rtx-50-series-and-unsloth) series. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40, A100, H100, L40 etc) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070, 1080 works, but is slow.
-* Unsloth should work on [AMD](https://github.com/unslothai/unsloth/pull/2520) and [Intel](https://github.com/unslothai/unsloth/pull/2621) GPUs! Apple/Silicon/MLX is in the works.
+* Unsloth should work on [AMD](../installing-+-updating/pip-install#amd-installation) and [Intel](https://github.com/unslothai/unsloth/pull/2621) GPUs! Apple/Silicon/MLX is in the works.
 * If you have different versions of torch, transformers etc., `pip install unsloth` will automatically install all the latest versions of those libraries so you don't need to worry about version compatibility.
 * Your device must have `xformers`, `torch`, `BitsandBytes` and `triton` support.
 

@electron271 electron271 marked this pull request as ready for review September 6, 2025 01:39
@electron271 (Author)

It seems 4-bit exporting has some issues, as a 64 blocksize is not supported with ROCm (ROCm/bitsandbytes#10). A 64 blocksize is possible depending on warp size, though, so I will look into submitting a PR to bitsandbytes.
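
As a quick way to see which case a given GPU falls into, a hedged probe (assuming a ROCm build of PyTorch recent enough to expose `warp_size` on device properties) might look like this:

```python
# Hypothetical probe, not part of this PR: report the wavefront (warp) size,
# which determines whether a 64-element 4-bit blocksize is feasible.
import torch

props = torch.cuda.get_device_properties(0)  # HIP devices appear under torch.cuda on ROCm
# RDNA (consumer) GPUs typically report 32; CDNA (datacenter) GPUs report 64.
print(f"{props.name}: warp_size={props.warp_size}")
```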

@electron271 (Author)

I have found a likely solution. If it works, I may switch the builds over to my fork until it's merged upstream so that 4-bit works.

@electron271 electron271 marked this pull request as draft September 6, 2025 05:56
@electron271 (Author)

Marking as draft until I get this issue fixed, as it is fairly major.

@electron271 (Author)

PR created: bitsandbytes-foundation/bitsandbytes#1748

@electron271 (Author)

Should work now; testing changes.

@electron271 electron271 marked this pull request as ready for review September 7, 2025 00:15
@electron271 (Author)

Works.

@emuchogu (Contributor) commented Sep 9, 2025

Works great on AMD MI100.

I added this to my vLLM Dockerfile and it just worked:

```dockerfile
RUN git clone --recurse https://github.com/ROCm/bitsandbytes && cd bitsandbytes && git checkout rocm_enabled_multi_backend && pip install -r requirements-dev.txt && cmake -DCOMPUTE_BACKEND=hip -S . && make -j && pip install .
RUN git clone https://github.com/electron271/unsloth-rocm.git && cd unsloth-rocm && pip install .
RUN pip install unsloth_zoo
```

Thanks

@electron271 (Author)

> Works great on AMD MI100. I added this to my vLLM Dockerfile and it just worked.

Great to hear! You also shouldn't need the ROCm fork of bitsandbytes (AFAIK): this branch installs a ROCm-supported bitsandbytes as a dependency. If you want to install it manually, the support was merged into main, so you can use mainline bitsandbytes.
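
If that's right, a minimal sketch of the Dockerfile above without the source build might look like this (an untested assumption based on the comment, not a verified recipe):

```dockerfile
# Hypothetical simplification: let this branch pull in a ROCm-supported
# bitsandbytes as a dependency instead of building the ROCm fork from source.
RUN git clone https://github.com/electron271/unsloth-rocm.git && cd unsloth-rocm && pip install .
RUN pip install unsloth_zoo
```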

@nole70 commented Sep 9, 2025

I ran `git clone https://github.com/electron271/unsloth-rocm.git && cd unsloth-rocm && pip install .[rocm-torch280]` on an MI300X, tried to do DPO, and got this error:

```
Traceback (most recent call last):
  File "/workspace/script.py", line 193, in <module>
    dpo_trainer.train()
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 323, in _fast_inner_training_loop
  File "<string>", line 40, in _unsloth_training_step
  File "/tmp/unsloth_compiled_cache/UnslothDPOTrainer.py", line 2065, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/unsloth_compiled_cache/UnslothDPOTrainer.py", line 1981, in get_batch_loss_metrics
    model_output = self.concatenated_forward(model, batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/unsloth_compiled_cache/UnslothDPOTrainer.py", line 1855, in concatenated_forward
    outputs = model(input_ids, **model_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/accelerate/utils/operations.py", line 818, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/accelerate/utils/operations.py", line 806, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/peft/peft_model.py", line 1850, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 222, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/unsloth_compiled_cache/unsloth_compiled_module_gemma3.py", line 880, in forward
    return Gemma3ForConditionalGeneration_forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 198, in nonrecursive_disable_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/unsloth_compiled_cache/unsloth_compiled_module_gemma3.py", line 696, in Gemma3ForConditionalGeneration_forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/utils/generic.py", line 940, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 937, in forward
    outputs = self.language_model(
              ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 555, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 488, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py", line 475, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/utils/generic.py", line 1024, in wrapped_forward
    output = orig_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 389, in forward
    hidden_states, self_attn_weights = self.self_attn(
                                       ^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py", line 762, in forward
    return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py", line 643, in forward_function
    query_states_fp16 = self.q_proj(hidden_states) # output fp16
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/unsloth_compiled_cache/Linear4bit_peft_forward.py", line 56, in unsloth_forward
    result = self.base_layer(x, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/bitsandbytes/nn/modules.py", line 565, in forward
    return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py", line 466, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py", line 380, in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/bitsandbytes/functional.py", line 1002, in dequantize_4bit
    out = torch.ops.bitsandbytes.dequantize_4bit.default(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/torch/library.py", line 752, in func_no_dynamo
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv312/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py", line 361, in _
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
  File "/workspace/venv312/lib/python3.12/site-packages/bitsandbytes/backends/cuda/ops.py", line 389, in _dequantize_4bit_impl
    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
  File "/workspace/venv312/lib/python3.12/site-packages/torch/__init__.py", line 1684, in _check
    _check_with(RuntimeError, cond, message)
  File "/workspace/venv312/lib/python3.12/site-packages/torch/__init__.py", line 1666, in _check_with
    raise error_type(message_evaluated)
RuntimeError: Expected cond to be True, but got False. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```

@electron271 (Author)

> I ran `git clone https://github.com/electron271/unsloth-rocm.git && cd unsloth-rocm && pip install .[rocm-torch280]` on an MI300X, tried to do DPO, and got this error: *(full traceback quoted above)*

4-bit is broken on CDNA GPUs, as they do not support a 64 blocksize. I am not aware of a solution at the moment.
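
For anyone hitting this, here is a hedged way to probe which 4-bit blocksizes the installed bitsandbytes backend accepts on a given GPU (a roundtrip sketch, assuming a ROCm-enabled bitsandbytes build):

```python
# Hypothetical probe: quantize/dequantize a dummy tensor at several blocksizes
# and report which ones the backend supports on this device.
import torch
import bitsandbytes.functional as F

x = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
for blocksize in (64, 128, 256):
    try:
        q, state = F.quantize_4bit(x, blocksize=blocksize, quant_type="nf4")
        F.dequantize_4bit(q, state)
        print(f"blocksize {blocksize}: supported")
    except RuntimeError as err:
        print(f"blocksize {blocksize}: unsupported ({err})")
```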

@billishyahao (Contributor)

Hi @electron271, glad to see this fabulous contribution for AMD GPUs. Let me help verify it on more kinds of devices; I hope we can collaborate on this.

@billishyahao (Contributor) commented Sep 10, 2025

I like that the patch provides end users with a fresh prebuilt bnb binary directly. However, this does not work in some environments (screenshot of the build error omitted). That's one of the reasons why I installed bnb from source in my previous patch #2520. I suggest providing a ROCm Dockerfile so end users can be sure this ultimately works. What do you think? cc @danielhanchen @shimmyshimmer

@electron271 (Author)

> I like the way to provide our end user the fresh prebuilt bnb binary directly in the patch. Somehow this does not work in some environment [...]

I think a Dockerfile would be beneficial for systems that don't support this. That error is caused by an out-of-date system: the minimum usable compiler is GCC 13.2, released July 27, 2023. I will note that I had a lot of issues with dockerized ROCm when I was initially trying to get Unsloth working on ROCm, so I'm not sure I can help with that part.
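
As a quick sanity check before relying on the prebuilt binaries (the GCC floor is the one quoted above; the glibc check anticipates the wheel discussion below):

```bash
gcc --version | head -n1   # should report GCC >= 13.2
ldd --version | head -n1   # prints the system glibc version
```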

@electron271 (Author)

The upstream bitsandbytes PR should hopefully be merged soon.

@matthewdouglas
Hi @electron271, you'll want to build on Ubuntu 22.04 instead of Ubuntu 24.04 for better compatibility: your repo is currently producing wheels with a glibc 2.39 requirement.

With that said, the official bitsandbytes wheels we build and will eventually publish are compatible with Ubuntu 22.04 (and other supported systems with glibc>=2.24).
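
For anyone checking their own builds, one way to inspect the glibc floor a Linux wheel actually requires is `auditwheel` (a hedged sketch; the wheel path below is a placeholder):

```bash
pip install auditwheel
auditwheel show dist/bitsandbytes-*.whl   # reports the manylinux/glibc compatibility tag
```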

I am going to go ahead and merge that PR on bitsandbytes soon; we'll drop the ROCm 6.1 build and keep 6.2/6.3/6.4/7.0. We still need to add the RDNA4/CDNA4 build targets (RX 9070/9060, MI350X/MI355X), and we need to keep in mind that while this can enable blocksize 64 on RDNA (consumer), it won't for CDNA (datacenter).

cc @billishyahao @danielhanchen

@electron271 (Author)

> You'll want to try to build on Ubuntu 22.04 instead of Ubuntu 24.04 to have better compatibility - your repo is producing wheels with a glibc 2.39 requirement. [...]

Done. My bitsandbytes builds are temporarily broken, though, as I hit the maximum Git LFS bandwidth and the limit resets in ~30 days; I will think about a potential solution.
