Skip to content

Commit

Permalink
Begin setting up chat history database
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Nov 29, 2024
1 parent d8123c7 commit abe0d1d
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 12 deletions.
98 changes: 98 additions & 0 deletions llamafile/db.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "db.h"
#include <stdio.h>
#include <string>

__static_yoink("llamafile/schema.sql");

#define SCHEMA_VERSION 1

namespace llamafile {
namespace db {

static bool table_exists(sqlite3* db, const char* table_name) {
const char* query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;";
sqlite3_stmt* stmt;
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
return false;
}
if (sqlite3_bind_text(stmt, 1, table_name, -1, SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
return false;
}
bool exists = sqlite3_step(stmt) == SQLITE_ROW;
sqlite3_finalize(stmt);
return exists;
}

static bool init_schema(sqlite3* db) {
FILE* f = fopen("/zip/llamafile/schema.sql", "r");
if (!f)
return false;
std::string schema;
int c;
while ((c = fgetc(f)) != EOF)
schema += c;
fclose(f);
char* errmsg = nullptr;
int rc = sqlite3_exec(db, schema.c_str(), nullptr, nullptr, &errmsg);
if (rc != SQLITE_OK) {
if (errmsg) {
fprintf(stderr, "SQL error: %s\n", errmsg);
sqlite3_free(errmsg);
}
return false;
}
return true;
}

sqlite3* open(const char* path) {
sqlite3* db;
int rc = sqlite3_open(path, &db);
if (rc) {
fprintf(stderr, "%s: can't open database: %s\n", path, sqlite3_errmsg(db));
return nullptr;
}
char* errmsg = nullptr;
if (sqlite3_exec(db, "PRAGMA journal_mode=WAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
fprintf(stderr, "Failed to set journal mode to WAL: %s\n", errmsg);
sqlite3_free(errmsg);
sqlite3_close(db);
return nullptr;
}
if (sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
fprintf(stderr, "Failed to set synchronous to NORMAL: %s\n", errmsg);
sqlite3_free(errmsg);
sqlite3_close(db);
return nullptr;
}
if (!table_exists(db, "metadata") && !init_schema(db)) {
fprintf(stderr, "%s: failed to initialize database schema\n", path);
sqlite3_close(db);
return nullptr;
}
return db;
}

void close(sqlite3* db) {
sqlite3_close(db);
}

} // namespace db
} // namespace llamafile
28 changes: 28 additions & 0 deletions llamafile/db.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "third_party/sqlite/sqlite3.h"

namespace llamafile {
namespace db {

sqlite3* open(const char*);
void close(sqlite3*);

} // namespace db
} // namespace llamafile
8 changes: 8 additions & 0 deletions llamafile/flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ bool FLAG_tinyblas = false;
bool FLAG_trace = false;
bool FLAG_unsecure = false;
const char *FLAG_chat_template = "";
const char *FLAG_db = nullptr;
const char *FLAG_file = nullptr;
const char *FLAG_ip_header = nullptr;
const char *FLAG_listen = "127.0.0.1:8080";
Expand Down Expand Up @@ -185,6 +186,13 @@ void llamafile_get_flags(int argc, char **argv) {
continue;
}

if (!strcmp(flag, "--db")) {
if (i == argc)
missing("--db");
FLAG_db = argv[i++];
continue;
}

//////////////////////////////////////////////////////////////////////
// server flags

Expand Down
1 change: 1 addition & 0 deletions llamafile/llamafile.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ extern bool FLAG_trace;
extern bool FLAG_trap;
extern bool FLAG_unsecure;
extern const char *FLAG_chat_template;
extern const char *FLAG_db;
extern const char *FLAG_file;
extern const char *FLAG_ip_header;
extern const char *FLAG_listen;
Expand Down
24 changes: 24 additions & 0 deletions llamafile/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
CREATE TABLE metadata (
key TEXT PRIMARY KEY,
value TEXT
);

CREATE TABLE chats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
model TEXT,
title TEXT
);

CREATE TABLE messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
chat_id INTEGER,
role TEXT,
message TEXT,
temperature REAL,
top_p REAL,
presence_penalty REAL,
frequency_penalty REAL,
FOREIGN KEY (chat_id) REFERENCES chats(id)
);
7 changes: 7 additions & 0 deletions third_party/sqlite/BUILD.mk
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

PKGS += THIRD_PARTY_SQLITE

