11mod ggml;
22
3+ use core:: slice;
34use std:: {
45 collections:: HashMap ,
56 fmt:: Display ,
@@ -422,6 +423,17 @@ pub enum InferenceError {
422423 UserCallback ( Box < dyn std:: error:: Error > ) ,
423424}
424425
/// Used in a call to `evaluate` to request information from the transformer.
///
/// Each field defaults to `None` (via `Default`), meaning "do not fetch".
/// Setting a field to `Some(vec)` asks `evaluate` to fill that vector.
#[derive(Default)]
pub struct EvaluateOutputRequest {
    /// Returns all the logits for the provided batch of tokens.
    /// Output shape is n_batch * n_vocab
    pub all_logits: Option<Vec<f32>>,
    /// Returns the embeddings for the provided batch of tokens
    /// Output shape is n_batch * n_embd
    pub embeddings: Option<Vec<f32>>,
}
436+
425437/// NOTE: The original code relies in promotion rules and automatic cast between
426438/// int to float. What we do instead is use this macro to convert every term of
427439/// the multiplication to f64, which should have enough precision bits to hold
@@ -1094,11 +1106,17 @@ impl Model {
10941106 }
10951107
10961108 /// Evaluates the transformer.
1109+ ///
1110+ /// The provided `output_request` struct lets you specify which additional
1111+ /// data you are interested in fetching from the transformer. Setting a
1112+ /// field to a `Some` value will clear and fill the provided vector with
1113+ /// data. The provided vector will be resized to the exact output size.
10971114 pub fn evaluate (
10981115 & self ,
10991116 session : & mut InferenceSession ,
11001117 params : & InferenceParameters ,
11011118 input_tokens : & [ TokenId ] ,
1119+ output_request : & mut EvaluateOutputRequest ,
11021120 ) {
11031121 let n = input_tokens. len ( ) ;
11041122 let n_past = session. n_past as i32 ;
@@ -1317,12 +1335,16 @@ impl Model {
13171335 input_layer = current;
13181336 }
13191337
1338+ // Used at the end to optionally extract the embeddings.
1339+ let embeddings_tensor;
1340+
13201341 // norm
13211342 {
13221343 input_layer = ctx0. op_norm ( & input_layer) ;
13231344
13241345 // inpL = norm*inpL
13251346 input_layer = ctx0. op_mul ( & ctx0. op_repeat ( & self . norm , & input_layer) , & input_layer) ;
1347+ embeddings_tensor = input_layer. share ( ) ;
13261348 }
13271349
13281350 // lm_head
@@ -1347,6 +1369,28 @@ impl Model {
13471369 )
13481370 } ;
13491371
1372+ // Extract logits
1373+ if let Some ( all_logits) = & mut output_request. all_logits {
1374+ all_logits. resize ( n_vocab as usize * n, 0.0 ) ;
1375+ // SAFETY: Tensor data can be read (properly aligned, initialized,
1376+ // data will not be mutated or otherwise aliased during the copy),
1377+ // and we're not reading past the end of the tensor data.
1378+ assert_eq ! ( input_layer. nelements( ) , n_vocab * n as i32 ) ;
1379+ unsafe {
1380+ input_layer. read_data ( 0 , bytemuck:: cast_slice_mut ( all_logits) ) ;
1381+ }
1382+ }
1383+
1384+ // Extract embeddings
1385+ if let Some ( embeddings) = & mut output_request. embeddings {
1386+ embeddings. resize ( n_embd as usize * n, 0.0 ) ;
1387+ // SAFETY: Same rationale as for the "Extract logits" section applies.
1388+ assert_eq ! ( embeddings_tensor. nelements( ) , n_embd * n as i32 ) ;
1389+ unsafe {
1390+ embeddings_tensor. read_data ( 0 , bytemuck:: cast_slice_mut ( embeddings) ) ;
1391+ }
1392+ }
1393+
13501394 // Adjust the required memory per token if we didn't know that already
13511395 if session. mem_per_token == 0 {
13521396 session. mem_per_token = ctx0. used_mem ( ) / n;
@@ -1418,7 +1462,7 @@ impl InferenceSession {
14181462 }
14191463
14201464 for batch in prompt_tokens. chunks ( 8 ) {
1421- model. evaluate ( self , params, batch) ;
1465+ model. evaluate ( self , params, batch, & mut EvaluateOutputRequest :: default ( ) ) ;
14221466 for & tk in batch {
14231467 // NOTE: No string ever tokenizes to the end of sentence. So we
14241468 // can just return the id here.
@@ -1452,7 +1496,12 @@ impl InferenceSession {
14521496 self . tokens . push ( next_token) ;
14531497
14541498 // Then, evaluate the network again to compute the new last_logits
1455- model. evaluate ( self , params, & [ next_token] ) ;
1499+ model. evaluate (
1500+ self ,
1501+ params,
1502+ & [ next_token] ,
1503+ & mut EvaluateOutputRequest :: default ( ) ,
1504+ ) ;
14561505
14571506 // Return the next token
14581507 Ok ( if next_token as TokenId == EOD_TOKEN_ID {
@@ -1533,7 +1582,6 @@ impl InferenceSession {
15331582 /// ggml context. While the provided `InferenceSnapshotRef` object is alive,
15341583 /// no other methods for this model object should be called.
15351584 pub unsafe fn get_snapshot ( & mut self ) -> InferenceSnapshotRef < ' _ > {
1536- use core:: slice;
15371585 let memory_k = unsafe {
15381586 slice:: from_raw_parts ( self . memory_k . data ( ) as * mut u8 , self . memory_k . nbytes ( ) )
15391587 } ;
0 commit comments