Skip to content

Commit

Permalink
Update Llama convert: load eot id from json
Browse files Browse the repository at this point in the history
  • Loading branch information
yvonwin committed Apr 26, 2024
1 parent 2cf8901 commit 82c89cf
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions qwen_cpp/convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Convert Hugging Face Qwen models to GGML format
Convert Hugging Face Qwen and LLama models to GGML format
"""
import argparse
import platform
Expand Down Expand Up @@ -229,8 +229,7 @@ def dump_model(f, model, ggml_type):

class Llama3Converter:
"""
主要是tokenizer的区别。
还有qkv没有bias
qkv has no bias
"""
MODEL_TYPE = ModelType.Llama3
@classmethod
Expand All @@ -251,10 +250,10 @@ def dump_config(f, config, generation_config, tokenizer, ggml_type):
config.num_hidden_layers,
config.intermediate_size,
config.max_position_embeddings,
config.eos_token_id, # eos 151645
128001, # pad
128000, # <|begin_of_text|>
128009, # "<|end_of_text|>"
config.eos_token_id, # <|end_of_text|> 128001
config.eos_token_id, # llama3 no pad, so not use actually
config.bos_token_id, # <|begin_of_text|> 128000
generation_config.eos_token_id[1] if isinstance(generation_config.eos_token_id, list) else 128009 # <|eot_id|> 128009
]
f.write(struct.pack("i" * len(config_values), *config_values))

Expand All @@ -278,7 +277,7 @@ def dump_model(f, model, ggml_type):
"model.norm.weight",
"lm_head.weight",
]
print(len(weight_names))
# print(len(weight_names)) // 8b: 291
dump_state_dict(f, weight_names, model.state_dict(), ggml_type)

class Qwen2Converter:
Expand Down Expand Up @@ -461,7 +460,7 @@ def convert(f: BinaryIO, model_name_or_path: str, dtype: str = "q4_0"):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

print(model)
# print(model)
# state_dict = model.state_dict()
# keys = state_dict.keys()
# print(keys)
Expand All @@ -471,7 +470,6 @@ def convert(f: BinaryIO, model_name_or_path: str, dtype: str = "q4_0"):
# print(model.config.eos_token_id) # 151645
# print(model.generation_config)
# print(tokenizer.model_max_length)
# print(list(tokenizer.added_tokens_decoder.keys())) # 1.5 only

if model.config.architectures[0]=="Qwen2ForCausalLM":
if "Code" not in model_name_or_path:
Expand All @@ -498,7 +496,7 @@ def main():
parser.add_argument(
"-i",
"--model_name_or_path",
default="Qwen/Qwen1.5-1.8B-Chat", # Qwen/Qwen1.5-0.5B-Chat; Qwen/Qwen1.5-MoE-A2.7B-Chat
default="Qwen/Qwen1.5-1.8B-Chat", # Qwen/Qwen1.5-7B-Chat; meta-llama/Meta-Llama-3-8B-Instruct
type=str,
help="Model name or path used in AutoModel.from_pretrained",
)
Expand Down

0 comments on commit 82c89cf

Please sign in to comment.