THIRD_PARTY_SQLITE_SRCS = \
third_party/sqlite/sqlite3.c \
third_party/sqlite/shell.c \

THIRD_PARTY_SQLITE_HDRS = \
third_party/sqlite/sqlite3.h \

o/$(MODE)/third_party/sqlite/sqlite.a: \
o/$(MODE)/third_party/sqlite/sqlite3.o \

Expand Down
1 change: 1 addition & 0 deletions third_party/sqlite/README.llamafile
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ LICENSE
LOCAL CHANGES

- Renamed <zlib.h> to <third_party/zlib/zlib.h>
- Mangled some quoted includes to not confuse mkdeps
12 changes: 6 additions & 6 deletions third_party/sqlite/shell.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ typedef sqlite3_int64 i64;
typedef sqlite3_uint64 u64;
typedef unsigned char u8;
#if SQLITE_USER_AUTHENTICATION
# include "sqlite3userauth.h"
# includez "sqlite3userauth.h"
#endif
#include <ctype.h>
#include <stdarg.h>
Expand Down Expand Up @@ -169,7 +169,7 @@ typedef unsigned char u8;

#elif HAVE_LINENOISE

# include "linenoise.h"
# includez "linenoise.h"
# define shell_add_history(X) linenoiseHistoryAdd(X)
# define shell_read_history(X) linenoiseHistoryLoad(X)
# define shell_write_history(X) linenoiseHistorySave(X)
Expand Down Expand Up @@ -1710,7 +1710,7 @@ static void shellAddSchemaName(
#define WIN32_LEAN_AND_MEAN
#endif

#include "windows.h"
#includez "windows.h"

/*
** We need several support functions from the SQLite core.
Expand Down Expand Up @@ -7996,10 +7996,10 @@ SQLITE_EXTENSION_INIT1
# include <utime.h>
# include <sys/time.h>
#else
# include "windows.h"
# includez "windows.h"
# include <io.h>
# include <direct.h>
/* # include "test_windirent.h" */
/* # includez "test_windirent.h" */
# define dirent DIRENT
# ifndef chmod
# define chmod _chmod
Expand Down Expand Up @@ -8945,7 +8945,7 @@ int sqlite3_fileio_init(
* redefined SQLite API calls as the above extension code does.
* Just pull in this .c to accomplish this. As a beneficial side
* effect, this extension becomes a single translation unit. */
# include "test_windirent.c"
# includez "test_windirent.c"
#endif

/************************* End ../ext/misc/fileio.c ********************/
Expand Down
12 changes: 6 additions & 6 deletions third_party/sqlite/sqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@
** disabled.
*/
#if defined(_HAVE_MINGW_H)
# include "mingw.h"
# includez "mingw.h"
#elif defined(_HAVE__MINGW_H)
# include "_mingw.h"
# includez "_mingw.h"
#endif

/*
Expand Down Expand Up @@ -13911,7 +13911,7 @@ struct fts5_api {
** autoconf-based build
*/
#if defined(_HAVE_SQLITE_CONFIG_H) && !defined(SQLITECONFIG_H)
#include "sqlite_cfg.h"
#includez "sqlite_cfg.h"
#define SQLITECONFIG_H 1
#endif

Expand Down Expand Up @@ -29996,7 +29996,7 @@ SQLITE_PRIVATE sqlite3_mutex_methods const *sqlite3DefaultMutex(void){
/*
** Include the primary Windows SDK header file.
*/
#include "windows.h"
#includez "windows.h"

#ifdef __CYGWIN__
# include <sys/cygwin.h>
Expand Down Expand Up @@ -196803,7 +196803,7 @@ SQLITE_PRIVATE int sqlite3Fts3InitTokenizer(

#ifdef SQLITE_TEST

#include "tclsqlite.h"
#includez "tclsqlite.h"
/* #include <string.h> */

/*
Expand Down Expand Up @@ -211715,7 +211715,7 @@ SQLITE_PRIVATE int sqlite3GetToken(const unsigned char*,int*); /* In the SQLite
** found in sqliteInt.h
*/
#if !defined(SQLITE_AMALGAMATION)
#include "sqlite3rtree.h"
#includez "sqlite3rtree.h"
typedef sqlite3_int64 i64;
typedef sqlite3_uint64 u64;
typedef unsigned char u8;
Expand Down

0 comments on commit abe0d1d

Please sign in to comment.