Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit 4a207f0

Browse files
authored
Merge pull request #36 from setzer22/feat/seed_rng
Adds a way to specify the seed for RNG
2 parents 75e9bbb + df16b3e commit 4a207f0

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

llama-cli/src/cli_args.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ pub struct Args {
6767
/// --cache-prompt
6868
#[arg(long, default_value = None)]
6969
pub restore_prompt: Option<String>,
70+
71+
/// Specifies the seed to use during sampling. Note that, depending on
72+
/// hardware, the same seed may lead to different results on two separate
73+
/// machines.
74+
#[arg(long, default_value = None)]
75+
pub seed: Option<u64>,
7076
}
7177

7278
/// CLI args are stored in a lazy static variable so they're accessible from

llama-cli/src/main.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{convert::Infallible, io::Write};
22

33
use cli_args::CLI_ARGS;
44
use llama_rs::{InferenceParameters, InferenceSnapshot};
5-
use rand::thread_rng;
5+
use rand::SeedableRng;
66

77
mod cli_args;
88

@@ -94,7 +94,11 @@ fn main() {
9494

9595
log::info!("Model fully loaded!");
9696

97-
let mut rng = thread_rng();
97+
let mut rng = if let Some(seed) = CLI_ARGS.seed {
98+
rand::rngs::StdRng::seed_from_u64(seed)
99+
} else {
100+
rand::rngs::StdRng::from_entropy()
101+
};
98102

99103
let mut session = if let Some(restore_path) = &args.restore_prompt {
100104
let snapshot = InferenceSnapshot::load_from_disk(restore_path);

0 commit comments

Comments
 (0)