
Commit a9220bc

Merge pull request #11 from SilasMarvin/silas-add-ollama-api
Added Ollama as a backend
2 parents (2ffc236 + a9a069f) · commit a9220bc

3 files changed (+286 −0 lines)


src/config.rs

Lines changed: 40 additions & 0 deletions
@@ -29,6 +29,8 @@ pub enum ValidModel {
     Anthropic(Anthropic),
     #[serde(rename = "mistral_fim")]
     MistralFIM(MistralFIM),
+    #[serde(rename = "ollama")]
+    Ollama(Ollama),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -89,6 +91,16 @@ const fn n_ctx_default() -> u32 {
     1000
 }
 
+#[derive(Clone, Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct Ollama {
+    // The model name
+    pub model: String,
+    // The maximum requests per second
+    #[serde(default = "max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct MistralFIM {
@@ -237,6 +249,7 @@ impl Config {
             ValidModel::OpenAI(open_ai) => Ok(open_ai.max_requests_per_second),
             ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
             ValidModel::MistralFIM(mistral_fim) => Ok(mistral_fim.max_requests_per_second),
+            ValidModel::Ollama(ollama) => Ok(ollama.max_requests_per_second),
         }
     }
 }
@@ -298,6 +311,33 @@ mod test {
         Config::new(args).unwrap();
     }
 
+    #[test]
+    fn ollama_config() {
+        let args = json!({
+            "initializationOptions": {
+                "memory": {
+                    "file_store": {}
+                },
+                "models": {
+                    "model1": {
+                        "type": "ollama",
+                        "model": "llama3"
+                    }
+                },
+                "completion": {
+                    "model": "model1",
+                    "parameters": {
+                        "max_context": 1024,
+                        "options": {
+                            "num_predict": 32
+                        }
+                    }
+                }
+            }
+        });
+        Config::new(args).unwrap();
+    }
+
     #[test]
     fn open_ai_config() {
         let args = json!({
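
Note on the config shape: the ollama_config test above implies that each entry in the models map is deserialized into ValidModel via a "type" tag, with unknown fields rejected on the inner struct. Below is a minimal standalone sketch of that pattern using only serde and serde_json; the #[serde(tag = "type")] attribute and the 1.0 default for max_requests_per_second are assumptions for illustration, not taken from this diff.

use serde::Deserialize;
use serde_json::json;

// Assumed default value, for illustration only; the real default lives elsewhere in config.rs.
fn max_requests_per_second_default() -> f32 {
    1.0
}

#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct Ollama {
    model: String,
    #[serde(default = "max_requests_per_second_default")]
    max_requests_per_second: f32,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "type")] // assumption: the real ValidModel appears to be internally tagged on "type"
enum ValidModel {
    #[serde(rename = "ollama")]
    Ollama(Ollama),
}

fn main() -> Result<(), serde_json::Error> {
    // Mirrors the "model1" entry from the ollama_config test above.
    let model: ValidModel = serde_json::from_value(json!({
        "type": "ollama",
        "model": "llama3"
    }))?;
    // max_requests_per_second falls back to the default because it was omitted.
    println!("{model:?}");
    Ok(())
}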

src/transformer_backends/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ mod anthropic;
 #[cfg(feature = "llama_cpp")]
 mod llama_cpp;
 mod mistral_fim;
+mod ollama;
 mod open_ai;
 
 #[async_trait::async_trait]
@@ -71,6 +72,7 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
             ValidModel::MistralFIM(mistral_fim) => {
                 Ok(Box::new(mistral_fim::MistralFIM::new(mistral_fim)))
             }
+            ValidModel::Ollama(ollama) => Ok(Box::new(ollama::Ollama::new(ollama))),
         }
     }
 }
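
For context, the TryFrom dispatch above is the usual enum-to-trait-object pattern: each config variant constructs its backend and returns it boxed. A self-contained sketch of that shape follows; the types here are simplified stand-ins for illustration, not the crate's real definitions.

// Simplified stand-ins for illustration; the real trait and enum live in this crate.
trait TransformerBackend {
    fn name(&self) -> &'static str;
}

struct OllamaBackend;

impl TransformerBackend for OllamaBackend {
    fn name(&self) -> &'static str {
        "ollama"
    }
}

enum ValidModel {
    Ollama, // the real variant carries its config struct
}

impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
    type Error = anyhow::Error;

    fn try_from(value: ValidModel) -> Result<Self, Self::Error> {
        match value {
            ValidModel::Ollama => Ok(Box::new(OllamaBackend)),
        }
    }
}

fn main() -> anyhow::Result<()> {
    let backend: Box<dyn TransformerBackend + Send + Sync> = ValidModel::Ollama.try_into()?;
    println!("selected backend: {}", backend.name());
    Ok(())
}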

src/transformer_backends/ollama.rs

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use std::collections::HashMap;
+use tracing::instrument;
+
+use crate::{
+    config::{self, ChatMessage, FIM},
+    memory_backends::Prompt,
+    transformer_worker::{
+        DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
+    },
+    utils::{format_chat_messages, format_context_code},
+};
+
+use super::TransformerBackend;
+
+// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
+#[derive(Debug, Deserialize)]
+pub struct OllamaRunParams {
+    pub fim: Option<FIM>,
+    messages: Option<Vec<ChatMessage>>,
+    #[serde(default)]
+    options: HashMap<String, Value>,
+    system: Option<String>,
+    template: Option<String>,
+    keep_alive: Option<String>,
+}
+
+pub struct Ollama {
+    configuration: config::Ollama,
+}
+
+#[derive(Deserialize)]
+struct OllamaCompletionsResponse {
+    response: Option<String>,
+    error: Option<Value>,
+    #[serde(default)]
+    #[serde(flatten)]
+    other: HashMap<String, Value>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+struct OllamaChatMessage {
+    role: String,
+    content: String,
+}
+
+#[derive(Deserialize)]
+struct OllamaChatResponse {
+    message: Option<OllamaChatMessage>,
+    error: Option<Value>,
+    #[serde(default)]
+    #[serde(flatten)]
+    other: HashMap<String, Value>,
+}
+
+impl Ollama {
+    #[instrument]
+    pub fn new(configuration: config::Ollama) -> Self {
+        Self { configuration }
+    }
+
+    async fn get_completion(
+        &self,
+        prompt: &str,
+        params: OllamaRunParams,
+    ) -> anyhow::Result<String> {
+        let client = reqwest::Client::new();
+        let res: OllamaCompletionsResponse = client
+            .post("http://localhost:11434/api/generate")
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "prompt": prompt,
+                "options": params.options,
+                "keep_alive": params.keep_alive,
+                "raw": true,
+                "stream": false
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(response) = res.response {
+            Ok(response)
+        } else {
+            anyhow::bail!(
+                "Unknown error while making request to Ollama: {:?}",
+                res.other
+            )
+        }
+    }
+
+    async fn get_chat(
+        &self,
+        messages: Vec<ChatMessage>,
+        params: OllamaRunParams,
+    ) -> anyhow::Result<String> {
+        let client = reqwest::Client::new();
+        let res: OllamaChatResponse = client
+            .post("http://localhost:11434/api/chat")
+            .header("Content-Type", "application/json")
+            .header("Accept", "application/json")
+            .json(&json!({
+                "model": self.configuration.model,
+                "system": params.system,
+                "template": params.template,
+                "messages": messages,
+                "options": params.options,
+                "keep_alive": params.keep_alive,
+                "stream": false
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.error {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(message) = res.message {
+            Ok(message.content)
+        } else {
+            anyhow::bail!(
+                "Unknown error while making request to Ollama: {:?}",
+                res.other
+            )
+        }
+    }
+
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        params: OllamaRunParams,
+    ) -> anyhow::Result<String> {
+        match prompt {
+            Prompt::ContextAndCode(code_and_context) => match &params.messages {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, code_and_context);
+                    self.get_chat(messages, params).await
+                }
+                None => {
+                    self.get_completion(
+                        &format_context_code(&code_and_context.context, &code_and_context.code),
+                        params,
+                    )
+                    .await
+                }
+            },
+            Prompt::FIM(fim) => match &params.fim {
+                Some(fim_params) => {
+                    self.get_completion(
+                        &format!(
+                            "{}{}{}{}{}",
+                            fim_params.start,
+                            fim.prompt,
+                            fim_params.middle,
+                            fim.suffix,
+                            fim_params.end
+                        ),
+                        params,
+                    )
+                    .await
+                }
+                None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
+            },
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Ollama {
+    #[instrument(skip(self))]
+    async fn do_generate(
+        &self,
+        prompt: &Prompt,
+
+        params: Value,
+    ) -> anyhow::Result<DoGenerationResponse> {
+        let params: OllamaRunParams = serde_json::from_value(params)?;
+        let generated_text = self.do_chat_completion(prompt, params).await?;
+        Ok(DoGenerationResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerationStreamRequest,
+        _params: Value,
+    ) -> anyhow::Result<DoGenerationStreamResponse> {
+        anyhow::bail!("GenerationStream is not yet implemented")
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::{from_value, json};
+
+    #[tokio::test]
+    async fn ollama_completion_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Ollama = from_value(json!({
+            "model": "llama3",
+        }))?;
+        let ollama = Ollama::new(configuration);
+        let prompt = Prompt::default_without_cursor();
+        let run_params = json!({
+            "options": {
+                "num_predict": 4
+            }
+        });
+        let response = ollama.do_generate(&prompt, run_params).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn ollama_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Ollama = from_value(json!({
+            "model": "llama3",
+        }))?;
+        let ollama = Ollama::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let run_params = json!({
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "Test"
+                },
+                {
+                    "role": "user",
+                    "content": "Test {CONTEXT} - {CODE}"
+                }
+            ],
+            "options": {
+                "num_predict": 4
+            }
+        });
+        let response = ollama.do_generate(&prompt, run_params).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
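
The two tests above hit a live endpoint: they assume an Ollama server is listening on the default port 11434 with the llama3 model already pulled. A minimal standalone sketch to sanity-check that setup before running them, mirroring the request get_completion builds (the prompt string here is just an example):

// Sketch only: requires reqwest (with the "json" feature), tokio, serde_json, and anyhow,
// plus a local Ollama server with "llama3" pulled (e.g. `ollama pull llama3`).
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = reqwest::Client::new();
    let res: Value = client
        .post("http://localhost:11434/api/generate")
        .json(&json!({
            "model": "llama3",
            "prompt": "fn add(a: i32, b: i32) -> i32 {",
            "raw": true,
            "stream": false,
            "options": { "num_predict": 8 }
        }))
        .send()
        .await?
        .json()
        .await?;
    // On success the generated text is returned in the "response" field.
    println!("{}", res["response"]);
    Ok(())
}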
