Skip to content

Commit

Permalink
Cleanup after conv refactor (mlc-ai#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored May 29, 2023
1 parent 9aed143 commit 642669d
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 184 deletions.
189 changes: 21 additions & 168 deletions cpp/conv_templates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,31 +92,28 @@ Conversation RedPajamaChat() {
Conversation RWKV() {
Conversation conv;
conv.name = "rwkv";
conv.system = (
"\nThe following is a coherent verbose detailed conversation between a girl named Alice "
"and her friend Bob. \n"
"Alice is very intelligent, creative and friendly. \n"
"Alice is unlikely to disagree with Bob, and Alice doesn't like to ask Bob questions. \n"
"Alice likes to tell Bob a lot about herself and her opinions. \n"
"Alice usually gives Bob kind, helpful and informative advices."
);
conv.system =
("\nThe following is a coherent verbose detailed conversation between a girl named Alice "
"and her friend Bob. \n"
"Alice is very intelligent, creative and friendly. \n"
"Alice is unlikely to disagree with Bob, and Alice doesn't like to ask Bob questions. \n"
"Alice likes to tell Bob a lot about herself and her opinions. \n"
"Alice usually gives Bob kind, helpful and informative advices.");
conv.roles = {"Bob", "Alice"};
conv.messages = {
{"Bob", "Hello Alice, how are you doing?"},
{"Alice", "Hi! Thanks, I'm fine. What about you?"},
{"Bob", "I am fine. It's nice to see you. Look, here is a store selling tea and juice."},
{"Alice",
"Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!"},
{"Bob", "What is it?"},
{"Alice",
"Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its "
"flavors are frequently sweet."},
{"Bob",
"Sounds tasty. I'll try it next time. Would you like to chat with me for a while?"},
{"Alice",
"Of course! I'm glad to answer your questions or give helpful advices. You know, I am "
"confident with my expertise. So please go ahead!"}
};
{"Bob", "Hello Alice, how are you doing?"},
{"Alice", "Hi! Thanks, I'm fine. What about you?"},
{"Bob", "I am fine. It's nice to see you. Look, here is a store selling tea and juice."},
{"Alice",
"Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!"},
{"Bob", "What is it?"},
{"Alice",
"Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its "
"flavors are frequently sweet."},
{"Bob", "Sounds tasty. I'll try it next time. Would you like to chat with me for a while?"},
{"Alice",
"Of course! I'm glad to answer your questions or give helpful advices. You know, I am "
"confident with my expertise. So please go ahead!"}};
conv.separator_style = SeparatorStyle::kAddColon;
conv.offset = 8;
conv.seps = {"\n\n"};
Expand All @@ -142,7 +139,7 @@ Conversation VanillaLM() {
conv.add_bos = true;
return conv;
}
} // namespace
} // namespace

using ConvFactory = Conversation (*)();

Expand All @@ -161,149 +158,5 @@ Conversation Conversation::FromTemplate(const std::string& name) {
return it->second();
}

