Driver support for true quantization in eager mode (#20)
* Driver supports proper quantization now

* Provide llama3 example

* initial support for i4 quantization

* Add test for explicit quantization

* Refactor bindings

* Fix linux compilation error

* Faster tests

* Update to OV 2024.1

* Add a bunch of examples

* fix issues with scikit-learn

* fix

* Add test for int4

* Refactor quantization code to accommodate int4

* Add llava example
alessandropalla authored May 25, 2024
1 parent 541c79d commit 83e51bd
Showing 26 changed files with 586 additions and 83 deletions.
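
Before the per-file diffs, a quick orientation: the quantization path this commit enables is driven through intel_npu_acceleration_library.compile, as the new examples below show. The sketch that follows is distilled from llama3.py and tiny_llama_chat.py and is only illustrative (the model id is borrowed from those examples); the int4 path the commit adds is noted in a comment because none of the examples here pass an int4 dtype directly.

# Illustrative sketch of the quantized-compile flow used by the new examples.
from transformers import AutoModelForCausalLM
import intel_npu_acceleration_library
import torch

# Any of the causal LMs from the examples works; TinyLlama is the smallest.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_cache=True
).eval()

# Quantize the weights to int8 and offload to the NPU, as in llama3.py.
# The commit also adds int4 support (see dtype_from_string in common.h),
# but these examples only exercise torch.int8 and torch.float16.
with torch.no_grad():
    npu_model = intel_npu_acceleration_library.compile(model, dtype=torch.int8)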
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)

@@ -58,4 +58,4 @@ repos:
rev: 'v1.8.0'
hooks:
- id: mypy
exclude: 'docs|script|test|venv'
exclude: 'docs|script|test|venv|examples'
CMakeLists.txt (10 changes: 8 additions & 2 deletions)

@@ -37,17 +37,20 @@ function(get_linux_lsb_release_information)
set(LSB_RELEASE_VERSION "${LSB_RELEASE_VERSION}" PARENT_SCOPE)
endfunction()

set(OV_VERSION_SHORT "2024.1")
set(OV_VERSION "2024.1.0.15008.f4afc983258_x86_64")

if (WIN32)
if(NOT OV_LIBRARY_URL)
set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/windows/w_openvino_toolkit_windows_2024.0.0.14509.34caeefd078_x86_64.zip")
set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/${OV_VERSION_SHORT}/windows/w_openvino_toolkit_windows_${OV_VERSION}.zip")
endif()
elseif(UNIX)
if(NOT OV_LIBRARY_URL)
get_linux_lsb_release_information()
if (LSB_RELEASE_ID STREQUAL "Ubuntu")
if (${LSB_RELEASE_VERSION} STREQUAL "18.04" OR ${LSB_RELEASE_VERSION} STREQUAL "20.04" OR ${LSB_RELEASE_VERSION} STREQUAL "22.04")
string(REPLACE ".04" "" LSB_RELEASE_VERSION_SHORT ${LSB_RELEASE_VERSION})
set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/linux/l_openvino_toolkit_ubuntu${LSB_RELEASE_VERSION_SHORT}_2024.0.0.14509.34caeefd078_x86_64.tgz")
set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/${OV_VERSION_SHORT}/linux/l_openvino_toolkit_ubuntu${LSB_RELEASE_VERSION_SHORT}_${OV_VERSION}.tgz")
else()
message(FATAL_ERROR "Ubuntu version ${LSB_RELEASE_VERSION} is unsupported")
endif()
@@ -81,6 +84,9 @@ else()
file(COPY ${OpenVINObin} DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
endif()

file(GLOB OpenVINOPython ${openvino_SOURCE_DIR}/python/openvino/*)
file(COPY ${OpenVINOPython} DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../external/openvino)

set(CMAKE_POSITION_INDEPENDENT_CODE ON)

include_directories(include)
dev_requirements.txt (2 changes: 1 addition & 1 deletion)

@@ -1,7 +1,7 @@
pytest
pytest-xdist
pytest-cov
scikit-learn
scikit-learn < 1.5.0
pre-commit; sys_platform == 'darwin'
sphinx
breathe
examples/llama3.py (61 additions & 0 deletions)

@@ -0,0 +1,61 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import intel_npu_acceleration_library
import torch
import os

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
dtype = "int8"

PATH = os.path.join("models", model_id, dtype)
filename = os.path.join(PATH, "model.pth")
os.makedirs(PATH, exist_ok=True)

if not os.path.exists(filename):
print("Compile model for the NPU")
model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
torch_dtype = torch.int8 if dtype == "int8" else torch.float16
with torch.no_grad():
model = intel_npu_acceleration_library.compile(model, dtype=torch_dtype)
torch.save(model, filename)
del model


print(f"Loading model from {filename}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = torch.load(filename).eval()
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

print("Run inference with Llama3 on NPU\n")


query = input(">")


messages = [
{
"role": "system",
"content": "You are an helpful chatbot that can provide information about the Intel NPU",
},
{"role": "user", "content": query},
]

input_ids = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]


outputs = model.generate(
input_ids,
max_new_tokens=256,
eos_token_id=terminators,
do_sample=False,
streamer=streamer,
)
examples/llava.py (55 additions & 0 deletions)

@@ -0,0 +1,55 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

import requests
from PIL import Image
from transformers import (
LlavaForConditionalGeneration,
AutoTokenizer,
CLIPImageProcessor,
TextStreamer,
)
from transformers.feature_extraction_utils import BatchFeature
import intel_npu_acceleration_library
import torch


checkpoint = "Intel/llava-gemma-2b"

# Load model
model = LlavaForConditionalGeneration.from_pretrained(checkpoint)

model = intel_npu_acceleration_library.compile(model)

image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

# Prepare inputs
# Use gemma chat template
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<image>\nWhat's the content of the image?"}],
tokenize=False,
add_generation_prompt=True,
)
text_inputs = tokenizer(prompt, return_tensors="pt")

# clean the console
print("\033[H\033[J")
print("LLaVA Gemma Chatbot\n")
print("Please provide an image URL to generate a response.\n")
url = input("Image URL: ")

print("Description: ", end="", flush=True)
# url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]

inputs = BatchFeature(data={**text_inputs, "pixel_values": pixel_values})

# Generate
model.generate(**inputs, max_new_tokens=150, streamer=streamer)
examples/phi-2.py (46 additions & 0 deletions)

@@ -0,0 +1,46 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
import intel_npu_acceleration_library
import torch

model_id = "microsoft/Phi-2"

model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

npu_model = intel_npu_acceleration_library.compile(model, dtype=torch.float16)

pipe = pipeline(
"text-generation",
model=npu_model,
tokenizer=tokenizer,
max_length=256,
temperature=0.9,
top_p=0.95,
repetition_penalty=1.2,
streamer=streamer,
)

local_llm = HuggingFacePipeline(pipeline=pipe)
pipe.model.config.pad_token_id = pipe.model.config.eos_token_id


template = """Question: {question}
Answer: """

prompt = PromptTemplate(template=template, input_variables=["question"])

llm_chain = LLMChain(prompt=prompt, llm=local_llm)

question = "What's the distance between the Earth and the Moon?"

llm_chain.run(question)
examples/phi-3.py (49 additions & 0 deletions)

@@ -0,0 +1,49 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
import intel_npu_acceleration_library
import warnings

torch.random.manual_seed(0)

model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
torch_dtype="auto",
trust_remote_code=True,
)
model = intel_npu_acceleration_library.compile(model, torch.float16)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
streamer = TextStreamer(tokenizer)

messages = [
{
"role": "system",
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
},
{
"role": "user",
"content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
},
]

pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)

generation_args = {
"max_new_tokens": 500,
"return_full_text": False,
"temperature": 0.0,
"do_sample": False,
"streamer": streamer,
}

with warnings.catch_warnings():
warnings.simplefilter("ignore")
pipe(messages, **generation_args)
examples/tiny_llama_chat.py (54 additions & 0 deletions)

@@ -0,0 +1,54 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from transformers import pipeline, TextStreamer, set_seed
import intel_npu_acceleration_library
import torch
import os

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading the model...")
pipe = pipeline(
"text-generation", model=model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
print("Compiling the model for NPU...")
pipe.model = intel_npu_acceleration_library.compile(pipe.model, dtype=torch.int8)

streamer = TextStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)

set_seed(42)


messages = [
{
"role": "system",
"content": "You are a friendly chatbot. You can ask me anything.",
},
]

print("NPU Chatbot is ready! Please ask a question. Type 'exit' to quit.")
while True:
query = input("User: ")
if query.lower() == "exit":
break
messages.append({"role": "user", "content": query})

prompt = pipe.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
print("Assistant: ", end="", flush=True)
out = pipe(
prompt,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_k=50,
top_p=0.95,
streamer=streamer,
)

reply = out[0]["generated_text"].split("<|assistant|>")[-1].strip()
messages.append({"role": "assistant", "content": reply})
include/intel_npu_acceleration_library/common.h (16 additions & 0 deletions)

@@ -35,6 +35,22 @@ bool _isNPUAvailable(ov::Core& core) {
return std::find(availableDevices.begin(), availableDevices.end(), "NPU") != availableDevices.end();
}

ov::element::Type_t dtype_from_string(const std::string& dtype) {
if (dtype == "int8" || dtype == "i8") {
return ov::element::Type_t::i8;
} else if (dtype == "int4" || dtype == "i4") {
return ov::element::Type_t::i4;
}
if (dtype == "float16" || dtype == "half" || dtype == "f16") {
return ov::element::Type_t::f16;
}
if (dtype == "bfloat16" || dtype == "bf16") {
return ov::element::Type_t::bf16;
} else {
throw std::invalid_argument("Unsupported datatype: " + dtype);
}
}

} // namespace intel_npu_acceleration_library

// Define half pointer as uint16_t pointer datatype
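
For quick reference, the dtype spellings accepted by the new dtype_from_string helper above can be summarized as follows; this Python dict is purely illustrative (it is not part of the library), and any other string makes the C++ helper throw std::invalid_argument.

# Illustrative summary of the aliases handled by dtype_from_string (common.h).
DTYPE_ALIASES = {
    "int8": "i8", "i8": "i8",
    "int4": "i4", "i4": "i4",
    "float16": "f16", "half": "f16", "f16": "f16",
    "bfloat16": "bf16", "bf16": "bf16",
}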
(Diffs for the remaining changed files are not shown.)