From 28e98b617aed12445b7304ac0903934461261aac Mon Sep 17 00:00:00 2001
From: Justine Tunney <jtunney@mozilla.com>
Date: Sat, 12 Oct 2024 17:41:52 -0700
Subject: [PATCH] Show prompt loading progress in chatbot

---
 llamafile/chatbot.cpp            | 76 +++++++++++++++++++++++---------
 llamafile/highlight_markdown.cpp |  2 +
 2 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/llamafile/chatbot.cpp b/llamafile/chatbot.cpp
index 44932c8a6c..54d88e92fe 100644
--- a/llamafile/chatbot.cpp
+++ b/llamafile/chatbot.cpp
@@ -15,7 +15,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "llamafile/highlight.h"
 #include <assert.h>
 #include <cosmo.h>
 #include <ctype.h>
@@ -29,6 +28,7 @@
 #include "llama.cpp/common.h"
 #include "llama.cpp/llama.h"
 #include "llamafile/bestline.h"
+#include "llamafile/highlight.h"
 #include "llamafile/llamafile.h"
 
 #define BOLD "\e[1m"
@@ -74,6 +74,21 @@ static std::string basename(const std::string_view path) {
     }
 }
 
+__attribute__((format(printf, 1, 2))) static std::string format(const char *fmt, ...) {
+    va_list ap, ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = 512;
+    std::string res(size, '\0');
+    int need = vsnprintf(res.data(), size, fmt, ap);
+    res.resize(need + 1, '\0');
+    if (need + 1 > size)
+        vsnprintf(res.data(), need + 1, fmt, ap2);
+    va_end(ap2);
+    va_end(ap);
+    return res;
+}
+
 static void on_completion(const char *line, bestlineCompletions *comp) {
     static const char *const kCompletions[] = {
         "/context", //
@@ -134,6 +149,15 @@ static void print_logo(const char16_t *s) {
     }
 }
 
+static void print_ephemeral(const std::string_view &description) {
+    fprintf(stderr, " " BRIGHT_BLACK "%.*s" UNFOREGROUND "\r", (int)description.size(),
+            description.data());
+}
+
+static void clear_ephemeral(void) {
+    fprintf(stderr, CLEAR_FORWARD);
+}
+
 static void die_out_of_context(void) {
     fprintf(stderr,
             "\n" BRIGHT_RED
@@ -145,7 +169,13 @@ static void die_out_of_context(void) {
 
 static void eval_tokens(std::vector<llama_token> tokens, int n_batch) {
     int N = (int)tokens.size();
+    if (n_past + N > llama_n_ctx(g_ctx)) {
+        n_past += N;
+        die_out_of_context();
+    }
     for (int i = 0; i < N; i += n_batch) {
+        if (N > n_batch)
+            print_ephemeral(format("loading prompt %d%%...", (int)((double)i / N * 100)));
         int n_eval = (int)tokens.size() - i;
         if (n_eval > n_batch)
             n_eval = n_batch;
@@ -161,17 +191,8 @@ static void eval_id(int id) {
     eval_tokens(tokens, 1);
 }
 
-static void eval_string(const char *str, int n_batch, bool add_special, bool parse_special) {
-    std::string str2 = str;
-    eval_tokens(llama_tokenize(g_ctx, str2, add_special, parse_special), n_batch);
-}
-
-static void print_ephemeral(const char *description) {
-    fprintf(stderr, " " BRIGHT_BLACK "%s" UNFOREGROUND "\r", description);
-}
-
-static void clear_ephemeral(void) {
-    fprintf(stderr, CLEAR_FORWARD);
+static void eval_string(const std::string &str, int n_batch, bool add_special, bool parse_special) {
+    eval_tokens(llama_tokenize(g_ctx, str, add_special, parse_special), n_batch);
 }
 
 int chatbot_main(int argc, char **argv) {
@@ -180,8 +201,12 @@ int chatbot_main(int argc, char **argv) {
     log_disable();
 
     gpt_params params;
-    if (!gpt_params_parse(argc, argv, params))
-        return 1;
+    params.n_batch = 512; // for better progress indication
+    params.sparams.temp = 0; // don't believe in randomness by default
+    if (!gpt_params_parse(argc, argv, params)) {
+        fprintf(stderr, "error: failed to parse flags\n");
+        exit(1);
+    }
 
     print_logo(u"\n\
 ██╗ ██╗ █████╗ ███╗ ███╗ █████╗ ███████╗██╗██╗ ███████╗\n\
@@ -203,15 +228,25 @@ int chatbot_main(int argc, char **argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = llamafile_gpu_layers(35);
     g_model = llama_load_model_from_file(params.model.c_str(), model_params);
-    if (g_model == NULL)
-        return 2;
+    if (g_model == NULL) {
+        clear_ephemeral();
+        fprintf(stderr, "%s: failed to load model\n", params.model.c_str());
+        exit(2);
+    }
+    if (!params.n_ctx)
+        params.n_ctx = llama_n_ctx_train(g_model);
+    if (params.n_ctx < params.n_batch)
+        params.n_batch = params.n_ctx;
     clear_ephemeral();
 
     print_ephemeral("initializing context...");
     llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
     g_ctx = llama_new_context_with_model(g_model, ctx_params);
-    if (g_ctx == NULL)
-        return 3;
+    if (g_ctx == NULL) {
+        clear_ephemeral();
+        fprintf(stderr, "error: failed to initialize context\n");
+        exit(3);
+    }
     clear_ephemeral();
 
     if (params.prompt.empty())
@@ -219,11 +254,10 @@ int chatbot_main(int argc, char **argv) {
             "A chat between a curious human and an artificial intelligence assistant. The "
             "assistant gives helpful, detailed, and polite answers to the human's questions.";
 
-    print_ephemeral("loading prompt...");
     bool add_bos = llama_should_add_bos_token(llama_get_model(g_ctx));
     std::vector<llama_chat_msg> chat = {{"system", params.prompt}};
     std::string msg = llama_chat_apply_template(g_model, params.chat_template, chat, false);
-    eval_string(msg.c_str(), params.n_batch, add_bos, true);
+    eval_string(msg, params.n_batch, add_bos, true);
     clear_ephemeral();
 
     printf("%s\n", params.special ? msg.c_str() : params.prompt.c_str());
@@ -254,7 +288,7 @@ int chatbot_main(int argc, char **argv) {
         }
         std::vector<llama_chat_msg> chat = {{"user", line}};
         std::string msg = llama_chat_apply_template(g_model, params.chat_template, chat, true);
-        eval_string(msg.c_str(), params.n_batch, false, true);
+        eval_string(msg, params.n_batch, false, true);
         while (!g_got_sigint) {
            llama_token id = llama_sampling_sample(sampler, g_ctx, NULL);
            llama_sampling_accept(sampler, g_ctx, id, true);
diff --git a/llamafile/highlight_markdown.cpp b/llamafile/highlight_markdown.cpp
index cde006f01c..c77b661142 100644
--- a/llamafile/highlight_markdown.cpp
+++ b/llamafile/highlight_markdown.cpp
@@ -83,6 +83,8 @@ void HighlightMarkdown::feed(std::string *r, std::string_view input) {
             if (c == '*') {
                 t_ = NORMAL;
                 *r += RESET;
+            } else {
+                t_ = STRONG;
             }
             break;
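
Reviewer note, not part of the patch: the progress display works by having
print_ephemeral() draw a dim status line that ends in a carriage return, so
the next write to stderr overdraws it, and clear_ephemeral() erase whatever
is left once the work finishes. Below is a minimal standalone sketch of that
idea, assuming an ANSI-capable terminal; the raw escape codes \033[90m,
\033[39m, and \033[K stand in for the BRIGHT_BLACK, UNFOREGROUND, and
CLEAR_FORWARD macros from chatbot.cpp, and the percentage uses the same
(double)i / N * 100 arithmetic as the eval_tokens() loop in the patch.

// progress_sketch.cpp (illustrative only; not part of llamafile)
#include <chrono>
#include <cstdio>
#include <string>
#include <thread>

static void print_ephemeral(const std::string &description) {
    // Dim gray text followed by '\r': the cursor returns to column 0, so the
    // next status update (or the real output) overwrites this line instead of
    // scrolling the terminal.
    fprintf(stderr, " \033[90m%s\033[39m\r", description.c_str());
}

static void clear_ephemeral(void) {
    // Erase from the cursor to the end of the line once the work is done.
    fprintf(stderr, "\033[K");
}

int main() {
    const int N = 2048;      // pretend prompt length, in tokens
    const int n_batch = 512; // same default batch size the patch configures
    for (int i = 0; i < N; i += n_batch) {
        print_ephemeral("loading prompt " +
                        std::to_string((int)((double)i / N * 100)) + "%...");
        std::this_thread::sleep_for(std::chrono::milliseconds(200)); // fake work
    }
    clear_ephemeral();
    fprintf(stderr, "done\n");
}

Because the status line never ends in a newline, it never lands in the
terminal scrollback; the patch also sets n_batch to 512 by default ("for
better progress indication"), so the percentage in eval_tokens() refreshes
once per batch.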