
[Bug] Segmentation fault when using TVM together with transformers (flan-t5-base + bfloat16 + use_cache=True) #18653


Describe the bug

When I import TVM, create a target, and then load a Hugging Face transformers model (google/flan-t5-base with torch_dtype=torch.bfloat16 and use_cache=True) and call generate(), the Python process crashes with a segmentation fault.

The crash happens before any TVM compilation or runtime calls on the model — simply creating a TVM target and then using AutoModelForSeq2SeqLM.generate() is enough to trigger a segfault. The stack trace shows the failure occurring during dlopen and initialization of LLVM’s COFF option table (llvm::opt::OptTable::buildPrefixChars() / COFFDirectiveParser.cpp).

This looks like a dynamic linking / LLVM initialization interaction between TVM and other LLVM-using components loaded by transformers / PyTorch.
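To help narrow down which shared objects bring LLVM into the process, something like the following diagnostic could be run in the same interpreter (a minimal sketch, Linux-only, relying on /proc/self/maps; the helper name loaded_llvm_libraries is just for illustration). Note that libraries which link LLVM statically, as libtvm.so usually does, will not show up by file name, so this is only a rough heuristic:

import re

import tvm  # importing TVM loads libtvm.so into the process


def loaded_llvm_libraries():
    # Rough heuristic: list loaded shared objects whose file name mentions
    # LLVM or lld.
    libs = set()
    with open("/proc/self/maps") as maps:
        for line in maps:
            match = re.search(r"(\S+\.so\S*)\s*$", line)
            if match and re.search(r"llvm|lld", match.group(1), re.IGNORECASE):
                libs.add(match.group(1))
    return sorted(libs)


print("\n".join(loaded_llvm_libraries()) or "no LLVM-named libraries loaded")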


Minimal reproducible example

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Minimal repro: TVM + transformers (flan-t5-base) cause a segfault.

Steps:
  1) Import TVM and create a target/device.
  2) Load google/flan-t5-base via AutoModelForSeq2SeqLM (torch_dtype=bfloat16, use_cache=True).
  3) Run one generate() on random input_ids.
"""

import torch
from torch import nn
import tvm


def main():
    # 1) Load TVM and create a target (triggers LLVM / TVM runtime loading)
    if torch.cuda.is_available():
        target = tvm.target.Target("cuda")
        device = "cuda"
    else:
        target = tvm.target.Target("llvm")
        device = "cpu"
    print("TVM target:", target)

    # 2) Now import transformers and load flan-t5-base
    from transformers import AutoModelForSeq2SeqLM

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "google/flan-t5-base",
                torch_dtype=torch.bfloat16,
                use_cache=True,
            )

        def forward(self, input_ids, attention_mask=None, **gen_kwargs):
            return self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **gen_kwargs,
            )

    model = MyModel().to(device)
    model.eval()

    # 3) Single generate() call on random input
    input_ids = torch.randint(0, 10000, (1, 512), dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=8,
        )

    print("generate() finished, output shape:", out.shape)


if __name__ == "__main__":
    main()

Run:

python minimal_tvm_transformers_segfault.py

Actual behavior

On my machine, the script prints the TVM target and then immediately crashes with a segmentation fault. The beginning of the output looks like this:

TVM target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
!!!!!!! Segfault encountered !!!!!!!
  File "./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c", line 0, in 0x00007e827de4251f
  File "<unknown>", line 0, in llvm::opt::OptTable::buildPrefixChars()
  File "<unknown>", line 0, in COFFOptTable::COFFOptTable()
  File "<unknown>", line 0, in _GLOBAL__sub_I_COFFDirectiveParser.cpp
  File "./elf/dl-init.c", line 70, in call_init
  File "./elf/dl-init.c", line 33, in call_init
  File "./elf/dl-init.c", line 117, in _dl_init
  File "./elf/dl-error-skeleton.c", line 182, in __GI__dl_catch_exception
  File "./elf/dl-open.c", line 808, in dl_open_worker
  ...
Segmentation fault (core dumped)

The full stack trace is quite long, but it mainly consists of dlopen / dl-init frames and LLVM initialization calls such as llvm::opt::OptTable::buildPrefixChars() and COFFDirectiveParser.cpp global constructors.
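If it helps with triage, the Python-level stack at the moment of the crash can be captured as well by enabling the standard-library faulthandler at the very top of the repro script (a minimal sketch; it only adds a Python traceback on SIGSEGV alongside the native frames shown above):

# At the very top of minimal_tvm_transformers_segfault.py, before importing
# torch or tvm: dump the Python traceback of every thread on SIGSEGV.
import faulthandler
faulthandler.enable(all_threads=True)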


Expected behavior

I expect the script to run without a segmentation fault, print the TVM target, run one generate() call on flan-t5-base, and print the generated output tensor shape.

TVM is not actually compiling or running this model in the repro — only importing TVM and creating a target is required — so ideally it should coexist safely with transformers / PyTorch / their dependencies.
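Since the hypothesis is an initialization-order interaction, a variant that reverses the load order (transformers/torch model first, TVM afterwards) may be a useful comparison point for maintainers. The sketch below only illustrates the ordering change; whether it avoids the crash has not been verified here:

# Variant for comparison (untested): load the transformers model first, then
# import TVM and create the target, to check whether the crash depends on
# which side initializes LLVM-related libraries first.
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base",
    torch_dtype=torch.bfloat16,
    use_cache=True,
)

import tvm  # imported only after the model is fully loaded

target = tvm.target.Target("cuda" if torch.cuda.is_available() else "llvm")
print("TVM target:", target)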


Environment

  • OS: Linux x86_64 (glibc-based, from backtrace paths such as ./elf/dl-open.c)

  • Python: 3.10.16 | packaged by conda-forge | (main, Apr 8 2025, 20:53:32) [GCC 13.3.0]

  • NumPy: 2.2.6

  • PyTorch: 2.9.0+cu128

  • TVM:
    • Version: 0.22.0
    • LLVM version (reported by tvm.support.libinfo()): 17.0.6
    • GIT_COMMIT_HASH: 9dbf3f22ff6f44962472f9af310fda368ca85ef2

  • GPU / CUDA:
    • TVM target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
    • CUDA toolkit likely 12.8 (from PyTorch build tag +cu128)

The versions above were collected with the following snippet:

import tvm, torch, transformers
from tvm import support

print("TVM version:", getattr(tvm, "__version__", "unknown"))
print("TVM LLVM version:", support.libinfo().get("LLVM_VERSION", "unknown"))
print("PyTorch:", torch.__version__)
print("transformers:", transformers.__version__)

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage
  • bug
