From 28e98b617aed12445b7304ac0903934461261aac Mon Sep 17 00:00:00 2001
From: Justine Tunney <jtunney@mozilla.com>
Date: Sat, 12 Oct 2024 17:41:52 -0700
Subject: [PATCH] Show prompt loading progress in chatbot

---
 llamafile/chatbot.cpp            | 76 +++++++++++++++++++++++---------
 llamafile/highlight_markdown.cpp |  2 +
 2 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/llamafile/chatbot.cpp b/llamafile/chatbot.cpp
index 44932c8a6c..54d88e92fe 100644
--- a/llamafile/chatbot.cpp
+++ b/llamafile/chatbot.cpp
@@ -15,7 +15,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "llamafile/highlight.h"
 #include <assert.h>
 #include <cosmo.h>
 #include <ctype.h>
@@ -29,6 +28,7 @@
 #include "llama.cpp/common.h"
 #include "llama.cpp/llama.h"
 #include "llamafile/bestline.h"
+#include "llamafile/highlight.h"
 #include "llamafile/llamafile.h"
 
 #define BOLD "\e[1m"
@@ -74,6 +74,21 @@ static std::string basename(const std::string_view path) {
     }
 }
 
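+// formats printf-style arguments into a std::string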
+__attribute__((format(printf, 1, 2))) static std::string format(const char *fmt, ...) {
+    va_list ap, ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = 512;
+    std::string res(size, '\0');
+    int need = vsnprintf(res.data(), size, fmt, ap);
+    if (need + 1 > size) {
+        res.resize(need + 1, '\0');
+        vsnprintf(res.data(), need + 1, fmt, ap2);
+    }
+    res.resize(need); // exclude the NUL terminator from the string's length
+    va_end(ap2);
+    va_end(ap);
+    return res;
+}
+
 static void on_completion(const char *line, bestlineCompletions *comp) {
     static const char *const kCompletions[] = {
         "/context", //
@@ -134,6 +149,15 @@ static void print_logo(const char16_t *s) {
     }
 }
 
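+// prints a transient gray status message that later output overwrites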
+static void print_ephemeral(std::string_view description) {
+    fprintf(stderr, " " BRIGHT_BLACK "%.*s" UNFOREGROUND "\r", (int)description.size(),
+            description.data());
+}
+
+static void clear_ephemeral(void) {
+    fprintf(stderr, CLEAR_FORWARD);
+}
+
 static void die_out_of_context(void) {
     fprintf(stderr,
             "\n" BRIGHT_RED
@@ -145,7 +169,13 @@ static void die_out_of_context(void) {
 
 static void eval_tokens(std::vector<llama_token> tokens, int n_batch) {
     int N = (int)tokens.size();
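+    // refuse to evaluate tokens that would overflow the context window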
+    if (n_past + N > llama_n_ctx(g_ctx)) {
+        n_past += N;
+        die_out_of_context();
+    }
     for (int i = 0; i < N; i += n_batch) {
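+        // show percentage progress while prefilling a long prompt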
+        if (N > n_batch)
+            print_ephemeral(format("loading prompt %d%%...", (int)((double)i / N * 100)));
         int n_eval = (int)tokens.size() - i;
         if (n_eval > n_batch)
             n_eval = n_batch;
@@ -161,17 +191,8 @@ static void eval_id(int id) {
     eval_tokens(tokens, 1);
 }
 
-static void eval_string(const char *str, int n_batch, bool add_special, bool parse_special) {
-    std::string str2 = str;
-    eval_tokens(llama_tokenize(g_ctx, str2, add_special, parse_special), n_batch);
-}
-
-static void print_ephemeral(const char *description) {
-    fprintf(stderr, " " BRIGHT_BLACK "%s" UNFOREGROUND "\r", description);
-}
-
-static void clear_ephemeral(void) {
-    fprintf(stderr, CLEAR_FORWARD);
+static void eval_string(const std::string &str, int n_batch, bool add_special, bool parse_special) {
+    eval_tokens(llama_tokenize(g_ctx, str, add_special, parse_special), n_batch);
 }
 
 int chatbot_main(int argc, char **argv) {
@@ -180,8 +201,12 @@ int chatbot_main(int argc, char **argv) {
     log_disable();
 
     gpt_params params;
-    if (!gpt_params_parse(argc, argv, params))
-        return 1;
+    params.n_batch = 512; // for better progress indication
+    params.sparams.temp = 0; // don't believe in randomness by default
+    if (!gpt_params_parse(argc, argv, params)) {
+        fprintf(stderr, "error: failed to parse flags\n");
+        exit(1);
+    }
 
     print_logo(u"\n\
 ██╗     ██╗      █████╗ ███╗   ███╗ █████╗ ███████╗██╗██╗     ███████╗\n\
@@ -203,15 +228,25 @@ int chatbot_main(int argc, char **argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = llamafile_gpu_layers(35);
     g_model = llama_load_model_from_file(params.model.c_str(), model_params);
-    if (g_model == NULL)
-        return 2;
+    if (g_model == NULL) {
+        clear_ephemeral();
+        fprintf(stderr, "%s: failed to load model\n", params.model.c_str());
+        exit(2);
+    }
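+    // if no context size was given, fall back to the model's training context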
+    if (!params.n_ctx)
+        params.n_ctx = llama_n_ctx_train(g_model);
+    if (params.n_ctx < params.n_batch)
+        params.n_batch = params.n_ctx;
     clear_ephemeral();
 
     print_ephemeral("initializing context...");
     llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
     g_ctx = llama_new_context_with_model(g_model, ctx_params);
-    if (g_ctx == NULL)
-        return 3;
+    if (g_ctx == NULL) {
+        clear_ephemeral();
+        fprintf(stderr, "error: failed to initialize context\n");
+        exit(3);
+    }
     clear_ephemeral();
 
     if (params.prompt.empty())
@@ -219,11 +254,10 @@ int chatbot_main(int argc, char **argv) {
             "A chat between a curious human and an artificial intelligence assistant. The "
             "assistant gives helpful, detailed, and polite answers to the human's questions.";
 
-    print_ephemeral("loading prompt...");
     bool add_bos = llama_should_add_bos_token(llama_get_model(g_ctx));
     std::vector<llama_chat_msg> chat = {{"system", params.prompt}};
     std::string msg = llama_chat_apply_template(g_model, params.chat_template, chat, false);
-    eval_string(msg.c_str(), params.n_batch, add_bos, true);
+    eval_string(msg, params.n_batch, add_bos, true);
     clear_ephemeral();
     printf("%s\n", params.special ? msg.c_str() : params.prompt.c_str());
 
@@ -254,7 +288,7 @@ int chatbot_main(int argc, char **argv) {
         }
         std::vector<llama_chat_msg> chat = {{"user", line}};
         std::string msg = llama_chat_apply_template(g_model, params.chat_template, chat, true);
-        eval_string(msg.c_str(), params.n_batch, false, true);
+        eval_string(msg, params.n_batch, false, true);
         while (!g_got_sigint) {
             llama_token id = llama_sampling_sample(sampler, g_ctx, NULL);
             llama_sampling_accept(sampler, g_ctx, id, true);
diff --git a/llamafile/highlight_markdown.cpp b/llamafile/highlight_markdown.cpp
index cde006f01c..c77b661142 100644
--- a/llamafile/highlight_markdown.cpp
+++ b/llamafile/highlight_markdown.cpp
@@ -83,6 +83,8 @@ void HighlightMarkdown::feed(std::string *r, std::string_view input) {
             if (c == '*') {
                 t_ = NORMAL;
                 *r += RESET;
+            } else {
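+                // not a second '*', so drop back into the strong state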
+                t_ = STRONG;
             }
             break;