This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit ecbbb81

Merge pull request #34 from philpax/misc-tweaks
Miscellaneous tweaks

2 parents: c0e7708 + 9b3911b

File tree: 5 files changed, +152 -115 lines


README.md

Lines changed: 10 additions & 9 deletions

````diff
@@ -4,12 +4,13 @@
 
 ![A llama riding a crab, AI-generated](./doc/resources/logo2.png)
 
-> *Image by [@darthdeus](https://github.com/darthdeus/), using Stable Diffusion*
+> _Image by [@darthdeus](https://github.com/darthdeus/), using Stable Diffusion_
 
 [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/F1F8DNO5D)
 
 [![Latest version](https://img.shields.io/crates/v/llama-rs.svg)](https://crates.io/crates/llama_rs)
 ![MIT](https://img.shields.io/badge/license-MIT-blue.svg)
+[![Discord](https://img.shields.io/discord/1085885067601137734)](https://discord.gg/YB9WaXYAWU)
 
 ![Gif showcasing language generation using llama-rs](./doc/resources/llama_gif.gif)
 
@@ -43,7 +44,7 @@ Some additional things to try:
 
 For example, you try the following prompt:
 
-``` shell
+```shell
 cargo run --release -- -m /data/Llama/LLaMA/7B/ggml-model-q4_0.bin -p "Tell me how cool the Rust programming language is
 ```
 
@@ -52,14 +53,14 @@ cargo run --release -- -m /data/Llama/LLaMA/7B/ggml-model-q4_0.bin -p "Tell me h
 - **Q: Why did you do this?**
 - **A:** It was not my choice. Ferris appeared to me in my dreams and asked me
 to rewrite this in the name of the Holy crab.
-
+
 - **Q: Seriously now**
 - **A:** Come on! I don't want to get into a flame war. You know how it goes,
-*something something* memory *something something* cargo is nice, don't make
+_something something_ memory _something something_ cargo is nice, don't make
 me say it, everybody knows this already.
 
 - **Q: I insist.**
-- **A:** *Sheesh! Okaaay*. After seeing the huge potential for **llama.cpp**,
+- **A:** _Sheesh! Okaaay_. After seeing the huge potential for **llama.cpp**,
 the first thing I did was to see how hard would it be to turn it into a
 library to embed in my projects. I started digging into the code, and realized
 the heavy lifting is done by `ggml` (a C library, easy to bind to Rust) and
@@ -69,9 +70,9 @@ cargo run --release -- -m /data/Llama/LLaMA/7B/ggml-model-q4_0.bin -p "Tell me h
 I'm more comfortable.
 
 - **Q: Is this the real reason?**
-- **A:** Haha. Of course *not*. I just like collecting imaginary internet
+- **A:** Haha. Of course _not_. I just like collecting imaginary internet
 points, in the form of little stars, that people seem to give to me whenever I
-embark on pointless quests for *rewriting X thing, but in Rust*.
+embark on pointless quests for _rewriting X thing, but in Rust_.
 
 
 ## Known issues / To-dos
@@ -86,5 +87,5 @@ Contributions welcome! Here's a few pressing issues:
 - [x] The code needs to be "library"-fied. It is nice as a showcase binary, but
 the real potential for this tool is to allow embedding in other services.
 - [x] The code only sets the right CFLAGS on Linux. The `build.rs` script in
-`ggml_raw` needs to be fixed, so inference *will be very slow on every
-other OS*.
+`ggml_raw` needs to be fixed, so inference _will be very slow on every
+other OS_.
````

llama-cli/src/cli_args.rs

Lines changed: 2 additions & 2 deletions

```diff
@@ -21,8 +21,8 @@ pub struct Args {
     pub num_threads: usize,
 
     /// Sets how many tokens to predict
-    #[arg(long, default_value_t = 128)]
-    pub num_predict: usize,
+    #[arg(long, short = 'n')]
+    pub num_predict: Option<usize>,
 
     /// Sets the size of the context (in tokens). Allows feeding longer prompts.
     /// Note that this affects memory. TODO: Unsure how large the limit is.
```
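For context, here is a minimal sketch of how the changed flag behaves under clap's derive API. Only the `num_predict` field and its `#[arg(long, short = 'n')]` attribute come from the hunk above; the trimmed-down `Args` struct and the `main` around it are illustrative assumptions, not llama-cli's actual code.

```rust
// Toy sketch (assumes a clap 4 dependency with the `derive` feature).
// Only `num_predict` is taken from the diff; everything else is trimmed.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Sets how many tokens to predict; omitting the flag yields `None`
    /// instead of the old baked-in default of 128.
    #[arg(long, short = 'n')]
    num_predict: Option<usize>,
}

fn main() {
    let args = Args::parse();
    // e.g. `-n 64` -> Some(64); no flag -> None.
    match args.num_predict {
        Some(n) => println!("predicting at most {n} tokens"),
        None => println!("no explicit prediction limit"),
    }
}
```

With the `default_value_t = 128` removed, the limit travels as an `Option<usize>` and is now handed to `inference_with_prompt` directly (see the main.rs diff below), so an omitted flag can mean "no explicit limit" rather than a hard-coded 128.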

llama-cli/src/main.rs

Lines changed: 19 additions & 12 deletions

```diff
@@ -1,4 +1,4 @@
-use std::io::Write;
+use std::{convert::Infallible, io::Write};
 
 use cli_args::CLI_ARGS;
 use llama_rs::{InferenceParameters, InferenceSnapshot};
@@ -16,7 +16,6 @@ fn main() {
 
     let inference_params = InferenceParameters {
         n_threads: args.num_threads as i32,
-        n_predict: args.num_predict,
         n_batch: args.batch_size,
         top_k: args.top_k,
         top_p: args.top_p,
@@ -43,7 +42,7 @@ fn main() {
     llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| {
         use llama_rs::LoadProgress;
         match progress {
-            LoadProgress::HyperParamsLoaded(hparams) => {
+            LoadProgress::HyperparametersLoaded(hparams) => {
                 log::debug!("Loaded HyperParams {hparams:#?}")
             }
             LoadProgress::BadToken { index } => {
@@ -114,19 +113,24 @@ fn main() {
     };
 
     if let Some(cache_path) = &args.cache_prompt {
-        let res = session.feed_prompt(&model, &vocab, &inference_params, &prompt, |t| {
-            print!("{t}");
-            std::io::stdout().flush().unwrap();
-        });
+        let res =
+            session.feed_prompt::<Infallible>(&model, &vocab, &inference_params, &prompt, |t| {
+                print!("{t}");
+                std::io::stdout().flush().unwrap();
+
+                Ok(())
+            });
+
         println!();
+
         match res {
             Ok(_) => (),
-            Err(llama_rs::Error::ContextFull) => {
+            Err(llama_rs::InferenceError::ContextFull) => {
                 log::warn!(
                     "Context is not large enough to fit the prompt. Saving intermediate state."
                 );
             }
-            err => unreachable!("{err:?}"),
+            Err(llama_rs::InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
         }
 
         // Write the memory to the cache file
@@ -144,25 +148,28 @@ fn main() {
             }
         }
     } else {
-        let res = session.inference_with_prompt(
+        let res = session.inference_with_prompt::<Infallible>(
             &model,
             &vocab,
             &inference_params,
             &prompt,
+            args.num_predict,
             &mut rng,
             |t| {
                 print!("{t}");
                 std::io::stdout().flush().unwrap();
+
+                Ok(())
             },
         );
         println!();
 
         match res {
             Ok(_) => (),
-            Err(llama_rs::Error::ContextFull) => {
+            Err(llama_rs::InferenceError::ContextFull) => {
                 log::warn!("Context window full, stopping inference.")
             }
-            err => unreachable!("{err:?}"),
+            Err(llama_rs::InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
         }
     }
 }
```
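The switch to `feed_prompt::<Infallible>` / `inference_with_prompt::<Infallible>` and the new `InferenceError::UserCallback(_)` arm follow a common Rust pattern: the callback's error type is a generic parameter, and choosing `Infallible` records that this particular callback can never fail. The sketch below is a toy stand-in, not the llama_rs API; the `InferenceError` enum shape and the hypothetical `feed_prompt` function are assumptions modeled on the names visible in the diff.

```rust
// Toy illustration of the callback-error pattern the diff adopts.
// NOT the llama_rs API: `InferenceError` and `feed_prompt` here are
// hypothetical stand-ins for the identifiers seen in the diff.
use std::convert::Infallible;

#[derive(Debug)]
enum InferenceError<E> {
    ContextFull,
    UserCallback(E),
}

// Drives tokens through a user callback, propagating the callback's error.
fn feed_prompt<E>(
    tokens: &[&str],
    mut callback: impl FnMut(&str) -> Result<(), E>,
) -> Result<(), InferenceError<E>> {
    for &t in tokens {
        callback(t).map_err(InferenceError::UserCallback)?;
    }
    Ok(())
}

fn main() {
    // Mirrors `session.feed_prompt::<Infallible>(..)` in the diff: the
    // closure only prints and returns `Ok(())`, so the caller's only real
    // failure mode is the library's own error (e.g. `ContextFull`).
    let res = feed_prompt::<Infallible>(&["Hello", ", ", "world"], |t| {
        print!("{t}");
        Ok(())
    });

    match res {
        Ok(_) => println!(),
        Err(InferenceError::ContextFull) => eprintln!("context full"),
        // With `Infallible`, this variant can never be constructed.
        Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
    }
}
```

Because `Infallible` has no values, the `UserCallback(_)` arm is statically unreachable, which is why the diff can replace the catch-all `err => unreachable!("{err:?}")` with an explicit, exhaustive match.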
