This repository was archived by the owner on Jan 31, 2024. It is now read-only.

Commit df75753 (1 parent: 1bc8917)

StreamingTokenDecoder, AntipromptProcessor support

File tree: 10 files changed, +143 -86 lines


LLamaStack.Core/Extensions/Extensions.cs

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ public static ILLamaParams ToLLamaParams(this IModelConfig modelConfig)
     RopeFrequencyBase = modelConfig.RopeFrequencyBase,
     RopeFrequencyScale = modelConfig.RopeFrequencyScale,
     Seed = modelConfig.Seed,
-    TensorSplits = modelConfig.TensorSplits,
+    TensorSplits = modelConfig.TensorSplits == null ? new() : new (modelConfig.TensorSplits),
     UseFp16Memory = modelConfig.UseFp16Memory,
     UseMemoryLock = modelConfig.UseMemoryLock,
     UseMemorymap = modelConfig.UseMemorymap,
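
Note: when the config carries no tensor splits, the target-typed new() now builds an empty collection rather than passing null through; otherwise the configured values are copied into a new instance. This syntax suggests the property changed from an array to a collection type on the LLamaSharp side, though the diff itself does not confirm that.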

LLamaStack.Core/Inference/AntipromptProcessor.cs (new file; path assumed from the namespace)

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+namespace LLamaStack.Core.Inference
+{
+    //TODO remove when made public interface LLamaSharp
+    public sealed class AntipromptProcessor
+    {
+        private int _longestAntiprompt;
+        private readonly List<string> _antiprompts = new();
+
+        private string? _string;
+
+        public AntipromptProcessor(IEnumerable<string>? antiprompts = null)
+        {
+            if (antiprompts != null)
+                SetAntiprompts(antiprompts);
+        }
+
+        /// <summary>
+        /// Add an antiprompt to the collection
+        /// </summary>
+        /// <param name="antiprompt"></param>
+        public void AddAntiprompt(string antiprompt)
+        {
+            _antiprompts.Add(antiprompt);
+            _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
+        }
+
+        /// <summary>
+        /// Overwrite all current antiprompts with a new set
+        /// </summary>
+        /// <param name="antiprompts"></param>
+        public void SetAntiprompts(IEnumerable<string> antiprompts)
+        {
+            _antiprompts.Clear();
+            _antiprompts.AddRange(antiprompts);
+
+            _longestAntiprompt = 0;
+            foreach (var antiprompt in _antiprompts)
+                _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
+        }
+
+        /// <summary>
+        /// Add some text and check if the buffer now ends with any antiprompt
+        /// </summary>
+        /// <param name="text"></param>
+        /// <returns>true if the text buffer ends with any antiprompt</returns>
+        public bool Add(string text)
+        {
+            _string += text;
+
+            // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
+            // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
+            // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
+            var maxLength = Math.Max(32, _longestAntiprompt * 4);
+            var trimLength = Math.Max(16, _longestAntiprompt * 2);
+            if (_string.Length > maxLength)
+                _string = _string.Substring(_string.Length - trimLength);
+
+            foreach (var antiprompt in _antiprompts)
+                if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
+                    return true;
+
+            return false;
+        }
+    }
+}
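
For orientation, a minimal usage sketch of the new class; the names all come from this diff, and the surrounding handler wiring is elided:

    var antiprompts = new AntipromptProcessor(new[] { "User:" });

    // Add() appends streamed text to a rolling buffer (trimmed to a few
    // multiples of the longest antiprompt) and reports whether the buffer
    // now ends with any registered antiprompt.
    bool stop = antiprompts.Add("Assistant: done.\nUser:");   // true: wait for input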

LLamaStack.Core/Inference/InferenceHandlerBase.cs

Lines changed: 32 additions & 3 deletions

@@ -1,8 +1,12 @@
-using LLama.Abstractions;
+using LLama;
+using LLama.Abstractions;
 using LLama.Common;
 using LLamaStack.Core.Common;
+using LLamaStack.Core.Extensions;
+using LLamaStack.Core.Models;
 using LLamaStack.Core.Services;
 using System.Runtime.CompilerServices;
+using System.Text;
 
 namespace LLamaStack.Core.Inference
 {
@@ -48,6 +52,11 @@ public abstract class InferenceHandlerBase<T> : IInferenceHandler
     /// </summary>
     protected ISampleService _sampleService;
 
+    /// <summary>
+    /// The token decoder
+    /// </summary>
+    protected StreamingTokenDecoder _tokenDecoder;
+
 
     /// <summary>
     /// Initializes a new instance of the <see cref="InferenceHandlerBase{T}"/> class.
@@ -61,6 +70,7 @@ protected InferenceHandlerBase(LLamaStackModel<T> model, LLamaStackContext conte
         _pastTokensCount = 0;
         _consumedTokensCount = 0;
         _sampleService = new SampleService(_context);
+        _tokenDecoder = new StreamingTokenDecoder(_context.LLamaContext);
         _lastTokens = new FixedSizeQueue<TokenData>(_context.ContextSize).FillWith(new(0));
     }
 
@@ -136,6 +146,7 @@ public async virtual IAsyncEnumerable<TokenData> InferAsync(string text, IInfere
     {
         cancellationToken.ThrowIfCancellationRequested();
         inferenceParams ??= new InferenceParams();
+        var antipromptProcessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
 
         InferStateArgs args = new InferStateArgs()
         {
@@ -158,8 +169,14 @@ public async virtual IAsyncEnumerable<TokenData> InferAsync(string text, IInfere
 
         if (args.ReturnValue)
         {
-            foreach (var embed in _currentTokens)
-                yield return embed;
+            foreach (var tokenData in ProcessTokens(_currentTokens))
+            {
+                // Check if any of the antiprompts have been generated
+                if (!tokenData.IsChild && antipromptProcessor.Add(tokenData.Content))
+                    args.WaitForInput = true;
+
+                yield return tokenData;
+            }
         }
 
         var breakGeneration = await PostProcess(inferenceParams, args);
@@ -170,6 +187,18 @@ public async virtual IAsyncEnumerable<TokenData> InferAsync(string text, IInfere
         }
     }
 
+    protected List<TokenData> ProcessTokens(List<TokenData> tokens)
+    {
+        _tokenDecoder.AddRange(tokens.ToTokenIds());
+
+        // First token is parent, contains full Content,
+        // Others are Child, no Content, Data only
+        tokens[0].Content = _tokenDecoder.Read();
+        foreach (var token in tokens.Skip(1))
+            token.IsChild = true;
+
+        return tokens;
+    }
 
     /// <summary>
     /// Gets the state.
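
The parent/child split introduced by ProcessTokens matters to callers: each batch is decoded as a whole, the first token carries the full decoded text, and the rest carry only sampling data. A hypothetical consumer sketch (handler and parameter setup elided; the InferAsync parameter list is abbreviated in the hunk header above and assumed to be text, inferenceParams, cancellationToken):

    await foreach (var token in handler.InferAsync(prompt, inferenceParams, cancellationToken))
    {
        // Child tokens carry no Content of their own; their text is already
        // folded into the preceding parent token.
        if (!token.IsChild)
            Console.Write(token.Content);
    }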

LLamaStack.Core/Inference/InstructInferenceHandler.cs

Lines changed: 1 addition & 20 deletions

@@ -89,32 +89,13 @@ protected override Task<bool> PostProcess(IInferenceParams inferenceParams, Infe
     {
         if (_promptTokens.Count <= _consumedTokensCount)
         {
-            if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
-            {
-                var last_output_builder = new StringBuilder();
-                foreach (var token in _lastTokens)
-                {
-                    _context.TokenToString(token, last_output_builder);
-                }
-
-                var last_output = last_output_builder.ToString();
-                foreach (var antiprompt in args.Antiprompts)
-                {
-                    if (last_output.EndsWith(antiprompt))
-                    {
-                        args.WaitForInput = true;
-                        return Task.FromResult(true);
-                    }
-                }
-            }
-
             if (_pastTokensCount > 0 && args.WaitForInput)
             {
                 return Task.FromResult(true);
             }
         }
 
-        if (_currentTokens.Count > 0 && _currentTokens.Last()?.Id == _context.TokenEOS)
+        if (_currentTokens.Count > 0 && _currentTokens.Last()?.Id == _model.TokenEOS)
        {
             args.WaitForInput = true;
         }

LLamaStack.Core/Inference/InteractiveInferenceHandler.cs

Lines changed: 3 additions & 28 deletions

@@ -18,7 +18,7 @@ public sealed class InteractiveInferenceHandler<T> : InferenceHandlerBase<T>
     /// <param name="context">The context.</param>
     public InteractiveInferenceHandler(LLamaStackModel<T> model, LLamaStackContext context) : base(model, context)
     {
-        _tokenNewline = new TokenData(_context.TokenNL);
+        _tokenNewline = new TokenData(model.TokenNL);
     }
 
 
@@ -69,32 +69,13 @@ protected override Task<bool> PostProcess(IInferenceParams inferenceParams, Infe
     {
         if (_promptTokens.Count <= _consumedTokensCount)
         {
-            if (!args.Antiprompts.IsNullOrEmpty())
-            {
-                var last_output_builder = new StringBuilder();
-                foreach (var token in _lastTokens)
-                {
-                    _context.TokenToString(token, last_output_builder);
-                }
-
-                var last_output = last_output_builder.ToString();
-                foreach (var antiprompt in args.Antiprompts)
-                {
-                    if (last_output.EndsWith(antiprompt))
-                    {
-                        args.WaitForInput = true;
-                        break;
-                    }
-                }
-            }
-
             if (_pastTokensCount > 0 && args.WaitForInput)
             {
                 return Task.FromResult(true);
             }
         }
 
-        if (_currentTokens.Count > 0 && _currentTokens.Last()?.Id == _context.TokenEOS)
+        if (_currentTokens.Count > 0 && _currentTokens.Last()?.Id == _model.TokenEOS)
         {
             return Task.FromResult(true);
         }
@@ -133,15 +114,9 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
 
         _lastTokens.Enqueue(tokenData);
 
-
-        if (tokenData.Id == _context.TokenEOS)
+        if (tokenData.Id == _model.TokenEOS)
         {
             tokenData = _tokenNewline;
-            if (!args.Antiprompts.IsNullOrEmpty())
-            {
-                var first_antiprompt = _context.TokenizeTextToList(args.Antiprompts[0], false);
-                _promptTokens.AddRange(first_antiprompt);
-            }
         }
 
         _currentTokens.Add(tokenData);

LLamaStack.Core/Inference/TokenData.cs

Lines changed: 1 addition & 0 deletions

@@ -5,5 +5,6 @@ public sealed record TokenData(int Id)
     public float Logit { get; set; }
     public float Probability { get; set; }
     public string Content { get; set; }
+    public bool IsChild { get; set; }
 }
 }
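
IsChild marks the trailing tokens of a decoded batch; their text is folded into the parent token's Content by ProcessTokens above, so consumers should read Content only from non-child tokens.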

LLamaStack.Core/LLamaStackContext.cs

Lines changed: 4 additions & 31 deletions

@@ -37,18 +37,6 @@ public LLamaStackContext(LLamaContext context)
     public int ContextSize => _context.ContextSize;
 
 
-    /// <summary>
-    /// Gets the native llama EOS tokenid.
-    /// </summary>
-    public int TokenEOS => NativeApi.llama_token_eos(_context.NativeHandle);
-
-
-    /// <summary>
-    /// Gets the native llama NL tokenid.
-    /// </summary>
-    public int TokenNL => NativeApi.llama_token_nl(_context.NativeHandle);
-
-
     /// <summary>
     /// Loads the state.
     /// </summary>
@@ -100,8 +88,7 @@ public TokenData GetTokenData(LLamaTokenDataArray tokenDataArray, int id)
         return new TokenData(tokenData.id)
         {
             Logit = tokenData.logit,
-            Probability = tokenData.p,
-            Content = _context.TokenToString(tokenData.id)
+            Probability = tokenData.p
         };
     }
 
@@ -152,7 +139,8 @@ public int Sample(LLamaTokenDataArray tokenDataArray, IInferenceParams inference
             inferenceParams.TopP,
             inferenceParams.TfsZ,
             inferenceParams.TypicalP,
-            inferenceParams.Grammar
+            inferenceParams.Grammar,
+            inferenceParams.MinP
         );
     }
 
@@ -166,7 +154,7 @@ public int Sample(LLamaTokenDataArray tokenDataArray, IInferenceParams inference
     private IEnumerable<TokenData> TokenizeText(string text, bool addBos)
     {
         return _context.Tokenize(text, addBos)
-            .Select(x => new TokenData(x) { Content = _context.TokenToString(x) });
+            .Select(x => new TokenData(x));
     }
 
 
@@ -205,27 +193,12 @@ public Task<int> EvalAsync(IEnumerable<TokenData> tokens, int pastTokensCount)
         return Task.Run(() => _context.Eval(tokens.ToTokenIds(), pastTokensCount));
     }
 
-
-    /// <summary>
-    /// Token to string.
-    /// </summary>
-    /// <param name="token">The token.</param>
-    /// <param name="stringBuilder">The string builder.</param>
-    public void TokenToString(TokenData token, StringBuilder stringBuilder)
-    {
-        _context.NativeHandle.TokenToString(token.Id, _context.Encoding, stringBuilder);
-    }
-
     /// <summary>
     /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
     /// </summary>
     public void Dispose()
     {
         _context?.Dispose();
     }
-
-
 }
-
-
 }
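
For reference, min-p (the new MinP parameter forwarded to the sampler above) keeps only candidates whose probability is at least minP times that of the most likely token. A standalone sketch of the filtering rule, not LLamaSharp's implementation:

    using System.Collections.Generic;
    using System.Linq;

    static class Sampling
    {
        // Keep only candidates whose probability is at least minP times
        // the probability of the most likely candidate. The cutoff scales
        // with the top token, so the filter is permissive when the model
        // is uncertain and strict when it is confident.
        public static IEnumerable<(int Id, float P)> MinPFilter(
            IReadOnlyList<(int Id, float P)> candidates, float minP)
        {
            var threshold = candidates.Max(c => c.P) * minP;
            return candidates.Where(c => c.P >= threshold);
        }
    }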

LLamaStack.Core/LLamaStackModel.cs

Lines changed: 14 additions & 0 deletions

@@ -1,5 +1,6 @@
 using LLama;
 using LLama.Abstractions;
+using LLama.Native;
 using LLamaStack.Core.Config;
 using LLamaStack.Core.Extensions;
 using System.Collections.Concurrent;
@@ -45,6 +46,19 @@ public LLamaStackModel(ModelConfig modelParams)
     public int ContextCount => _contexts.Count;
 
 
+
+    /// <summary>
+    /// Gets the native llama EOS tokenid.
+    /// </summary>
+    public int TokenEOS => NativeApi.llama_token_eos(_weights.NativeHandle);
+
+
+    /// <summary>
+    /// Gets the native llama NL tokenid.
+    /// </summary>
+    public int TokenNL => NativeApi.llama_token_nl(_weights.NativeHandle);
+
+
     /// <summary>
     /// Creates a new context session on this model
