diff --git a/BRUCE_LOGPROBS_COMPARISON.md b/BRUCE_LOGPROBS_COMPARISON.md
new file mode 100644
index 00000000000..fd75d71e918
--- /dev/null
+++ b/BRUCE_LOGPROBS_COMPARISON.md
@@ -0,0 +1,143 @@
+# Comparison: Bruce MacDonald's vs Our Log Probabilities Implementation
+
+## Overview
+
+After examining Bruce MacDonald's `brucemacd/logprobs` branch, here are the key differences and insights from his approach compared to our implementation.
+
+## Key Differences
+
+### 1. API Structure
+
+**Bruce's Approach:**
+- Added `LogProbs int` field to `GenerateRequest` and `ChatRequest` (specifies number of log probs to return)
+- Uses `TokenProbs` struct with fields:
+  ```go
+  type TokenProbs struct {
+      TokenID int     `json:"id"`
+      LogProb float32 `json:"logprob"`
+      Token   string  `json:"token"`
+  }
+  ```
+- Returns `LogProbs []TokenProbs` in `ChatResponse` and `GenerateResponse`
+- Does NOT modify the OpenAI compatibility layer
+
+**Our Approach:**
+- Added `LogProbs bool` and `TopLogProbs *int` to `ChatCompletionRequest` in OpenAI layer
+- Created more complex structures to match OpenAI's format:
+  ```go
+  type LogProbs struct {
+      Content []LogProbContent `json:"content"`
+  }
+  type LogProbContent struct {
+      Token       string         `json:"token"`
+      LogProb     float32        `json:"logprob"`
+      Bytes       []byte         `json:"bytes,omitempty"`
+      TopLogProbs []LogProbToken `json:"top_logprobs,omitempty"`
+  }
+  ```
+- Modified both Ollama native API and OpenAI compatibility layer
+
+### 2. Implementation Scope
+
+**Bruce's Approach:**
+- Focused on Ollama's native API only
+- Simpler, more direct implementation
+- No OpenAI compatibility layer modifications
+- Returns token ID alongside the token text
+
+**Our Approach:**
+- Comprehensive implementation covering both APIs
+- Added OpenAI-compatible request/response structures
+- More complex conversion logic between formats
+- Focus on OpenAI schema compatibility
+
+### 3. LLM Server Integration
+
+**Bruce's Approach:**
+- Modified `llm/server.go` to include `LogProbs int` in `CompletionRequest`
+- Returns `LogProbs []TokenProbs` in `CompletionResponse`
+- Simpler token probability structure in the completion response
+
+**Our Approach:**
+- Added `n_probs` parameter to llama.cpp requests
+- More complex parsing of `completion_probabilities` from llama.cpp
+- Conversion logic to transform llama.cpp format to OpenAI format
+
+### 4. Token Information
+
+**Bruce's Approach:**
+- Includes `TokenID` in the response (useful for debugging and analysis)
+- Simpler structure makes it easier to process
+
+**Our Approach:**
+- Focuses on OpenAI compatibility
+- Includes byte representation of tokens
+- Supports top-k log probabilities for each token
+
+## Lessons Learned from Bruce's Implementation
+
+### 1. **Simplicity First**
+Bruce's implementation is notably simpler and more focused. He chose to:
+- Keep the native Ollama API separate from OpenAI compatibility
+- Use a minimal structure that provides essential information
+- Avoid complex conversions and nested structures
+
+### 2. **Token IDs are Valuable**
+Bruce includes token IDs in the response, which our implementation doesn't. This is useful for:
+- Debugging tokenization issues
+- Understanding model behavior
+- Correlating with vocabulary files
+
+### 3. **Incremental Approach**
+Bruce's implementation doesn't try to solve everything at once:
+- No OpenAI compatibility layer changes
+- Focus on core functionality first
+- Leaves room for future enhancements
+
+### 4. **Native API Design**
+Bruce's approach suggests that Ollama's native API should have its own design philosophy rather than trying to mirror OpenAI exactly.
+
+## Recommendations
+
+Based on Bruce's approach, we might consider:
+
+1. **Simplifying our native API implementation** - Use Bruce's simpler `TokenProbs` structure for Ollama's native API
+2. **Including Token IDs** - Add token IDs to our response for better debugging capabilities
+3. **Separating concerns** - Keep OpenAI compatibility as a separate layer rather than mixing it with native API
+4. **Phased approach** - Consider implementing log probabilities in phases:
+   - Phase 1: Native API support (like Bruce's)
+   - Phase 2: OpenAI compatibility layer
+   - Phase 3: Advanced features (top-k, bytes, etc.)
+
+## Technical Implementation Details
+
+### Bruce's Server Route Handler
+```go
+// Simplified log probability handling
+for _, p := range r.LogProbs {
+	res.LogProbs = append(res.LogProbs, api.TokenProbs{
+		TokenID: p.TokenID,
+		LogProb: p.LogProb,
+		Token:   p.Token,
+	})
+}
+```
+
+### Our Implementation
+```go
+// Complex conversion with top-k support
+topK := int(3)
+logits := make([]float32, len(cr.Logits))
+copy(logits, cr.Logits)
+res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+```
+
+## Conclusion
+
+Bruce MacDonald's implementation demonstrates a more idiomatic approach for Ollama:
+- Simpler and more maintainable
+- Focuses on core functionality
+- Doesn't conflate native API with OpenAI compatibility
+- Provides useful debugging information (token IDs)
+
+While our implementation provides fuller OpenAI compatibility, Bruce's approach suggests that starting simple and building incrementally might be a better strategy for the official Ollama codebase.
\ No newline at end of file
diff --git a/api/types.go b/api/types.go
index f4c5b1058e1..c508891d6c2 100644
--- a/api/types.go
+++ b/api/types.go
@@ -77,6 +77,8 @@ type GenerateRequest struct {
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`
 
+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
@@ -103,6 +105,8 @@ type ChatRequest struct {
 	// Tools is an optional list of tools the model has access to.
 	Tools `json:"tools,omitempty"`
 
+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -182,13 +186,20 @@ func (t *ToolFunction) String() string {
 	return string(bts)
 }
 
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
+	Model      string       `json:"model"`
+	CreatedAt  time.Time    `json:"created_at"`
+	Message    Message      `json:"message"`
+	DoneReason string       `json:"done_reason,omitempty"`
+	LogProbs   []TokenProbs `json:"logprobs,omitempty"`
 
 	Done bool `json:"done"`
 
@@ -452,6 +463,8 @@ type GenerateResponse struct {
 	// can be sent in the next request to keep a conversational memory.
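For reference, here is a minimal sketch of how a client might exercise the native-API fields added in api/types.go above (`ChatRequest.LogProbs` and `ChatResponse.LogProbs`). The model name, candidate count, and overall flow are illustrative assumptions, not part of this diff:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "llama3.2", // illustrative model name
		Messages: []api.Message{{Role: "user", Content: "Say hello"}},
		LogProbs: 5, // ask for up to 5 candidates per generated token, per this proposal
	}

	// Each streamed ChatResponse would carry the per-token candidates.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		for _, tp := range resp.LogProbs {
			fmt.Printf("token=%q id=%d logprob=%.4f\n", tp.Token, tp.TokenID, tp.LogProb)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```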
Context []int `json:"context,omitempty"` + LogProbs []TokenProbs `json:"logprobs,omitempty"` + Metrics } diff --git a/llama/llama.go b/llama/llama.go index a20f23578a2..90852c3d8c0 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -50,7 +50,7 @@ import ( _ "github.com/ollama/ollama/llama/llama.cpp/common" _ "github.com/ollama/ollama/llama/llama.cpp/examples/llava" _ "github.com/ollama/ollama/llama/llama.cpp/src" - "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) func BackendInit() { @@ -220,6 +220,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 { return embeddings } +// GetLogits returns the logits from the last decode operation. +// The returned slice has length equal to the vocabulary size. +func (c *Context) GetLogits() []float32 { + logits := unsafe.Pointer(C.llama_get_logits(c.c)) + if logits == nil { + return nil + } + + // Get the number of vocabulary tokens to determine array size + vocabSize := c.Model().NumVocab() + return unsafe.Slice((*float32)(logits), vocabSize) +} + type ModelParams struct { NumGpuLayers int MainGpu int diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 60ae88dacb2..815ce619186 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -8,12 +8,14 @@ import ( "fmt" "log" "log/slog" + "math" "net" "net/http" "os" "path/filepath" "regexp" "runtime" + "sort" "strconv" "strings" "sync" @@ -48,8 +50,9 @@ type Sequence struct { // inputs that have been added to a batch but not yet submitted to Decode pendingInputs []input + // TODO: update this comment // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []CompletionResponse // input cache being used by this sequence cache *InputCacheSlot @@ -59,7 +62,7 @@ type Sequence struct { crossAttention bool // channel to send responses over - responses chan string + responses chan CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -83,6 +86,11 @@ type Sequence struct { doneReason string + logits []float32 + + // number of logprobs to return with the completion response + logprobs int + // Metrics startProcessingTime time.Time startGenerationTime time.Time @@ -96,6 +104,7 @@ type NewSequenceParams struct { numKeep int samplingParams *llama.SamplingParams embedding bool + logprobs int } func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { @@ -148,14 +157,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]CompletionResponse, 0), + responses: make(chan CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, embeddingOnly: params.embedding, stop: params.stop, numKeep: params.numKeep, + logprobs: params.logprobs, }, nil } @@ -274,29 +284,37 @@ func (s *Server) allNil() bool { } func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} - - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. 
generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } - - if len(joined) == 0 { + if len(seq.pendingResponses) == 0 { return true } + resps := []CompletionResponse{} + for _, resp := range seq.pendingResponses { + resps = append(resps, resp) + } + seq.pendingResponses = []CompletionResponse{} + + // TODO: figure out this result logic + result := false + for _, resp := range resps { + // Check if there are any partial UTF-8 characters remaining. + // We already check and queue as we are generating but some may + // still make it here: + // - Sequence is ending, e.g. generation limit has been hit + // - Invalid characters in the middle of a string + // This is a stricter check to ensure we never output invalid Unicode. + for !utf8.ValidString(resp.Content) { + resp.Content = resp.Content[:len(resp.Content)-1] + } - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false + select { + case seq.responses <- resp: + result = true + case <-seq.quit: + result = false + } } + + return result } func (s *Server) removeSequence(seqIndex int, reason string) { @@ -350,6 +368,63 @@ func (s *Server) run(ctx context.Context) { } } +// TokenProbs represents probability information for a token +type TokenProbs struct { + TokenID int `json:"id"` + Logit float32 `json:"logit"` + Prob float32 `json:"prob"` + LogProb float32 `json:"logprob"` + Token string `json:"token"` +} + +// probs returns sorted token probabilities for a specific token index +func probs(logits []float32, vocabSize int) []TokenProbs { + probs := make([]TokenProbs, vocabSize) + + // Initialize token data with logits + for i := 0; i < vocabSize; i++ { + probs[i] = TokenProbs{ + TokenID: i, + Logit: logits[i], + } + } + + // Sort tokens by logits in descending order + sort.Slice(probs, func(i, j int) bool { + return probs[i].Logit > probs[j].Logit + }) + + // Apply softmax + maxLogit := probs[0].Logit + var sum float32 = 0.0 + + for i := range probs { + p := float32(math.Exp(float64(probs[i].Logit - maxLogit))) + probs[i].Prob = p + sum += p + } + + // Normalize probabilities and calculate log probs + for i := range probs { + prob := probs[i].Prob / sum + probs[i].Prob = prob + probs[i].LogProb = float32(math.Log(float64(prob))) + } + + return probs +} + +// probs returns sorted token probabilities for a specific token index +func (s *Server) probs(seq *Sequence) []TokenProbs { + // Get logits for the specific token index + logits := s.lc.GetLogits() + seq.logits = make([]float32, len(logits)) + copy(seq.logits, logits) + + vocabSize := s.model.NumVocab() + return probs(logits, vocabSize) +} + // TODO (jmorganca): processBatch should be simplified, removing: // * sampling // * stop token checking @@ -483,6 +558,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.numPredicted++ + resp := CompletionResponse{Content: piece} + + if seq.logprobs > 0 { + // TODO: return selected token in logprobs always + resp.LogProbs = s.probs(seq) + // TODO: fix this logprobs limit + resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)] + for i := range resp.LogProbs { + // decode the token id to a piece + resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID) + } + } + // if it's an end of sequence token, break if s.model.TokenIsEog(token) { // TODO (jmorganca): we should send this back @@ 
-495,16 +583,21 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.inputs = []input{{token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + // TODO: add probs here + seq.pendingResponses = append(seq.pendingResponses, resp) + var sequence string + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := findStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) + // TODO: fix this stop sequence caching var tokenTruncated bool - origLen := len(seq.pendingResponses) - seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) - newLen := len(seq.pendingResponses) + origLen := len(sequence) + sequence, tokenTruncated = truncateStop(sequence, stop) + newLen := len(sequence) // Update the cache based on the tokens that will be returned: // - We have 1 token more than is currently in the cache because @@ -575,6 +668,7 @@ type CompletionRequest struct { Images []ImageData `json:"image_data"` Grammar string `json:"grammar"` CachePrompt bool `json:"cache_prompt"` + Logprobs int `json:"logprobs,omitempty"` Options } @@ -590,8 +684,10 @@ type CompletionResponse struct { Content string `json:"content"` Stop bool `json:"stop"` - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + LogProbs []TokenProbs `json:"logprobs,omitempty"` + StoppedLimit bool `json:"stopped_limit,omitempty"` PredictedN int `json:"predicted_n,omitempty"` PredictedMS float64 `json:"predicted_ms,omitempty"` @@ -609,10 +705,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - // Set the headers to indicate streaming - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Transfer-Encoding", "chunked") - flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming not supported", http.StatusInternalServerError) @@ -641,6 +733,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { numKeep: req.NumKeep, samplingParams: &samplingParams, embedding: false, + logprobs: req.Logprobs, }) if err != nil { http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) @@ -688,11 +781,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(seq.quit) return - case content, ok := <-seq.responses: + case resp, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Content: content, - }); err != nil { + fmt.Println("response", resp) + if err := json.NewEncoder(w).Encode(&resp); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) return diff --git a/llama/runner/runner_test.go b/llama/runner/runner_test.go new file mode 100644 index 00000000000..bb4a6da9e4a --- /dev/null +++ b/llama/runner/runner_test.go @@ -0,0 +1,58 @@ +package runner + +import ( + "math" + "testing" +) + +func TestProbs(t *testing.T) { + // Input test data + logits := []float32{1.0, 2.0, 0.5, -1.0} + vocabSize := 4 + want := []TokenProbs{ + {TokenID: 1, Logit: 2.0}, // Highest logit + {TokenID: 0, Logit: 1.0}, // Second highest + {TokenID: 2, Logit: 0.5}, // Third + {TokenID: 3, Logit: -1.0}, // Lowest + } + + got := probs(logits, vocabSize) + + // Test 1: Check sorting order + for i := 0; i < 
len(got)-1; i++ { + if got[i].Logit < got[i+1].Logit { + t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)", + i, got[i].Logit, i+1, got[i+1].Logit) + } + } + + // Test 2: Check probability normalization + var sum float32 + for _, p := range got { + sum += p.Prob + } + if math.Abs(float64(sum-1.0)) > 1e-6 { + t.Errorf("probabilities do not sum to 1: got %v", sum) + } + + // Test 3: Check token IDs match expected order + for i, want := range want { + if got[i].TokenID != want.TokenID { + t.Errorf("wrong token ID at position %d: got %d, want %d", + i, got[i].TokenID, want.TokenID) + } + if got[i].Logit != want.Logit { + t.Errorf("wrong logit at position %d: got %f, want %f", + i, got[i].Logit, want.Logit) + } + } + + // Test 4: Check log probs are correctly calculated + for i, p := range got { + expectedLogProb := float32(math.Log(float64(p.Prob))) + if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 { + t.Errorf("wrong log prob at position %d: got %f, want %f", + i, p.LogProb, expectedLogProb) + } + } +} diff --git a/llama/runner/stop.go b/llama/runner/stop.go index 8dcb08d331d..ff5de43c636 100644 --- a/llama/runner/stop.go +++ b/llama/runner/stop.go @@ -26,43 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool { return false } -// truncateStop removes the provided stop string from pieces, -// returning the partial pieces with stop removed, including truncating -// the last piece if required (and signalling if this was the case) -func truncateStop(pieces []string, stop string) ([]string, bool) { - joined := strings.Join(pieces, "") - - index := strings.Index(joined, stop) +// truncateStop removes the provided stop string from sequence, +// returning both the truncated sequence and a bool indicating if truncation occurred +func truncateStop(sequence string, stop string) (string, bool) { + index := strings.Index(sequence, stop) if index == -1 { - return pieces, false - } - - joined = joined[:index] - - // Split truncated string back into pieces of original lengths - lengths := make([]int, len(pieces)) - for i, piece := range pieces { - lengths[i] = len(piece) - } - - var result []string - tokenTruncated := false - start := 0 - for _, length := range lengths { - if start >= len(joined) { - break - } - - end := start + length - if end > len(joined) { - end = len(joined) - tokenTruncated = true - } - result = append(result, joined[start:end]) - start = end + return sequence, false } - return result, tokenTruncated + return sequence[:index], true } func incompleteUnicode(token string) bool { diff --git a/llama/runner/stop_test.go b/llama/runner/stop_test.go index 31dc161f379..52637ff5e29 100644 --- a/llama/runner/stop_test.go +++ b/llama/runner/stop_test.go @@ -1,60 +1,60 @@ package runner import ( - "reflect" "testing" ) func TestTruncateStop(t *testing.T) { tests := []struct { name string - pieces []string + sequence string stop string - expected []string + expected string expectedTrunc bool }{ { name: "Single word", - pieces: []string{"hello", "world"}, + sequence: "helloworld", stop: "world", - expected: []string{"hello"}, - expectedTrunc: false, + expected: "hello", + expectedTrunc: true, }, { name: "Partial", - pieces: []string{"hello", "wor"}, + sequence: "hellowor", stop: "or", - expected: []string{"hello", "w"}, + expected: "hellow", expectedTrunc: true, }, { name: "Suffix", - pieces: []string{"Hello", " there", "!"}, + sequence: "Hello there!", stop: "!", - expected: []string{"Hello", " there"}, - expectedTrunc: false, - }, - { - 
name: "Suffix partial", - pieces: []string{"Hello", " the", "re!"}, - stop: "there!", - expected: []string{"Hello", " "}, + expected: "Hello there", expectedTrunc: true, }, { name: "Middle", - pieces: []string{"hello", " wor"}, + sequence: "hello wor", stop: "llo w", - expected: []string{"he"}, + expected: "he", expectedTrunc: true, }, + { + name: "No stop found", + sequence: "hello world", + stop: "xyz", + expected: "hello world", + expectedTrunc: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, resultTrunc := truncateStop(tt.pieces, tt.stop) - if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { - t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) + result, truncated := truncateStop(tt.sequence, tt.stop) + if result != tt.expected || truncated != tt.expectedTrunc { + t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)", + tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc) } }) } diff --git a/llm/server.go b/llm/server.go index 881209b3951..0f409c7cfef 100644 --- a/llm/server.go +++ b/llm/server.go @@ -644,12 +644,22 @@ type ImageData struct { AspectRatioID int `json:"aspect_ratio_id"` } +// TokenProbs represents probability information for a token +type TokenProbs struct { + TokenID int `json:"id"` + Logit float32 `json:"logit"` + Prob float32 `json:"prob"` + LogProb float32 `json:"logprob"` + Token string `json:"token"` +} + type completion struct { - Content string `json:"content"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Stop bool `json:"stop"` - StoppedLimit bool `json:"stopped_limit"` + Content string `json:"content"` + Model string `json:"model"` + Prompt string `json:"prompt"` + Stop bool `json:"stop"` + StoppedLimit bool `json:"stopped_limit"` + LogProbs []TokenProbs `json:"logprobs"` Timings struct { PredictedN int `json:"predicted_n"` @@ -660,14 +670,16 @@ type completion struct { } type CompletionRequest struct { - Prompt string - Format json.RawMessage - Images []ImageData - Options *api.Options + Prompt string + Format json.RawMessage + Images []ImageData + LogProbs int + Options *api.Options } type CompletionResponse struct { Content string + LogProbs []TokenProbs DoneReason string Done bool PromptEvalCount int @@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu "seed": req.Options.Seed, "stop": req.Options.Stop, "image_data": req.Images, + "logprobs": req.LogProbs, "cache_prompt": true, } + fmt.Println("completion request:", request) + if len(req.Format) > 0 { switch string(req.Format) { case `null`, `""`: @@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu continue } - // slog.Debug("got line", "line", string(line)) evt, ok := bytes.CutPrefix(line, []byte("data: ")) if !ok { evt = line @@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if c.Content != "" { fn(CompletionResponse{ - Content: c.Content, + Content: c.Content, + LogProbs: c.LogProbs, }) } @@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), EvalCount: c.Timings.PredictedN, EvalDuration: parseDurationMs(c.Timings.PredictedMS), + LogProbs: c.LogProbs, }) return nil } diff --git a/server/routes.go b/server/routes.go index 5a4bb485c8f..dfedc1c3b36 100644 --- 
a/server/routes.go
+++ b/server/routes.go
@@ -293,11 +293,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
@@ -311,6 +313,13 @@
 					EvalDuration:    cr.EvalDuration,
 				},
 			}
+			for _, p := range cr.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}
 
 			if _, err := sb.WriteString(cr.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
@@ -1466,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,
@@ -1484,6 +1494,13 @@
 					EvalDuration:    r.EvalDuration,
 				},
 			}
+			for _, p := range r.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}
 
 			if r.Done {
 				res.TotalDuration = time.Since(checkpointStart)
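
Looking ahead to the "Phase 2" recommendation above (OpenAI compatibility built on top of the native `TokenProbs` slice), a conversion in the OpenAI middleware layer could look roughly like the sketch below. The struct names mirror the structures proposed in BRUCE_LOGPROBS_COMPARISON.md; `toOpenAILogProbs`, the grouping of candidates per generated position, and the assumption that the sampled token comes first in each group are illustrative, not code from either branch.

```go
package openai

// Sketch only: maps native per-token candidates onto OpenAI-style logprob
// structures. TokenProbs mirrors api.TokenProbs from this diff; the other
// types mirror the structures proposed in the comparison document.

type TokenProbs struct {
	TokenID int     `json:"id"`
	LogProb float32 `json:"logprob"`
	Token   string  `json:"token"`
}

type LogProbToken struct {
	Token   string  `json:"token"`
	LogProb float32 `json:"logprob"`
}

type LogProbContent struct {
	Token       string         `json:"token"`
	LogProb     float32        `json:"logprob"`
	Bytes       []byte         `json:"bytes,omitempty"`
	TopLogProbs []LogProbToken `json:"top_logprobs,omitempty"`
}

type LogProbs struct {
	Content []LogProbContent `json:"content"`
}

// toOpenAILogProbs assumes candidates are grouped per generated position and
// that the sampled token is the first entry of each group.
func toOpenAILogProbs(positions [][]TokenProbs, topK int) LogProbs {
	var out LogProbs
	for _, candidates := range positions {
		if len(candidates) == 0 {
			continue
		}
		selected := candidates[0]
		entry := LogProbContent{
			Token:   selected.Token,
			LogProb: selected.LogProb,
			Bytes:   []byte(selected.Token),
		}
		for _, c := range candidates[:min(topK, len(candidates))] {
			entry.TopLogProbs = append(entry.TopLogProbs, LogProbToken{
				Token:   c.Token,
				LogProb: c.LogProb,
			})
		}
		out.Content = append(out.Content, entry)
	}
	return out
}
```

Keeping this mapping in the OpenAI layer, rather than in the runner or the native API, follows the "separating concerns" recommendation above: the native API stays minimal and token-ID-centric, while OpenAI-specific shape (bytes, top_logprobs) is produced only where it is needed.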