This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 7baa96b

Support CodeLlama model in NeuralChat (#711)
* Support neural-chat-7b-v3 and neural-chat-7b-v3-1

Signed-off-by: lvliang-intel <[email protected]>
1 parent: df40d5e

4 files changed (+34, −4)


intel_extension_for_transformers/neural_chat/chatbot.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def build_chatbot(config: PipelineConfig=None):
         adapter = BaseModel()
     else:
         raise ValueError("NeuralChat Error: Unsupported model name or path, \
-                         only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL now.")
+                         only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER now.")
 
     # register plugin instance in model adaptor
     if config.plugins:
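
With this change a CodeLlama checkpoint flows through the same public API as the other supported models. A minimal usage sketch, assuming the `build_chatbot`/`PipelineConfig` exports used by the new nightly tests in this commit (the checkpoint name and prompt are copied from those tests):

# Minimal sketch; mirrors the API exercised by the new tests below.
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

# Checkpoint name taken from this commit's TestCodeLlamaModel.
config = PipelineConfig(model_name_or_path="codellama/CodeLlama-7b-hf")
chatbot = build_chatbot(config=config)

# Code-completion prompt; the model is expected to continue the function body.
result = chatbot.predict("def print_hello_world():")
print(result)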

intel_extension_for_transformers/neural_chat/models/base_model.py

Lines changed: 2 additions & 2 deletions
@@ -145,7 +145,7 @@ def predict_stream(self, query, config=None):
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
-            "starcoder" in self.model_name:
+            "starcoder" in self.model_name or "codellama" in self.model_name.lower():
             query_include_prompt = True
 
         # plugin pre actions
@@ -220,7 +220,7 @@ def predict(self, query, config=None):
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
-            "starcoder" in self.model_name:
+            "starcoder" in self.model_name or "codellama" in self.model_name.lower():
             query_include_prompt = True
 
         # plugin pre actions
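
`query_include_prompt` marks queries that are passed to the model verbatim rather than wrapped in a conversation template; code models such as StarCoder and CodeLlama take this path. The added `.lower()` matters because Hugging Face CodeLlama checkpoint names are mixed case. A standalone illustration (the variable value is just an example):

# Why .lower() is needed: CodeLlama checkpoint names are mixed case.
model_name = "CodeLlama-7b-hf"

print("codellama" in model_name)          # False: case-sensitive substring check
print("codellama" in model_name.lower())  # True:  matches after lowercasing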

intel_extension_for_transformers/neural_chat/models/model_utils.py

Lines changed: 3 additions & 1 deletion
@@ -365,6 +365,7 @@ def load_model(
             or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
             or re.search("qwen", model_name, re.IGNORECASE)
             or re.search("starcoder", model_name, re.IGNORECASE)
+            or re.search("codellama", model_name, re.IGNORECASE)
             or re.search("Mistral", model_name, re.IGNORECASE)
         ) and not ipex_int8) or re.search("opt", model_name, re.IGNORECASE):
         with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -377,6 +378,7 @@ def load_model(
         )
     elif (
         (re.search("starcoder", model_name, re.IGNORECASE)
+         or re.search("codellama", model_name, re.IGNORECASE)
         ) and ipex_int8
     ):
         with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -389,7 +391,7 @@ def load_model(
     else:
         raise ValueError(
             f"Unsupported model {model_name}, only supports "
-            "FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL now."
+            "FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER now."
         )
 
     if re.search("llama", model.config.architectures[0], re.IGNORECASE):

intel_extension_for_transformers/neural_chat/tests/nightly/models/test_model.py

Lines changed: 28 additions & 0 deletions
@@ -144,5 +144,33 @@ def test_get_default_conv_template_v3_1(self):
         print(result)
         self.assertIn('The Intel Xeon Scalable Processors', str(result))
 
+class TestStarCoderModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_code_gen(self):
+        config = PipelineConfig(model_name_or_path="bigcode/starcoder")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("def print_hello_world():")
+        print(result)
+        self.assertIn("""print('Hello World')""", str(result))
+
+class TestCodeLlamaModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_code_gen(self):
+        config = PipelineConfig(model_name_or_path="codellama/CodeLlama-7b-hf")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("def print_hello_world():")
+        print(result)
+        self.assertIn("""print('Hello World')""", str(result))
+
 if __name__ == "__main__":
     unittest.main()
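
Besides running the whole file through its `__main__` guard, the new cases can be loaded selectively with the standard unittest loader. A sketch, assuming test_model.py is importable from the working directory (both tests download large checkpoints, which is presumably why they sit under nightly):

import unittest

# Hypothetical selective run; test_model must be on sys.path.
from test_model import TestCodeLlamaModel, TestStarCoderModel

loader = unittest.TestLoader()
suite = unittest.TestSuite([
    loader.loadTestsFromTestCase(TestStarCoderModel),
    loader.loadTestsFromTestCase(TestCodeLlamaModel),
])
unittest.TextTestRunner(verbosity=2).run(suite)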
