diff --git a/tensorrt_llm/models/phi3/model.py b/tensorrt_llm/models/phi3/model.py index ac29ab9a09..aed5983b22 100644 --- a/tensorrt_llm/models/phi3/model.py +++ b/tensorrt_llm/models/phi3/model.py @@ -265,6 +265,7 @@ def from_hugging_face( dtype: str = 'auto', mapping: Optional[Mapping] = None, quant_config: Optional[QuantConfig] = None, + attn_implementation: str = 'eager', **kwargs): import transformers @@ -281,6 +282,7 @@ def from_hugging_face( dtype=dtype, mapping=mapping, quant_config=quant_config, + attn_implementation=attn_implementation **kwargs) if not use_preloading: @@ -289,7 +291,8 @@ def from_hugging_face( hf_model = AutoModelForCausalLM.from_pretrained( hf_model_dir, torch_dtype="auto", - trust_remote_code=trust_remote_code) + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation) assert isinstance(hf_model, transformers.PreTrainedModel)