diff --git a/src/engine/parameters.h b/src/engine/parameters.h
index cee92c9e..f90835b8 100644
--- a/src/engine/parameters.h
+++ b/src/engine/parameters.h
@@ -37,6 +37,9 @@ struct ModelOutput {
 
   // logits for selected indices
   torch::Tensor logits;
+
+  // hidden states for selected indices
+  torch::Tensor hidden_states;
 };
 
 }  // namespace llm
\ No newline at end of file
diff --git a/src/engine/worker.cpp b/src/engine/worker.cpp
index 0961eb36..d33e59fa 100644
--- a/src/engine/worker.cpp
+++ b/src/engine/worker.cpp
@@ -161,6 +161,16 @@ ModelOutput Worker::execute_model(const ModelInput& inputs) {
     output.logprobs = sampling_params.logprobs;
     output.max_top_logprobs = sampling_params.max_top_logprobs;
   }
+
+  // if (inputs.pooling_params.selected_token_idxes.defined()) {
+  //   // call model to get embedding for each sequence
+  //   // do pooling based on query sequence length
+  //   torch::Tensor selected_hidden_states = model_->pooling(
+  //       hidden_states, inputs.pooling_params.selected_token_idxes);
+
+  //   // set hidden states to output
+  //   output.hidden_states = selected_hidden_states;
+  // }
   return output;
 }
 
diff --git a/src/request/output.h b/src/request/output.h
index 29ac4e6f..52c25f2c 100644
--- a/src/request/output.h
+++ b/src/request/output.h
@@ -67,6 +67,9 @@ struct SequenceOutput {
 
   // log probabilities of the generated tokens.
   std::optional<std::vector<LogProb>> logprobs;
+
+  // the embeddings of the sequence, only used for embedding generation.
+  std::optional<std::vector<float>> embeddings;
 };
 
 struct RequestOutput {
diff --git a/src/request/sequence.cpp b/src/request/sequence.cpp
index bdc18077..9e7da064 100644
--- a/src/request/sequence.cpp
+++ b/src/request/sequence.cpp
@@ -210,6 +210,14 @@ std::optional<SequenceOutput> Sequence::build_delta_output_until(
 
 SequenceOutput Sequence::build_output(const Tokenizer& tokenizer) {
   AUTO_COUNTER(non_stream_decode_latency_seconds);
+
+  // return embeddings if available
+  if (this->embeddings_.has_value()) {
+    SequenceOutput output;
+    output.index = index_;
+    output.embeddings = std::move(this->embeddings_);
+    return output;
+  }
 
   const auto ids = token_ids();
   const size_t size = ids.size();
diff --git a/src/request/sequence.h b/src/request/sequence.h
index d3d9abb8..92bb5959 100644
--- a/src/request/sequence.h
+++ b/src/request/sequence.h
@@ -145,6 +145,11 @@ class Sequence final {
   void append_token(const Token& token);
   void append_token(int64_t token_id) { append_token(Token(token_id)); }
 
+  // set embeddings for the sequence
+  void set_embeddings(std::vector<float>&& embeddings) {
+    embeddings_ = std::move(embeddings);
+  }
+
   // validate draft tokens with accepted tokens for speculative decoding
   // N.B. take int64_t as input to be compatible with torch::Tensor
   // returns the number of accepted tokens, including the resampled token
@@ -259,6 +264,9 @@ class Sequence final {
   std::vector<std::vector<int64_t>> top_tokens_;
   std::vector<std::vector<float>> top_logprobs_;
 
+  // sequence embeddings
+  std::optional<std::vector<float>> embeddings_;
+
   // number of tokens in the sequence
   size_t num_tokens_ = 0;
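
For reference, a minimal sketch of the pooling step that worker.cpp leaves commented out. It assumes last-token pooling over a flattened [num_tokens, hidden_size] hidden-state tensor, with selected_token_idxes holding the index of each sequence's final token. The name model_->pooling comes from the diff; the free function below and its body are hypothetical, not the model's actual implementation.

#include <torch/torch.h>

// Hypothetical pooling helper matching the commented-out call in
// Worker::execute_model. Gathers one hidden-state row per sequence,
// yielding a [num_seqs, hidden_size] embedding tensor.
torch::Tensor pooling(const torch::Tensor& hidden_states,
                      const torch::Tensor& selected_token_idxes) {
  // index_select picks the rows at the given token indices along dim 0,
  // i.e. the hidden state of each sequence's selected (last) token.
  return hidden_states.index_select(/*dim=*/0, selected_token_idxes);
}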