Run on second GPU (torch.device("cuda:1")) #24

@imabot2

Hi, you did awesome work! I ran your code on an RTX 3090 with offload_per_layer = 0 and it works great!

I noticed that when I switch to my second GPU with device = torch.device("cuda:1"), the model loads into GPU memory correctly, but inference fails:

Traceback (most recent call last):
  File "/home/philippe/tmp/mixtral2/main.py", line 112, in <module>
    result = model.generate(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/generation/utils.py", line 2861, in sample
    outputs = self(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1213, in forward
    outputs = self.model(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1081, in forward
    layer_outputs = decoder_layer(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 797, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 305, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/custom_layers.py", line 50, in forward
    return self.forward_triton(x)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/custom_layers.py", line 80, in forward_triton
    output = fn(
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/triton_kernels.py", line 172, in triton_matmul4_transpose
    matmul4_kernel_transpose[grid](
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in run
    ret = self.fn.run(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

I can't figure out what's wrong. Any ideas?
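
For reference, a minimal sketch of two common workarounds, assuming the cause is that Triton launches kernels on the current CUDA device (cuda:0 by default) rather than on the device the tensors live on. The set_device call and the CUDA_VISIBLE_DEVICES variable are standard PyTorch/CUDA mechanisms, not something specific to this repo:

import torch

# Option 1 (sketch): make cuda:1 the current device before any forward pass,
# so Triton kernel launches target the GPU that actually holds the weights.
device = torch.device("cuda:1")
torch.cuda.set_device(device)
# ... build the model and call model.generate(...) as before ...

# Option 2 (sketch): hide the first GPU from the process entirely, so the
# second physical GPU shows up as cuda:0 and no code changes are needed:
#     CUDA_VISIBLE_DEVICES=1 python main.py

If this diagnosis is right, the "cpu tensor?" hint in the error message is misleading: the tensor is on a GPU, just not the one the kernel launch targets.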