void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversion template json file.";
picojson::object config = config_json.get<picojson::object>();
if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"name\" not found.";
}
if (config.count("system")) {
CHECK(config["system"].is<std::string>()) << "Invalid system" << err_templ;
this->system = config["system"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"system\" not found.";
}
if (config.count("roles")) {
CHECK(config["roles"].is<picojson::array>()) << "Invalid roles" << err_templ;
picojson::array roles_arr = config["roles"].get<picojson::array>();
std::vector<std::string> roles;
for (const picojson::value& v : roles_arr) {
CHECK(v.is<std::string>()) << "Invalid roles" << err_templ;
roles.push_back(v.get<std::string>());
}
this->roles = roles;
} else {
CHECK(partial_update) << "Key \"roles\" not found.";
}
if (config.count("messages")) {
CHECK(config["messages"].is<picojson::array>()) << "Invalid messages" << err_templ;
std::vector<std::vector<std::string>> messages;
picojson::array msgs_arr = config["messages"].get<picojson::array>();
for (const picojson::value& msgs_i : msgs_arr) {
CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
picojson::array msgs_i_arr = msgs_i.get<picojson::array>();
std::vector<std::string> messages_i;
for (const picojson::value& msg_v : msgs_i_arr) {
CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
messages_i.push_back(msg_v.get<std::string>());
}
messages.push_back(messages_i);
}
this->messages = messages;
} else {
CHECK(partial_update) << "Key \"messages\" not found.";
}
if (config.count("offset")) {
CHECK(config["offset"].is<int64_t>()) << "Invalid offset" << err_templ;
this->offset = config["offset"].get<int64_t>();
} else {
CHECK(partial_update) << "Key \"offset\" not found.";
}
if (config.count("separator_style")) {
CHECK(config["separator_style"].is<int64_t>()) << "Invalid separator style" << err_templ;
this->separator_style = SeparatorStyle(config["separator_style"].get<int64_t>());
} else {
CHECK(partial_update) << "Key \"separator_style\" not found.";
}
if (config.count("seps")) {
std::vector<std::string> seps;
CHECK(config["seps"].is<picojson::array>()) << "Invalid seps" << err_templ;
picojson::array seps_arr = config["seps"].get<picojson::array>();
for (const picojson::value& sep : seps_arr) {
CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
seps.push_back(sep.get<std::string>());
}
this->seps = seps;
} else {
CHECK(partial_update) << "Key \"seps\" not found.";
}
if (config.count("stop_str")) {
CHECK(config["stop_str"].is<std::string>()) << "Invalid stop_str" << err_templ;
this->stop_str = config["stop_str"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"stop_str\" not found.";
}
if (config.count("stop_tokens")) {
CHECK(config["stop_tokens"].is<picojson::array>()) << "Invalid stop_tokens" << err_templ;
picojson::array stop_tokens_arr = config["stop_tokens"].get<picojson::array>();
std::vector<int32_t> stop_tokens;
for (const picojson::value& stop_token : stop_tokens_arr) {
CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
stop_tokens.push_back(stop_token.get<int64_t>());
}
this->stop_tokens = stop_tokens;
} else {
CHECK(partial_update) << "Key \"stop_tokens\" not found.";
}
if (config.count("add_bos")) {
CHECK(config["add_bos"].is<bool>()) << "Invalid add_bos" << err_templ;
this->add_bos = config["add_bos"].get<bool>();
} else {
CHECK(partial_update) << "Key \"add_bos\" not found.";
}
}

void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_update) {
picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
if (!err.empty()) {
LOG(FATAL) << err;
return;
}
LoadJSONOverride(config_json, partial_update);
}

picojson::value Conversation::SerializeToJSON() const {
picojson::object config;
config["name"] = picojson::value(this->name);
config["system"] = picojson::value(this->system);
picojson::array roles_arr;
for (const std::string& role_str : this->roles) {
roles_arr.push_back(picojson::value(role_str));
}
config["roles"] = picojson::value(roles_arr);
picojson::array msgs_arr;
for (const std::vector<std::string>& msgs_i : this->messages) {
picojson::array msgs_i_arr;
for (const std::string& msg_str : msgs_i) {
msgs_i_arr.push_back(picojson::value(msg_str));
}
msgs_arr.push_back(picojson::value(msgs_i_arr));
}
config["messages"] = picojson::value(msgs_arr);
config["offset"] = picojson::value((int64_t)this->offset);
config["separator_style"] = picojson::value((int64_t)this->separator_style);
picojson::array seps_arr;
for (const std::string& sep_str : this->seps) {
seps_arr.push_back(picojson::value(sep_str));
}
config["seps"] = picojson::value(seps_arr);
config["stop_str"] = picojson::value(this->stop_str);
picojson::array stop_tokens_arr;
for (const int32_t& stop_token_str : this->stop_tokens) {
stop_tokens_arr.push_back(picojson::value((int64_t)stop_token_str));
}
config["stop_tokens"] = picojson::value(stop_tokens_arr);
config["add_bos"] = picojson::value(this->add_bos);
return picojson::value(config);
}

std::string Conversation::SerializeToJSONStr() const {
return SerializeToJSON().serialize(true);
}

} // namespace llm
} // namespace mlc
154 changes: 154 additions & 0 deletions cpp/conversation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@


#include "conversation.h"

#include <string>
#include <unordered_map>

