From fe6f0a6f4d6d5679927e959bc287e3f95896ca96 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Mon, 29 May 2023 06:17:35 -0700
Subject: [PATCH] Load conversation from JSON (#252)

---
 .gitmodules                |   6 ++
 3rdparty/googletest        |   1 +
 3rdparty/picojson          |   1 +
 CMakeLists.txt             |  16 ++++-
 cpp/conv_templates.cc      | 144 +++++++++++++++++++++++++++++++++++++
 cpp/conversation.h         |  87 +++++++++++++++++++---
 cpp/llm_chat.cc            | 139 ++++++++++++++++++++++++++++-------
 tests/cpp/conv_unittest.cc |  29 ++++
 8 files changed, 387 insertions(+), 36 deletions(-)
 create mode 160000 3rdparty/googletest
 create mode 160000 3rdparty/picojson
 create mode 100644 tests/cpp/conv_unittest.cc

diff --git a/.gitmodules b/.gitmodules
index d0d1610453..1905048b81 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -4,3 +4,9 @@
 [submodule "3rdparty/tokenizers-cpp"]
 	path = 3rdparty/tokenizers-cpp
 	url = https://github.com/mlc-ai/tokenizers-cpp
+[submodule "3rdparty/googletest"]
+	path = 3rdparty/googletest
+	url = https://github.com/google/googletest.git
+[submodule "3rdparty/picojson"]
+	path = 3rdparty/picojson
+	url = https://github.com/kazuho/picojson.git
diff --git a/3rdparty/googletest b/3rdparty/googletest
new file mode 160000
index 0000000000..4580469122
--- /dev/null
+++ b/3rdparty/googletest
@@ -0,0 +1 @@
+Subproject commit 45804691223635953f311cf31a10c632553bbfc3
diff --git a/3rdparty/picojson b/3rdparty/picojson
new file mode 160000
index 0000000000..111c9be518
--- /dev/null
+++ b/3rdparty/picojson
@@ -0,0 +1 @@
+Subproject commit 111c9be5188f7350c2eac9ddaedd8cca3d7bf394
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1518d5df20..7309edc5ee 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,6 +33,8 @@ if (MLC_LLM_INSTALL_STATIC_LIB)
   set(BUILD_STATIC_RUNTIME ON)
 endif()
 
+option(BUILD_CPP_TEST "Build cpp unittests" OFF)
+
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 
 # tvm runtime config: minimize runtime components
@@ -81,8 +83,18 @@ set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
 target_link_libraries(mlc_llm PUBLIC tvm_runtime)
 target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 
-
-# Exmaple app that may depends on mlc_llm
+if (BUILD_CPP_TEST)
+  message(STATUS "Building cpp unittests")
+  add_subdirectory(3rdparty/googletest)
+  file(GLOB_RECURSE MLC_LLM_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/cpp/*unittest.cc)
+  add_executable(mlc_llm_cpp_tests ${MLC_LLM_TEST_SRCS})
+  target_include_directories(mlc_llm_cpp_tests PRIVATE ${MLC_LLM_INCLUDES})
+  target_include_directories(mlc_llm_cpp_tests PRIVATE ${PROJECT_SOURCE_DIR}/cpp)
+  target_include_directories(mlc_llm_cpp_tests PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
+  target_link_libraries(mlc_llm_cpp_tests PUBLIC mlc_llm gtest gtest_main)
+endif(BUILD_CPP_TEST)
+
+# Example app that may depend on mlc_llm
 add_executable(mlc_chat_cli $<TARGET_OBJECTS:mlc_cli_objs>)
 target_include_directories(mlc_cli_objs PRIVATE ${MLC_LLM_INCLUDES})
 target_include_directories(mlc_cli_objs PRIVATE 3rdparty/argparse/include)
diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc
index 00a7c301c1..8e7cf24cf6 100644
--- a/cpp/conv_templates.cc
+++ b/cpp/conv_templates.cc
@@ -122,5 +122,149 @@ Conversation Conversation::FromTemplate(const std::string& name) {
   return it->second();
 }
 
+void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
+  std::string err_templ = " in conversation template json file.";
+  picojson::object config = config_json.get<picojson::object>();
+  if (config.count("name")) {
+    CHECK(config["name"].is<std::string>()) << "Invalid name" << err_templ;
+    this->name = config["name"].get<std::string>();
config["name"].get(); + } else { + CHECK(partial_update) << "Key \"name\" not found."; + } + if (config.count("system")) { + CHECK(config["system"].is()) << "Invalid system" << err_templ; + this->system = config["system"].get(); + } else { + CHECK(partial_update) << "Key \"system\" not found."; + } + if (config.count("roles")) { + CHECK(config["roles"].is()) << "Invalid roles" << err_templ; + picojson::array roles_arr = config["roles"].get(); + std::vector roles; + for (const picojson::value& v : roles_arr) { + CHECK(v.is()) << "Invalid roles" << err_templ; + roles.push_back(v.get()); + } + this->roles = roles; + } else { + CHECK(partial_update) << "Key \"roles\" not found."; + } + if (config.count("messages")) { + CHECK(config["messages"].is()) << "Invalid messages" << err_templ; + std::vector> messages; + picojson::array msgs_arr = config["messages"].get(); + for (const picojson::value& msgs_i : msgs_arr) { + CHECK(msgs_i.is()) << "Invalid messages" << err_templ; + picojson::array msgs_i_arr = msgs_i.get(); + std::vector messages_i; + for (const picojson::value& msg_v : msgs_i_arr) { + CHECK(msg_v.is()) << "Invalid messages" << err_templ; + messages_i.push_back(msg_v.get()); + } + messages.push_back(messages_i); + } + this->messages = messages; + } else { + CHECK(partial_update) << "Key \"messages\" not found."; + } + if (config.count("offset")) { + CHECK(config["offset"].is()) << "Invalid offset" << err_templ; + this->offset = config["offset"].get(); + } else { + CHECK(partial_update) << "Key \"offset\" not found."; + } + if (config.count("separator_style")) { + CHECK(config["separator_style"].is()) << "Invalid separator style" << err_templ; + this->separator_style = SeparatorStyle(config["separator_style"].get()); + } else { + CHECK(partial_update) << "Key \"separator_style\" not found."; + } + if (config.count("seps")) { + std::vector seps; + CHECK(config["seps"].is()) << "Invalid seps" << err_templ; + picojson::array seps_arr = config["seps"].get(); + for (const picojson::value& sep : seps_arr) { + CHECK(sep.is()) << "Invalid seps" << err_templ; + seps.push_back(sep.get()); + } + this->seps = seps; + } else { + CHECK(partial_update) << "Key \"seps\" not found."; + } + if (config.count("stop_str")) { + CHECK(config["stop_str"].is()) << "Invalid stop_str" << err_templ; + this->stop_str = config["stop_str"].get(); + } else { + CHECK(partial_update) << "Key \"stop_str\" not found."; + } + if (config.count("stop_tokens")) { + CHECK(config["stop_tokens"].is()) << "Invalid stop_tokens" << err_templ; + picojson::array stop_tokens_arr = config["stop_tokens"].get(); + std::vector stop_tokens; + for (const picojson::value& stop_token : stop_tokens_arr) { + CHECK(stop_token.is()) << "Invalid stop_tokens" << err_templ; + stop_tokens.push_back(stop_token.get()); + } + this->stop_tokens = stop_tokens; + } else { + CHECK(partial_update) << "Key \"stop_tokens\" not found."; + } + if (config.count("add_bos")) { + CHECK(config["add_bos"].is()) << "Invalid add_bos" << err_templ; + this->add_bos = config["add_bos"].get(); + } else { + CHECK(partial_update) << "Key \"add_bos\" not found."; + } +} + +void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_update) { + picojson::value config_json; + std::string err = picojson::parse(config_json, config_str); + if (!err.empty()) { + LOG(FATAL) << err; + return; + } + LoadJSONOverride(config_json, partial_update); +} + +picojson::value Conversation::SerializeToJSON() const { + picojson::object config; + config["name"] = 
+  config["system"] = picojson::value(this->system);
+  picojson::array roles_arr;
+  for (const std::string& role_str : this->roles) {
+    roles_arr.push_back(picojson::value(role_str));
+  }
+  config["roles"] = picojson::value(roles_arr);
+  picojson::array msgs_arr;
+  for (const std::vector<std::string>& msgs_i : this->messages) {
+    picojson::array msgs_i_arr;
+    for (const std::string& msg_str : msgs_i) {
+      msgs_i_arr.push_back(picojson::value(msg_str));
+    }
+    msgs_arr.push_back(picojson::value(msgs_i_arr));
+  }
+  config["messages"] = picojson::value(msgs_arr);
+  config["offset"] = picojson::value((int64_t)this->offset);
+  config["separator_style"] = picojson::value((int64_t)this->separator_style);
+  picojson::array seps_arr;
+  for (const std::string& sep_str : this->seps) {
+    seps_arr.push_back(picojson::value(sep_str));
+  }
+  config["seps"] = picojson::value(seps_arr);
+  config["stop_str"] = picojson::value(this->stop_str);
+  picojson::array stop_tokens_arr;
+  for (const int32_t& stop_token_str : this->stop_tokens) {
+    stop_tokens_arr.push_back(picojson::value((int64_t)stop_token_str));
+  }
+  config["stop_tokens"] = picojson::value(stop_tokens_arr);
+  config["add_bos"] = picojson::value(this->add_bos);
+  return picojson::value(config);
+}
+
+std::string Conversation::SerializeToJSONStr() const {
+  return SerializeToJSON().serialize(true);
+}
+
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/conversation.h b/cpp/conversation.h
index db4be51931..c2f2fc8a77 100644
--- a/cpp/conversation.h
+++ b/cpp/conversation.h
@@ -1,14 +1,15 @@
 /*!
  * Copyright (c) 2023 by Contributors
- * \file llm_chat.cc
- * \brief Implementation of llm chat.
+ * \file conversation.h
+ * \brief Header of conversation template in MLC-LLM.
  */
-#define PICOJSON_USE_INT64
-#define __STDC_FORMAT_MACROS
 #include <string>
 #include <unordered_map>
 #include <vector>
+#define PICOJSON_USE_INT64
+#define __STDC_FORMAT_MACROS
+#include <picojson.h>
 
 namespace mlc {
 namespace llm {
 
@@ -53,22 +54,88 @@ class Conversation {
 
   Conversation() = default;
 
+  inline bool operator==(const Conversation& other) const {
+    bool eq_roles = true;
+    if (roles.size() != other.roles.size()) {
+      eq_roles = false;
+    } else {
+      eq_roles = std::equal(roles.begin(), roles.end(), other.roles.begin());
+    }
+    bool eq_messages = true;
+    if (messages.size() != other.messages.size()) {
+      eq_messages = false;
+    } else {
+      for (size_t i = 0; i < messages.size(); ++i) {
+        const std::vector<std::string>& lhs_message_i = messages[i];
+        const std::vector<std::string>& rhs_message_i = other.messages[i];
+        if (lhs_message_i.size() != rhs_message_i.size()) {
+          eq_messages = false;
+          break;
+        } else {
+          eq_messages &=
+              std::equal(lhs_message_i.begin(), lhs_message_i.end(), rhs_message_i.begin());
+        }
+      }
+    }
+    bool eq_seps = true;
+    if (seps.size() != other.seps.size()) {
+      eq_seps = false;
+    } else {
+      eq_seps = std::equal(seps.begin(), seps.end(), other.seps.begin());
+    }
+    bool eq_stop_tokens = true;
+    if (stop_tokens.size() != other.stop_tokens.size()) {
+      eq_stop_tokens = false;
+    } else {
+      eq_stop_tokens =
+          std::equal(stop_tokens.begin(), stop_tokens.end(), other.stop_tokens.begin());
+    }
+    return (name == other.name) && (system == other.system) && (offset == other.offset) &&
+           (separator_style == other.separator_style) && (stop_str == other.stop_str) &&
+           (add_bos == other.add_bos) && eq_roles && eq_messages && eq_seps && eq_stop_tokens;
+  }
+
   /**
    * \brief Create conversation from existing registered template.
   * \param name The template name.
   */
  static Conversation FromTemplate(const std::string& name);
 
-  // TODO(mlc-team): Implement this
  /*!
-   * \brief Load overrides from JSON that partially
-   * overrides some of the options.
+   * \brief Load JSON config in raw string and override options.
+   *
+   * \param config_str A json config in raw string that partially specifies
+   * some of the options.
+   * \param partial_update Whether it's a partial update or full update: if set to true,
+   * we perform a partial update on some of the provided options; if set to false, all
+   * options must be provided.
+   * \note This function overrides existing configurations.
+   */
+  void LoadJSONOverride(const std::string& config_str, bool partial_update = false);
+
+  /*!
+   * \brief Load JSON config and override options.
   *
-   * \param config_json A json config that partially specifies
-   * some of the options
+   * \param config_json A json config in picojson type that partially specifies
+   * some of the options.
+   * \param partial_update Whether it's a partial update or full update: if set to true,
+   * we perform a partial update on some of the provided options; if set to false, all
+   * options must be provided.
   * \note This function overrides existing configurations.
   */
-  void LoadJSONOverride(const std::string& config_json);
+  void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);
+
+  /*!
+   * \brief Serialize the Conversation to JSON.
+   * \return Serialized conversation in JSON format.
+   */
+  picojson::value SerializeToJSON() const;
+
+  /*!
+   * \brief Serialize the Conversation to a JSON string.
+   * \return A string storing the serialized conversation in JSON format.
+   */
+  std::string SerializeToJSONStr() const;
 
  /*!
   * \brief Get the entire prompt array
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index edac4ef6d2..6700c11326 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -143,10 +143,103 @@ class LLMChat {
     return os.str();
   }
 
-  // TODO(mlc-team):
-  // Add an option: app_config_json
-  // which optionally overrides the mlc-chat-config
-  void Reload(tvm::runtime::Module executable, String model_path) {
+  /*!
+   * \brief Load JSON config and override options.
+   * \param config_json A json config in picojson type that partially specifies
+   * some of the options.
+   * \param partial_update Whether it's a partial update or full update: if set to true,
+   * we perform a partial update on some of the provided options; if set to false, all
+   * options must be provided.
+   * \note This function overrides existing configurations.
+   */
+  void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false) {
+    picojson::object config = config_json.get<picojson::object>();
+    if (config.count("temperature")) {
+      CHECK(config["temperature"].is<double>());
+      this->temperature_ = config["temperature"].get<double>();
+    } else {
+      CHECK(partial_update) << "Key \"temperature\" not found.";
+    }
+    if (config.count("repetition_penalty")) {
+      CHECK(config["repetition_penalty"].is<double>());
+      CHECK(this->repetition_penalty_ > 0) << "Repetition penalty must be a positive number!";
+      this->repetition_penalty_ = config["repetition_penalty"].get<double>();
+    } else {
+      CHECK(partial_update) << "Key \"repetition_penalty\" not found.";
+    }
+    if (config.count("top_p")) {
+      CHECK(config["top_p"].is<double>());
+      this->top_p_ = config["top_p"].get<double>();
+    } else {
+      CHECK(partial_update) << "Key \"top_p\" not found.";
+    }
+    if (config.count("mean_gen_len")) {
+      CHECK(config["mean_gen_len"].is<int64_t>());
+      this->mean_gen_len_ = config["mean_gen_len"].get<int64_t>();
+    } else {
+      CHECK(partial_update) << "Key \"mean_gen_len\" not found.";
+    }
+    if (config.count("shift_fill_factor")) {
+      CHECK(config["shift_fill_factor"].is<double>());
+      this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
+    } else {
+      CHECK(partial_update) << "Key \"shift_fill_factor\" not found.";
+    }
+    if (config.count("conv_template") || config.count("conv_config")) {
+      if (config.count("conv_template")) {
+        ICHECK(config["conv_template"].is<std::string>());
+        std::string conv_template = config["conv_template"].get<std::string>();
+        this->conversation_ = Conversation::FromTemplate(conv_template);
+      }
+      if (config.count("conv_config")) {
+        // conv_config can override conv_template
+        this->conversation_.LoadJSONOverride(config["conv_config"]);
+      }
+    } else {
+      CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found.";
+    }
+  }
+
+  /*!
+   * \brief Load JSON config and override options.
+   * \param config_str A json config in raw string that partially specifies
+   * some of the options.
+   * \param partial_update Whether it's a partial update or full update: if set to true,
+   * we perform a partial update on some of the provided options; if set to false, all
+   * options must be provided.
+   * \note This function overrides existing configurations.
+   */
+  void LoadJSONOverride(const std::string& config_str, bool partial_update = false) {
+    picojson::value config_json;
+    std::string err = picojson::parse(config_json, config_str);
+    if (!err.empty()) {
+      LOG(FATAL) << err;
+      return;
+    }
+    LoadJSONOverride(config_json, partial_update);
+  }
+
+  picojson::value SerializeToJSON() const {
+    picojson::object config;
+    config["temperature"] = picojson::value(this->temperature_);
+    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
+    config["top_p"] = picojson::value(this->top_p_);
+    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
+    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
+    config["conv_config"] = this->conversation_.SerializeToJSON();
+    return picojson::value(config);
+  }
+
+  std::string SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+
+  /*!
+   * \brief Reload model, tokenizers and configurations from the specified model path.
+   * \param executable The module to reload.
+   * \param model_path The path to search for models.
+   * \param app_config_json The JSON string used to partially override the configuration loaded
+   * from disk; defaults to an empty string.
+   */
+  void Reload(tvm::runtime::Module executable, String model_path, String app_config_json = "") {
     // Step 1. Set tokenizer.
     this->tokenizer_ = TokenizerFromPath(model_path);
 
@@ -210,22 +303,7 @@
     ICHECK(config_istream);
     config_ostream << config_istream.rdbuf();
     std::string config_str = config_ostream.str();
-    picojson::value config_info;
-    picojson::parse(config_info, config_str);
-    auto config = config_info.get<picojson::object>();
-    ICHECK(config["conv_template"].is<std::string>());
-    ICHECK(config["temperature"].is<double>());
-    ICHECK(config["repetition_penalty"].is<double>());
-    ICHECK(config["top_p"].is<double>());
-    ICHECK(config["mean_gen_len"].is<int64_t>());
-    ICHECK(config["shift_fill_factor"].is<double>());
-    std::string conv_template = config["conv_template"].get<std::string>();
-    this->temperature_ = config["temperature"].get<double>();
-    this->repetition_penalty_ = config["repetition_penalty"].get<double>();
-    CHECK(this->repetition_penalty_ > 0) << "Repetition penalty must be a positive number!";
-    this->top_p_ = config["top_p"].get<double>();
-    this->mean_gen_len_ = config["mean_gen_len"].get<int64_t>();
-    this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
+    LoadJSONOverride(config_str, false);
 
     // Step 6. Process metadata
     String metadata_str = this->get_metadata_func_();
@@ -237,8 +315,12 @@
     this->model_name_ = metadata["model_name"].get<std::string>();
     this->max_window_size_ = metadata["max_window_size"].get<int64_t>();
 
-    // Step 7. Initialize conversation.
-    this->conversation_ = Conversation::FromTemplate(conv_template);
+    // Step 7. Override configuration from app_config_json.
+    if (!app_config_json.empty()) {
+      LOG(INFO) << "Apply configuration override from user-provided JSON...";
+      LoadJSONOverride(app_config_json, true);
+    }
+
     this->ResetChat();
   }
 
@@ -760,11 +842,16 @@ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
     if (name == "reload") {
       return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-        ICHECK_EQ(args.size(), 2);
         chat_ = nullptr;
         ClearGlobalMemoryManager();
         chat_ = std::make_unique<LLMChat>(LLMChat(device_));
-        chat_->Reload(args[0], args[1]);
+        if (args.size() == 2) {
+          chat_->Reload(args[0], args[1]);
+        } else if (args.size() == 3) {
+          chat_->Reload(args[0], args[1], args[2]);
+        } else {
+          LOG(FATAL) << "Invalid number of arguments for reload function";
+        }
       });
     } else if (name == "unload") {
       return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
@@ -814,6 +901,10 @@
     } else if (name == "reset_runtime_stats") {
       return PackedFunc(
           [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->ResetRuntimeStats(); });
+    } else if (name == "serialize_config") {
+      return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
+        *rv = GetChat()->SerializeToJSONStr();
+      });
     } else {
       return PackedFunc(nullptr);
     }
diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc
new file mode 100644
index 0000000000..4ab6141b6f
--- /dev/null
+++ b/tests/cpp/conv_unittest.cc
@@ -0,0 +1,29 @@
+#include <conversation.h>
+#include <gtest/gtest.h>
+
+void _TestConversationJSONRoundTrip(std::string templ_name) {
+  mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name);
+  std::string conv_json = conv.SerializeToJSONStr();
+  mlc::llm::Conversation conv_new;
+  conv_new.LoadJSONOverride(conv_json, false);
+  ASSERT_EQ(conv, conv_new);
+}
+
+void _TestConversationPartialUpdate() {
+  mlc::llm::Conversation conv;
+  std::string json_str = "{\"offset\": -1}";
+  ASSERT_ANY_THROW(conv.LoadJSONOverride(json_str, false));
+  conv.LoadJSONOverride(json_str, true);
+  ASSERT_EQ(conv.offset, -1);
+}
+
+TEST(ConversationTest, ConversationJSONRoundTripTest) {
+  _TestConversationJSONRoundTrip("vicuna_v1.1");
+  _TestConversationJSONRoundTrip("conv_one_shot");
+  _TestConversationJSONRoundTrip("redpajama_chat");
+  _TestConversationJSONRoundTrip("LM");
+}
+
+TEST(ConversationTest, ConversationPartialUpdateTest) {
+  _TestConversationPartialUpdate();
+}
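
A minimal usage sketch of the interface introduced by this patch (illustrative only, not part of the patch itself): it shows a front end passing the new optional app_config_json argument to the module's "reload" packed function and reading the effective settings back through "serialize_config". The helper name ReloadWithOverrides, the way chat_mod and executable are obtained, and the override values are assumptions for illustration; only the packed-function names and argument order come from the patch above.

#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <string>

// `chat_mod` is assumed to be an already-created LLMChatModule instance and
// `executable` a compiled model library; both are obtained through the usual
// creation path and are placeholders here.
void ReloadWithOverrides(tvm::runtime::Module chat_mod, tvm::runtime::Module executable,
                         const std::string& model_path) {
  // Partial override of mlc-chat-config.json: keys listed here replace the
  // values loaded from disk, everything else keeps its on-disk value.
  std::string app_config_json =
      "{\"temperature\": 0.9, \"conv_template\": \"redpajama_chat\"}";

  // reload now accepts an optional third argument carrying the JSON override.
  chat_mod.GetFunction("reload")(executable, model_path, app_config_json);

  // serialize_config returns the effective configuration as a JSON string.
  std::string effective = chat_mod.GetFunction("serialize_config")();
  LOG(INFO) << "Effective chat config: " << effective;
}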