RWKV rebase (mlc-ai#142)
Hzfengsy authored May 29, 2023
1 parent fe6f0a6 commit 9aed143
Showing 22 changed files with 2,535 additions and 421 deletions.
48 changes: 33 additions & 15 deletions build.py
@@ -11,7 +11,7 @@
 
 import mlc_llm
 from mlc_llm import utils
-from mlc_llm.relax_model import gpt_neox, llama, moss
+from mlc_llm.relax_model import gpt_neox, llama, moss, rwkv
 
 
 def _parse_args():
@@ -224,22 +224,37 @@ def mod_transform_before_build(
     args: argparse.Namespace,
 ) -> tvm.IRModule:
     """First-stage: Legalize ops and trace"""
-    model_names = [
-        "prefill",
-        "decode",
-        "create_kv_cache",
-        "softmax_with_temperature",
-        "get_metadata",
-    ]
+    if ARGS.model.startswith("rwkv-"):
+        model_names = [
+            "decode",
+            "create_kv_cache",
+            "softmax_with_temperature",
+            "get_metadata",
+            "reset_kv_cache",
+        ]
+    else:
+        model_names = [
+            "prefill",
+            "decode",
+            "create_kv_cache",
+            "softmax_with_temperature",
+            "get_metadata",
+        ]
 
     if args.quantization.mode != "no":
-        mod = mlc_llm.transform.GroupQuantize(  # pylint: disable=not-callable
-            group_size=40 if args.quantization.mode.endswith("3") else 32,
-            sym=args.quantization.sym,
-            mode=args.quantization.mode,
-            storage_nbit=args.quantization.storage_nbit,
-            dtype=args.quantization.model_dtype,
-        )(mod)
+        if ARGS.model.startswith("rwkv-"):
+            mod = mlc_llm.transform.RWKVQuantize(  # pylint: disable=not-callable
+                mode=args.quantization.mode,
+                dtype=args.quantization.model_dtype,
+            )(mod)
+        else:
+            mod = mlc_llm.transform.GroupQuantize(  # pylint: disable=not-callable
+                group_size=40 if args.quantization.mode.endswith("3") else 32,
+                sym=args.quantization.sym,
+                mode=args.quantization.mode,
+                storage_nbit=args.quantization.storage_nbit,
+                dtype=args.quantization.model_dtype,
+            )(mod)
     mod = mlc_llm.transform.FuseTransposeMatmul()(mod)  # pylint: disable=not-callable
     mod = relax.pipeline.get_pipeline()(mod)  # pylint: disable=no-value-for-parameter
     mod = mlc_llm.transform.FuseDecodeMatmulEwise(  # pylint: disable=not-callable
@@ -250,6 +265,7 @@ def mod_transform_before_build(
     mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, model_names)
 
     debug_dump_script(mod_transform, "mod_lift_params.py", args)
+    debug_dump_script(mod_deploy, "mod_deploy.py", args)
 
     new_params = utils.transform_params(mod_transform, model_params)
     utils.save_params(new_params, args.artifact_path)
Expand Down Expand Up @@ -363,6 +379,8 @@ def main():
mod, params = gpt_neox.get_model(ARGS, config)
elif ARGS.model_category == "moss":
mod, params = moss.get_model(ARGS, config)
elif ARGS.model_category == "rwkv":
mod, params = rwkv.get_model(ARGS, config)
else:
raise ValueError(f"Model {ARGS.model} not supported")
mod = mod_transform_before_build(mod, params, ARGS)
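Note on the `model_names` split above: RWKV is a recurrent architecture, so the compiled module exports no separate `prefill` function (which suggests prompt tokens are pushed through `decode` step by step) and instead adds `reset_kv_cache` to clear the recurrent state between conversations. The sketch below is illustrative only; the driver object `vm`, the call signatures, and the sampling loop are hypothetical and not part of this commit, only the entry-point names come from the `model_names` list above.

```python
# Hypothetical driver loop, for illustration only.  `vm` stands in for the
# compiled module; call signatures are made up for the sketch.
def chat_once(vm, prompt_tokens, max_new_tokens):
    kv_cache = vm["create_kv_cache"]()
    vm["reset_kv_cache"](kv_cache)     # recurrent state starts from scratch
    logits = None
    for tok in prompt_tokens:          # no prefill: the prompt is decoded token by token
        logits = vm["decode"](tok, kv_cache)
    generated = []
    for _ in range(max_new_tokens):
        probs = vm["softmax_with_temperature"](logits, 1.0)
        tok = int(probs.argmax())      # greedy pick keeps the sketch short
        generated.append(tok)
        logits = vm["decode"](tok, kv_cache)
    return generated
```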
17 changes: 10 additions & 7 deletions cpp/cli_main.cc
@@ -75,9 +75,12 @@ std::optional<std::filesystem::path> FindFile(
   for (const std::filesystem::path& prefix : search_paths) {
     for (const std::string& name : names) {
       for (const std::string& suffix : suffixes) {
-        std::filesystem::path path = prefix / (name + suffix);
-        if (std::filesystem::exists(path) && std::filesystem::is_regular_file(path)) {
-          return path;
+        try {
+          std::filesystem::path path = std::filesystem::canonical(prefix / (name + suffix));
+          if (std::filesystem::exists(path) && std::filesystem::is_regular_file(path)) {
+            return path;
+          }
+        } catch (const std::filesystem::filesystem_error& e) {
         }
       }
     }
@@ -156,8 +159,8 @@ struct ModelPaths {
   /*!
    * \brief Path to ${model}-${device}.{so|dylib}
    *
-   * This dynamic library contains all the compute kernels used in LLM inference, and can be loaded
-   * using tvm::runtime::Module::LoadFromFile.
+   * This dynamic library contains all the compute kernels used in LLM inference, and can be
+   * loaded using tvm::runtime::Module::LoadFromFile.
    */
   std::filesystem::path lib;
 
@@ -519,8 +522,8 @@ int main(int argc, char* argv[]) {
   try {
     ChatModule chat(GetDevice(device_name, device_id));
     if (args.get<bool>("--evaluate")) {
-      // `--evaluate` is only used for performance debugging, and thus will call low-level APIs that
-      // are not supposed to be used in chat app setting
+      // `--evaluate` is only used for performance debugging, and thus will call low-level APIs
+      // that are not supposed to be used in chat app setting
       ModelPaths model = ModelPaths::Find(artifact_path, device_name, local_id);
       tvm::runtime::Module chat_mod = mlc::llm::CreateChatModule(GetDevice(device_name, device_id));
       std::string model_path = model.config.parent_path().string();
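The FindFile change above resolves each candidate through std::filesystem::canonical, which throws filesystem_error when the path does not exist, so the lookup is wrapped in try/catch and a failed resolution simply falls through to the next candidate. The same control flow in Python terms, purely as an illustration (the argument lists are made up, not the CLI's actual search paths):

```python
from pathlib import Path

def find_file(search_paths, names, suffixes):
    """Illustrative analogue of FindFile: resolve each candidate path and
    treat a resolution error as 'not found', continuing the search."""
    for prefix in search_paths:
        for name in names:
            for suffix in suffixes:
                candidate = Path(prefix) / f"{name}{suffix}"
                try:
                    path = candidate.resolve(strict=True)  # raises if the path is missing
                except OSError:
                    continue
                if path.is_file():
                    return path
    return None
```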
41 changes: 40 additions & 1 deletion cpp/conv_templates.cc
@@ -5,7 +5,7 @@
 
 namespace mlc {
 namespace llm {
-
+namespace {
 Conversation VicunaV11() {
   Conversation conv;
   conv.name = "vicuna_v1.1";
@@ -89,6 +89,43 @@ Conversation RedPajamaChat() {
   return conv;
 }
 
+Conversation RWKV() {
+  Conversation conv;
+  conv.name = "rwkv";
+  conv.system = (
+      "\nThe following is a coherent verbose detailed conversation between a girl named Alice "
+      "and her friend Bob. \n"
+      "Alice is very intelligent, creative and friendly. \n"
+      "Alice is unlikely to disagree with Bob, and Alice doesn't like to ask Bob questions. \n"
+      "Alice likes to tell Bob a lot about herself and her opinions. \n"
+      "Alice usually gives Bob kind, helpful and informative advices."
+  );
+  conv.roles = {"Bob", "Alice"};
+  conv.messages = {
+      {"Bob", "Hello Alice, how are you doing?"},
+      {"Alice", "Hi! Thanks, I'm fine. What about you?"},
+      {"Bob", "I am fine. It's nice to see you. Look, here is a store selling tea and juice."},
+      {"Alice",
+       "Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!"},
+      {"Bob", "What is it?"},
+      {"Alice",
+       "Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its "
+       "flavors are frequently sweet."},
+      {"Bob",
+       "Sounds tasty. I'll try it next time. Would you like to chat with me for a while?"},
+      {"Alice",
+       "Of course! I'm glad to answer your questions or give helpful advices. You know, I am "
+       "confident with my expertise. So please go ahead!"}
+  };
+  conv.separator_style = SeparatorStyle::kAddColon;
+  conv.offset = 8;
+  conv.seps = {"\n\n"};
+  conv.stop_str = "\n\n";
+  conv.stop_tokens = {0};
+  conv.add_bos = false;
+  return conv;
+}
+
 Conversation VanillaLM() {
   Conversation conv;
   conv.name = "LM";
@@ -105,6 +142,7 @@ Conversation VanillaLM() {
   conv.add_bos = true;
   return conv;
 }
+}  // namespace
 
 using ConvFactory = Conversation (*)();
 
@@ -113,6 +151,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
       {"vicuna_v1.1", VicunaV11},
       {"conv_one_shot", ConvOneShot},
       {"redpajama_chat", RedPajamaChat},
+      {"rwkv", RWKV},
       {"LM", VanillaLM},
   };
   auto it = factory.find(name);
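For the new rwkv template, offset = 8 marks the eight seeded messages as a fixed few-shot preamble, SeparatorStyle::kAddColon renders each turn as "Role: text", and generation stops at a blank line ("\n\n") or token 0. Roughly, the prompt sent to the model is assembled as in the sketch below; this is an illustration, not the C++ prompt builder, whose exact whitespace handling may differ.

```python
# Rough illustration of how a kAddColon-style template turns the seeded
# dialogue plus a new user turn into a prompt string.
def build_prompt(system, roles, messages, user_input, sep="\n\n"):
    turns = [f"{role}: {text}" for role, text in messages]  # the 8 seeded messages
    turns.append(f"{roles[0]}: {user_input}")               # new "Bob" turn
    turns.append(f"{roles[1]}:")                            # cue "Alice" to respond
    return system + sep + sep.join(turns)
```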
6 changes: 5 additions & 1 deletion cpp/conversation.h
@@ -163,6 +163,10 @@ class Conversation {
     this->messages.back().push_back(msg);
   }
 
+  void Reset() {
+    this->messages.resize(this->offset);
+  }
+
  private:
   // Identity function
   static std::string Identity(std::string msg) { return msg; }
@@ -187,7 +191,7 @@
     } else {
       // need to add a sep of last response
       // which was not added in the processing step.
-      ret.push_back(this->seps[1 % this->seps.size()]);
+      // ret.push_back(this->seps[1 % this->seps.size()]);
     }
 
     ICHECK_EQ(start_pos % 2, 0);
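The new Reset() ties back to the offset field used by the conversation templates: everything before offset is the seeded preamble and survives a reset, while the turns accumulated during the chat are dropped. In Python terms, illustrative only:

```python
# Illustrative equivalent of Conversation::Reset(): keep the first `offset`
# seeded messages and discard every turn added during the chat.
def reset(messages, offset):
    del messages[offset:]   # same effect as messages.resize(offset) in C++
    return messages
```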
(The remaining changed files in this commit are not shown here.)
