diff --git a/llama/runner/runner.go b/llama/runner/runner.go index ffbea9e9dc4..bf799d37cc4 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -451,14 +451,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) sequence := strings.Join(seq.pendingResponses, "") if ok, stop := findStop(sequence, seq.stop); ok { - slog.Debug("hit stop token", "stop", seq.stop) - - trimCacheLen := len(seq.pendingResponses) - 1 - seq.pendingResponses = truncateStop(seq.pendingResponses, stop) - trimCacheLen -= len(seq.pendingResponses) + slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) + + var tokenTruncated bool + origLen := len(seq.pendingResponses) + seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) + newLen := len(seq.pendingResponses) + + // Update the cache based on the tokens that will be returned: + // - We have 1 token more than is currently in the cache because + // the last one generated wasn't submitted to Decode + // - Remove any stop sequences that we stripped out + // - If truncateStop removed a portion of a token, drop that + // - As defense-in-depth, if truncatedToken didn't find a stop token + // remove the extra one that we added to the cache len + tokenLen := len(seq.cache.Inputs) + 1 + tokenLen -= origLen - newLen + if tokenTruncated || origLen == newLen { + tokenLen-- + } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] - // remove any tokens from the cache that we don't actually return - seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen] s.removeSequence(i, "stop") continue } diff --git a/llama/runner/stop.go b/llama/runner/stop.go index ece06c2103c..c05f5e3d5cf 100644 --- a/llama/runner/stop.go +++ b/llama/runner/stop.go @@ -28,13 +28,13 @@ func containsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating -// the last piece if required -func truncateStop(pieces []string, stop string) []string { +// the last piece if required (and signalling if this was the case) +func truncateStop(pieces []string, stop string) ([]string, bool) { joined := strings.Join(pieces, "") index := strings.Index(joined, stop) if index == -1 { - return pieces + return pieces, false } joined = joined[:index] @@ -46,6 +46,7 @@ func truncateStop(pieces []string, stop string) []string { } var result []string + tokenTruncated := false start := 0 for _, length := range lengths { if start >= len(joined) { @@ -55,12 +56,13 @@ func truncateStop(pieces []string, stop string) []string { end := start + length if end > len(joined) { end = len(joined) + tokenTruncated = true } result = append(result, joined[start:end]) start = end } - return result + return result, tokenTruncated } func incompleteUnicode(token string) bool { diff --git a/llama/runner/stop_test.go b/llama/runner/stop_test.go index 1455398768c..51b35fde358 100644 --- a/llama/runner/stop_test.go +++ b/llama/runner/stop_test.go @@ -7,42 +7,54 @@ import ( func TestTruncateStop(t *testing.T) { tests := []struct { - name string - pieces []string - stop string - expected []string + name string + pieces []string + stop string + expected []string + expectedTrunc bool }{ { - name: "Single word", - pieces: []string{"hello", "world"}, - stop: "world", - expected: []string{"hello"}, + name: "Single word", + pieces: []string{"hello", "world"}, + stop: "world", + expected: []string{"hello"}, + expectedTrunc: false, + }, + { + name: "Partial", + pieces: []string{"hello", "wor"}, + stop: "or", + expected: []string{"hello", "w"}, + expectedTrunc: true, }, { - name: "Partial", - pieces: []string{"hello", "wor"}, - stop: "or", - expected: []string{"hello", "w"}, + name: "Suffix", + pieces: []string{"Hello", " there", "!"}, + stop: "!", + expected: []string{"Hello", " there"}, + expectedTrunc: false, }, { - name: "Suffix", - pieces: []string{"Hello", " there", "!"}, - stop: "!", - expected: []string{"Hello", " there"}, + name: "Suffix partial", + pieces: []string{"Hello", " the", "re!"}, + stop: "there!", + expected: []string{"Hello", " "}, + expectedTrunc: true, }, { - name: "Middle", - pieces: []string{"hello", " wor"}, - stop: "llo w", - expected: []string{"he"}, + name: "Middle", + pieces: []string{"hello", " wor"}, + stop: "llo w", + expected: []string{"he"}, + expectedTrunc: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := truncateStop(tt.pieces, tt.stop) - if !reflect.DeepEqual(result, tt.expected) { - t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected) + result, resultTrunc := truncateStop(tt.pieces, tt.stop) + if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { + t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) } }) }