Skip to content

Commit

Permalink
Load conversation from JSON (mlc-ai#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored May 29, 2023
1 parent ddb14d4 commit fe6f0a6
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 36 deletions.
6 changes: 6 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@
[submodule "3rdparty/tokenizers-cpp"]
path = 3rdparty/tokenizers-cpp
url = https://github.com/mlc-ai/tokenizers-cpp
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
[submodule "3rdparty/picojson"]
path = 3rdparty/picojson
url = https://github.com/kazuho/picojson.git
1 change: 1 addition & 0 deletions 3rdparty/googletest
Submodule googletest added at 458046
1 change: 1 addition & 0 deletions 3rdparty/picojson
Submodule picojson added at 111c9b
16 changes: 14 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ if (MLC_LLM_INSTALL_STATIC_LIB)
set(BUILD_STATIC_RUNTIME ON)
endif()

option(BUILD_CPP_TEST "Build cpp unittests" OFF)

set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# tvm runtime config: minimize runtime components
Expand Down Expand Up @@ -81,8 +83,18 @@ set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
target_link_libraries(mlc_llm PUBLIC tvm_runtime)
target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)


# Exmaple app that may depends on mlc_llm
# Optional C++ unit tests; only configured when BUILD_CPP_TEST is switched ON.
if (BUILD_CPP_TEST)
  message(STATUS "Building cpp unittests")
  # googletest is pulled in from the 3rdparty/googletest submodule.
  add_subdirectory(3rdparty/googletest)
  # Every file under tests/cpp whose name ends in "unittest.cc" goes into the test binary.
  file(GLOB_RECURSE MLC_LLM_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/cpp/*unittest.cc)
  add_executable(mlc_llm_cpp_tests ${MLC_LLM_TEST_SRCS})
  target_include_directories(mlc_llm_cpp_tests PRIVATE
    ${MLC_LLM_INCLUDES}
    ${PROJECT_SOURCE_DIR}/cpp
    ${gtest_SOURCE_DIR}/include
    ${gtest_SOURCE_DIR}
  )
  target_link_libraries(mlc_llm_cpp_tests PUBLIC mlc_llm gtest gtest_main)
endif()

# Example app that may depend on mlc_llm
add_executable(mlc_chat_cli $<TARGET_OBJECTS:mlc_cli_objs>)
target_include_directories(mlc_cli_objs PRIVATE ${MLC_LLM_INCLUDES})
target_include_directories(mlc_cli_objs PRIVATE 3rdparty/argparse/include)
Expand Down
144 changes: 144 additions & 0 deletions cpp/conv_templates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,149 @@ Conversation Conversation::FromTemplate(const std::string& name) {
return it->second();
}

/*!
 * \brief Override the fields of this conversation with values read from a JSON object.
 *
 * Recognized keys are "name", "system", "roles", "messages", "offset", "separator_style",
 * "seps", "stop_str", "stop_tokens" and "add_bos". Each key that is present is type-checked
 * and then assigned to the member of the same name.
 *
 * \param config_json The parsed JSON config; it must be a JSON object.
 * \param partial_update If true, a key that is missing from the config leaves the matching
 *        field untouched; if false, a missing key fails a CHECK.
 * \note Keys are applied in the order listed above, so when a CHECK fails, the fields that
 *       belong to the keys handled before it have already been overwritten.
 */
void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) {
  const std::string err_templ = " in conversion template json file.";

  // Validate the top-level type up front so that a non-object config is reported through the
  // same CHECK mechanism as every other malformed field.
  CHECK(config_json.is<picojson::object>()) << "Invalid config" << err_templ;
  // The config is only read, so bind it by reference instead of copying the whole object.
  const picojson::object& config = config_json.get<picojson::object>();

  // Returns the value stored under `key`, or nullptr when the key is absent.
  // An absent key is only tolerated when a partial update was requested.
  auto find_key = [&config, partial_update](const char* key) -> const picojson::value* {
    const auto it = config.find(key);
    if (it != config.end()) {
      return &it->second;
    }
    CHECK(partial_update) << "Key \"" << key << "\" not found.";
    return nullptr;
  };

  // Checks that `v` holds a JSON string and returns it.
  auto expect_string = [&err_templ](const picojson::value& v,
                                    const char* label) -> const std::string& {
    CHECK(v.is<std::string>()) << "Invalid " << label << err_templ;
    return v.get<std::string>();
  };

  // Checks that `v` holds a JSON integer and returns it.
  auto expect_int64 = [&err_templ](const picojson::value& v, const char* label) -> int64_t {
    CHECK(v.is<int64_t>()) << "Invalid " << label << err_templ;
    return v.get<int64_t>();
  };

  // Checks that `v` holds a JSON array made only of strings and returns them as a vector.
  auto expect_string_array = [&err_templ, &expect_string](const picojson::value& v,
                                                          const char* label) {
    CHECK(v.is<picojson::array>()) << "Invalid " << label << err_templ;
    const picojson::array& items = v.get<picojson::array>();
    std::vector<std::string> out;
    out.reserve(items.size());
    for (const picojson::value& item : items) {
      out.push_back(expect_string(item, label));
    }
    return out;
  };

  // Checks that `v` holds a JSON array of string arrays and returns it as nested vectors.
  auto expect_message_rows = [&err_templ, &expect_string_array](const picojson::value& v) {
    CHECK(v.is<picojson::array>()) << "Invalid messages" << err_templ;
    const picojson::array& rows = v.get<picojson::array>();
    std::vector<std::vector<std::string>> out;
    out.reserve(rows.size());
    for (const picojson::value& row : rows) {
      out.push_back(expect_string_array(row, "messages"));
    }
    return out;
  };

  if (const picojson::value* v = find_key("name")) {
    this->name = expect_string(*v, "name");
  }
  if (const picojson::value* v = find_key("system")) {
    this->system = expect_string(*v, "system");
  }
  if (const picojson::value* v = find_key("roles")) {
    this->roles = expect_string_array(*v, "roles");
  }
  if (const picojson::value* v = find_key("messages")) {
    this->messages = expect_message_rows(*v);
  }
  if (const picojson::value* v = find_key("offset")) {
    this->offset = expect_int64(*v, "offset");
  }
  if (const picojson::value* v = find_key("separator_style")) {
    this->separator_style = SeparatorStyle(expect_int64(*v, "separator style"));
  }
  if (const picojson::value* v = find_key("seps")) {
    this->seps = expect_string_array(*v, "seps");
  }
  if (const picojson::value* v = find_key("stop_str")) {
    this->stop_str = expect_string(*v, "stop_str");
  }
  if (const picojson::value* v = find_key("stop_tokens")) {
    CHECK(v->is<picojson::array>()) << "Invalid stop_tokens" << err_templ;
    const picojson::array& ids = v->get<picojson::array>();
    std::vector<int32_t> stop_tokens;
    stop_tokens.reserve(ids.size());
    for (const picojson::value& item : ids) {
      const int64_t raw = expect_int64(item, "stop_tokens");
      const int32_t id = static_cast<int32_t>(raw);
      // Token ids are stored as int32_t, while JSON integers are read as int64_t. Reject a
      // value that does not survive the round trip instead of silently storing a wrapped id.
      CHECK(static_cast<int64_t>(id) == raw) << "Invalid stop_tokens" << err_templ;
      stop_tokens.push_back(id);
    }
    this->stop_tokens = stop_tokens;
  }
  if (const picojson::value* v = find_key("add_bos")) {
    CHECK(v->is<bool>()) << "Invalid add_bos" << err_templ;
    this->add_bos = v->get<bool>();
  }
}

