[Bug] Mixtral inference encounters error about tensor location #963

Open
5 tasks done
pl752 opened this issue Mar 23, 2025 · 6 comments
Comments

pl752 commented Mar 23, 2025

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
  • 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.

Describe the bug

Loading Mixtral with the default config works, but any message results in the following error:

loading output_norm.weight to cuda:0
2025-03-23 21:26:35,453 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/context_manager.py[21]: Creating Context Manager
INFO:     Started server process [3617]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:10002 (Press CTRL+C to quit)
INFO:     127.0.0.1:60002 - "GET /v1/threads?limit=100 HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:60002 - "GET /v1/threads/?limit=100 HTTP/1.1" 200 OK
INFO:     127.0.0.1:60002 - "POST /v1/threads HTTP/1.1" 307 Temporary Redirect
2025-03-23 21:26:47,380 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/crud/assistants/threads.py[26]: Creating messages first for thread
INFO:     127.0.0.1:60002 - "POST /v1/threads/ HTTP/1.1" 200 OK
2025-03-23 21:26:47,439 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/context_manager.py[52]: no context for thread d775d844-ab13-497f-be09-d1bb5b6a15ca
INFO:     127.0.0.1:60002 - "POST /v1/threads/d775d844-ab13-497f-be09-d1bb5b6a15ca/messages HTTP/1.1" 200 OK
INFO:     127.0.0.1:60002 - "POST /v1/threads/d775d844-ab13-497f-be09-d1bb5b6a15ca/runs HTTP/1.1" 200 OK
2025-03-23 21:26:49,226 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/context_manager.py[29]: keys dict_keys([])
2025-03-23 21:26:49,226 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/context_manager.py[31]: new inference context d775d844-ab13-497f-be09-d1bb5b6a15ca
2025-03-23 21:26:49,228 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/base.py[98]: 1 messages loaded from database
2025-03-23 21:26:49,228 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/base.py[123]: start working
2025-03-23 21:26:49,265 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/transformers.py[198]: get input ids of shape torch.Size([1, 10])
2025-03-23 21:26:49,266 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py[137]: input_ids: torch.Size([1, 10])
2025-03-23 21:26:49,267 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py[162]: same prefix len: 0
2025-03-23 21:26:49,269 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py[171]: input_ids: torch.Size([1, 10])
2025-03-23 21:26:49,269 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py[172]: generate_ids: torch.Size([1, 0])
2025-03-23 21:26:49,288 DEBUG /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py[187]: cache position: 0 to 10
2025-03-23 21:26:58,776 INFO /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/transformers.py[333]: args.max_new_tokens: 2000, cache_lens: 8192, seq_length: 11
2025-03-23 21:26:58,777 INFO /home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/transformers.py[338]: max_new_tokens: 1999
ERROR:    Exception in ASGI application
  + Exception Group Traceback (most recent call last):
  |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/_utils.py", line 76, in collapse_excgroups
  |     yield
  |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/responses.py", line 263, in __call__
  |     async with anyio.create_task_group() as task_group:
  |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 772, in __aexit__
  |     raise BaseExceptionGroup(
  | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    |     result = await app(  # type: ignore[func-returns-value]
    |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    |     return await self.app(scope, receive, send)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
    |     await super().__call__(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/applications.py", line 112, in __call__
    |     await self.middleware_stack(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in __call__
    |     raise exc
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in __call__
    |     await self.app(scope, receive, _send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/middleware/cors.py", line 93, in __call__
    |     await self.simple_response(scope, receive, send, request_headers=headers)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/middleware/cors.py", line 144, in simple_response
    |     await self.app(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    |     await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    |     raise exc
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    |     await app(scope, receive, sender)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/routing.py", line 714, in __call__
    |     await self.middleware_stack(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/routing.py", line 734, in app
    |     await route.handle(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
    |     await self.app(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
    |     await wrap_app_handling_exceptions(app, request)(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    |     raise exc
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    |     await app(scope, receive, sender)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/routing.py", line 74, in app
    |     await response(scope, receive, send)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/responses.py", line 262, in __call__
    |     with collapse_excgroups():
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/contextlib.py", line 158, in __exit__
    |     self.gen.throw(typ, value, traceback)
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/_utils.py", line 82, in collapse_excgroups
    |     raise exc
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/responses.py", line 266, in wrap
    |     await func()
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/starlette/responses.py", line 246, in stream_response
    |     async for chunk in self.body_iterator:
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/schemas/assistants/streaming.py", line 80, in check_client_link
    |     async for event in async_events:
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/schemas/assistants/streaming.py", line 93, in to_stream_reply
    |     async for event in async_events:
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/schemas/assistants/streaming.py", line 87, in add_done
    |     async for event in async_events:
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/schemas/assistants/streaming.py", line 101, in filter_api_event
    |     async for event in async_events:
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/api/openai/assistants/runs.py", line 28, in inner
    |     async for event in ctx.work():
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/base.py", line 146, in work
    |     async for res in self.interface.inference(local_messages,self.thread.id):
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py", line 233, in inference
    |     async for v in super().inference(local_messages, thread_id, temperature, top_p):
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/transformers.py", line 411, in inference
    |     for t, finish_reason in self.generate():
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 36, in generator_context
    |     response = gen.send(None)
    |                ^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/transformers.py", line 348, in generate
    |     next_token = self.decode_one_tokens()
    |                  ^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/ktransformers.py", line 115, in decode_one_tokens
    |     logits = self.model(
    |              ^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    |     return self._call_impl(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    |     return forward_call(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/models/modeling_mixtral.py", line 1424, in forward
    |     outputs = self.model(
    |               ^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    |     return self._call_impl(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    |     return forward_call(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/ktransformers/models/modeling_mixtral.py", line 1170, in forward
    |     inputs_embeds = self.embed_tokens(input_ids)
    |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    |     return self._call_impl(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    |     return forward_call(*args, **kwargs)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 190, in forward
    |     return F.embedding(
    |            ^^^^^^^^^^^^
    |   File "/home/pl752/miniforge3/envs/ktr/lib/python3.11/site-packages/torch/nn/functional.py", line 2551, in embedding
    |     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    | RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
    +------------------------------------
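
The failing call is the token-embedding lookup: the embedding weight and the index tensor sit on different devices. A minimal, illustrative repro of the same error class (assuming a CUDA build of torch and a visible GPU; shapes are arbitrary, not taken from the report):

import torch
import torch.nn.functional as F

weight = torch.randn(32000, 4096)                 # embedding weight left on the CPU
ids = torch.tensor([[1, 2, 3]], device="cuda:0")  # token ids already moved to the GPU

# Raises a device-mismatch RuntimeError like the one above.
F.embedding(ids, weight)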

Reproduction

Model: https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/blob/main/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf

Command: ktransformers --model_path mistralai/Mixtral-8x7B-Instruct-v0.1 --gguf_path /home/pl752/mxtrl8b/ --optimize_config_path ktransformers/optimize/optimize_rules/Mixtral.yaml --web True

Config: default

Environment

Ubuntu 24.10, Python 3.11, torch 2.8 nightly, CUDA 12.8, RTX 3060M, Ryzen 5800H, ktransformers latest (git main)

pip list

Package                  Version
------------------------ ------------------------
accelerate               1.5.2
annotated-types          0.7.0
anyio                    4.9.0
blessed                  1.20.0
blobfile                 3.0.0
build                    1.2.2.post1
certifi                  2025.1.31
cffi                     1.17.1
charset-normalizer       3.4.1
click                    8.1.8
colorlog                 6.9.0
cpufeature               0.2.1
distro                   1.9.0
einops                   0.8.1
fastapi                  0.115.11
filelock                 3.16.1
fire                     0.7.0
flash_attn               2.7.4.post1
flashinfer-python        0.2.3
fsspec                   2024.10.0
greenlet                 3.1.1
h11                      0.14.0
httpcore                 1.0.7
httpx                    0.28.1
huggingface-hub          0.29.3
idna                     3.10
Jinja2                   3.1.4
jiter                    0.9.0
jsonpatch                1.33
jsonpointer              3.0.0
ktransformers            0.2.3.post2+torch28avx2
langchain                0.3.21
langchain-core           0.3.47
langchain-text-splitters 0.3.7
langsmith                0.3.18
lxml                     5.3.1
MarkupSafe               2.1.5
mpmath                   1.3.0
networkx                 3.4.2
ninja                    1.11.1.3
numpy                    2.1.2
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.8.0.87
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
openai                   1.68.0
orjson                   3.10.15
packaging                24.2
pillow                   11.0.0
pip                      25.0.1
protobuf                 6.30.1
psutil                   7.0.0
pycparser                2.22
pycryptodomex            3.22.0
pydantic                 2.10.6
pydantic_core            2.27.2
pyproject_hooks          1.2.0
pytorch-triton           3.3.0+git96316ce5
PyYAML                   6.0.2
regex                    2024.11.6
requests                 2.32.3
requests-toolbelt        1.0.0
safetensors              0.5.3
sentencepiece            0.2.0
setuptools               75.8.2
six                      1.17.0
sniffio                  1.3.1
sounddevice              0.5.1
SQLAlchemy               2.0.39
starlette                0.46.1
sympy                    1.13.3
tenacity                 9.0.0
termcolor                2.5.0
tiktoken                 0.9.0
tokenizers               0.19.1
torch                    2.8.0.dev20250322+cu128
torchaudio               2.6.0.dev20250322+cu128
torchvision              0.22.0.dev20250322+cu128
tqdm                     4.67.1
transformers             4.43.2
triton                   3.2.0
typing_extensions        4.12.2
urllib3                  2.3.0
uvicorn                  0.34.0
wcwidth                  0.2.13
wheel                    0.45.1
zstandard                0.23.0
qiyuxinlin (Contributor) commented

Hello, we haven't tested whether this runs on Mixtral in a long time, and this error looks like it is caused by the weights and the input not being on the same device.

pl752 (Author) commented Mar 24, 2025

Okay, then what should I change in the config or parameters so that the tensors end up on the proper devices?

pl752 (Author) commented Mar 24, 2025

Also, if I use the local chat module instead of the ktransformers server module, it just works.

pl752 (Author) commented Mar 24, 2025

It seems the input tokens somehow end up on different devices depending on the backend.

pl752 (Author) commented Mar 24, 2025

I have also noticed that the implementation of the generation loop differs between the server backend and the local chat module: one passes the input ids to the model directly, while the other uses prefill_and_generate, and they seem to handle tokens differently.

server:

logits = self.model(
    self.current_ids.to(torch_device),
    cache_position=self.active_cache_position,

local_chat:

torch.cuda.set_device(torch_device)
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)

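A toy illustration of the difference (assumed device layout, not the actual ktransformers classes: the embedding lives on the CPU, the ids arrive on cuda:0):

import torch
import torch.nn as nn

embed_tokens = nn.Embedding(1000, 64)                    # lookup table on the CPU
ids = torch.randint(0, 1000, (1, 10), device="cuda:0")   # ids already on the GPU

# server-style: GPU ids go straight into the CPU embedding -> device mismatch
# hidden = embed_tokens(ids)  # would raise the RuntimeError from the traceback

# local_chat-style: do the lookup where the weight lives, then move only the activations
hidden = embed_tokens(ids.to("cpu")).to("cuda:0")
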
pl752 (Author) commented Mar 24, 2025

Changing ktransformers/server/backend/interfaces/ktransformers.py at line 115 from

logits = self.model(
    self.current_ids.to(torch_device),
    cache_position=self.active_cache_position,
...

to

inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(torch_device)
logits = self.model(
    inputs_embeds=inputs_embeds,
    cache_position=self.active_cache_position,
...

seems to fix the problem. I think the generation methods need some code unification.
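
One way to unify them would be a small helper that always runs the lookup on whatever device holds the embedding weight and then moves the activations to the compute device. A sketch (embed_on_weight_device is hypothetical, not an existing ktransformers function):

import torch
import torch.nn as nn

def embed_on_weight_device(embed: nn.Embedding, ids: torch.Tensor,
                           out_device: torch.device) -> torch.Tensor:
    # Run the lookup where the weight lives, then move only the activations.
    return embed(ids.to(embed.weight.device)).to(out_device)

# toy check: CPU embedding, CUDA ids
embed = nn.Embedding(1000, 64)
ids = torch.randint(0, 1000, (1, 10), device="cuda:0")
hidden = embed_on_weight_device(embed, ids, torch.device("cuda:0"))
assert hidden.device == torch.device("cuda:0")

Both the server backend and local_chat could then share this instead of each hand-rolling the .to("cpu") / .to(torch_device) dance.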
