This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit 34ba664

Copy v_transposed like llama.cpp

See ggml-org/llama.cpp#439. Closes #67.

1 parent: 08b875c

2 files changed: +45 −19 lines


llama-cli/src/main.rs

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ fn main() {
             }
         }),
         play_back_previous_tokens: false,
+        ..Default::default()
     };
     let inference_session_params = {
         let mem_typ = if args.float16 {
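For orientation, here is a minimal sketch of the kind of struct literal this hunk sits in (the field values and the `llama_rs::` path are illustrative assumptions, not the CLI's exact code): because the library gains a new field in this commit, the `..Default::default()` spread lets the CLI pick up the new `increased_determinism` default without listing it explicitly.

    // Illustrative only: values and the crate path are assumed, not copied from main.rs.
    let inference_params = llama_rs::InferenceParameters {
        n_threads: 8,
        play_back_previous_tokens: false,
        // Any field not listed here, including the new `increased_determinism`
        // (default: true), is filled in by the struct's Default impl.
        ..Default::default()
    };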

llama-rs/src/lib.rs

Lines changed: 44 additions & 19 deletions
@@ -174,6 +174,7 @@ impl Default for InferenceSessionParameters {
     }
 }

+#[derive(Clone, Debug, PartialEq)]
 /// The parameters that drive text generation.
 pub struct InferenceParameters {
     pub n_threads: i32,
@@ -184,6 +185,7 @@ pub struct InferenceParameters {
     pub temp: f32,
     pub bias_tokens: TokenBias,
     pub play_back_previous_tokens: bool,
+    pub increased_determinism: bool,
 }

 impl Default for InferenceParameters {
@@ -197,6 +199,7 @@ impl Default for InferenceParameters {
             temp: 0.80,
             bias_tokens: TokenBias::default(),
             play_back_previous_tokens: false,
+            increased_determinism: true,
         }
     }
 }
@@ -1094,11 +1097,13 @@ impl Model {
     pub fn evaluate(
         &self,
         session: &mut InferenceSession,
-        n_threads: i32,
+        params: &InferenceParameters,
         input_tokens: &[TokenId],
     ) {
         let n = input_tokens.len();
         let n_past = session.n_past as i32;
+        let n_threads = params.n_threads;
+        let increased_determinism = params.increased_determinism;

         let Hyperparameters {
             n_vocab,
@@ -1127,6 +1132,27 @@ impl Model {

         let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);

+        // Defined here to avoid repetition and creating a binding inside nested loops.
+        // See the call site below for more context.
+        let vtrans_fun = |il: usize| -> ggml::Tensor {
+            ctx0.op_permute(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_view_1d(
+                        &session.memory_v,
+                        (n_past + n as i32) * n_embd,
+                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
+                    ),
+                    n_embd / n_head,
+                    n_head,
+                    n_past + n as i32,
+                ),
+                1,
+                2,
+                0,
+                3,
+            )
+        };
+
         for il in 0..n_layer as usize {
             let input_self_attention = input_layer.share();
             let mut current: ggml::Tensor;
@@ -1226,22 +1252,21 @@ impl Model {
             let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);

             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            let v_transposed = ctx0.op_permute(
-                &ctx0.op_reshape_3d(
-                    &ctx0.op_view_1d(
-                        &session.memory_v,
-                        (n_past + n as i32) * n_embd,
-                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
-                    ),
-                    n_embd / n_head,
-                    n_head,
-                    n_past + n as i32,
-                ),
-                1,
-                2,
-                0,
-                3,
-            );
+            let v_transposed = {
+                if !increased_determinism {
+                    vtrans_fun(il)
+                } else {
+                    ctx0.op_cpy(
+                        &vtrans_fun(il),
+                        &ctx0.new_tensor_3d(
+                            ggml::TYPE_F32,
+                            n_past + n as i32,
+                            n_embd / n_head,
+                            n_head,
+                        ),
+                    )
+                }
+            };

             // KQV = transpose(V) * KQ_soft_max
             let k_q_v = ctx0.op_mul_mat(&v_transposed, &k_q_soft_max);
@@ -1393,7 +1418,7 @@ impl InferenceSession {
         }

         for batch in prompt_tokens.chunks(8) {
-            model.evaluate(self, params.n_threads, batch);
+            model.evaluate(self, params, batch);
             for &tk in batch {
                 // NOTE: No string ever tokenizes to the end of sentence. So we
                 // can just return the id here.
@@ -1427,7 +1452,7 @@ impl InferenceSession {
         self.tokens.push(next_token);

         // Then, evaluate the network again to compute the new last_logits
-        model.evaluate(self, params.n_threads, &[next_token]);
+        model.evaluate(self, params, &[next_token]);

         // Return the next token
         Ok(if next_token as TokenId == EOD_TOKEN_ID {
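Taken together, callers now pass the whole `InferenceParameters` struct to `Model::evaluate` instead of a bare thread count, and can opt out of the contiguous copy of the transposed V tensor. A minimal sketch of a hypothetical caller follows (the bindings `model`, `session`, and `tokens` are assumed to exist; they are not part of this diff):

    // Sketch only: `model`, `session`, and `tokens` are assumed bindings.
    let params = InferenceParameters {
        // `false` skips the op_cpy into a contiguous tensor and uses the permuted
        // view directly, matching the pre-commit behaviour.
        increased_determinism: false,
        ..Default::default()
    };
    model.evaluate(&mut session, &params, &tokens);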