namespace mlc {
namespace llm {

void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
std::string err_templ = " in conversion template json file.";
picojson::object config = config_json.get<picojson::object>();
if (config.count("name")) {
CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
this->name = config["name"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"name\" not found.";
}
if (config.count("system")) {
CHECK(config["system"].is<std::string>()) << "Invalid system" << err_templ;
this->system = config["system"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"system\" not found.";
}
if (config.count("roles")) {
CHECK(config["roles"].is<picojson::array>()) << "Invalid roles" << err_templ;
picojson::array roles_arr = config["roles"].get<picojson::array>();
std::vector<std::string> roles;
for (const picojson::value& v : roles_arr) {
CHECK(v.is<std::string>()) << "Invalid roles" << err_templ;
roles.push_back(v.get<std::string>());
}
this->roles = roles;
} else {
CHECK(partial_update) << "Key \"roles\" not found.";
}
if (config.count("messages")) {
CHECK(config["messages"].is<picojson::array>()) << "Invalid messages" << err_templ;
std::vector<std::vector<std::string>> messages;
picojson::array msgs_arr = config["messages"].get<picojson::array>();
for (const picojson::value& msgs_i : msgs_arr) {
CHECK(msgs_i.is<picojson::array>()) << "Invalid messages" << err_templ;
picojson::array msgs_i_arr = msgs_i.get<picojson::array>();
std::vector<std::string> messages_i;
for (const picojson::value& msg_v : msgs_i_arr) {
CHECK(msg_v.is<std::string>()) << "Invalid messages" << err_templ;
messages_i.push_back(msg_v.get<std::string>());
}
messages.push_back(messages_i);
}
this->messages = messages;
} else {
CHECK(partial_update) << "Key \"messages\" not found.";
}
if (config.count("offset")) {
CHECK(config["offset"].is<int64_t>()) << "Invalid offset" << err_templ;
this->offset = config["offset"].get<int64_t>();
} else {
CHECK(partial_update) << "Key \"offset\" not found.";
}
if (config.count("separator_style")) {
CHECK(config["separator_style"].is<int64_t>()) << "Invalid separator style" << err_templ;
this->separator_style = SeparatorStyle(config["separator_style"].get<int64_t>());
} else {
CHECK(partial_update) << "Key \"separator_style\" not found.";
}
if (config.count("seps")) {
std::vector<std::string> seps;
CHECK(config["seps"].is<picojson::array>()) << "Invalid seps" << err_templ;
picojson::array seps_arr = config["seps"].get<picojson::array>();
for (const picojson::value& sep : seps_arr) {
CHECK(sep.is<std::string>()) << "Invalid seps" << err_templ;
seps.push_back(sep.get<std::string>());
}
this->seps = seps;
} else {
CHECK(partial_update) << "Key \"seps\" not found.";
}
if (config.count("stop_str")) {
CHECK(config["stop_str"].is<std::string>()) << "Invalid stop_str" << err_templ;
this->stop_str = config["stop_str"].get<std::string>();
} else {
CHECK(partial_update) << "Key \"stop_str\" not found.";
}
if (config.count("stop_tokens")) {
CHECK(config["stop_tokens"].is<picojson::array>()) << "Invalid stop_tokens" << err_templ;
picojson::array stop_tokens_arr = config["stop_tokens"].get<picojson::array>();
std::vector<int32_t> stop_tokens;
for (const picojson::value& stop_token : stop_tokens_arr) {
CHECK(stop_token.is<int64_t>()) << "Invalid stop_tokens" << err_templ;
stop_tokens.push_back(stop_token.get<int64_t>());
}
this->stop_tokens = stop_tokens;
} else {
CHECK(partial_update) << "Key \"stop_tokens\" not found.";
}
if (config.count("add_bos")) {
CHECK(config["add_bos"].is<bool>()) << "Invalid add_bos" << err_templ;
this->add_bos = config["add_bos"].get<bool>();
} else {
CHECK(partial_update) << "Key \"add_bos\" not found.";
}
}

void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_update) {
picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
if (!err.empty()) {
LOG(FATAL) << err;
return;
}
LoadJSONOverride(config_json, partial_update);
}

picojson::value Conversation::SerializeToJSON() const {
picojson::object config;
config["name"] = picojson::value(this->name);
config["system"] = picojson::value(this->system);
picojson::array roles_arr;
for (const std::string& role_str : this->roles) {
roles_arr.push_back(picojson::value(role_str));
}
config["roles"] = picojson::value(roles_arr);
picojson::array msgs_arr;
for (const std::vector<std::string>& msgs_i : this->messages) {
picojson::array msgs_i_arr;
for (const std::string& msg_str : msgs_i) {
msgs_i_arr.push_back(picojson::value(msg_str));
}
msgs_arr.push_back(picojson::value(msgs_i_arr));
}
config["messages"] = picojson::value(msgs_arr);
config["offset"] = picojson::value((int64_t)this->offset);
config["separator_style"] = picojson::value((int64_t)this->separator_style);
picojson::array seps_arr;
for (const std::string& sep_str : this->seps) {
seps_arr.push_back(picojson::value(sep_str));
}
config["seps"] = picojson::value(seps_arr);
config["stop_str"] = picojson::value(this->stop_str);
picojson::array stop_tokens_arr;
for (const int32_t& stop_token_str : this->stop_tokens) {
stop_tokens_arr.push_back(picojson::value((int64_t)stop_token_str));
}
config["stop_tokens"] = picojson::value(stop_tokens_arr);
config["add_bos"] = picojson::value(this->add_bos);
return picojson::value(config);
}

std::string Conversation::SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }

} // namespace llm
} // namespace mlc
Loading

0 comments on commit 642669d

Please sign in to comment.