Support fp8 w8a8 for pt backend #2959

Merged
10 commits merged on Jan 3, 2025

Conversation

RunningLeon (Collaborator)

Motivation

Support fp8 w8a8 for pt backend

Modification

Support fp8 w8a8 for pt backend

BC-breaking (Optional)

None

Use cases (Optional)

fp8 quant

lmdeploy lite smooth_quant \
  meta-llama/Meta-Llama-3-8B-Instruct \
  --work-dir Meta-Llama-3-8B-Instruct-fp8-w8a8 \
  --quant-dtype fp8

chat

lmdeploy chat Meta-Llama-3-8B-Instruct-fp8-w8a8 --backend pytorch
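For reference, the quantized model can also be consumed through the Python API. Below is a minimal sketch assuming the work dir produced above; `pipeline` and `PytorchEngineConfig` are the same entry points used in the repro script later in this thread:

```python
from lmdeploy import pipeline, PytorchEngineConfig

# Load the fp8 w8a8 model produced by
# `lmdeploy lite smooth_quant ... --quant-dtype fp8` with the PyTorch backend.
pipe = pipeline('Meta-Llama-3-8B-Instruct-fp8-w8a8',
                backend_config=PytorchEngineConfig())
print(pipe('Hello, how are you?'))
```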

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@@ -663,7 +664,9 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
                                  loaded_weight: torch.Tensor, rank: int,
                                  world_size: int):
        """weight loader for rowwise linear."""
-       if loaded_weight.dim() == 2 and param.dtype == torch.int8:
+       if loaded_weight.dim() == 2 and param.dtype in (torch.int8,
+                                                       torch.float8_e4m3fn,
Collaborator

What's the minimum requirement of torch given "torch.float8_e4m3fn" and "torch.float8_e5m2"?
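For context, the thread does not state which torch release introduced these dtypes, so one way to express the requirement is an attribute-based runtime guard rather than a version pin. A minimal sketch:

```python
import torch

# torch.float8_e4m3fn / torch.float8_e5m2 only exist in sufficiently recent
# PyTorch releases; guard their use so older installs fail with a clear message.
FP8_DTYPES = tuple(
    getattr(torch, name) for name in ('float8_e4m3fn', 'float8_e5m2')
    if hasattr(torch, name)
)

if not FP8_DTYPES:
    raise RuntimeError('fp8 w8a8 requires a PyTorch build that provides '
                       'torch.float8_e4m3fn and torch.float8_e5m2.')
```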

Collaborator Author

Conflicts:
	lmdeploy/pytorch/backends/cuda/qmodules.py
@lvhan028 (Collaborator)

@AllentDan
Collaborator

https://github.com/mit-han-lab/smoothquant/blob/c61476d728e42ae0d8a35e7e78494edcac3237b5/smoothquant/smooth.py#L61 cc @AllentDan The official smoothquant clamps the weight scales with 1e-5.

It applies the clamp to the weight scales themselves, while in our code the clamp only affects the denominator.
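To make the distinction concrete, here is a minimal sketch of the two clamping strategies as read from the comment above; it is an illustration only, not either repo's exact code:

```python
import torch

def smooth_scales_clamped_weights(act_scales: torch.Tensor,
                                  weight_scales: torch.Tensor,
                                  alpha: float = 0.5) -> torch.Tensor:
    # Official-style: clamp the weight scales themselves before they are used.
    weight_scales = weight_scales.clamp(min=1e-5)
    return (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)

def smooth_scales_clamped_denominator(act_scales: torch.Tensor,
                                      weight_scales: torch.Tensor,
                                      alpha: float = 0.5) -> torch.Tensor:
    # Alternative: clamp only where the weight scales appear in the denominator.
    return act_scales.pow(alpha) / weight_scales.pow(1 - alpha).clamp(min=1e-5)
```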

        async def __step():
            """step decoding."""
-           prefill = self.scheduler.has_waiting()
+           prefill = __do_prefill()
Collaborator

What's the benefit of this change?

Collaborator Author

This was added by @grimoire for a better scheduling policy.

Collaborator

Prefill does not use CUDA graphs, so a short prefill would incur a long kernel-launch overhead. Therefore we do not run prefill unless the batch is large enough.
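A minimal sketch of the policy described here; `scheduler.has_waiting` comes from the diff above, while `num_waiting`, `num_running`, and the threshold are illustrative and not the actual `__do_prefill` in engine.py:

```python
def should_prefill(scheduler, num_running: int, min_waiting: int = 4) -> bool:
    """Decide whether to run prefill this step.

    Prefill is not captured by CUDA graphs, so a tiny prefill batch pays the
    full kernel-launch cost; defer prefill until enough requests are waiting,
    unless nothing is running at all.
    """
    if not scheduler.has_waiting():        # accessor shown in the diff above
        return False
    num_waiting = scheduler.num_waiting()  # hypothetical accessor
    return num_running == 0 or num_waiting >= min_waiting
```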

@@ -137,7 +150,8 @@ def forward(self, input):
        """

        if isinstance(input, torch.Tensor):
-           input_quant, input_scale = per_token_quant_int8(input, 1e-7)
+           input_quant, input_scale = per_token_quant_int8(
Collaborator

In what case this branch will be called?

Collaborator Author

This may happen when the preceding op is a non-quantized op whose output feeds into the quantized linear.
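For context, a minimal sketch of what per-token quantization of a plain float tensor involves; this illustrates the idea only and is not lmdeploy's `per_token_quant_int8` kernel:

```python
import torch

def per_token_quant_int8_sketch(x: torch.Tensor, eps: float = 1e-7):
    """Quantize each token (row) of a float tensor to int8 with its own scale."""
    absmax = x.abs().amax(dim=-1, keepdim=True)
    scale = absmax.clamp(min=eps) / 127.0          # one scale per token
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale
```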

lvhan028 (Collaborator) commented Jan 2, 2025

Is the input always quantized to int8 even when self.quant_dtype is torch.float8_e4m3fn?

@lvhan028 (Collaborator)

Please resolve the conflict.

@lvhan028 added the enhancement (New feature or request) and improvement labels and removed the enhancement (New feature or request) label on Dec 31, 2024
@lvhan028 (Collaborator)

cc @jinminxi104 Will this affect dlinfer?

@lvhan028 requested a review from jinminxi104 on December 31, 2024 09:35
@lvhan028 (Collaborator)

Please check whether dlinfer supports w8a8-int8/fp8.
If it doesn't yet, a NotImplementedError should be raised.

@jinminxi104 (Collaborator)

Please check whether dlinfer supports w8a8-int8/fp8.

If it doesn't yet, a NotImplementedError should be raised.

dlinfer is planning to support w8a8-int8, but there is no plan for fp8.

@lvhan028 (Collaborator)

The unit tests failed.

@RunningLeon requested a review from lvhan028 on January 2, 2025 03:08
@lvhan028 (Collaborator) commented Jan 2, 2025

CUDA_VISIBLE_DEVICES=4 lmdeploy lite smooth_quant /models/140/InternLM/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-int8

Then I tested internlm2_5-7b-chat-int8 with the following command, but it failed:

python examples/workspace/test_pipeline.py internlm2_5-7b-chat-int8
2025-01-02 05:56:11,900 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error: 'QTensor' object is not subscriptable
Traceback (most recent call last):
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 1071, in async_loop
    await self._async_loop()
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 1065, in _async_loop
    await __step()
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 1052, in __step
    raise e
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 1046, in __step
    raise out
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 949, in _async_loop_background
    await self._async_step_background(
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 797, in _async_step_background
    output = await self._async_model_forward(
  File "/workspace/lmdeploy/lmdeploy/utils.py", line 241, in __tmp
    return (await func(*args, **kwargs))
  File "/workspace/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 682, in _async_model_forward
    ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
TypeError: 'QTensor' object is not subscriptable

The reproducible script "examples/workspace/test_pipeline.py" is:

from lmdeploy import pipeline, TurbomindEngineConfig, PytorchEngineConfig, GenerationConfig, ChatTemplateConfig
import fire
import time
from transformers import AutoModelForCausalLM
from lmdeploy.utils import get_logger

def get_chat_template_config(**kwargs):
    chat_template_config = None
    chat_template = kwargs.get('model_name', None)
    if chat_template:
        chat_template_config = dict(model_name=chat_template)
        chat_template_config.update({k: v for k, v in kwargs.items() if hasattr(ChatTemplateConfig, k)})
        chat_template_config = ChatTemplateConfig(**chat_template_config)
    return chat_template_config


def apply_turbomind(model_path, log_level='INFO', **kwargs):
    engine_config = TurbomindEngineConfig()
    for key, value in kwargs.items():
        if hasattr(TurbomindEngineConfig, key):
            setattr(engine_config, key, value)
    chat_template_config = get_chat_template_config(**kwargs)
    pipe = pipeline(model_path, backend_config=engine_config, chat_template_config=chat_template_config, log_level=log_level)

    return pipe


def apply_pytorch(model_path, log_level='INFO', **kwargs):
    engine_config = PytorchEngineConfig()
    for key, value in kwargs.items():
        if hasattr(PytorchEngineConfig, key):
            setattr(engine_config, key, value)
    chat_template_config = get_chat_template_config(**kwargs)
    pipe = pipeline(model_path, backend_config=engine_config, chat_template_config=chat_template_config, log_level=log_level)
    
    return pipe

def apply_transformers(model_path, **kwargs):
    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
    return model


def get_gen_config(**kwargs):
    gen_config = GenerationConfig()
    for key, value in kwargs.items():
        if hasattr(gen_config, key):
            setattr(gen_config, key, value)
    return gen_config

def main(model_path, backend='turbomind', log_level='INFO', **kwargs):
    gen_config = get_gen_config(**kwargs)

    start = time.perf_counter()
    if backend == 'turbomind':
        pipe = apply_turbomind(model_path, log_level, **kwargs)
    elif backend == 'pytorch':
        pipe = apply_pytorch(model_path, log_level, **kwargs)
    # elif backend == 'transformers':
    #     pipe = apply_transformers(model_path, **kwargs)
    else:
        assert 0, f'unsupported backend {backend}'
    end = time.perf_counter()
    print(f'building pipeline cost: {end - start} s')
    prompt = 'write a "hello, world" using cuda'
    response = pipe(prompt, gen_config=gen_config)
    # messages = [
    # dict(role='user', content=[
    #     dict(type='text', text='hi')
    # ])]
    # response = pipe(messages, gen_config=gen_config)
    print(response)
    # input('press any key to exit')
    
if __name__ == "__main__":
    fire.Fire(main)

@lvhan028 merged commit bc86b94 into InternLM:main on Jan 3, 2025
5 checks passed