Skip to content

Commit

Permalink
sqlite, test: expose sqlite online backup api
Browse files Browse the repository at this point in the history
  • Loading branch information
geeksilva97 committed Jan 12, 2025
1 parent 7409a1d commit 2aee11d
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 0 deletions.
296 changes: 296 additions & 0 deletions src/node_sqlite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "node_errors.h"
#include "node_mem-inl.h"
#include "sqlite3.h"
#include "threadpoolwork-inl.h"
#include "util-inl.h"

#include <cinttypes>
Expand All @@ -29,6 +30,7 @@ using v8::FunctionCallback;
using v8::FunctionCallbackInfo;
using v8::FunctionTemplate;
using v8::Global;
using v8::HandleScope;
using v8::Int32;
using v8::Integer;
using v8::Isolate;
Expand All @@ -40,6 +42,7 @@ using v8::NewStringType;
using v8::Null;
using v8::Number;
using v8::Object;
using v8::Promise;
using v8::SideEffectType;
using v8::String;
using v8::TryCatch;
Expand Down Expand Up @@ -81,6 +84,24 @@ inline MaybeLocal<Object> CreateSQLiteError(Isolate* isolate,
return e;
}

inline MaybeLocal<Object> CreateSQLiteError(Isolate* isolate, int errcode) {
const char* errstr = sqlite3_errstr(errcode);
Local<String> js_errmsg;
Local<Object> e;
Environment* env = Environment::GetCurrent(isolate);
if (!String::NewFromUtf8(isolate, errstr).ToLocal(&js_errmsg) ||
!CreateSQLiteError(isolate, errstr).ToLocal(&e) ||
e->Set(isolate->GetCurrentContext(),
env->errcode_string(),
Integer::New(isolate, errcode))
.IsNothing() ||
e->Set(isolate->GetCurrentContext(), env->errstr_string(), js_errmsg)
.IsNothing()) {
return MaybeLocal<Object>();
}
return e;
}

