Skip to content

Commit

Permalink
Replace warp with axum
Browse files Browse the repository at this point in the history
Set llm as optional feature
  • Loading branch information
Endle authored Aug 18, 2024
1 parent 8b43c01 commit fad2a13
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 59 deletions.
15 changes: 11 additions & 4 deletions fire_seq_search_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ license = "MIT"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["llm"]
llm = []

[dependencies]
# Http Client

tokio = { version = "1", features = ["full"] }
warp = "0.3"

# Http Client
axum = "0.7.5"
serde_json = "1.0"

# Serde
# https://serde.rs/derive.html
# https://stackoverflow.com/a/49313680/1166518
Expand All @@ -21,8 +28,8 @@ url = "2.3.1"
tantivy = "0.18"


log = "0.4.0"
env_logger = "0.9.0"
log = "0.4.22"
env_logger = "0.11.5"

# Rust
clap = { version = "4.0", features = ["derive"] }
Expand Down
28 changes: 19 additions & 9 deletions fire_seq_search_server/src/http_client/endpoints.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
use std::sync::Arc;
use log::debug;
use crate::query_engine::QueryEngine;
use serde_json;

pub fn get_server_info(engine_arc: Arc<QueryEngine>) -> String {
serde_json::to_string( &engine_arc.server_info ).unwrap()
use crate::query_engine::{QueryEngine, ServerInformation};
use axum::Json;
use axum::extract::State;
use axum::{response::Html, routing::get, Router, extract::Path};

pub async fn get_server_info(State(engine_arc): State<Arc<QueryEngine>>)
-> Json<ServerInformation> {
axum::Json( engine_arc.server_info.to_owned() )
}
pub fn query(term: String, engine_arc: Arc<QueryEngine>)
-> String {

pub async fn query(
Path(term) : Path<String>,
State(engine_arc): State<Arc<QueryEngine>>
) -> Html<String>{

debug!("Original Search term {}", term);
engine_arc.query_pipeline(term)
let r = engine_arc.query_pipeline(term);
Html(r)
}


pub fn generate_word_cloud(engine_arc: Arc<QueryEngine>) -> String {
pub async fn generate_word_cloud(State(engine_arc): State<Arc<QueryEngine>>)
-> Html<String> {
let div_id = "fireSeqSearchWordcloudRawJson";
let json = engine_arc.generate_wordcloud();

let div = format!("<div id=\"{}\">{}</div>", div_id, json);
div
Html(div)
}

3 changes: 2 additions & 1 deletion fire_seq_search_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ pub fn generate_server_info_for_test() -> ServerInformation {
parse_pdf_links: false,
exclude_zotero_items: false,
obsidian_md: false,
convert_underline_hierarchy: true
convert_underline_hierarchy: true,
host: "127.0.0.1:22024".to_string(),
};
server_info
}
Expand Down
71 changes: 27 additions & 44 deletions fire_seq_search_server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::net::SocketAddr;

use warp::Filter;
use log::info;
use fire_seq_search_server::query_engine::{QueryEngine, ServerInformation};
use fire_seq_search_server::local_llm::LlmEngine;
Expand Down Expand Up @@ -48,65 +47,47 @@ struct Cli{

use tokio::task;

use axum;
use axum::routing::get;
use fire_seq_search_server::http_client::endpoints;

#[tokio::main]
async fn main() {
env_logger::builder()
.format_timestamp(None)
.format_target(false)
.init();

let llm = task::spawn( async { LlmEngine::llm_init().await });
//let llm = llm.await.unwrap();
//llm.summarize("hi my friend").await;
let mut llm_loader = None;
if cfg!(feature="llm") {
info!("LLM Enabled");
//tokio::task::JoinHandle<LlmEngine>
llm_loader = Some(task::spawn( async { LlmEngine::llm_init().await }));
}

info!("main thread running");
let matches = Cli::parse();
let host: String = matches.host.clone().unwrap_or_else(|| "127.0.0.1:3030".to_string());
let host: SocketAddr = host.parse().unwrap_or_else(
|_| panic!("Invalid host: {}", host)
);
let server_info: ServerInformation = build_server_info(matches);
let engine = QueryEngine::construct(server_info);

let mut engine = QueryEngine::construct(server_info);
if cfg!(feature="llm") {
let llm:LlmEngine = llm_loader.unwrap().await.unwrap();
engine.llm = Some(llm);
}

let engine_arc = std::sync::Arc::new(engine);
let arc_for_query = engine_arc.clone();
let call_query = warp::path!("query" / String)
.map(move |name| {
fire_seq_search_server::http_client::endpoints::query(
name, arc_for_query.clone() )
});

let arc_for_server_info = engine_arc.clone();
let get_server_info = warp::path("server_info")
.map(move ||
fire_seq_search_server::http_client::endpoints::get_server_info(
arc_for_server_info.clone()
));

let arc_for_wordcloud = engine_arc.clone();
let create_word_cloud = warp::path("wordcloud")
.map(move || {
let div = fire_seq_search_server::http_client::endpoints::generate_word_cloud(
arc_for_wordcloud.clone()
);
warp::http::Response::builder()
.header("content-type", "text/html; charset=utf-8")
.body(div)
// .status(warp::http::StatusCode::OK)
});

let routes = warp::get().and(
call_query
.or(get_server_info)
.or(create_word_cloud)
);
warp::serve(routes)
.run(host)
.await;


let app = axum::Router::new()
.route("/query/:term", get(endpoints::query))
.route("/server_info", get(endpoints::get_server_info))
.route("/wordcloud", get(endpoints::generate_word_cloud))
.with_state(engine_arc.clone());

let listener = tokio::net::TcpListener::bind(&engine_arc.server_info.host)
.await.unwrap();
axum::serve(listener, app).await.unwrap();
// let llm = llm.await.unwrap();
//llm.summarize("hi my friend").await;
}


Expand All @@ -123,6 +104,7 @@ fn build_server_info(args: Cli) -> ServerInformation {
String::from(guess)
}
};
let host: String = args.host.clone().unwrap_or_else(|| "127.0.0.1:3030".to_string());
ServerInformation{
notebook_path: args.notebook_path,
notebook_name,
Expand All @@ -134,6 +116,7 @@ fn build_server_info(args: Cli) -> ServerInformation {
exclude_zotero_items:args.exclude_zotero_items,
obsidian_md: args.obsidian_md,
convert_underline_hierarchy: true,
host,
}
}

Expand Down
7 changes: 6 additions & 1 deletion fire_seq_search_server/src/query_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::post_query::post_query_wrapper;



// This struct should be immutable when the program starts running
#[derive(Debug, Clone, serde::Serialize)]
pub struct ServerInformation {
pub notebook_path: String,
Expand All @@ -18,21 +19,24 @@ pub struct ServerInformation {
pub exclude_zotero_items:bool,
pub obsidian_md: bool,


/// Experimental. Not sure if I should use this global config - 2022-12-30
pub convert_underline_hierarchy: bool,

pub host: String,
}

struct DocumentSetting {
schema: tantivy::schema::Schema,
tokenizer: JiebaTokenizer,
}

use crate::local_llm::LlmEngine;
pub struct QueryEngine {
pub server_info: ServerInformation,
reader: tantivy::IndexReader,
query_parser: tantivy::query::QueryParser,
articles: Vec<Article>,
pub llm: Option<LlmEngine>,
}

impl QueryEngine {
Expand All @@ -50,6 +54,7 @@ impl QueryEngine {
reader,
query_parser,
articles: loaded_articles,
llm: None,
}
}

Expand Down

0 comments on commit fad2a13

Please sign in to comment.