Support fp8 w8a8 for pt backend #2959
Conversation
@@ -663,7 +664,9 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
                                   loaded_weight: torch.Tensor, rank: int,
                                   world_size: int):
         """weight loader for rowwise linear."""
-        if loaded_weight.dim() == 2 and param.dtype == torch.int8:
+        if loaded_weight.dim() == 2 and param.dtype in (torch.int8,
+                                                        torch.float8_e4m3fn,
What's the minimum torch version required for "torch.float8_e4m3fn" and "torch.float8_e5m2"?
torch>=2.3.0 https://pytorch.org/docs/2.3/tensors.html#data-types
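For illustration only, a minimal sketch of how fp8 dtype availability could be guarded at runtime; the helper name `fp8_supported` and the attribute check are assumptions, not code from this PR:

```python
import torch

# Assumption for illustration: torch exposes float8_e4m3fn / float8_e5m2 as
# module attributes on versions that support them (>=2.3.0 per the answer above).
_FP8_DTYPES = tuple(
    getattr(torch, name)
    for name in ('float8_e4m3fn', 'float8_e5m2')
    if hasattr(torch, name)
)


def fp8_supported() -> bool:
    """Return True if the installed torch exposes both fp8 dtypes."""
    return len(_FP8_DTYPES) == 2
```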
Conflicts: lmdeploy/pytorch/backends/cuda/qmodules.py
https://github.com/mit-han-lab/smoothquant/blob/c61476d728e42ae0d8a35e7e78494edcac3237b5/smoothquant/smooth.py#L61
It performs the clamp directly on the weight, while in our code it only influences the
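For context, a generic sketch of SmoothQuant-style smoothing with a clamp on the per-channel scales; `smooth_fc`, its signature, and the clamp placement are illustrative assumptions, not the referenced code or this PR's implementation:

```python
import torch

def smooth_fc(act_scales: torch.Tensor, weight: torch.Tensor, alpha: float = 0.5):
    """Migrate activation outliers into the weight via per-input-channel scales."""
    # per-input-channel max magnitude of the weight, clamped for numerical safety
    weight_scales = weight.abs().amax(dim=0).clamp(min=1e-5)
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)
    # activations would be divided by `scales`; the weight is multiplied by them
    smoothed_weight = weight * scales.unsqueeze(0)
    return smoothed_weight, scales
```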
         async def __step():
             """step decoding."""
-            prefill = self.scheduler.has_waiting()
+            prefill = __do_prefill()
What's the benefit of this change?
This was added by @grimoire for a better scheduling policy.
Prefill does not enable cudagraph, so a short prefill leads to long kernel launch overhead. Therefore we do not do prefill unless the batch size is large enough.
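A hypothetical sketch of that heuristic; the name `_should_prefill` and the `min_prefill_batch` threshold are illustrative, not the PR's actual code:

```python
def _should_prefill(num_waiting: int, num_running: int,
                    min_prefill_batch: int = 4) -> bool:
    """Delay prefill until it is worth paying the non-cudagraph launch cost."""
    if num_waiting == 0:
        return False
    if num_running == 0:
        # nothing left to decode, so prefill even a small batch to make progress
        return True
    # otherwise wait until enough requests accumulate to amortize the overhead
    return num_waiting >= min_prefill_batch
```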
@@ -137,7 +150,8 @@ def forward(self, input):
         """

         if isinstance(input, torch.Tensor):
-            input_quant, input_scale = per_token_quant_int8(input, 1e-7)
+            input_quant, input_scale = per_token_quant_int8(
In what case will this branch be called?
This may happen when the preceding op is a non-quant op and its output feeds into the quant linear.
This always quants the input into int8 even though self.quant_dtype is torch.float8_e4m3fn?
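A minimal sketch of what dispatching on `self.quant_dtype` could look like; `per_token_quant_int8` exists in lmdeploy, but the `quant_input` helper and its fp8 branch below are assumptions for illustration, not this PR's code:

```python
import torch

def quant_input(input: torch.Tensor, quant_dtype: torch.dtype, eps: float = 1e-7):
    """Quantize activations per token to the requested 8-bit dtype."""
    absmax = input.abs().amax(dim=-1, keepdim=True).clamp_min(eps).float()
    if quant_dtype == torch.int8:
        scale = absmax / 127
        q = (input / scale).round().clamp(-128, 127).to(torch.int8)
    elif quant_dtype == torch.float8_e4m3fn:
        fmax = torch.finfo(torch.float8_e4m3fn).max  # 448.0
        scale = absmax / fmax
        q = (input / scale).clamp(-fmax, fmax).to(torch.float8_e4m3fn)
    else:
        raise ValueError(f'unsupported quant dtype {quant_dtype}')
    return q, scale
```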
Please resolve the conflict.
cc @jinminxi104, will it affect dlinfer?
Please check whether dlinfer supports w8a8-int8/fp8.
dlinfer is planning to support w8a8-int8, but there is no fp8 plan.
UT failed.
CUDA_VISIBLE_DEVICES=4 lmdeploy lite smooth_quant /models/140/InternLM/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-int8
Then test with:
python examples/workspace/test_pipeline.py internlm2_5-7b-chat-int8
The reproducible script "examples/workspace/test_pipeline.py" is:
from lmdeploy import pipeline, TurbomindEngineConfig, PytorchEngineConfig, GenerationConfig, ChatTemplateConfig
import fire
import time
from transformers import AutoModelForCausalLM
from lmdeploy.utils import get_logger
def get_chat_template_config(**kwargs):
    chat_template_config = None
    chat_template = kwargs.get('model_name', None)
    if chat_template:
        chat_template_config = dict(model_name=chat_template)
        chat_template_config.update({k: v for k, v in kwargs.items() if hasattr(ChatTemplateConfig, k)})
        chat_template_config = ChatTemplateConfig(**chat_template_config)
    return chat_template_config


def apply_turbomind(model_path, log_level='INFO', **kwargs):
    engine_config = TurbomindEngineConfig()
    for key, value in kwargs.items():
        if hasattr(TurbomindEngineConfig, key):
            setattr(engine_config, key, value)
    chat_template_config = get_chat_template_config(**kwargs)
    pipe = pipeline(model_path, backend_config=engine_config, chat_template_config=chat_template_config, log_level=log_level)
    return pipe


def apply_pytorch(model_path, log_level='INFO', **kwargs):
    engine_config = PytorchEngineConfig()
    for key, value in kwargs.items():
        if hasattr(PytorchEngineConfig, key):
            setattr(engine_config, key, value)
    chat_template_config = get_chat_template_config(**kwargs)
    pipe = pipeline(model_path, backend_config=engine_config, chat_template_config=chat_template_config, log_level=log_level)
    return pipe


def apply_transformers(model_path, **kwargs):
    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
    return model


def get_gen_config(**kwargs):
    gen_config = GenerationConfig()
    for key, value in kwargs.items():
        if hasattr(gen_config, key):
            setattr(gen_config, key, value)
    return gen_config


def main(model_path, backend='turbomind', log_level='INFO', **kwargs):
    gen_config = get_gen_config(**kwargs)
    start = time.perf_counter()
    if backend == 'turbomind':
        pipe = apply_turbomind(model_path, log_level, **kwargs)
    elif backend == 'pytorch':
        pipe = apply_pytorch(model_path, log_level, **kwargs)
    # elif backend == 'transformers':
    #     pipe = apply_transformers(model_path, **kwargs)
    else:
        assert 0, f'unsupported backend {backend}'
    end = time.perf_counter()
    print(f'building pipeline cost: {end - start} s')
    prompt = 'write a "hello, world" using cuda'
    response = pipe(prompt, gen_config=gen_config)
    # messages = [
    #     dict(role='user', content=[
    #         dict(type='text', text='hi')
    #     ])]
    # response = pipe(messages, gen_config=gen_config)
    print(response)
    # input('press any key to exit')


if __name__ == "__main__":
    fire.Fire(main)
Motivation
Support fp8 w8a8 for pt backend
Modification
Support fp8 w8a8 for pt backend
BC-breaking (Optional)
None
Use cases (Optional)
fp8 quant
chat