This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit a067431

setzer22 and philpax authored
Embedding extraction (#72)
* Add code to extract embeddings
* Add ad_hoc_test for embeddings
* Use Tensor::read_data
* Adjust safety comment
* Remove ad-hoc test
* Clippy & fmt
* Fix comment that ended halfway

Co-authored-by: Philpax <me@philpax.me>
1 parent b103dcd commit a067431

1 file changed: llama-rs/src/lib.rs

51 additions & 3 deletions
@@ -1,5 +1,6 @@
 mod ggml;
 
+use core::slice;
 use std::{
     collections::HashMap,
     fmt::Display,
@@ -422,6 +423,17 @@ pub enum InferenceError {
     UserCallback(Box<dyn std::error::Error>),
 }
 
+/// Used in a call to `evaluate` to request information from the transformer.
+#[derive(Default)]
+pub struct EvaluateOutputRequest {
+    /// Returns all the logits for the provided batch of tokens.
+    /// Output shape is n_batch * n_vocab
+    pub all_logits: Option<Vec<f32>>,
+    /// Returns the embeddings for the provided batch of tokens
+    /// Output shape is n_batch * n_embd
+    pub embeddings: Option<Vec<f32>>,
+}
+
 /// NOTE: The original code relies in promotion rules and automatic cast between
 /// int to float. What we do instead is use this macro to convert every term of
 /// the multiplication to f64, which should have enough precision bits to hold
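As an aside, here is a minimal sketch of how a caller fills out the new struct; it assumes the crate exposes it as llama_rs::EvaluateOutputRequest (that module path is an assumption). Only fields set to Some are populated by evaluate().

use llama_rs::EvaluateOutputRequest; // assumed path

fn main() {
    // Request embeddings only; `all_logits` stays None and is skipped
    // by evaluate(). The Vec will be resized to n_batch * n_embd.
    let request = EvaluateOutputRequest {
        embeddings: Some(Vec::new()),
        ..Default::default()
    };
    assert!(request.all_logits.is_none());
    assert!(request.embeddings.is_some());
}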
@@ -1094,11 +1106,17 @@ impl Model {
     }
 
     /// Evaluates the transformer.
+    ///
+    /// The provided `output_request` struct lets you specify which additional
+    /// data you are interested in fetching from the transformer. Setting a
+    /// field to a `Some` value will clear and fill the provided vector with
+    /// data. The provided vector will be resized to the exact output size.
     pub fn evaluate(
         &self,
         session: &mut InferenceSession,
         params: &InferenceParameters,
         input_tokens: &[TokenId],
+        output_request: &mut EvaluateOutputRequest,
     ) {
         let n = input_tokens.len();
         let n_past = session.n_past as i32;
@@ -1317,12 +1335,16 @@ impl Model {
             input_layer = current;
         }
 
+        // Used at the end to optionally extract the embeddings.
+        let embeddings_tensor;
+
         // norm
         {
             input_layer = ctx0.op_norm(&input_layer);
 
             // inpL = norm*inpL
             input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
+            embeddings_tensor = input_layer.share();
         }
 
         // lm_head
@@ -1347,6 +1369,28 @@ impl Model {
             )
         };
 
+        // Extract logits
+        if let Some(all_logits) = &mut output_request.all_logits {
+            all_logits.resize(n_vocab as usize * n, 0.0);
+            // SAFETY: Tensor data can be read (properly aligned, initialized,
+            // data will not be mutated or otherwise aliased during the copy),
+            // and we're not reading past the end of the tensor data.
+            assert_eq!(input_layer.nelements(), n_vocab * n as i32);
+            unsafe {
+                input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
+            }
+        }
+
+        // Extract embeddings
+        if let Some(embeddings) = &mut output_request.embeddings {
+            embeddings.resize(n_embd as usize * n, 0.0);
+            // SAFETY: Same rationale as for the "Extract logits" section applies.
+            assert_eq!(embeddings_tensor.nelements(), n_embd * n as i32);
+            unsafe {
+                embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings));
+            }
+        }
+
         // Adjust the required memory per token if we didn't know that already
         if session.mem_per_token == 0 {
             session.mem_per_token = ctx0.used_mem() / n;
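Worth noting: read_data takes a byte slice, so the caller's Vec<f32> is reinterpreted in place via bytemuck::cast_slice_mut rather than copied through an intermediate buffer. A standalone sketch of that cast, using only the bytemuck crate (the tensor read itself is simulated here with a plain byte write):

fn main() {
    // A stand-in for the destination buffer that evaluate() resizes.
    let mut logits = vec![0.0f32; 4];

    // Reinterpret &mut [f32] as &mut [u8] without copying; this is the
    // slice that Tensor::read_data writes the raw tensor bytes into.
    let bytes: &mut [u8] = bytemuck::cast_slice_mut(&mut logits);
    assert_eq!(bytes.len(), 4 * std::mem::size_of::<f32>());

    // Simulate read_data writing one native-endian f32 into the buffer.
    bytes[..4].copy_from_slice(&1.5f32.to_ne_bytes());
    assert_eq!(logits[0], 1.5);
}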
@@ -1418,7 +1462,7 @@ impl InferenceSession {
         }
 
         for batch in prompt_tokens.chunks(8) {
-            model.evaluate(self, params, batch);
+            model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default());
             for &tk in batch {
                 // NOTE: No string ever tokenizes to the end of sentence. So we
                 // can just return the id here.
@@ -1452,7 +1496,12 @@ impl InferenceSession {
         self.tokens.push(next_token);
 
         // Then, evaluate the network again to compute the new last_logits
-        model.evaluate(self, params, &[next_token]);
+        model.evaluate(
+            self,
+            params,
+            &[next_token],
+            &mut EvaluateOutputRequest::default(),
+        );
 
         // Return the next token
         Ok(if next_token as TokenId == EOD_TOKEN_ID {
@@ -1533,7 +1582,6 @@ impl InferenceSession {
     /// ggml context. While the provided `InferenceSnapshotRef` object is alive,
     /// no other methods for this model object should be called.
     pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
-        use core::slice;
         let memory_k = unsafe {
             slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
         };
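Taken together, a hedged end-to-end sketch of pulling embeddings out through the new parameter; the llama_rs import paths and this helper function are assumptions built from the diff above, not a documented API:

// Assumed imports: these types appear in the diff, the paths are guesses.
use llama_rs::{EvaluateOutputRequest, InferenceParameters, InferenceSession, Model, TokenId};

/// Hypothetical helper: evaluate one batch of tokens and return its
/// embeddings (n_batch * n_embd values, per the EvaluateOutputRequest docs).
fn embeddings_for(
    model: &Model,
    session: &mut InferenceSession,
    params: &InferenceParameters,
    tokens: &[TokenId],
) -> Vec<f32> {
    let mut request = EvaluateOutputRequest {
        embeddings: Some(Vec::new()), // resized and filled by evaluate()
        ..Default::default()          // all_logits stays None
    };
    model.evaluate(session, params, tokens, &mut request);
    request.embeddings.expect("requested above, filled by evaluate()")
}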
