
Commit 2ffc236

Merge pull request #10 from SilasMarvin/silas-llamacpp-custom-file
Added file_path config option for llama_cpp models
2 parents f84b5fe + ef86f20

File tree: 2 files changed (+25 −20 lines)

src/config.rs

Lines changed: 9 additions & 10 deletions
```diff
@@ -81,13 +81,6 @@ pub struct FileStore {
     pub crawl: bool,
 }
 
-#[derive(Clone, Debug, Deserialize)]
-#[serde(deny_unknown_fields)]
-pub struct Model {
-    pub repository: String,
-    pub name: Option<String>,
-}
-
 const fn n_gpu_layers_default() -> u32 {
     1000
 }
@@ -106,20 +99,25 @@ pub struct MistralFIM {
     pub fim_endpoint: Option<String>,
     // The model name
     pub model: String,
+    // The maximum requests per second
     #[serde(default = "max_requests_per_second_default")]
     pub max_requests_per_second: f32,
 }
 
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct LLaMACPP {
-    // The model to use
-    #[serde(flatten)]
-    pub model: Model,
+    // Which model to use
+    pub repository: Option<String>,
+    pub name: Option<String>,
+    pub file_path: Option<String>,
+    // The layers to put on the GPU
     #[serde(default = "n_gpu_layers_default")]
     pub n_gpu_layers: u32,
+    // The context size
     #[serde(default = "n_ctx_default")]
     pub n_ctx: u32,
+    // The maximum requests per second
     #[serde(default = "max_requests_per_second_default")]
     pub max_requests_per_second: f32,
 }
@@ -129,6 +127,7 @@ pub struct LLaMACPP {
 pub struct OpenAI {
     // The auth token env var name
     pub auth_token_env_var_name: Option<String>,
+    // The auth token
     pub auth_token: Option<String>,
     // The completions endpoint
     pub completions_endpoint: Option<String>,
```
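The practical effect of this change: `repository` and `name` are no longer flattened in from a required `Model` struct, so a llama.cpp model can now be specified by a bare local `file_path`. A minimal sketch (not part of this commit) of how the reworked section deserializes; the struct is trimmed to the fields shown in this diff, and the `n_ctx` default value and the model names are assumptions for illustration:

```rust
use serde::Deserialize;

// Trimmed copy of the post-commit config::LLaMACPP, for illustration only.
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct LLaMACPP {
    // Which model to use: either a local file or a repository + name pair.
    repository: Option<String>,
    name: Option<String>,
    file_path: Option<String>,
    // The context size
    #[serde(default = "n_ctx_default")]
    n_ctx: u32,
}

// Assumed value; the real default lives elsewhere in config.rs.
const fn n_ctx_default() -> u32 {
    2048
}

fn main() -> anyhow::Result<()> {
    // New in this commit: a purely local model, no Hugging Face fields.
    let local: LLaMACPP =
        serde_json::from_str(r#"{ "file_path": "/models/model.gguf" }"#)?;
    assert!(local.file_path.is_some() && local.repository.is_none());

    // The pre-existing download path still works (hypothetical repo/name).
    let remote: LLaMACPP = serde_json::from_str(
        r#"{ "repository": "someone/some-model-GGUF", "name": "model.Q4_K_M.gguf" }"#,
    )?;
    assert_eq!(remote.n_ctx, 2048); // serde default filled in
    Ok(())
}
```

Making all three model fields `Option` and validating the combination in `LLaMACPP::new` (below) rather than in the type itself is what lets the backend report one clear error when neither source is fully specified.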

src/transformer_backends/llama_cpp/mod.rs

Lines changed: 16 additions & 10 deletions
```diff
@@ -9,7 +9,6 @@ use crate::{
     },
     utils::format_chat_messages,
 };
-use anyhow::Context;
 use hf_hub::api::sync::ApiBuilder;
 use serde::Deserialize;
 use serde_json::Value;
@@ -41,15 +40,22 @@ pub struct LLaMACPP {
 impl LLaMACPP {
     #[instrument]
     pub fn new(configuration: config::LLaMACPP) -> anyhow::Result<Self> {
-        let api = ApiBuilder::new().with_progress(true).build()?;
-        let name = configuration
-            .model
-            .name
-            .as_ref()
-            .context("Please set `name` to use LLaMA.cpp")?;
-        error!("Loading in: {} - {}\nIf this model has not been loaded before it may take a few minutes to download it. Please hangtight.", configuration.model.repository, name);
-        let repo = api.model(configuration.model.repository.to_owned());
-        let model_path = repo.get(name)?;
+        let model_path = match (
+            &configuration.file_path,
+            &configuration.repository,
+            &configuration.name,
+        ) {
+            (Some(file_path), _, _) => std::path::PathBuf::from(file_path),
+            (_, Some(repository), Some(name)) => {
+                let api = ApiBuilder::new().with_progress(true).build()?;
+                error!("Loading in: {} - {}\nIf this model has not been loaded before it may take a few minutes to download it. Please hangtight.", repository, name);
+                let repo = api.model(repository.clone());
+                repo.get(&name)?
+            }
+            _ => anyhow::bail!(
+                "To use llama.cpp provide either `file_path` or `repository` and `name`"
+            ),
+        };
         let model = Model::new(model_path, &configuration)?;
         Ok(Self { model })
     }
```
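Pulled out of the constructor, the new logic is a three-way precedence match over the optional config fields. A standalone sketch of that pattern, assuming the Hub download is abstracted behind a hypothetical `fetch_from_hub` helper (the real code uses `hf_hub::api::sync::ApiBuilder` inline, exactly as in the diff):

```rust
use std::path::PathBuf;

// Hypothetical stand-in for the hf_hub download performed in the diff:
// ApiBuilder::new().with_progress(true).build()?.model(repository).get(name)
fn fetch_from_hub(repository: &str, name: &str) -> anyhow::Result<PathBuf> {
    Ok(PathBuf::from(format!("/tmp/hf-cache/{repository}/{name}")))
}

fn resolve_model_path(
    file_path: Option<&str>,
    repository: Option<&str>,
    name: Option<&str>,
) -> anyhow::Result<PathBuf> {
    match (file_path, repository, name) {
        // A local file path takes precedence over everything else.
        (Some(file_path), _, _) => Ok(PathBuf::from(file_path)),
        // Otherwise download `name` from the Hugging Face `repository`.
        (_, Some(repository), Some(name)) => fetch_from_hub(repository, name),
        // Neither source fully specified: refuse to guess.
        _ => anyhow::bail!(
            "To use llama.cpp provide either `file_path` or `repository` and `name`"
        ),
    }
}

fn main() -> anyhow::Result<()> {
    // file_path alone is now sufficient.
    assert_eq!(
        resolve_model_path(Some("/models/model.gguf"), None, None)?,
        PathBuf::from("/models/model.gguf")
    );
    // name without repository is rejected with the new error message.
    assert!(resolve_model_path(None, None, Some("model.gguf")).is_err());
    Ok(())
}
```

Note the match order: `file_path` wins even when `repository` and `name` are also set, so a fully specified config never triggers a download. Dropping `use anyhow::Context` follows from this rewrite, since the old `.context(...)` call on the required `name` was the only use.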