inline MaybeLocal<Object> CreateSQLiteError(Isolate* isolate, sqlite3* db) {
int errcode = sqlite3_extended_errcode(db);
const char* errstr = sqlite3_errstr(errcode);
Expand Down Expand Up @@ -128,6 +149,171 @@ inline void THROW_ERR_SQLITE_ERROR(Isolate* isolate, int errcode) {
isolate->ThrowException(error);
}

class BackupJob : public ThreadPoolWork {
public:
explicit BackupJob(Environment* env,
DatabaseSync* source,
Local<Promise::Resolver> resolver,
std::string source_db,
std::string destination_name,
std::string dest_db,
int pages,
Local<Function> progressFunc)
: ThreadPoolWork(env, "node_sqlite3.BackupJob"),
env_(env),
source_(source),
source_db_(source_db),
destination_name_(destination_name),
dest_db_(dest_db),
pages_(pages) {
resolver_.Reset(env->isolate(), resolver);
progressFunc_.Reset(env->isolate(), progressFunc);
}

void ScheduleBackup() {
Isolate* isolate = env()->isolate();
HandleScope handle_scope(isolate);

backup_status_ = sqlite3_open(destination_name_.c_str(), &pDest_);

Local<Promise::Resolver> resolver =
Local<Promise::Resolver>::New(env()->isolate(), resolver_);

Local<Object> e = Local<Object>();

if (backup_status_ != SQLITE_OK) {
CreateSQLiteError(isolate, pDest_).ToLocal(&e);

Cleanup();

resolver->Reject(env()->context(), e).ToChecked();

return;
}

pBackup_ = sqlite3_backup_init(
pDest_, dest_db_.c_str(), source_->Connection(), source_db_.c_str());

if (pBackup_ == nullptr) {
CreateSQLiteError(isolate, pDest_).ToLocal(&e);

sqlite3_close(pDest_);

resolver->Reject(env()->context(), e).ToChecked();

return;
}

this->ScheduleWork();
}

void DoThreadPoolWork() override {
backup_status_ = sqlite3_backup_step(pBackup_, pages_);

const char* errstr = sqlite3_errstr(backup_status_);
}

void AfterThreadPoolWork(int status) override {
HandleScope handle_scope(env()->isolate());

if (resolver_.IsEmpty()) {
Cleanup();

return;
}

Local<Promise::Resolver> resolver =
Local<Promise::Resolver>::New(env()->isolate(), resolver_);

if (!(backup_status_ == SQLITE_OK || backup_status_ == SQLITE_DONE ||
backup_status_ == SQLITE_BUSY || backup_status_ == SQLITE_LOCKED)) {
Local<Object> e = Local<Object>();

CreateSQLiteError(env()->isolate(), backup_status_).ToLocal(&e);

Cleanup();

resolver->Reject(env()->context(), e).ToChecked();

return;
}

int total_pages = sqlite3_backup_pagecount(pBackup_);
int remaining_pages = sqlite3_backup_remaining(pBackup_);

if (remaining_pages != 0) {
Local<Function> fn =
Local<Function>::New(env()->isolate(), progressFunc_);

if (!fn.IsEmpty()) {
Local<Value> argv[] = {
Integer::New(env()->isolate(), total_pages),
Integer::New(env()->isolate(), remaining_pages),
};

TryCatch try_catch(env()->isolate());
fn->Call(env()->context(), Null(env()->isolate()), 2, argv)
.FromMaybe(Local<Value>());

if (try_catch.HasCaught()) {
Cleanup();

resolver->Reject(env()->context(), try_catch.Exception()).ToChecked();

return;
}
}

// There's still work to do
this->ScheduleWork();

return;
}

Local<String> message =
String::NewFromUtf8(
env()->isolate(), "Backup completed", NewStringType::kNormal)
.ToLocalChecked();

Local<Object> e = Local<Object>();
CreateSQLiteError(env()->isolate(), pDest_).ToLocal(&e);

Cleanup();

if (backup_status_ == SQLITE_OK) {
resolver->Resolve(env()->context(), message).ToChecked();
} else {
resolver->Reject(env()->context(), e).ToChecked();
}
}

private:
void Cleanup() {
if (pBackup_) {
sqlite3_backup_finish(pBackup_);
}

if (pDest_) {
backup_status_ = sqlite3_errcode(pDest_);
sqlite3_close(pDest_);
}
}

// https://github.com/nodejs/node/blob/649da3b8377e030ea7b9a1bc0308451e26e28740/src/crypto/crypto_keygen.h#L126
int backup_status_;
Environment* env() const { return env_; }
sqlite3* pDest_;
sqlite3_backup* pBackup_;
Environment* env_;
DatabaseSync* source_;
Global<Promise::Resolver> resolver_;
Global<Function> progressFunc_;
std::string source_db_;
std::string destination_name_;
std::string dest_db_;
int pages_;
};

class UserDefinedFunction {
public:
explicit UserDefinedFunction(Environment* env,
Expand Down Expand Up @@ -533,6 +719,115 @@ void DatabaseSync::Exec(const FunctionCallbackInfo<Value>& args) {
CHECK_ERROR_OR_THROW(env->isolate(), db->connection_, r, SQLITE_OK, void());
}

// database.backup(destination, { sourceDb, targetDb, rate, progress: (total,
// remaining) => {} )
void DatabaseSync::Backup(const FunctionCallbackInfo<Value>& args) {
Environment* env = Environment::GetCurrent(args);

if (!args[0]->IsString()) {
THROW_ERR_INVALID_ARG_TYPE(
env->isolate(), "The \"destination\" argument must be a string.");
return;
}

int rate = 100;
std::string source_db = "main";
std::string dest_db = "main";

DatabaseSync* db;
ASSIGN_OR_RETURN_UNWRAP(&db, args.This());

THROW_AND_RETURN_ON_BAD_STATE(env, !db->IsOpen(), "database is not open");

Utf8Value destFilename(env->isolate(), args[0].As<String>());
Local<Function> progressFunc = Local<Function>();

if (args.Length() > 1) {
if (!args[1]->IsObject()) {
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
"The \"options\" argument must be an object.");
return;
}

Local<Object> options = args[1].As<Object>();
Local<String> progress_string =
FIXED_ONE_BYTE_STRING(env->isolate(), "progress");
Local<String> rate_string = FIXED_ONE_BYTE_STRING(env->isolate(), "rate");
Local<String> target_db_string =
FIXED_ONE_BYTE_STRING(env->isolate(), "targetDb");
Local<String> source_db_string =
FIXED_ONE_BYTE_STRING(env->isolate(), "sourceDb");

Local<Value> rateValue =
options->Get(env->context(), rate_string).ToLocalChecked();

if (!rateValue->IsUndefined()) {
if (!rateValue->IsInt32()) {
THROW_ERR_INVALID_ARG_TYPE(
env->isolate(),
"The \"options.rate\" argument must be an integer.");
return;
}

rate = rateValue.As<Int32>()->Value();
}

Local<Value> sourceDbValue =
options->Get(env->context(), source_db_string).ToLocalChecked();

if (!sourceDbValue->IsUndefined()) {
if (!sourceDbValue->IsString()) {
THROW_ERR_INVALID_ARG_TYPE(
env->isolate(),
"The \"options.sourceDb\" argument must be a string.");
return;
}

source_db =
Utf8Value(env->isolate(), sourceDbValue.As<String>()).ToString();
}

Local<Value> targetDbValue =
options->Get(env->context(), target_db_string).ToLocalChecked();

if (!targetDbValue->IsUndefined()) {
if (!targetDbValue->IsString()) {
THROW_ERR_INVALID_ARG_TYPE(
env->isolate(),
"The \"options.targetDb\" argument must be a string.");
return;
}

dest_db =
Utf8Value(env->isolate(), targetDbValue.As<String>()).ToString();
}

Local<Value> progressValue =
options->Get(env->context(), progress_string).ToLocalChecked();

if (!progressValue->IsUndefined()) {
if (!progressValue->IsFunction()) {
THROW_ERR_INVALID_ARG_TYPE(
env->isolate(),
"The \"options.progress\" argument must be a function.");
return;
}

progressFunc = progressValue.As<Function>();
}
}

Local<Promise::Resolver> resolver = Promise::Resolver::New(env->context())
.ToLocalChecked()
.As<Promise::Resolver>();

args.GetReturnValue().Set(resolver->GetPromise());

BackupJob* job = new BackupJob(
env, db, resolver, source_db, *destFilename, dest_db, rate, progressFunc);
job->ScheduleBackup();
}

void DatabaseSync::CustomFunction(const FunctionCallbackInfo<Value>& args) {
DatabaseSync* db;
ASSIGN_OR_RETURN_UNWRAP(&db, args.This());
Expand Down Expand Up @@ -1718,6 +2013,7 @@ static void Initialize(Local<Object> target,
SetProtoMethod(isolate, db_tmpl, "close", DatabaseSync::Close);
SetProtoMethod(isolate, db_tmpl, "prepare", DatabaseSync::Prepare);
SetProtoMethod(isolate, db_tmpl, "exec", DatabaseSync::Exec);
SetProtoMethod(isolate, db_tmpl, "backup", DatabaseSync::Backup);
SetProtoMethod(isolate, db_tmpl, "function", DatabaseSync::CustomFunction);
SetProtoMethod(
isolate, db_tmpl, "createSession", DatabaseSync::CreateSession);
Expand Down
2 changes: 2 additions & 0 deletions src/node_sqlite.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class DatabaseSync : public BaseObject {
static void Close(const v8::FunctionCallbackInfo<v8::Value>& args);
static void Prepare(const v8::FunctionCallbackInfo<v8::Value>& args);
static void Exec(const v8::FunctionCallbackInfo<v8::Value>& args);
static void Backup(const v8::FunctionCallbackInfo<v8::Value>& args);
static void CustomFunction(const v8::FunctionCallbackInfo<v8::Value>& args);
static void CreateSession(const v8::FunctionCallbackInfo<v8::Value>& args);
static void ApplyChangeset(const v8::FunctionCallbackInfo<v8::Value>& args);
Expand All @@ -81,6 +82,7 @@ class DatabaseSync : public BaseObject {
bool enable_load_extension_;
sqlite3* connection_;

std::set<sqlite3_backup*> backups_;
std::set<sqlite3_session*> sessions_;
std::unordered_set<StatementSync*> statements_;

Expand Down
Loading

0 comments on commit 2aee11d

Please sign in to comment.