[CLI] Add stats command (#14)
This PR adds a stats command to the CLI to display
the stats of the last round of conversation.
tqchen authored Apr 30, 2023
1 parent b8c421b commit d3e7f16
Showing 1 changed file with 22 additions and 3 deletions.
cpp/cli_main.cc
@@ -122,6 +122,16 @@ std::vector<std::string> CountUTF8(const std::string& s) {
   return std::move(ret);
 }
 
+void PrintSpecialCommands() {
+  std::cout << "You can use the following special commands:\n"
+            << "  /help    print the special commands\n"
+            << "  /exit    quit the cli\n"
+            << "  /stats   print out the latest stats (token/sec)\n"
+            << "  /reset   restart a fresh chat\n"
+            << std::endl
+            << std::flush;
+}
+
 /*!
  * \brief Start a chat conversation.
  *
@@ -153,6 +163,7 @@ void Chat(tvm::runtime::Module chat_mod, const std::string& model, int64_t max_g
   auto f_stop = chat_mod.GetFunction("stopped");
   auto f_encode = chat_mod.GetFunction("encode");
   auto f_decode = chat_mod.GetFunction("decode");
+  auto f_stats = chat_mod.GetFunction("runtime_stats_text");
   std::string role0 = chat_mod.GetFunction("get_role0")();
   std::string role1 = chat_mod.GetFunction("get_role1")();
 
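Aside, for readers unfamiliar with the pattern above: GetFunction resolves a function that the chat module exports under the given name, and the returned PackedFunc is callable like an ordinary function, with the result converting to the requested C++ type. A minimal standalone sketch of fetching and invoking the new runtime_stats_text hook (QueryStats is a hypothetical helper for illustration, not code from this commit):

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <iostream>
#include <string>

// Hypothetical helper: fetch and print the chat module's runtime stats.
void QueryStats(tvm::runtime::Module chat_mod) {
  // Look up the packed function the module exports as "runtime_stats_text".
  tvm::runtime::PackedFunc f_stats = chat_mod.GetFunction("runtime_stats_text");
  if (f_stats == nullptr) return;  // the module may not export this hook
  // Calling a PackedFunc yields a TVMRetValue, which converts to std::string.
  std::string stats_text = f_stats();
  std::cout << stats_text << std::endl;
}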
@@ -168,6 +179,13 @@
       continue;
     } else if (inp.substr(0, 5) == "/exit") {
       break;
+    } else if (inp.substr(0, 6) == "/stats") {
+      std::string stats_text = f_stats();
+      std::cout << stats_text << std::endl << std::flush;
+      continue;
+    } else if (inp.substr(0, 5) == "/help") {
+      PrintSpecialCommands();
+      continue;
     }
 
     std::string prev_printed = "";
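Aside: command matching above is by prefix via substr. std::string::substr(0, n) clamps n to the string's length instead of throwing, so short inputs are safe; it also means any input that merely begins with a command name matches. A small illustration (HasCommandPrefix is a hypothetical helper, not part of this commit):

#include <string>

// Hypothetical helper mirroring the dispatch above: true when inp begins
// with cmd. substr(0, cmd.size()) returns at most cmd.size() characters
// (fewer when inp is shorter), so the comparison never throws.
bool HasCommandPrefix(const std::string& inp, const std::string& cmd) {
  return inp.substr(0, cmd.size()) == cmd;
}
// HasCommandPrefix("/stats", "/stats")    -> true
// HasCommandPrefix("/statsfoo", "/stats") -> true (prefix match, by design)
// HasCommandPrefix("/st", "/stats")       -> false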
@@ -249,7 +267,6 @@ int main(int argc, char* argv[]) {
   std::vector<std::string> search_paths = {artifact_path + "/" + model + "/" + candidate,
                                            artifact_path + "/models/" + model,
                                            artifact_path + "/" + model, artifact_path + "/lib"};
-  std::string prefix = lib_name + "_" + candidate;
   // search for lib_x86_64 and lib
   lib_path_opt = FindFile(
       search_paths,
@@ -264,7 +281,7 @@
   if (!lib_path_opt) {
     std::cerr << "Cannot find " << model << " lib in preferred path \"" << artifact_path << "/"
               << model << "/" << dtype_candidates[0] << "/" << lib_name << "_"
-              << GetLibSuffixes()[0] << "\" or other candidate paths";
+              << dtype_candidates[0] << GetLibSuffixes()[0] << "\" or other candidate paths";
     return 1;
   }
   std::cout << "Use lib " << lib_path_opt.value().string() << std::endl;
@@ -309,7 +326,9 @@
   std::cout << "Initializing the chat module..." << std::endl;
   Module chat_mod =
       mlc::llm::CreateChatModule(lib, tokenizer_path_opt.value().string(), params, device);
-  std::cout << "Finish loading\n" << std::endl;
+  std::cout << "Finish loading" << std::endl;
+  PrintSpecialCommands();
+
   if (args.get<bool>("--evaluate")) {
     chat_mod.GetFunction("evaluate")();
   } else {
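Aside on the last hunk: the old "Finish loading\n" << std::endl printed two newlines (the literal \n plus the one from std::endl), leaving a stray blank line; the new code prints a single newline and immediately follows it with the command list, so the special commands, including the new /stats, are advertised right after loading. A minimal comparison:

#include <iostream>

int main() {
  std::cout << "Finish loading\n" << std::endl;  // before: "\n" plus endl -> extra blank line
  std::cout << "Finish loading" << std::endl;    // after: exactly one newline
  return 0;
}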
