diff --git a/.gitignore b/.gitignore index 8ffe29584056..ef15c70d86eb 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,6 @@ release/ backend-assets/ prepare /ggml-metal.metal +target/ +Cargo.lock +model.bin diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml new file mode 100644 index 000000000000..c9ec91b1559f --- /dev/null +++ b/backend/rust/Cargo.toml @@ -0,0 +1,8 @@ +[workspace] +resolver = "2" +members = [ + "bunker", + "backend", + "codegen", + "models", +] \ No newline at end of file diff --git a/backend/rust/Makefile b/backend/rust/Makefile new file mode 100644 index 000000000000..4213797500f8 --- /dev/null +++ b/backend/rust/Makefile @@ -0,0 +1,57 @@ +# Default goal: run the formatting check (fmt), then cargo check (check) +.DEFAULT_GOAL := all + +.PHONY: all +all: fmt check + +.PHONY: fmt +fmt: + @echo "Formatting code..." + @cargo fmt --all --check + +.PHONY: build +build: + @echo "Building..." + @cargo build --release + +.PHONY: test +test: + @echo "Testing..." + @cargo test --all + +.PHONY: check +check: + @echo "Checking..." + @cargo check --all + +.PHONY: clean +clean: + @echo "Cleaning..." + @cargo clean + +.PHONY: burn +burn: + @echo "Burning..." + @cargo run --bin server --package backend + + + +############################################################################################################ +# gRPC testing commands (need grpcurl; health and status talk to a server listening on [::1]:50051) + + +.PHONY: list +list: + @echo "Listing gRPC methods..." + @grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto list backend.Backend + +.PHONY: health +health: + @echo "Checking health..." + @grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto -d '{}' '[::1]:50051' backend.Backend/Health + + +.PHONY: status +status: + @echo "Checking status..." 
+ @grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto -d '{}' '[::1]:50051' backend.Backend/Status diff --git a/backend/rust/README.md b/backend/rust/README.md new file mode 100644 index 000000000000..ae1e2ff2eb9e --- /dev/null +++ b/backend/rust/README.md @@ -0,0 +1,54 @@ +## A backend written in Rust for the LocalAI project + +Here are some rules for the Rust backend: +* Use the same proto file as LocalAI's other backends, so that the backend interface stays the same. +* `async` should be the default way to write code. +* Streaming responses should be supported. +* Only server-side gRPC services are supported by the current backend. +* The backend should also expose metrics for monitoring. + + +### Environment information + +* cargo 1.73.0 (9c4383fb5 2023-08-26) +* rustup 1.26.0 (5af9b9484 2023-04-05) +* rustc 1.73.0 (cc66ad468 2023-10-03) + +## Build the development environment + +#### Protocol Buffers compiler + +Ubuntu or Debian + +``` +sudo apt update && sudo apt upgrade -y +sudo apt install -y protobuf-compiler libprotobuf-dev +``` + +macOS +``` +brew install protobuf +``` + +### Check the formatting of all the code + +``` +cargo fmt --all --check +``` + +### Check the gRPC backend status + +It returns `OK` as a base64-encoded string. 
+ + +```bash +make burn + +make health +``` + +``` +{ + "message": "T0s=" +} +``` diff --git a/backend/rust/backend/Cargo.toml b/backend/rust/backend/Cargo.toml new file mode 100644 index 000000000000..767c349c8574 --- /dev/null +++ b/backend/rust/backend/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "backend" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[[bin]] +name = "server" +path = "src/main.rs" + +[dependencies] + +# import bunker here +bunker = { path = "../bunker" } +models = { path = "../models" } + +tokio = "1.33.0" +async-trait = "0.1.74" +tonic = "0.10.2" +tokio-stream = "0.1.14" + +tracing = "0.1" +tracing-subscriber = "0.3" +nix = { version="0.27.1", features=["resource"]} diff --git a/backend/rust/backend/Makefile b/backend/rust/backend/Makefile new file mode 100644 index 000000000000..12c9ff3073e0 --- /dev/null +++ b/backend/rust/backend/Makefile @@ -0,0 +1,19 @@ +.PHONY: check +check: + @echo "Checking..." + @cargo check + +.PHONY: fmt +fmt: + @echo "Formatting code..." + @cargo fmt + +.PHONY: build +build: + @echo "Building..." + @cargo build + +.PHONY: doc +doc: + @echo "Documenting..." 
+ @cargo doc --no-deps --document-private-items --open diff --git a/backend/rust/backend/src/main.rs b/backend/rust/backend/src/main.rs new file mode 100644 index 000000000000..0b60212d5790 --- /dev/null +++ b/backend/rust/backend/src/main.rs @@ -0,0 +1,334 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +use bunker::pb::Result as PbResult; +use bunker::pb::{ + EmbeddingResult, GenerateImageRequest, HealthMessage, MemoryUsageData, ModelOptions, + PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, + TranscriptResult, TtsRequest, +}; + +use bunker::BackendService; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; + +use async_trait::async_trait; + +use tracing::{event, span, Level}; + +use models::*; + + +/// Backend service implementation; currently a unit struct that holds no state. +/// TODO: to use a model we need shared, thread-safe attributes such as: model, device, tokenizer, embeddings, etc. +#[derive(Default, Debug)] +pub struct BurnBackend; + +#[async_trait] +impl BackendService for BurnBackend { + type PredictStreamStream = ReceiverStream>; + + #[tracing::instrument] + async fn health(&self, request: Request) -> Result, Status> { + // Always answers with a Reply whose message is the bytes of "OK"; the request content is not inspected. + let reply = Reply { + message: "OK".into(), + }; + let res = Response::new(reply); + Ok(res) + } + + #[tracing::instrument] + async fn predict(&self, request: Request) -> Result, Status> { + // TODO: How to get the model from the load_model function? Until then this handler is an unimplemented stub. 
+ todo!("How to get model from load_model function?") + } + + #[tracing::instrument] + async fn load_model( + &self, + request: Request, + ) -> Result, Status> { + let result = match request.get_ref().model.as_str() { + "mnist" => { + let model = MNINST::load_model(request.get_ref().clone()); + let result= PbResult { + message: format!("Model {} loaded successfully", request.get_ref().model), + success: true, + }; + Ok(Response::new(result)) + } + _ => { + let result = PbResult { + message: format!("Model {} not found", request.get_ref().model), + success: false, + }; + Ok(Response::new(result)) + } + }; + // TODO: add model to backend, how to transfer model to backend and let predict function use it? + result + } + + #[tracing::instrument] + async fn predict_stream( + &self, + request: Request, + ) -> Result>>, Status> { + todo!() + } + + #[tracing::instrument] + async fn embedding( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn generate_image( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn audio_transcription( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn tts(&self, request: Request) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn tokenize_string( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn status( + &self, + request: Request, + ) -> Result, Status> { + // Reports ru_maxrss from getrusage for this process as the "RSS" entry (units are platform-dependent); the windows platform is not covered here. + let mut breakdown = HashMap::new(); + use nix::sys::resource::{getrusage, UsageWho}; + + // Surface a getrusage failure to the client as an INTERNAL gRPC status instead of panicking inside the handler. + let usage = getrusage(UsageWho::RUSAGE_SELF).map_err(|e| Status::internal(format!("Failed to get usage: {e}")))?; + let memory_usage = usage.as_ref().ru_maxrss as u64; + breakdown.insert("RSS".to_string(), memory_usage); + + let memory_usage = Some(MemoryUsageData { + total: memory_usage, + breakdown, + }); + 
+ let response = StatusResponse { + state: 0, //TODO: add state https://github.com/mudler/LocalAI/blob/9b17af18b3aa0c3cab16284df2d6f691736c30c1/pkg/grpc/proto/backend.proto#L188C9-L188C9 + memory: memory_usage, + }; + + Ok(Response::new(response)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tonic::Request; + + #[tokio::test] + async fn test_health() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = backend.health(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = String::from_utf8(response.get_ref().message.clone()).unwrap(); + assert_eq!(message_str, "OK"); + } + #[tokio::test] + async fn test_status() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = backend.status(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let state = response.get_ref().state; + let memory = response.get_ref().memory.clone(); + assert_eq!(state, 0); + assert!(memory.is_some()); + } + + #[tokio::test] + async fn test_load_model() { + let backend = BurnBackend::default(); + let model_options = ModelOptions { + model: "test".to_string(), + context_size: 0, + seed: 0, + n_batch: 0, + f16_memory: false, + m_lock: false, + m_map: false, + vocab_only: false, + low_vram: false, + embeddings: false, + numa: false, + ngpu_layers: 0, + main_gpu: "".to_string(), + tensor_split: "".to_string(), + threads: 1, + library_search_path: "".to_string(), + rope_freq_base: 0.0, + rope_freq_scale: 0.0, + rms_norm_eps: 0.0, + ngqa: 0, + model_file: "models/src/mnist/model.bin".to_string(), + device: "".to_string(), + use_triton: false, + model_base_name: "".to_string(), + use_fast_tokenizer: false, + pipeline_type: "".to_string(), + scheduler_type: "".to_string(), + cuda: false, + cfg_scale: 0.0, + img2img: false, + clip_model: "".to_string(), + clip_subfolder: "".to_string(), + clip_skip: 0, + 
tokenizer: "".to_string(), + lora_base: "".to_string(), + lora_adapter: "".to_string(), + no_mul_mat_q: false, + draft_model: "".to_string(), + audio_path: "".to_string(), + quantization: "".to_string(), + }; + + // Load the wrong model + let request = Request::new(model_options.clone()); + let response = backend.load_model(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = response.get_ref().message.clone(); + assert_eq!( + message_str, + format!("Model {} not found", model_options.model.clone()) + ); + + // Load the correct model + let mut model_options2=model_options.clone(); + model_options2.model="mnist".to_string(); + model_options2.model_file="models/src/mnist/model.bin".to_string(); + + let request = Request::new(model_options2.clone()); + let response = backend.load_model(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = response.get_ref().message.clone(); + assert_eq!( + message_str, + format!("Model {} loaded successfully", model_options2.model.clone()) + ); + } + + #[tokio::test] + async fn test_predict() { + let backend = BurnBackend::default(); + let request = Request::new(PredictOptions { + prompt: "test".to_string(), + seed: 100, + threads: 1, + tokens: 10, + temperature: 0.0, + top_k: 0, + top_p: 0.0, + repeat: 0, + batch: 1, + n_keep: 0, + penalty: 0.0, + f16kv: false, + debug_mode: false, + stop_prompts: vec!["".to_string()], + ignore_eos: false, + tail_free_sampling_z: 0.0, + typical_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + mirostat: 0, + mirostat_eta: 0.0, + mirostat_tau: 0.0, + penalize_nl: false, + logit_bias: "".to_string(), + m_lock: false, + m_map: false, + prompt_cache_all: false, + prompt_cache_ro: false, + grammar: "".to_string(), + main_gpu: "".to_string(), + tensor_split: "".to_string(), + prompt_cache_path: "".to_string(), + debug: false, + embedding_tokens: vec![0], + embeddings: "".to_string(), + 
rope_freq_base: 0.0, + rope_freq_scale: 0.0, + negative_prompt_scale: 0.0, + negative_prompt: "".to_string(), + n_draft: 0, + }); + let response: Result, Status> = backend.predict(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + // TODO: add assertions on the response. NOTE(review): predict is currently a todo!() stub, so this test is expected to panic inside the call above until it is implemented. + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = tracing_subscriber::fmt() + .compact() + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_target(false) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + + // Build the backend and the listen address, then hand both to bunker::run. + let burn_backend = BurnBackend {}; + let addr = "[::1]:50051" + .parse::() + .expect("Failed to parse address"); + + // TODO: implement Into for addr. NOTE(review): the events below are only emitted once bunker::run(...).await has returned; if run serves until shutdown they are logged too late, so confirm the intended order. + let result = bunker::run(burn_backend, addr).await?; + + event!(Level::INFO, "Burn Server is starting"); + + let span = span!(Level::INFO, "Burn Server"); + let _enter = span.enter(); + + event!(Level::INFO, "Burn Server started successfully"); + + Ok(result) +} diff --git a/backend/rust/bunker/Cargo.toml b/backend/rust/bunker/Cargo.toml new file mode 100644 index 000000000000..d9d9d8ff505f --- /dev/null +++ b/backend/rust/bunker/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "bunker" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tonic = { version = "0.10", features = [] } +prost = "0.12" +tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } +tokio-stream = "0.1" + +async-stream = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +rand = "0.7" +async-trait = "0.1.74" +tracing-subscriber = "0.3.17" +tracing = "0.1.39" diff --git a/backend/rust/bunker/generated/backend.rs b/backend/rust/bunker/generated/backend.rs new file mode 100644 index 000000000000..8e06344905d5 --- /dev/null +++ b/backend/rust/bunker/generated/backend.rs @@ -0,0 +1,966 @@ 
+#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthMessage {} +/// The request message containing the user's name. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PredictOptions { + #[prost(string, tag = "1")] + pub prompt: ::prost::alloc::string::String, + #[prost(int32, tag = "2")] + pub seed: i32, + #[prost(int32, tag = "3")] + pub threads: i32, + #[prost(int32, tag = "4")] + pub tokens: i32, + #[prost(int32, tag = "5")] + pub top_k: i32, + #[prost(int32, tag = "6")] + pub repeat: i32, + #[prost(int32, tag = "7")] + pub batch: i32, + #[prost(int32, tag = "8")] + pub n_keep: i32, + #[prost(float, tag = "9")] + pub temperature: f32, + #[prost(float, tag = "10")] + pub penalty: f32, + #[prost(bool, tag = "11")] + pub f16kv: bool, + #[prost(bool, tag = "12")] + pub debug_mode: bool, + #[prost(string, repeated, tag = "13")] + pub stop_prompts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(bool, tag = "14")] + pub ignore_eos: bool, + #[prost(float, tag = "15")] + pub tail_free_sampling_z: f32, + #[prost(float, tag = "16")] + pub typical_p: f32, + #[prost(float, tag = "17")] + pub frequency_penalty: f32, + #[prost(float, tag = "18")] + pub presence_penalty: f32, + #[prost(int32, tag = "19")] + pub mirostat: i32, + #[prost(float, tag = "20")] + pub mirostat_eta: f32, + #[prost(float, tag = "21")] + pub mirostat_tau: f32, + #[prost(bool, tag = "22")] + pub penalize_nl: bool, + #[prost(string, tag = "23")] + pub logit_bias: ::prost::alloc::string::String, + #[prost(bool, tag = "25")] + pub m_lock: bool, + #[prost(bool, tag = "26")] + pub m_map: bool, + #[prost(bool, tag = "27")] + pub prompt_cache_all: bool, + #[prost(bool, tag = "28")] + pub prompt_cache_ro: bool, + #[prost(string, tag = "29")] + pub grammar: ::prost::alloc::string::String, + #[prost(string, tag = "30")] + pub main_gpu: ::prost::alloc::string::String, + 
#[prost(string, tag = "31")] + pub tensor_split: ::prost::alloc::string::String, + #[prost(float, tag = "32")] + pub top_p: f32, + #[prost(string, tag = "33")] + pub prompt_cache_path: ::prost::alloc::string::String, + #[prost(bool, tag = "34")] + pub debug: bool, + #[prost(int32, repeated, tag = "35")] + pub embedding_tokens: ::prost::alloc::vec::Vec, + #[prost(string, tag = "36")] + pub embeddings: ::prost::alloc::string::String, + #[prost(float, tag = "37")] + pub rope_freq_base: f32, + #[prost(float, tag = "38")] + pub rope_freq_scale: f32, + #[prost(float, tag = "39")] + pub negative_prompt_scale: f32, + #[prost(string, tag = "40")] + pub negative_prompt: ::prost::alloc::string::String, + #[prost(int32, tag = "41")] + pub n_draft: i32, +} +/// The response message containing the result +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Reply { + #[prost(bytes = "vec", tag = "1")] + pub message: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ModelOptions { + #[prost(string, tag = "1")] + pub model: ::prost::alloc::string::String, + #[prost(int32, tag = "2")] + pub context_size: i32, + #[prost(int32, tag = "3")] + pub seed: i32, + #[prost(int32, tag = "4")] + pub n_batch: i32, + #[prost(bool, tag = "5")] + pub f16_memory: bool, + #[prost(bool, tag = "6")] + pub m_lock: bool, + #[prost(bool, tag = "7")] + pub m_map: bool, + #[prost(bool, tag = "8")] + pub vocab_only: bool, + #[prost(bool, tag = "9")] + pub low_vram: bool, + #[prost(bool, tag = "10")] + pub embeddings: bool, + #[prost(bool, tag = "11")] + pub numa: bool, + #[prost(int32, tag = "12")] + pub ngpu_layers: i32, + #[prost(string, tag = "13")] + pub main_gpu: ::prost::alloc::string::String, + #[prost(string, tag = "14")] + pub tensor_split: ::prost::alloc::string::String, + #[prost(int32, tag = "15")] + pub threads: i32, + #[prost(string, tag = "16")] + pub 
library_search_path: ::prost::alloc::string::String, + #[prost(float, tag = "17")] + pub rope_freq_base: f32, + #[prost(float, tag = "18")] + pub rope_freq_scale: f32, + #[prost(float, tag = "19")] + pub rms_norm_eps: f32, + #[prost(int32, tag = "20")] + pub ngqa: i32, + #[prost(string, tag = "21")] + pub model_file: ::prost::alloc::string::String, + /// AutoGPTQ + #[prost(string, tag = "22")] + pub device: ::prost::alloc::string::String, + #[prost(bool, tag = "23")] + pub use_triton: bool, + #[prost(string, tag = "24")] + pub model_base_name: ::prost::alloc::string::String, + #[prost(bool, tag = "25")] + pub use_fast_tokenizer: bool, + /// Diffusers + #[prost(string, tag = "26")] + pub pipeline_type: ::prost::alloc::string::String, + #[prost(string, tag = "27")] + pub scheduler_type: ::prost::alloc::string::String, + #[prost(bool, tag = "28")] + pub cuda: bool, + #[prost(float, tag = "29")] + pub cfg_scale: f32, + #[prost(bool, tag = "30")] + pub img2img: bool, + #[prost(string, tag = "31")] + pub clip_model: ::prost::alloc::string::String, + #[prost(string, tag = "32")] + pub clip_subfolder: ::prost::alloc::string::String, + #[prost(int32, tag = "33")] + pub clip_skip: i32, + /// RWKV + #[prost(string, tag = "34")] + pub tokenizer: ::prost::alloc::string::String, + /// LLM (llama.cpp) + #[prost(string, tag = "35")] + pub lora_base: ::prost::alloc::string::String, + #[prost(string, tag = "36")] + pub lora_adapter: ::prost::alloc::string::String, + #[prost(bool, tag = "37")] + pub no_mul_mat_q: bool, + #[prost(string, tag = "39")] + pub draft_model: ::prost::alloc::string::String, + #[prost(string, tag = "38")] + pub audio_path: ::prost::alloc::string::String, + /// vllm + #[prost(string, tag = "40")] + pub quantization: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Result { + #[prost(string, tag = "1")] + pub message: ::prost::alloc::string::String, + #[prost(bool, tag 
= "2")] + pub success: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct EmbeddingResult { + #[prost(float, repeated, tag = "1")] + pub embeddings: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TranscriptRequest { + #[prost(string, tag = "2")] + pub dst: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub language: ::prost::alloc::string::String, + #[prost(uint32, tag = "4")] + pub threads: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TranscriptResult { + #[prost(message, repeated, tag = "1")] + pub segments: ::prost::alloc::vec::Vec, + #[prost(string, tag = "2")] + pub text: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TranscriptSegment { + #[prost(int32, tag = "1")] + pub id: i32, + #[prost(int64, tag = "2")] + pub start: i64, + #[prost(int64, tag = "3")] + pub end: i64, + #[prost(string, tag = "4")] + pub text: ::prost::alloc::string::String, + #[prost(int32, repeated, tag = "5")] + pub tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GenerateImageRequest { + #[prost(int32, tag = "1")] + pub height: i32, + #[prost(int32, tag = "2")] + pub width: i32, + #[prost(int32, tag = "3")] + pub mode: i32, + #[prost(int32, tag = "4")] + pub step: i32, + #[prost(int32, tag = "5")] + pub seed: i32, + #[prost(string, tag = "6")] + pub positive_prompt: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub negative_prompt: ::prost::alloc::string::String, + #[prost(string, tag = "8")] + pub dst: ::prost::alloc::string::String, + #[prost(string, tag = "9")] + pub src: ::prost::alloc::string::String, + /// Diffusers + #[prost(string, tag = 
"10")] + pub enable_parameters: ::prost::alloc::string::String, + #[prost(int32, tag = "11")] + pub clip_skip: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TtsRequest { + #[prost(string, tag = "1")] + pub text: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub model: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub dst: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TokenizationResponse { + #[prost(int32, tag = "1")] + pub length: i32, + #[prost(int32, repeated, tag = "2")] + pub tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MemoryUsageData { + #[prost(uint64, tag = "1")] + pub total: u64, + #[prost(map = "string, uint64", tag = "2")] + pub breakdown: ::std::collections::HashMap<::prost::alloc::string::String, u64>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StatusResponse { + #[prost(enumeration = "status_response::State", tag = "1")] + pub state: i32, + #[prost(message, optional, tag = "2")] + pub memory: ::core::option::Option, +} +/// Nested message and enum types in `StatusResponse`. +pub mod status_response { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum State { + Uninitialized = 0, + Busy = 1, + Ready = 2, + Error = -1, + } + impl State { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
+ pub fn as_str_name(&self) -> &'static str { + match self { + State::Uninitialized => "UNINITIALIZED", + State::Busy => "BUSY", + State::Ready => "READY", + State::Error => "ERROR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNINITIALIZED" => Some(Self::Uninitialized), + "BUSY" => Some(Self::Busy), + "READY" => Some(Self::Ready), + "ERROR" => Some(Self::Error), + _ => None, + } + } + } +} +/// Generated server implementations. +pub mod backend_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with BackendServer. + #[async_trait] + pub trait Backend: Send + Sync + 'static { + async fn health( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn predict( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn load_model( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the PredictStream method. 
+ type PredictStreamStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + Send + + 'static; + async fn predict_stream( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + async fn embedding( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn generate_image( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn audio_transcription( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + async fn tts( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn tokenize_string( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + async fn status( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct BackendServer { + inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + struct _Inner(Arc); + impl BackendServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + let inner = _Inner(inner); + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. 
+ #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for BackendServer + where + T: Backend, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + match req.uri().path() { + "/backend.Backend/Health" => { + #[allow(non_camel_case_types)] + struct HealthSvc(pub Arc); + impl tonic::server::UnaryService + for HealthSvc { + type Response = super::Reply; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::health(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = 
self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = HealthSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/Predict" => { + #[allow(non_camel_case_types)] + struct PredictSvc(pub Arc); + impl tonic::server::UnaryService + for PredictSvc { + type Response = super::Reply; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::predict(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = PredictSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/LoadModel" => { + #[allow(non_camel_case_types)] + struct LoadModelSvc(pub Arc); + impl tonic::server::UnaryService + for LoadModelSvc { + type Response = super::Result; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut 
self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::load_model(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = LoadModelSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/PredictStream" => { + #[allow(non_camel_case_types)] + struct PredictStreamSvc(pub Arc); + impl< + T: Backend, + > tonic::server::ServerStreamingService + for PredictStreamSvc { + type Response = super::Reply; + type ResponseStream = T::PredictStreamStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::predict_stream(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = PredictStreamSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + 
.apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/Embedding" => { + #[allow(non_camel_case_types)] + struct EmbeddingSvc(pub Arc); + impl tonic::server::UnaryService + for EmbeddingSvc { + type Response = super::EmbeddingResult; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::embedding(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = EmbeddingSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/GenerateImage" => { + #[allow(non_camel_case_types)] + struct GenerateImageSvc(pub Arc); + impl< + T: Backend, + > tonic::server::UnaryService + for GenerateImageSvc { + type Response = super::Result; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::generate_image(&inner, request).await + }; + Box::pin(fut) + } + } + 
let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GenerateImageSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/AudioTranscription" => { + #[allow(non_camel_case_types)] + struct AudioTranscriptionSvc(pub Arc); + impl< + T: Backend, + > tonic::server::UnaryService + for AudioTranscriptionSvc { + type Response = super::TranscriptResult; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::audio_transcription(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = AudioTranscriptionSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = 
grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/TTS" => { + #[allow(non_camel_case_types)] + struct TTSSvc(pub Arc); + impl tonic::server::UnaryService + for TTSSvc { + type Response = super::Result; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::tts(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = TTSSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/TokenizeString" => { + #[allow(non_camel_case_types)] + struct TokenizeStringSvc(pub Arc); + impl tonic::server::UnaryService + for TokenizeStringSvc { + type Response = super::TokenizationResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::tokenize_string(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = 
self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = TokenizeStringSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/Status" => { + #[allow(non_camel_case_types)] + struct StatusSvc(pub Arc); + impl tonic::server::UnaryService + for StatusSvc { + type Response = super::StatusResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::status(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = StatusSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } + } + } + } + impl Clone for BackendServer { + fn 
clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } + } + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl tonic::server::NamedService for BackendServer { + const NAME: &'static str = "backend.Backend"; + } +} diff --git a/backend/rust/bunker/src/lib.rs b/backend/rust/bunker/src/lib.rs new file mode 100644 index 000000000000..9e0f781e0443 --- /dev/null +++ b/backend/rust/bunker/src/lib.rs @@ -0,0 +1,25 @@ +/// Import the code() backend.rs was generated from backend.proto +pub mod pb { + include!("../generated/backend.rs"); +} + +use std::net::SocketAddr; +use tonic::transport::Server; + +pub use crate::pb::backend_server::Backend as BackendService; +use crate::pb::backend_server::BackendServer; + +// Run the backend with the default behavior +pub async fn run( + backend: impl BackendService, + addr: impl Into, +) -> Result<(), Box> { + let svc = BackendServer::new(backend); + + let r = Server::builder() + .add_service(svc) + .serve(addr.into()) + .await?; + + Ok(r) +} diff --git a/backend/rust/codegen/Cargo.toml b/backend/rust/codegen/Cargo.toml new file mode 100644 index 000000000000..c8eadcf13764 --- /dev/null +++ b/backend/rust/codegen/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "codegen" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[build-dependencies] +tonic-build = "0.10" diff --git a/backend/rust/codegen/build.rs b/backend/rust/codegen/build.rs new file mode 100644 index 000000000000..f702cc7fdc62 
--- /dev/null +++ b/backend/rust/codegen/build.rs @@ -0,0 +1,11 @@ +fn main() { + tonic_build::configure() + .out_dir("../bunker/generated") + .build_server(true) + .build_client(false) + .compile( + &["../../../pkg/grpc/proto/backend.proto"], + &["../../../pkg/grpc/proto"], + ) + .unwrap(); +} diff --git a/backend/rust/codegen/src/lib.rs b/backend/rust/codegen/src/lib.rs new file mode 100644 index 000000000000..219ad14b9ebb --- /dev/null +++ b/backend/rust/codegen/src/lib.rs @@ -0,0 +1,2 @@ +//! Here is for the more complex situation of code generation. For example, defind the constant for +//! the different build target. diff --git a/backend/rust/models/Cargo.toml b/backend/rust/models/Cargo.toml new file mode 100644 index 000000000000..8cb651876cf4 --- /dev/null +++ b/backend/rust/models/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "models" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = ["ndarray"] + +ndarray = ["burn/ndarray"] +wgpu = ["burn/wgpu"] + +[dependencies] +bunker = { path = "../bunker" } +burn = { version="0.10.0", features=["ndarray","wgpu"] } # https://github.com/mudler/LocalAI/discussions/1219 +serde = "1.0.190" diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs new file mode 100644 index 000000000000..8d13c577703f --- /dev/null +++ b/backend/rust/models/src/lib.rs @@ -0,0 +1,19 @@ +use bunker::pb::{ModelOptions, PredictOptions}; + +pub(crate) mod mnist; +pub use mnist::mnist::MNINST; + +pub(crate) mod llama; + + +/// Trait for implementing a Language Model. +pub trait LLM { + + type Model: LLM; + + /// Loads the model from the given options. + fn load_model(request: ModelOptions) -> Result>; + /// Predicts the output for the given input options. 
+ fn predict(request: PredictOptions) -> Result>; +} + diff --git a/backend/rust/models/src/llama/llama.rs b/backend/rust/models/src/llama/llama.rs new file mode 100644 index 000000000000..805f3eb7c34a --- /dev/null +++ b/backend/rust/models/src/llama/llama.rs @@ -0,0 +1,720 @@ +//! The source code is from https://github.com/Gadersd/llama2-burn/blob/main/src/model.rs +//! The license is Special MIT License.(And the code will be replaced in the future, currently it is just a test.) +//! Adapter by Aisuko + +use std::f32::NEG_INFINITY; + +use burn::{ + config::Config, + module::{Module, Param}, + nn, + tensor::{ + activation::{sigmoid, softmax}, + backend::Backend, + module::embedding, + Data, Distribution, Int, Tensor, + }, backend::wgpu::tensor, +}; + +#[derive(Config, Debug)] +pub struct LlamaConfig { + n_vocab: usize, + n_ctx: usize, + n_state: usize, + multiple_of: usize, + ffn_dim_multiplier: Option, + n_head: usize, + n_kv_head: usize, + n_layer: usize, + #[config(default = 1e-6)] + norm_eps: f64, +} + +impl LlamaConfig { + pub fn init(&self) -> Llama { + let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); + let rotary_encoder = + RotaryEncodingConfig::new(self.n_ctx, self.n_state / self.n_head, 10000.0).init(); + let blocks: Vec<_> = (0..self.n_layer) + .into_iter() + .map(|_| { + ResidualDecoderAttentionBlockConfig::new( + self.n_state, + self.multiple_of, + self.n_head, + self.n_kv_head, + self.norm_eps, + ) + .with_ffn_dim_multiplier(self.ffn_dim_multiplier) + .init() + }) + .collect(); + + let norm = RMSNormConfig::new(self.n_state, self.norm_eps).init(); + let output = nn::LinearConfig::new(self.n_state, self.n_vocab) + .with_bias(false) + .init(); + + let mask = attn_decoder_mask(self.n_ctx).into(); + + let n_vocab = self.n_vocab; + let n_ctx = self.n_ctx; + + Llama { + token_embedding, + rotary_encoder, + blocks, + norm, + output, + mask, + n_vocab, + n_ctx, + } + } +} + +#[derive(Module, Debug)] +pub struct Llama { + 
token_embedding: nn::Embedding, + rotary_encoder: RotaryEncoding, + blocks: Vec>, + norm: RMSNorm, + output: nn::Linear, + mask: Param>, + n_vocab: usize, + n_ctx: usize, +} + +impl Llama { + pub fn forward(&self, x: Tensor) -> Tensor { + let [n_batch, seq_len] = x.dims(); + + assert!( + seq_len <= self.n_ctx, + "Token sequence length {} must not exceed {}.", + seq_len, + self.n_ctx + ); + + let x = self.token_embedding.forward(x); + + let mut x = x; + for block in &self.blocks { + x = block.forward(x, &self.rotary_encoder, self.mask.val()); + } + + self.output.forward(self.norm.forward(x)) + } +} + +#[derive(Config)] +pub struct ResidualDecoderAttentionBlockConfig { + n_state: usize, + multiple_of: usize, + ffn_dim_multiplier: Option, + n_head: usize, + n_kv_head: usize, + norm_eps: f64, +} + +impl ResidualDecoderAttentionBlockConfig { + fn init(&self) -> ResidualDecoderAttentionBlock { + let attn = + MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head, self.n_kv_head).init(); + let attn_norm = RMSNormConfig::new(self.n_state, self.norm_eps).init(); + + let mlp = MLPConfig::new(self.n_state, 4 * self.n_state, self.multiple_of) + .with_ffn_dim_multiplier(self.ffn_dim_multiplier) + .init(); + let mlp_norm = RMSNormConfig::new(self.n_state, self.norm_eps).init(); + + ResidualDecoderAttentionBlock { + attn, + attn_norm, + mlp, + mlp_norm, + } + } +} + +#[derive(Module, Debug)] +pub struct ResidualDecoderAttentionBlock { + attn: MultiHeadSelfAttention, + attn_norm: RMSNorm, + mlp: MLP, + mlp_norm: RMSNorm, +} + +impl ResidualDecoderAttentionBlock { + fn forward( + &self, + x: Tensor, + rotary_encoder: &RotaryEncoding, + mask: Tensor, + ) -> Tensor { + let x = x.clone() + + self + .attn + .forward(self.attn_norm.forward(x), rotary_encoder, Some(mask)); + let x = x.clone() + self.mlp.forward(self.mlp_norm.forward(x)); + return x; + } +} + +#[derive(Config)] +pub struct MLPConfig { + n_state: usize, + n_state_hidden: usize, + multiple_of: usize, + 
ffn_dim_multiplier: Option, +} + +impl MLPConfig { + fn init(&self) -> MLP { + let mut hidden_dim = 2 * self.n_state_hidden / 3; + if let Some(ffn_dim_multiplier) = self.ffn_dim_multiplier { + hidden_dim = ffn_dim_multiplier * hidden_dim; + } + hidden_dim = self.multiple_of * ((hidden_dim + self.multiple_of - 1) / self.multiple_of); + + let w1 = nn::LinearConfig::new(self.n_state, hidden_dim) + .with_bias(false) + .init(); + let w2 = nn::LinearConfig::new(hidden_dim, self.n_state) + .with_bias(false) + .init(); + let w3 = nn::LinearConfig::new(self.n_state, hidden_dim) + .with_bias(false) + .init(); + + let silu = SILU::new(); + + MLP { w1, w2, w3, silu } + } +} + +#[derive(Module, Debug)] +pub struct MLP { + w1: nn::Linear, + w2: nn::Linear, + w3: nn::Linear, + silu: SILU, +} + +impl MLP { + fn forward(&self, x: Tensor) -> Tensor { + self.w2 + .forward(self.silu.forward(self.w1.forward(x.clone())) * self.w3.forward(x)) + } +} + +#[derive(Config)] +pub struct MultiHeadSelfAttentionConfig { + n_state: usize, + n_head: usize, + n_kv_head: usize, +} + +impl MultiHeadSelfAttentionConfig { + fn init(&self) -> MultiHeadSelfAttention { + assert!( + self.n_state % self.n_head == 0, + "State size {} must be a multiple of the number of heads {}", + self.n_state, + self.n_head + ); + assert!( + self.n_head % self.n_kv_head == 0, + "The number of query heads {} must be a multiple of the number of k/v heads {}", + self.n_head, + self.n_kv_head + ); + + let n_head_dim = self.n_state / self.n_head; + + let n_head = self.n_head; + let n_kv_head = self.n_kv_head; + let query = nn::LinearConfig::new(self.n_state, self.n_state) + .with_bias(false) + .init(); + let key = nn::LinearConfig::new(self.n_state, n_kv_head * n_head_dim) + .with_bias(false) + .init(); + let value = nn::LinearConfig::new(self.n_state, n_kv_head * n_head_dim) + .with_bias(false) + .init(); + let out = nn::LinearConfig::new(self.n_state, self.n_state) + .with_bias(false) + .init(); + + MultiHeadSelfAttention { + 
n_head, + n_kv_head, + query, + key, + value, + out, + } + } +} + +#[derive(Module, Debug)] +pub struct MultiHeadSelfAttention { + n_head: usize, + n_kv_head: usize, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + out: nn::Linear, +} + +impl MultiHeadSelfAttention { + fn forward( + &self, + x: Tensor, + rotary_encoder: &RotaryEncoding, + mask: Option>, + ) -> Tensor { + let q = self.query.forward(x.clone()); + let k = self.key.forward(x.clone()); + let v = self.value.forward(x); + + let wv = qkv_attention_rotary(q, k, v, mask, self.n_head, self.n_kv_head, rotary_encoder); + + return self.out.forward(wv); + } +} + +fn qkv_attention_rotary( + q: Tensor, + k: Tensor, + v: Tensor, + mask: Option>, + n_head: usize, + n_kv_head: usize, + rotary_encoder: &RotaryEncoding, +) -> Tensor { + let [n_batch, n_qctx, n_state] = q.dims(); + let [_, n_ctx, _] = k.dims(); + + let n_hstate = n_state / n_head; + let scale = (n_hstate as f64).powf(-0.25); // keeps the value weightings roughly normally distributed + + let q: Tensor = q.reshape([n_batch, n_qctx, n_head, n_hstate]); + // interleave kv heads to match the number of q heads + let n_repeat = n_head / n_kv_head; + let k = repeat_kv(k.reshape([n_batch, n_ctx, n_kv_head, n_hstate]), n_repeat); + let v = repeat_kv(v.reshape([n_batch, n_ctx, n_kv_head, n_hstate]), n_repeat); + + // the last two dims need to be (n_ctx, n_hstate) + let q = rotary_encoder.forward(q.swap_dims(1, 2)) * scale; + let k = rotary_encoder.forward(k.swap_dims(1, 2)) * scale; + let v = v.swap_dims(1, 2); + + // compute value weightings + let qk = q.matmul(k.transpose()); + + // apply mask + let qk = if let Some(mask) = mask { + qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>() + } else { + qk + }; + + // normalize value weightings + let w = softmax(qk, 3); + let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3); + + return o; +} + +/// For a tensor of size (n_batch, n_ctx, n_kv_head, n_hstate), repeats the head keys or values in an 
interleaving manner so that the number +/// of heads is effectively multiplied by n_repeat +fn repeat_kv(x: Tensor, n_repeat: usize) -> Tensor { + if n_repeat > 1 { + let [n_batch, n_ctx, n_kv_head, n_hstate] = x.dims(); + x.repeat(3, n_repeat) + .reshape([n_batch, n_ctx, n_kv_head * n_repeat, n_hstate]) + } else { + x + } +} + +/// Generates a strictly upper triangular matrix filled with -inf that when added to an attention weight matrix prevents +/// vectors from attending to other vectors further in the sequence, preventing future information from flowing into the past +pub fn attn_decoder_mask(seq_length: usize) -> Tensor { + let mut mask = Tensor::::zeros([seq_length, seq_length]); + + for i in 0..(seq_length - 1) { + let values = Tensor::::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); + mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); + } + + return mask; +} + +#[derive(Config, Debug)] +pub struct RotaryEncodingConfig { + max_sequence_length: usize, + state_size: usize, + theta: f64, +} + +impl RotaryEncodingConfig { + pub fn init(&self) -> RotaryEncoding { + assert!(self.state_size % 2 == 0, "Head dims must be even."); + assert!(self.theta > 0.0, "Theta must be positive."); + + let half_state_size = self.state_size / 2; + + let arange_m = Tensor::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]]).into(); + + let inv_freq = powto( + self.theta, + Tensor::arange(0..half_state_size).float() * (2.0 / self.state_size as f64), + ) + .powf(-1.0); + + let periods = Tensor::arange(0..self.max_sequence_length) + .float() + .unsqueeze::<2>() + .transpose() + .repeat(1, half_state_size) + * inv_freq.unsqueeze(); + + let p_cos = periods.clone().cos(); + let p_sin = periods.sin(); + let tensor=Tensor::cat(vec![p_cos, p_sin], 1); + + let tensor2=tensor.reshape([self.max_sequence_length,2,half_state_size]); + + let tensor3=tensor2.transpose(); + + let tensor41=tensor3.repeat(2, 2); + + let 
tensor5=tensor41.reshape([self.max_sequence_length,self.state_size,2]); + + let freq_cis=tensor5.into(); + + RotaryEncoding { arange_m, freq_cis } + } + +} + +fn powto(base: f64, x: Tensor) -> Tensor { + let logbase = base.ln(); + x.mul_scalar(logbase).exp() +} + +/// Conceptually, pairs the values of a vector (v0 v1 v2 ... vn) into complex numbers (c0 c1 c2 ... cn/2) +/// which are then rotated counter-clockwise by the angle seq_index / theta^(2*pair_index/n). +/// This encodes sequence positions in a way that is agnostic to the maximum sequence length +/// which potentially allows for arbitrarily long sequences without retraining. +#[derive(Module, Debug)] +pub struct RotaryEncoding { + arange_m: Param>, + freq_cis: Param>, +} + +impl RotaryEncoding { + /// Applies rotary positional encoding to a tensor of dimenions (..., seq_len, n_state) + fn forward(&self, x: Tensor) -> Tensor { + assert!(D >= 2); + let orig_shape = x.shape(); + let (n_ctx, n_state) = (orig_shape.dims[D - 2], orig_shape.dims[D - 1]); + let dummy_dim_size = orig_shape.num_elements() / (n_ctx * n_state); + + let out = x + .reshape([dummy_dim_size, n_ctx, n_state / 2, 2]) + .matmul(self.arange_m.val().unsqueeze()) + .reshape([dummy_dim_size, n_ctx, n_state, 2]) + * self.freq_cis.val().slice([0..n_ctx]).unsqueeze(); + + out.sum_dim(D - 1).reshape(orig_shape) + } +} + +#[derive(Config)] +pub struct RMSNormConfig { + layer_size: usize, + eps: f64, +} + +impl RMSNormConfig { + fn init(&self) -> RMSNorm { + assert!(self.eps > 0.0, "eps must be positive."); + + let weight = Tensor::ones([self.layer_size]); + let eps = self.eps; + + RMSNorm { weight, eps } + } +} + +#[derive(Module, Debug)] +pub struct RMSNorm { + weight: Tensor, + eps: f64, +} + +impl RMSNorm { + fn forward(&self, x: Tensor) -> Tensor { + let rms = (x.clone().powf(2.0).mean_dim(D - 1) + self.eps).sqrt(); + (x / rms) * self.weight.clone().unsqueeze() + } +} + +#[derive(Module, Clone, Debug)] +pub struct SILU {} + +impl SILU { + fn new() 
-> Self { + Self {} + } + + fn forward(&self, x: Tensor) -> Tensor { + x.clone() * sigmoid(x) + } +} + +use npy::{self, NpyData}; //TODO: NpyData is deprecated, use ndarray_npy instead, but before replace it, I want to make sure the project works well. +use num_traits::cast::ToPrimitive; // TODO: Same here. +use std::error::Error; +use std::io::Read; + +use burn::tensor::ElementConversion; + +fn numpy_to_tensor( + numpy_data: NpyData, + device: &B::Device, +) -> Tensor { + let mut v = numpy_data.to_vec(); + + let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); + let data: Vec = v[D..].into_iter().map(|e| e.elem()).collect(); + + Tensor::from_data_device(Data::new(data, shape.into()), device) +} + +fn load_tensor( + name: &str, + path: &str, + device: &B::Device, +) -> Result, Box> { + let tensor_path = format!("{}/{}.npy", path, name); + + let mut buf = vec![]; + std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?; + + let tensor_numpy: NpyData = NpyData::from_bytes(&buf)?; + + let tensor = numpy_to_tensor(tensor_numpy, device); + + println!("{}", tensor_path); + + Ok(tensor) +} + +fn load_f32(name: &str, path: &str, device: &B::Device) -> Result> { + load_tensor::(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) +} + +fn load_usize( + name: &str, + path: &str, + device: &B::Device, +) -> Result> { + load_tensor::(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) +} + +fn load_linear( + path: &str, + device: &B::Device, +) -> Result, Box> { + let weight = load_tensor::("weight", path, device)?; + let bias = load_tensor::("bias", path, device).ok(); + + let record = nn::LinearRecord { + weight: weight.into(), + bias: bias.map(|t| t.into()), + }; + + let linear: nn::Linear = nn::LinearConfig::new(3, 3).init_with(record); + Ok(linear) +} + +fn load_rmsnorm(path: &str, device: &B::Device) -> Result, Box> { + let weight = load_tensor::("weight", path, device)?; + let eps = load_f32::("eps", path, device)?.into(); + 
+ let rmsnorm = RMSNorm { + weight: weight.into(), + eps: eps, + }; + + Ok(rmsnorm) +} + +fn load_attention( + path: &str, + device: &B::Device, +) -> Result, Box> { + let query = load_linear(&format!("{}/{}", path, "wq"), device)?; + let key = load_linear(&format!("{}/{}", path, "wk"), device)?; + let value = load_linear(&format!("{}/{}", path, "wv"), device)?; + let out = load_linear(&format!("{}/{}", path, "wo"), device)?; + + let n_head = load_usize::("n_head", path, device)?; + let n_kv_head = load_usize::("n_kv_head", path, device)?; + + let attention = MultiHeadSelfAttention { + n_head: n_head, + n_kv_head: n_kv_head, + query: query, + key: key, + value: value, + out: out, + }; + + Ok(attention) +} + +fn load_feedforward(path: &str, device: &B::Device) -> Result, Box> { + let w1 = load_linear(&format!("{}/{}", path, "w1"), device)?; + let w2 = load_linear(&format!("{}/{}", path, "w2"), device)?; + let w3 = load_linear(&format!("{}/{}", path, "w3"), device)?; + + let mlp = MLP { + w1: w1, + w2: w2, + w3: w3, + silu: SILU::new(), + }; + + Ok(mlp) +} + +fn load_transformer_block( + path: &str, + device: &B::Device, +) -> Result, Box> { + let attn = load_attention(&format!("{}/{}", path, "attention"), device)?; + let attn_norm = load_rmsnorm(&format!("{}/{}", path, "attention_norm"), device)?; + let mlp = load_feedforward(&format!("{}/{}", path, "feedforward"), device)?; + let mlp_norm = load_rmsnorm(&format!("{}/{}", path, "ffn_norm"), device)?; + + let block = ResidualDecoderAttentionBlock { + attn: attn, + attn_norm: attn_norm, + mlp: mlp, + mlp_norm: mlp_norm, + }; + + Ok(block) +} + +use burn::nn::{EmbeddingConfig, EmbeddingRecord}; + +pub fn load_llama_dump( + path: &str, + device: &B::Device, +) -> Result<(Llama, LlamaConfig), Box> { + let mut blocks: Vec> = vec![]; + let n_layer = load_usize::("n_layer", path, device)?; + for i in 0..n_layer { + let block = load_transformer_block(&format!("{}/layer{}", path, i), device)?; + blocks.push(block); + } + + 
let n_ctx = load_usize::("n_ctx", path, device)?; + let theta = load_f32::("theta", path, device)?; + let multiple_of = load_usize::("multiple_of", path, device)?; + let ffn_dim_multiplier = load_usize::("ffn_dim_multiplier", path, device).ok(); + + let token_embedding = load_tensor("tok_embeddings/weight", path, device)?; + let [n_vocab, n_state] = token_embedding.dims(); + let n_head = blocks[0].attn.n_head; + let n_kv_head = blocks[0].attn.n_kv_head; + let head_dim = n_state / n_head; + + let token_embedding = EmbeddingConfig::new(n_vocab, n_state).init_with(EmbeddingRecord { + weight: token_embedding.into(), + }); + let rotary_encoding = RotaryEncodingConfig::new(n_ctx, head_dim, theta.into()).init(); + + let norm = load_rmsnorm(&format!("{}/{}", path, "norm"), device)?; + let output = load_linear(&format!("{}/{}", path, "output"), device)?; + let mask = attn_decoder_mask(n_ctx).into(); + + let norm_eps = norm.eps; + + let llama = Llama { + token_embedding: token_embedding, + rotary_encoder: rotary_encoding, + blocks: blocks, + norm: norm, + output: output, + mask: mask, + n_vocab: n_vocab, + n_ctx: n_ctx, + }; + + let llama_config = LlamaConfig::new( + n_vocab, + n_ctx, + n_state, + multiple_of, + n_head, + n_kv_head, + n_layer, + ) + .with_norm_eps(norm_eps) + .with_ffn_dim_multiplier(ffn_dim_multiplier); + + Ok((llama, llama_config)) +} + + +#[cfg(test)] +mod tests{ + use super::*; + + #[test] + fn test_feq_cis_reshape(){ + use burn::backend::WgpuBackend; + use burn::backend::wgpu::{AutoGraphicsApi}; + + type Backend = WgpuBackend; + + let config= RotaryEncodingConfig{ + max_sequence_length: 10, + state_size: 4, + theta: 1.0, + }; + + let encoding=config.init::(); + + assert_eq!(encoding.freq_cis.dims(),[10,4,2]); + assert_eq!(encoding.arange_m.dims(),[2,4]); + } + + #[test] + fn test_rotary_encoding_config_init(){ + use burn::backend::WgpuBackend; + use burn::backend::wgpu::{AutoGraphicsApi}; + + type Backend = WgpuBackend; + + let config = 
RotaryEncodingConfig{ + state_size: 4, + theta:1.0, + max_sequence_length: 10, + }; + + let encoding=config.init::(); + + assert_eq!(encoding.arange_m.dims(),[2,4]); + assert_eq!(encoding.freq_cis.dims(),[10,4,2]); + + } +} \ No newline at end of file diff --git a/backend/rust/models/src/llama/mod.rs b/backend/rust/models/src/llama/mod.rs new file mode 100644 index 000000000000..c000cb3c53ee --- /dev/null +++ b/backend/rust/models/src/llama/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod llama; + +use llama::*; \ No newline at end of file diff --git a/backend/rust/models/src/mnist/mnist.rs b/backend/rust/models/src/mnist/mnist.rs new file mode 100644 index 000000000000..6b2c06ae2932 --- /dev/null +++ b/backend/rust/models/src/mnist/mnist.rs @@ -0,0 +1,188 @@ +//! Defination of a mninst model and config of it. +//! The source code is from https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/src/model.rs +//! The license is Apache-2.0 and MIT. +//! Adapter by Aisuko + +use burn::{ + backend::wgpu::{AutoGraphicsApi, WgpuDevice}, + module::Module, + nn::{self, BatchNorm, PaddingConfig2d}, + record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, + tensor::{backend::Backend, Tensor}, +}; + +// https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin + +const NUM_CLASSES: usize = 10; + +#[derive(Module, Debug)] +/// A struct representing an MNINST model. +pub struct MNINST { + /// The first convolutional block of the model. + conv1: ConvBlock, + /// The second convolutional block of the model. + conv2: ConvBlock, + /// The third convolutional block of the model. + conv3: ConvBlock, + /// A dropout layer used in the model. + dropout: nn::Dropout, + /// The first fully connected layer of the model. + fc1: nn::Linear, + /// The second fully connected layer of the model. + fc2: nn::Linear, + /// The activation function used in the model. 
+ activation: nn::GELU, +} + +impl MNINST { + pub fn new(model_name: &str) -> Self { + let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size + let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size + let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size + let hidden_size = 24 * 22 * 22; // 24 channels x 22 x 22: three 3x3 "valid" convs each trim 2 px off a 28x28 input (assumes stride 1 - TODO confirm) + let fc1 = nn::LinearConfig::new(hidden_size, 32) + .with_bias(false) + .init(); + let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) + .with_bias(false) + .init(); + + let dropout = nn::DropoutConfig::new(0.5).init(); + + let instance = Self { + conv1: conv1, + conv2: conv2, + conv3: conv3, + dropout: dropout, + fc1: fc1, + fc2: fc2, + activation: nn::GELU::new(), + }; + use std::path::Path; + let path_name=Path::new(model_name); + + let state_encoded: &[u8] = &std::fs::read(&path_name).expect("Failed to read model file"); + let record = BinBytesRecorder::::default() + .load(state_encoded.to_vec()) + .expect("Failed to decode state"); + + instance.load_record(record) + } + + /// Applies the forward pass of the neural network on the given input tensor. + /// + /// # Arguments + /// + /// * `input` - A 3-dimensional tensor of shape [batch_size, height, width]. + /// + /// # Returns + /// + /// A 2-dimensional tensor of shape [batch_size, num_classes] containing the output of the neural network. 
+ pub fn forward(&self, input: Tensor) -> Tensor { + // Get the dimensions of the input tensor + let [batch_size, height, width] = input.dims(); + // Reshape the input tensor to have a shape of [batch_size, 1, height, width] and detach it + let x = input.reshape([batch_size, 1, height, width]).detach(); + // Apply the first convolutional block to the input tensor + let x = self.conv1.forward(x); + // Apply the second convolutional block to the output of the first convolutional block + let x = self.conv2.forward(x); + // Apply the third convolutional block to the output of the second convolutional block + let x = self.conv3.forward(x); + + // Get the dimensions of the output tensor from the third convolutional block + let [batch_size, channels, height, width] = x.dims(); + // Flatten the output tensor to a shape of [batch_size, channels*height*width] + let x = x.reshape([batch_size, channels * height * width]); + + // Apply dropout to the flattened output of the third convolutional block + let x = self.dropout.forward(x); + // Apply the first fully connected layer to the output of the dropout layer + let x = self.fc1.forward(x); + // Apply the activation function to the output of the first fully connected layer + let x = self.activation.forward(x); + + // Apply the second fully connected layer to the output of the activation function + self.fc2.forward(x) + } + + pub fn inference(&mut self, input: &[f32]) -> Result, Box> { + // Reshape from the 1D array to a 3D tensor [batch, height, width] + let input: Tensor = Tensor::from_floats(input).reshape([1, 28, 28]); + + // Normalize input: scale to [0,1], then standardize to mean=0 and std=1 + // values mean=0.1307, std=0.3081 + // Source: https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + let input = ((input / 255) - 0.1307) / 0.3081; + + // Run the tensor input through the model + let output: Tensor = self.forward(input); + + // Convert the model output into probability 
distribution using softmax formula + let output = burn::tensor::activation::softmax(output, 1); + + // Flatten output tensor with [1,10] shape into its flat [f32] data + let output = output.into_data().convert::().value; + + Ok(output) + } +} + +/// A struct representing a convolutional block in a neural network model. +#[derive(Module, Debug)] +pub struct ConvBlock { + /// A 2D convolutional layer. + conv: nn::conv::Conv2d, + /// A batch normalization layer. + norm: BatchNorm, + /// A GELU activation function. + activation: nn::GELU, +} + +/// A convolutional block with batch normalization and GELU activation. +impl ConvBlock { + /// Creates a new `ConvBlock` from `channels` = [input, output] channel counts and a kernel size. + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + // Initialize a 2D convolutional layer with the given input/output channels and kernel size, + // and set the padding to "valid". + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + + // Initialize a batch normalization layer sized to the convolution's output channel count (`channels[1]`). + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + // Create a new `ConvBlock` with the initialized convolutional and batch normalization layers, + // and a GELU activation function. + Self { + conv: conv, + norm: norm, + activation: nn::GELU::new(), + } + } + + /// Applies the convolutional block to the given input tensor. + pub fn forward(&self, input: Tensor) -> Tensor { + // Apply the convolutional layer to the input tensor. + let x = self.conv.forward(input); + + // Apply the batch normalization layer to the output of the convolutional layer. + let x = self.norm.forward(x); + + // Apply the GELU activation function to the output of the batch normalization layer. 
+ self.activation.forward(x) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[cfg(feature = "ndarray")] + pub type Backend = burn::backend::NdArrayBackend; + #[test] + fn test_inference() { + let mut model = MNINST::::new("model.bin"); + let output = model.inference(&[0.0; 28 * 28]).unwrap(); + assert_eq!(output.len(), 10); + } +} diff --git a/backend/rust/models/src/mnist/mod.rs b/backend/rust/models/src/mnist/mod.rs new file mode 100644 index 000000000000..8372ec3f1bbd --- /dev/null +++ b/backend/rust/models/src/mnist/mod.rs @@ -0,0 +1,26 @@ +use crate::LLM; + +pub(crate) mod mnist; + +use mnist::MNINST; + +use bunker::pb::{ModelOptions, PredictOptions}; + +#[cfg(feature = "ndarray")] +pub type Backend = burn::backend::NdArrayBackend; + +impl LLM for MNINST { + type Model = MNINST; + + fn load_model(request: ModelOptions) -> Result> { + let model = request.model_file; + let instance= MNINST::::new(&model); + // NOTE(review): `MNINST::new` panics via `expect` on an unreadable/undecodable model file, so no error ever reaches this `Ok` + Ok(instance) + } + + fn predict(pre_ops: PredictOptions) -> Result> { + // convert prost::alloc::string::String to &[f32] + todo!() + } +}