Make fixes to chatbot
jart committed Oct 1, 2024
1 parent d617c0b commit aca02fa
Showing 2 changed files with 37 additions and 58 deletions.
17 changes: 13 additions & 4 deletions llamafile/bestline.c
@@ -3201,6 +3201,7 @@ static ssize_t bestlineEdit(int stdin_fd, int stdout_fd, const char *prompt, con
// fallthrough
case '\n': {
char is_finished = 1;
char needs_strip = 0;
free(history[--historylen]);
history[historylen] = 0;
l.final = 1;
@@ -3215,11 +3216,19 @@ static ssize_t bestlineEdit(int stdin_fd, int stdout_fd, const char *prompt, con
is_finished = 0;
if (llamamode)
if (StartsWith(l.full.b, "\"\"\""))
is_finished = l.full.len > 3 && EndsWith(l.full.b, "\"\"\"");
needs_strip = is_finished = l.full.len > 6 && EndsWith(l.full.b, "\"\"\"");
if (is_finished) {
*obuf = l.full.b;
free(l.buf);
return l.len;
if (needs_strip) {
int len = l.full.len - 6;
*obuf = strndup(l.full.b + 3, len);
abFree(&l.full);
free(l.buf);
return len;
} else {
*obuf = l.full.b;
free(l.buf);
return l.full.len;
}
} else {
l.prompt = "... ";
abAppends(&l.full, "\n");
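For context, the change above makes the triple-quote handling symmetric: a line starting with """ keeps the prompt in multi-line mode until a closing """ arrives, and the new needs_strip path copies only the text between the two fences with strndup before returning it. A minimal stand-alone sketch of that stripping step (the strip_fences helper is illustrative, not part of the commit):

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    // Hypothetical helper mirroring the strndup(l.full.b + 3, l.full.len - 6)
    // call added above: drop a leading and trailing """ fence if both exist.
    static char *strip_fences(const char *b, size_t len) {
        if (len > 6 && !memcmp(b, "\"\"\"", 3) && !memcmp(b + len - 3, "\"\"\"", 3))
            return strndup(b + 3, len - 6); // copy only the fenced interior
        return strndup(b, len);             // not fenced: return a plain copy
    }

    int main(void) {
        char *out = strip_fences("\"\"\"first line\nsecond line\"\"\"", 28);
        printf("%s\n", out); // prints both lines without the quote fences
        free(out);
    }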
78 changes: 24 additions & 54 deletions llamafile/chatbot.cpp
@@ -17,6 +17,7 @@

#include <assert.h>
#include <cosmo.h>
#include <ctype.h>
#include <math.h>
#include <signal.h>
#include <stdio.h>
@@ -34,7 +35,15 @@ static void on_sigint(int sig) {
g_got_sigint = 1;
}

static std::string basename(const std::string_view path) noexcept {
static bool is_empty(const char *s) {
int c;
while ((c = *s++))
if (!isspace(c))
return false;
return true;
}

static std::string basename(const std::string_view path) {
size_t i, e;
if ((e = path.size())) {
while (e > 1 && path[e - 1] == '/')
@@ -48,47 +57,6 @@ static std::string basename(const std::string_view path) noexcept {
}
}

std::string apply_template(const struct llama_model *model, const char *tmpl,
const struct llama_chat_message *chat, size_t n_msg, bool add_ass) {

// allocate buffer
int alloc_size = 0;
for (size_t i = 0; i < n_msg; ++i)
alloc_size += (strlen(chat[i].role) + strlen(chat[i].content)) * 2;
std::string buf;
buf.resize(alloc_size);

// run the first time to get the total output length
int res = llama_chat_apply_template(model, tmpl, chat, n_msg, add_ass, &buf[0], buf.size());

// error: chat template is not supported
bool fallback = false;
if (res < 0) {
if (tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom
// template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat, n_msg, add_ass, &buf[0],
buf.size());
fallback = true;
}
}

// if it turns out that our buffer is too small, we resize it
if ((size_t)res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(fallback ? nullptr : model, fallback ? "chatml" : tmpl,
chat, n_msg, add_ass, &buf[0], buf.size());
assert(0 <= res && res <= buf.size());
}

buf.resize(res);
return buf;
}

static bool eval_tokens(struct llama_context *ctx_llama, std::vector<llama_token> tokens,
int n_batch, int *n_past) {
int N = (int)tokens.size();
@@ -136,13 +104,13 @@ int main(int argc, char **argv) {
if (!gpt_params_parse(argc, argv, params))
return 1;

printf("\n\
printf("\n\e[32m\
██╗ ██╗ █████╗ ███╗ ███╗ █████╗ ███████╗██╗██╗ ███████╗\n\
██║ ██║ ██╔══██╗████╗ ████║██╔══██╗██╔════╝██║██║ ██╔════╝\n\
██║ ██║ ███████║██╔████╔██║███████║█████╗ ██║██║ █████╗\n\
██║ ██║ ██╔══██║██║╚██╔╝██║██╔══██║██╔══╝ ██║██║ ██╔══╝\n\
███████╗███████╗██║ ██║██║ ╚═╝ ██║██║ ██║██║ ██║███████╗███████╗\n\
╚══════╝╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚══════╝\n\
╚══════╝╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚══════╝\e[39m\n\
\e[1msoftware\e[22m: llamafile " LLAMAFILE_VERSION_STRING "\n\
\e[1mmodel\e[22m: %s\n\n",
basename(params.model).c_str());
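For reference, the banner edit only wraps the existing art in ANSI SGR escapes: \e[32m switches the foreground to green, \e[39m restores the default color, and the \e[1m / \e[22m pair toggles bold for the software/model labels (\e is a GCC/Clang escape extension; \033 is the portable spelling). A tiny illustrative snippet, not from the commit:

    #include <stdio.h>

    int main(void) {
        printf("\033[32mthis line renders in green\033[39m\n"); // like the banner
        printf("\033[1msoftware\033[22m: llamafile\n");         // bold label, normal value
    }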
@@ -166,20 +134,19 @@ int main(int argc, char **argv) {
return 3;
clear_ephemeral();

if (params.prompt.empty()) {
if (params.prompt.empty())
params.prompt =
"A chat between a curious human and an artificial intelligence assistant. The "
"assistant gives helpful, detailed, and polite answers to the human's questions.";
}

print_ephemeral("loading prompt...");
int n_past = 0;
bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
llama_chat_message chat = {"system", params.prompt.c_str()};
std::string msg = apply_template(model, nullptr, &chat, 1, false);
eval_string(ctx, msg.c_str(), params.n_batch, &n_past, true, true);
std::vector<llama_chat_msg> chat = {{"system", params.prompt}};
std::string msg = llama_chat_apply_template(model, params.chat_template, chat, false);
eval_string(ctx, msg.c_str(), params.n_batch, &n_past, add_bos, true);
clear_ephemeral();
printf("%s\n", params.prompt.c_str());
printf("%s\n", params.special ? msg.c_str() : params.prompt.c_str());

// perform important setup
struct llama_sampling_context *ctx_sampling = llama_sampling_init(params.sparams);
@@ -189,16 +156,19 @@ int main(int argc, char **argv) {
char *line;
bestlineLlamaMode(true);
while ((line = bestlineWithHistory(">>> ", "llamafile"))) {
size_t buflen = strlen(line) * 2 + 1024;
llama_chat_message chat = {"user", line};
std::string msg = apply_template(model, nullptr, &chat, 1, true);
if (is_empty(line)) {
free(line);
continue;
}
std::vector<llama_chat_msg> chat = {{"user", line}};
std::string msg = llama_chat_apply_template(model, params.chat_template, chat, true);
eval_string(ctx, msg.c_str(), params.n_batch, &n_past, false, true);
while (!g_got_sigint) {
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);
llama_sampling_accept(ctx_sampling, ctx, id, true);
printf("%s", llama_token_to_piece(ctx, id, params.special).c_str());
if (llama_token_is_eog(model, id))
break;
printf("%s", llama_token_to_piece(ctx, id).c_str(), false);
fflush(stdout);
if (!eval_id(ctx, id, &n_past)) {
fprintf(stderr, "[out of context]\n");
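The chatbot.cpp side of the commit drops the hand-rolled apply_template wrapper in favor of the overload of llama_chat_apply_template from llama.cpp's common library, which takes a std::vector<llama_chat_msg> and returns the rendered prompt string. A minimal sketch of how a single user turn would be rendered with that API, assuming common.h is on the include path and a model is already loaded (render_user_turn is an illustrative name, not part of the commit):

    #include <string>
    #include <vector>
    #include "common.h" // llama_chat_msg, llama_chat_apply_template (llama.cpp)

    static std::string render_user_turn(const llama_model *model,
                                        const std::string &tmpl, // params.chat_template
                                        const std::string &line) {
        // One message per turn; add_ass=true appends the assistant prefix so
        // generation starts in the assistant role, as the REPL loop above does.
        std::vector<llama_chat_msg> chat = {{"user", line}};
        return llama_chat_apply_template(model, tmpl, chat, /*add_ass=*/true);
    }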