void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_update) {
picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
if (!err.empty()) {
LOG(FATAL) << err;
return;
}
LoadJSONOverride(config_json, partial_update);
}

/*!
 * \brief Build a JSON object from the fields of this conversation.
 * \return A picojson object value with the keys "name", "system", "roles", "messages",
 *         "offset", "separator_style", "seps", "stop_str", "stop_tokens" and "add_bos".
 */
picojson::value Conversation::SerializeToJSON() const {
  // Wraps a list of strings as a JSON array of JSON strings.
  const auto as_json_strings = [](const std::vector<std::string>& strings) {
    picojson::array items;
    for (const std::string& s : strings) {
      items.push_back(picojson::value(s));
    }
    return picojson::value(items);
  };

  // Each message is itself a list of strings, so "messages" becomes an array of arrays.
  picojson::array message_rows;
  for (const std::vector<std::string>& row : this->messages) {
    message_rows.push_back(as_json_strings(row));
  }

  // Token ids are widened to int64_t, the integer type used for the JSON values here.
  picojson::array stop_token_ids;
  for (const int32_t token : this->stop_tokens) {
    stop_token_ids.push_back(picojson::value(static_cast<int64_t>(token)));
  }

  picojson::object fields;
  fields["name"] = picojson::value(this->name);
  fields["system"] = picojson::value(this->system);
  fields["roles"] = as_json_strings(this->roles);
  fields["messages"] = picojson::value(message_rows);
  fields["offset"] = picojson::value(static_cast<int64_t>(this->offset));
  fields["separator_style"] = picojson::value(static_cast<int64_t>(this->separator_style));
  fields["seps"] = as_json_strings(this->seps);
  fields["stop_str"] = picojson::value(this->stop_str);
  fields["stop_tokens"] = picojson::value(stop_token_ids);
  fields["add_bos"] = picojson::value(this->add_bos);
  return picojson::value(fields);
}

