Persistent compilation (#39)
* Add versioning

* Create modelling classes

* Incremental improvement

* Additional examples

* Add info

* Remove the automatic float16 activation conversion
alessandropalla authored Jun 4, 2024
1 parent d4928c4 commit c26443e
Showing 14 changed files with 215 additions and 117 deletions.
11 changes: 5 additions & 6 deletions README.md
@@ -2,6 +2,8 @@

[![Test](https://github.com/intel/intel-npu-acceleration-library/actions/workflows/test.yml/badge.svg)](https://github.com/intel/intel-npu-acceleration-library/actions/workflows/test.yml) [![Style](https://github.com/intel/intel-npu-acceleration-library/actions/workflows/style.yml/badge.svg)](https://github.com/intel/intel-npu-acceleration-library/actions/workflows/style.yml) [![Documentation](https://github.com/intel/intel-npu-acceleration-library/actions/workflows/documentation.yml/badge.svg)](https://github.com/intel/intel-npu-acceleration-library/actions/workflows/documentation.yml)

+[![PyPI version](https://badge.fury.io/py/intel-npu-acceleration-library.svg)](https://badge.fury.io/py/intel-npu-acceleration-library) [![Downloads](https://static.pepy.tech/badge/intel-npu-acceleration-library)](https://pepy.tech/project/intel-npu-acceleration-library)

[Documentation](https://intel.github.io/intel-npu-acceleration-library/)

The Intel® NPU Acceleration Library is a Python library designed to boost the efficiency of your applications by leveraging the power of the Intel Neural Processing Unit (NPU) to perform high-speed computations on compatible hardware.
@@ -109,21 +111,18 @@ optimized_model = intel_npu_acceleration_library.compile(model, dtype=torch.int8
### Run a Tiny-llama model on the NPU

```python
-from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
-import intel_npu_acceleration_library
+from transformers import AutoTokenizer, TextStreamer
+from intel_npu_acceleration_library import NPUModelForCausalLM
import torch

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

-model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
+model = NPUModelForCausalLM.from_pretrained(model_id, use_cache=True, dtype=torch.int8).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)


print("Compile model for the NPU")
model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)

query = input("Ask something: ")
prefix = tokenizer(query, return_tensors="pt")["input_ids"]

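The README diff is truncated before the generation step. For reference, a script like this typically finishes with a plain `generate()` call that drives the streamer; the snippet below is an illustrative sketch rather than part of this commit, and the sampling parameters are placeholders.

```python
# Illustrative continuation: standard transformers generation using the
# model, prefix and streamer defined above.
print("Run inference")
_ = model.generate(
    input_ids=prefix,
    streamer=streamer,
    do_sample=True,
    max_new_tokens=512,
)
```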
11 changes: 4 additions & 7 deletions docs/source/index.rst
@@ -34,22 +34,19 @@ You are now up and running! You can create a simple script like the following on


.. code-block:: python
-:emphasize-lines: 12, 13
+:emphasize-lines: 2, 7
-from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
-import intel_npu_acceleration_library
+from transformers import AutoTokenizer, TextStreamer
+from intel_npu_acceleration_library import NPUModelForCausalLM
import torch
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
+model = NPUModelForCausalLM.from_pretrained(model_id, use_cache=True, dtype=torch.int8).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
print("Compile model for the NPU")
model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)
query = input("Ask something: ")
prefix = tokenizer(query, return_tensors="pt")["input_ids"]
1 change: 1 addition & 0 deletions docs/source/python/intel_npu_acceleration_library.rst
@@ -9,6 +9,7 @@ Subpackages

intel_npu_acceleration_library.backend
intel_npu_acceleration_library.nn
+intel_npu_acceleration_library.functional

Submodules
----------
11 changes: 5 additions & 6 deletions examples/llama.py
@@ -3,21 +3,20 @@
# SPDX-License-Identifier: Apache 2.0
#

-from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
-import intel_npu_acceleration_library
+from transformers import AutoTokenizer, TextStreamer
+from intel_npu_acceleration_library import NPUModelForCausalLM
import torch

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

-model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
+model = NPUModelForCausalLM.from_pretrained(
+model_id, use_cache=True, dtype=torch.int8
+).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)


print("Compile model for the NPU")
model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)

query = input("Ask something: ")
prefix = tokenizer(query, return_tensors="pt")["input_ids"]

26 changes: 5 additions & 21 deletions examples/llama3.py
@@ -3,32 +3,16 @@
# SPDX-License-Identifier: Apache 2.0
#

-from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
-import intel_npu_acceleration_library
+from transformers import AutoTokenizer, TextStreamer
+from intel_npu_acceleration_library import NPUModelForCausalLM
import torch
-import os

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-dtype = "int8"

-PATH = os.path.join("models", model_id, dtype)
-filename = os.path.join(PATH, "model.pth")
-os.makedirs(PATH, exist_ok=True)

-if not os.path.exists(filename):
-    print("Compile model for the NPU")
-    model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
-    torch_dtype = torch.int8 if dtype == "int8" else torch.float16
-    with torch.no_grad():
-        model = intel_npu_acceleration_library.compile(model, dtype=torch_dtype)
-    torch.save(model, filename)
-    del model


-print(f"Loading model from {filename}")

+model = NPUModelForCausalLM.from_pretrained(
+model_id, dtype=torch.int8, use_cache=True
+).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = torch.load(filename).eval()
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

print("Run inference with Llama3 on NPU\n")
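The manual cache-to-disk flow (compile, `torch.save`, then `torch.load` on later runs) is gone: `NPUModelForCausalLM.from_pretrained` now takes care of compiling the model for the NPU, so the example only needs ordinary generation code. The truncated remainder typically looks like the sketch below; it is illustrative, and the prompt and generation parameters are placeholders.

```python
# Illustrative continuation: Llama 3 instruct models are usually prompted
# through the tokenizer's chat template before calling generate().
messages = [{"role": "user", "content": "What is an NPU?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)

with torch.no_grad():
    model.generate(
        input_ids,
        streamer=streamer,
        max_new_tokens=256,
        eos_token_id=tokenizer.eos_token_id,
    )
```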
10 changes: 5 additions & 5 deletions examples/phi-2.py
@@ -6,20 +6,20 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
+from transformers import AutoTokenizer, pipeline, TextStreamer
import intel_npu_acceleration_library as npu_lib

model_id = "microsoft/Phi-2"

-model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
+model = npu_lib.NPUModelForCausalLM.from_pretrained(
+model_id, use_cache=True, dtype=npu_lib.int4
+).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

-npu_model = npu_lib.compile(model, dtype=npu_lib.int4)

pipe = pipeline(
"text-generation",
-model=npu_model,
+model=model,
tokenizer=tokenizer,
max_length=256,
temperature=0.9,
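The `pipeline(...)` call is truncated above. Downstream, this example hands the pipeline to the LangChain classes imported at the top of the file; a minimal sketch of that wiring, with a placeholder prompt template and question (the exact chain setup in the repository may differ):

```python
# Illustrative sketch: wrap the transformers pipeline for LangChain and run a
# simple prompt -> LLM chain. `pipe` is the pipeline object built above.
llm = HuggingFacePipeline(pipeline=pipe)

template = "Question: {question}\n\nAnswer:"
prompt = PromptTemplate(template=template, input_variables=["question"])
chain = LLMChain(prompt=prompt, llm=llm)

print(chain.invoke({"question": "What is the distance between the Earth and the Moon?"}))
```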
50 changes: 0 additions & 50 deletions examples/phi-3-nc.py

This file was deleted.

9 changes: 5 additions & 4 deletions examples/phi-3.py
@@ -5,19 +5,20 @@

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
-import intel_npu_acceleration_library
+import intel_npu_acceleration_library as npu_lib
import warnings

torch.random.manual_seed(0)

-model = AutoModelForCausalLM.from_pretrained(
+model = npu_lib.NPUModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
torch_dtype="auto",
trust_remote_code=True,
+dtype=npu_lib.int4,
)
-model = intel_npu_acceleration_library.compile(model, torch.float16)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
-streamer = TextStreamer(tokenizer)
+streamer = TextStreamer(tokenizer, skip_prompt=True)

messages = [
{
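The `messages` list is cut off above. Presumably the script continues by defining the chat messages and running them through a text-generation pipeline; a sketch under that assumption, with placeholder prompt text and generation parameters:

```python
# Illustrative sketch: chat-format messages run through a text-generation
# pipeline; the streamer prints tokens as they are produced.
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What are the advantages of running a model on an NPU?"},
]

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

generation_args = {
    "max_new_tokens": 200,
    "return_full_text": False,
    "streamer": streamer,
}

pipe(messages, **generation_args)
```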
14 changes: 3 additions & 11 deletions examples/t5.py
@@ -3,24 +3,16 @@
# SPDX-License-Identifier: Apache 2.0
#

-from transformers import AutoTokenizer, TextStreamer, AutoModelForSeq2SeqLM
-import intel_npu_acceleration_library
-import torch
+from transformers import AutoTokenizer, TextStreamer
+from intel_npu_acceleration_library import NPUModelForSeq2SeqLM

model_id = "google/flan-t5-small"

-model = AutoModelForSeq2SeqLM.from_pretrained(model_id, use_cache=True).eval()
+model = NPUModelForSeq2SeqLM.from_pretrained(model_id, use_cache=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)


print("Compile model for the NPU")
# TODO: Offload only the decoder as encoder is regressing in accuracy
model.decoder = intel_npu_acceleration_library.compile(
model.decoder, dtype=torch.float16
)

query = input("Ask something: ")
prefix = tokenizer(query, return_tensors="pt")["input_ids"]

14 changes: 13 additions & 1 deletion intel_npu_acceleration_library/__init__.py
@@ -5,6 +5,18 @@

from .compiler import compile
from .dtypes import int4, int8, float16
+from ._version import __version__
+from .modelling import NPUModel, NPUAutoModel, NPUModelForCausalLM, NPUModelForSeq2SeqLM


__all__ = ["compile", "int4", "int8", "float16"]
__all__ = [
"compile",
"int4",
"int8",
"float16",
"__version__",
"NPUModel",
"NPUAutoModel",
"NPUModelForCausalLM",
"NPUModelForSeq2SeqLM",
]
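With the widened `__all__`, the modelling classes and the package version are importable straight from the package root. A small sketch of the resulting surface (the model id is illustrative):

```python
import intel_npu_acceleration_library as npu_lib

print(npu_lib.__version__)

# from_pretrained mirrors the transformers API and compiles the model for the NPU.
model = npu_lib.NPUModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_cache=True, dtype=npu_lib.int8
).eval()
```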
6 changes: 6 additions & 0 deletions intel_npu_acceleration_library/_version.py
@@ -0,0 +1,6 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

__version__ = "v1.2.0"
5 changes: 0 additions & 5 deletions intel_npu_acceleration_library/compiler.py
@@ -12,7 +12,6 @@
import intel_npu_acceleration_library.nn as nn
from torch._dynamo import register_backend
from typing import Union, Callable, Any
-from packaging.version import Version
from typing import List
import torch

@@ -38,10 +37,6 @@ def compile(
f"intel-npu-acceleration-library library do not support yet the requeste datatype: {dtype}"
)

-# Convert model to half precision if torch version is greater or equal to 2.3.0
-if Version(torch.__version__) >= Version("2.3.0") and dtype != torch.float32:
-    model = model.half()

# Prepare and optimize model for NPU
with torch.no_grad():
# General optimizations
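Dropping this block means `compile` no longer casts the model to half precision behind the caller's back on torch >= 2.3.0. Callers who relied on that behavior can opt in explicitly; a minimal sketch, assuming a plain torch module as a stand-in model:

```python
import torch
import intel_npu_acceleration_library

# Stand-in model; any torch.nn.Module or transformers model works the same way.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU()).eval()

# Previously applied automatically when torch >= 2.3.0 and dtype != float32;
# now it is an explicit, optional step.
model = model.half()

optimized_model = intel_npu_acceleration_library.compile(model, dtype=torch.float16)
```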
