Describe the bug
When I import TVM, create a target, and then load a Hugging Face transformers model (google/flan-t5-base with torch_dtype=torch.bfloat16 and use_cache=True) and call generate(), the Python process crashes with a segmentation fault.
The crash happens before any TVM compilation or runtime calls on the model — simply creating a TVM target and then using AutoModelForSeq2SeqLM.generate() is enough to trigger a segfault. The stack trace shows the failure occurring during dlopen and initialization of LLVM’s COFF option table (llvm::opt::OptTable::buildPrefixChars() / COFFDirectiveParser.cpp).
This looks like a dynamic linking / LLVM initialization interaction between TVM and other LLVM-using components loaded by transformers / PyTorch.
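As a supporting check (separate from the minimal repro below), the following sketch lists the loaded shared objects that look LLVM-, TVM-, or Torch-related after both stacks have been imported and a target has been created, but before generate() is called. It is a hypothetical diagnostic that assumes Linux with a readable /proc/self/maps; if more than one independent LLVM copy shows up, a clash between their static initializers during dlopen would be consistent with the backtrace below.

# Hypothetical diagnostic, not part of the minimal repro. Assumes Linux
# (/proc/self/maps). Lists shared objects that appear LLVM/TVM/Torch related.
import tvm
import torch
import transformers  # imported only so its shared objects get loaded

tvm.target.Target("llvm")  # mirror the repro: create a target after importing TVM

with open("/proc/self/maps") as f:
    # Keep only mapped files ending in .so; the path is the last whitespace field.
    paths = sorted({line.split()[-1] for line in f if ".so" in line})

for p in paths:
    if any(key in p.lower() for key in ("llvm", "tvm", "torch")):
        print(p)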
Minimal reproducible example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Minimal repro: TVM + transformers (flan-t5-base) cause a segfault.

Steps:
1) Import TVM and create a target/device.
2) Load AutoModelForSeq2SeqLM("google/flan-t5-base", bfloat16, use_cache=True).
3) Run one generate() on random input_ids.
"""
import torch
from torch import nn

import tvm


def main():
    # 1) Load TVM and create a target (triggers LLVM / TVM runtime loading)
    if torch.cuda.is_available():
        target = tvm.target.Target("cuda")
        device = "cuda"
    else:
        target = tvm.target.Target("llvm")
        device = "cpu"
    print("TVM target:", target)

    # 2) Now import transformers and load flan-t5-base
    from transformers import AutoModelForSeq2SeqLM

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "google/flan-t5-base",
                torch_dtype=torch.bfloat16,
                use_cache=True,
            )

        def forward(self, input_ids, attention_mask=None, **gen_kwargs):
            return self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **gen_kwargs,
            )

    model = MyModel().to(device)
    model.eval()

    # 3) Single generate() call on random input
    input_ids = torch.randint(0, 10000, (1, 512), dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=8,
        )
    print("generate() finished, output shape:", out.shape)


if __name__ == "__main__":
    main()

Run:

python minimal_tvm_transformers_segfault.py

Actual behavior
On my machine, the script prints the TVM target and then immediately crashes with a segmentation fault. The beginning of the output looks like this:
TVM target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
!!!!!!! Segfault encountered !!!!!!!
File "./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c", line 0, in 0x00007e827de4251f
File "<unknown>", line 0, in llvm::opt::OptTable::buildPrefixChars()
File "<unknown>", line 0, in COFFOptTable::COFFOptTable()
File "<unknown>", line 0, in _GLOBAL__sub_I_COFFDirectiveParser.cpp
File "./elf/dl-init.c", line 70, in call_init
File "./elf/dl-init.c", line 33, in call_init
File "./elf/dl-init.c", line 117, in _dl_init
File "./elf/dl-error-skeleton.c", line 182, in __GI__dl_catch_exception
File "./elf/dl-open.c", line 808, in dl_open_worker
...
Segmentation fault (core dumped)
The full stack trace is quite long, but it mainly consists of dlopen / dl-init frames and LLVM initialization calls such as llvm::opt::OptTable::buildPrefixChars() and COFFDirectiveParser.cpp global constructors.
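If additional crash context is useful, a Python-level traceback at the moment of the fault can also be captured with the standard-library faulthandler module. This is a minimal sketch to be added at the top of the repro script; it complements, rather than replaces, the native frames printed above.

# Optional: dump the Python stack of every thread when SIGSEGV is received.
import faulthandler
import sys

faulthandler.enable(file=sys.stderr, all_threads=True)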
Expected behavior
I expect the script to run without a segmentation fault, print the TVM target, run one generate() call on flan-t5-base, and print the generated output tensor shape.
TVM is not actually compiling or running this model in the repro — only importing TVM and creating a target is required — so ideally it should coexist safely with transformers / PyTorch / their dependencies.
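To help narrow this down, here is a sketch of the same workload with the import order reversed (transformers/torch first, tvm last). It is a diagnostic variation, not a confirmed workaround; if this variant does not segfault, the trigger is likely the order in which the LLVM-containing shared objects are loaded.

# Sketch: reversed import order (transformers/torch first, tvm last), CPU only.
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base",
    torch_dtype=torch.bfloat16,
    use_cache=True,
)
model.eval()

input_ids = torch.randint(0, 10000, (1, 512), dtype=torch.long)
with torch.no_grad():
    out = model.generate(
        input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=8,
    )
print("generate() finished, output shape:", out.shape)

import tvm  # imported only after generate() has completed

print("TVM target:", tvm.target.Target("llvm"))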
Environment
- OS: Linux x86_64 (glibc-based, from backtrace paths such as ./elf/dl-open.c)
- Python: 3.10.16 | packaged by conda-forge | (main, Apr 8 2025, 20:53:32) [GCC 13.3.0]
- NumPy: 2.2.6
- PyTorch: 2.9.0+cu128
- TVM:
  - Version: 0.22.0
  - LLVM version (reported by tvm.support.libinfo()): 17.0.6
  - GIT_COMMIT_HASH: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
- GPU / CUDA:
  - TVM target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
  - CUDA toolkit: likely 12.8 (from the PyTorch build tag +cu128)

Versions were collected with:

import tvm, torch, transformers
from tvm import support

print("TVM version:", getattr(tvm, "__version__", "unknown"))
print("TVM LLVM version:", support.libinfo().get("LLVM_VERSION", "unknown"))
print("PyTorch:", torch.__version__)
print("transformers:", transformers.__version__)

Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage
- bug