std::string Conversation::SerializeToJSONStr() const {
return SerializeToJSON().serialize(true);
}

} // namespace llm
} // namespace mlc
87 changes: 77 additions & 10 deletions cpp/conversation.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
/*!
* Copyright (c) 2023 by Contributors
* \file llm_chat.cc
* \brief Implementation of llm chat.
* \file conversation.h
* \brief Header of conversation template in MLC-LLM.
*/
#define PICOJSON_USE_INT64
#define __STDC_FORMAT_MACROS
#include <tvm/runtime/module.h>

#include <string>
#include <vector>
#define PICOJSON_USE_INT64
#define __STDC_FORMAT_MACROS
#include <picojson.h>

namespace mlc {
namespace llm {
Expand Down Expand Up @@ -53,22 +54,88 @@ class Conversation {

Conversation() = default;

/*!
 * \brief Field-wise equality of two conversations.
 * \param other The conversation to compare against.
 * \return true iff name, system, offset, separator_style, stop_str, add_bos, roles,
 *         messages, seps and stop_tokens all compare equal.
 */
inline bool operator==(const Conversation& other) const {
  // The containers' own operator== already compares sizes first and then the elements
  // (recursively for the nested `messages`), so no hand-written size checks or std::equal
  // calls are needed. Scalar fields come first so that a cheap mismatch short-circuits
  // before any container is walked.
  return (offset == other.offset) && (separator_style == other.separator_style) &&
         (add_bos == other.add_bos) && (name == other.name) && (system == other.system) &&
         (stop_str == other.stop_str) && (roles == other.roles) && (seps == other.seps) &&
         (stop_tokens == other.stop_tokens) && (messages == other.messages);
}

/**
* \brief Create conversation from existing registered template.
* \param name The template name.
*/
static Conversation FromTemplate(const std::string& name);

// TODO(mlc-team): Implement this
/*!
* \brief Load overrides from JSON that partially
* overrides some of the options.
* \brief Load JSON config in raw string and overrides options.
*
* \param config_str A json config in raw string that partially specifies
* some of the options.
* \param partial_update Whether it's a partial update or full update, if set to true,
* we perform a partial update on some of the provided options; if set to false, all
* options must be provided.
* \note This function overrides existing configurations.
*/
void LoadJSONOverride(const std::string& config_str, bool partial_update = false);

/*!
* \brief Load JSON config and overrides options.
*
* \param config_json A json config that partially specifies
* some of the options
* \param config_json A json config in picojson type that partially specifies
* some of the options.
* \param partial_update Whether it's a partial update or full update, if set to true,
* we perform a partial update on some of the provided options; if set to false, all
* options must be provided.
* \note This function overrides existing configurations.
*/
void LoadJSONOverride(const std::string& config_json);
void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false);

/*!
* \brief Serialize the Conversation to JSON.
* \return Serialized conversation in JSON format.
*/
picojson::value SerializeToJSON() const;

/*!
* \brief Serialize the Conversation to JSON String.
* \return A string storing the serialized conversation in JSON format.
*/
std::string SerializeToJSONStr() const;

/*!
* \brief Get the entire prompt array
Expand Down
Loading

0 comments on commit fe6f0a6

Please sign in to comment.