RuntimeError: expected scalar type BFloat16 but found Float #293

@ZailoxTT

Description

When I try to run this in Google Colab:

!python -m qai_hub_models.models.llama_v3_2_3b_instruct.export \
    --chipset qualcomm-snapdragon-8-elite \
    --skip-inferencing \
    --skip-profiling \
    --output-dir ./genie_bundle

Then I get the error:

Some quantized models require the AIMET-ONNX package, which is only supported on Linux. Quantized model can be exported without this requirement.
Quantized models require the AIMET-ONNX package, which is only supported on Linux. Install qai-hub-models on a Linux machine to use quantized models.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
+-----------------------------------------------------------------------------------------+
| ⚠️ Warning: Insufficient memory                                                         |
|                                                                                         |
| Recommended memory (RAM + swap): 80 GB (currently 12 GB)                                |
|                                                                                         |
| Recommended swap space: 69 GB (currently 0 GB)                                          |
|                                                                                         |
| The process could get killed with out-of-memory error during export/demo.               |
|                                                                                         |
| This can be avoided by increasing your swap space. Please follow these instructions:    |
|                                                                                         |
|   https://github.com/quic/ai-hub-apps/blob/main/tutorials/llm_on_genie/increase_swap.md |
|                                                                                         |
+-----------------------------------------------------------------------------------------+

Loading model config from meta-llama/Llama-3.2-3B-Instruct

Loading tokenizer from meta-llama/Llama-3.2-3B-Instruct
Loading weights: 100% 254/254 [00:00<00:00, 675.07it/s, Materializing param=model.norm.weight]

Exporting ONNX model with sequence length 128 and context length 4096. This could take around 10 minutes.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/llama_v3_2_3b_instruct/export.py", line 53, in <module>
    main()
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/llama_v3_2_3b_instruct/export.py", line 37, in main
    export_main(
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/_shared/llm/export.py", line 914, in export_main
    export_model(
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/_shared/llm/export.py", line 331, in export_model
    model = model_cls.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/llama_v3_2_3b_instruct/model.py", line 224, in from_pretrained
    cls.create_onnx_models(
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/_shared/llm/model.py", line 1344, in create_onnx_models
    get_onnx_model(
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/_shared/llm/model.py", line 380, in get_onnx_model
    safe_torch_onnx_export(
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/utils/onnx/helpers.py", line 301, in safe_torch_onnx_export
    torch.onnx.export(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/onnx/__init__.py", line 424, in export
    export(
  File "/usr/local/lib/python3.12/dist-packages/torch/onnx/utils.py", line 522, in export
    _export(
  File "/usr/local/lib/python3.12/dist-packages/torch/onnx/utils.py", line 1457, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/onnx/utils.py", line 1080, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/onnx/utils.py", line 964, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/onnx/utils.py", line 871, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/jit/_trace.py", line 1504, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/jit/_trace.py", line 138, in forward
    graph, _out = torch._C._create_graph_by_tracing(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/utils/base_model.py", line 534, in __call__
    return torch.nn.Module.__call__(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/_shared/llm/model.py", line 720, in forward
    out = self.model(**model_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 835, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 486, in forward
    outputs: BaseModelOutputWithPast = self.model(
                                       ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 1002, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 421, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 320, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/qai_hub_models/models/_shared/llama3/model_adaptations.py", line 212, in forward_sha
    q_proj(hidden_states).permute(0, 2, 3, 1) for q_proj in self.q_proj_sha
    ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/conv.py", line 548, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/conv.py", line 543, in _conv_forward
    return F.conv2d(
           ^^^^^^^^^
RuntimeError: expected scalar type BFloat16 but found Float

qai-hub-models version: 0.49.1
qai-hub version: 0.47.0
huggingface_hub version: 1.9.0
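For context, the failure at the bottom of the traceback is a plain dtype mismatch inside `F.conv2d`: the model's weights end up in bfloat16 while the traced input tensor is float32. The snippet below is only a minimal, hypothetical repro of that class of error (it does not use qai-hub-models), showing that mismatched input/weight dtypes raise exactly this kind of RuntimeError and that casting the input to the weight dtype resolves it:

```python
import torch

# Hypothetical minimal repro: a Conv2d whose weights are bfloat16
# receiving a default float32 input, as in the failing export.
conv = torch.nn.Conv2d(4, 4, kernel_size=1).to(torch.bfloat16)
x = torch.randn(1, 4, 8, 8)  # float32 by default

try:
    conv(x)  # dtype mismatch: bfloat16 weights, float32 input
except RuntimeError as e:
    print(f"RuntimeError: {e}")

# Casting the input to match the weight dtype avoids the mismatch.
out = conv(x.to(torch.bfloat16))
print(out.dtype)  # torch.bfloat16
```

This suggests the export pipeline is loading the checkpoint in bfloat16 but feeding it float32 trace inputs (or vice versa); whether that is environment-specific (e.g. the transformers/torch versions in Colab) would need confirmation from the maintainers.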
