runner.go: Handle truncation of tokens for stop sequences
When a single token contains both text to be returned and a stop
sequence, this causes an out-of-bounds error when we update the
cache to match our text. This is because we currently assume that
removing the stop sequence will consume at least one token.

This also inverts the logic to work with positive counts rather
than a value to be subtracted, which is easier to reason about.

Fixes ollama#7153
jessegross committed Oct 10, 2024
1 parent 03408f3 commit 0077e22
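
To make the failure mode concrete, here is a minimal standalone sketch of the arithmetic (not the runner's actual code; the cache length, piece counts, and token contents are invented for illustration). When the stop sequence sits inside the final piece alongside text that should still be returned, the old "pieces removed" subtraction goes negative and the slice bound runs past the end of the cache, while the new positive count stays in bounds:

package main

import "fmt"

func main() {
	// Invented numbers for illustration: the final sampled token decodes to
	// "!</s>", carrying both text to return ("!") and the stop ("</s>"),
	// so stripping the stop leaves the piece count unchanged.
	cacheLen := 10         // tokens currently in seq.cache.Inputs (hypothetical)
	origLen := 3           // pending pieces before truncateStop
	newLen := 3            // pending pieces after the stop is stripped
	tokenTruncated := true // the final piece was cut partway through

	// Old accounting: assumed removing the stop always drops at least one piece.
	trimCacheLen := origLen - 1 - newLen                                  // -1
	fmt.Println("old slice end:", cacheLen-trimCacheLen, "of", cacheLen) // 11 of 10 -> out of range

	// New accounting: work out how many cache entries to keep directly.
	tokenLen := cacheLen + 1     // the last generated token was never submitted to Decode
	tokenLen -= origLen - newLen // whole pieces consumed by the stop
	if tokenTruncated || origLen == newLen {
		tokenLen-- // a partially kept final token is not left in the cache
	}
	fmt.Println("new cache length:", tokenLen) // 10
}
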
Showing 3 changed files with 61 additions and 34 deletions.
27 changes: 20 additions & 7 deletions llama/runner/runner.go
@@ -451,14 +451,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
     sequence := strings.Join(seq.pendingResponses, "")
 
     if ok, stop := findStop(sequence, seq.stop); ok {
-        slog.Debug("hit stop token", "stop", seq.stop)
-
-        trimCacheLen := len(seq.pendingResponses) - 1
-        seq.pendingResponses = truncateStop(seq.pendingResponses, stop)
-        trimCacheLen -= len(seq.pendingResponses)
+        slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
+
+        var tokenTruncated bool
+        origLen := len(seq.pendingResponses)
+        seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
+        newLen := len(seq.pendingResponses)
+
+        // Update the cache based on the tokens that will be returned:
+        // - We have 1 token more than is currently in the cache because
+        //   the last one generated wasn't submitted to Decode
+        // - Remove any stop sequences that we stripped out
+        // - If truncateStop removed a portion of a token, drop that
+        // - As defense-in-depth, if truncateStop didn't find a stop token,
+        //   remove the extra one that we added to the cache len
+        tokenLen := len(seq.cache.Inputs) + 1
+        tokenLen -= origLen - newLen
+        if tokenTruncated || origLen == newLen {
+            tokenLen--
+        }
+        seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
 
-        // remove any tokens from the cache that we don't actually return
-        seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen]
         s.removeSequence(i, "stop")
         continue
     }
10 changes: 6 additions & 4 deletions llama/runner/stop.go
@@ -28,13 +28,13 @@ func containsStopSuffix(sequence string, stops []string) bool {

 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
-// the last piece if required
-func truncateStop(pieces []string, stop string) []string {
+// the last piece if required (and signalling if this was the case)
+func truncateStop(pieces []string, stop string) ([]string, bool) {
     joined := strings.Join(pieces, "")
 
     index := strings.Index(joined, stop)
     if index == -1 {
-        return pieces
+        return pieces, false
     }
 
     joined = joined[:index]
@@ -46,6 +46,7 @@ func truncateStop(pieces []string, stop string) []string {
     }
 
     var result []string
+    tokenTruncated := false
     start := 0
     for _, length := range lengths {
         if start >= len(joined) {
@@ -55,12 +56,13 @@ func truncateStop(pieces []string, stop string) []string {
         end := start + length
         if end > len(joined) {
             end = len(joined)
+            tokenTruncated = true
         }
         result = append(result, joined[start:end])
         start = end
     }
 
-    return result
+    return result, tokenTruncated
 }
 
 func incompleteUnicode(token string) bool {
58 changes: 35 additions & 23 deletions llama/runner/stop_test.go
@@ -7,42 +7,54 @@ import (

 func TestTruncateStop(t *testing.T) {
     tests := []struct {
-        name     string
-        pieces   []string
-        stop     string
-        expected []string
+        name          string
+        pieces        []string
+        stop          string
+        expected      []string
+        expectedTrunc bool
     }{
         {
-            name:     "Single word",
-            pieces:   []string{"hello", "world"},
-            stop:     "world",
-            expected: []string{"hello"},
+            name:          "Single word",
+            pieces:        []string{"hello", "world"},
+            stop:          "world",
+            expected:      []string{"hello"},
+            expectedTrunc: false,
         },
+        {
+            name:          "Partial",
+            pieces:        []string{"hello", "wor"},
+            stop:          "or",
+            expected:      []string{"hello", "w"},
+            expectedTrunc: true,
+        },
         {
-            name:     "Partial",
-            pieces:   []string{"hello", "wor"},
-            stop:     "or",
-            expected: []string{"hello", "w"},
+            name:          "Suffix",
+            pieces:        []string{"Hello", " there", "!"},
+            stop:          "!",
+            expected:      []string{"Hello", " there"},
+            expectedTrunc: false,
         },
         {
-            name:     "Suffix",
-            pieces:   []string{"Hello", " there", "!"},
-            stop:     "!",
-            expected: []string{"Hello", " there"},
+            name:          "Suffix partial",
+            pieces:        []string{"Hello", " the", "re!"},
+            stop:          "there!",
+            expected:      []string{"Hello", " "},
+            expectedTrunc: true,
         },
         {
-            name:     "Middle",
-            pieces:   []string{"hello", " wor"},
-            stop:     "llo w",
-            expected: []string{"he"},
+            name:          "Middle",
+            pieces:        []string{"hello", " wor"},
+            stop:          "llo w",
+            expected:      []string{"he"},
+            expectedTrunc: true,
         },
     }
 
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            result := truncateStop(tt.pieces, tt.stop)
-            if !reflect.DeepEqual(result, tt.expected) {
-                t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
+            result, resultTrunc := truncateStop(tt.pieces, tt.stop)
+            if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
+                t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
             }
         })
     }
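
As a side note on the splitting logic itself, the sketch below (a standalone toy, not the package's code) traces the new "Suffix partial" case from the table above: the kept text is re-split along the original piece boundaries, and the boolean is set because the last surviving piece is only partially kept, matching the expected {"Hello", " "} and expectedTrunc true.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// "Suffix partial" case from the test table above: the stop "there!"
	// begins inside " the" and spans into "re!".
	pieces := []string{"Hello", " the", "re!"}
	stop := "there!"

	joined := strings.Join(pieces, "")           // "Hello there!"
	kept := joined[:strings.Index(joined, stop)] // "Hello "

	// Re-split the kept text along the original piece boundaries.
	var result []string
	truncated := false
	start := 0
	for _, piece := range pieces {
		if start >= len(kept) {
			break
		}
		end := start + len(piece)
		if end > len(kept) {
			end = len(kept)
			truncated = true // " the" only partially survives, so the flag is set
		}
		result = append(result, kept[start:end])
		start = end
	}

	fmt.Printf("%q %v\n", result, truncated) // ["Hello" " "] true
}
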
