1- using LLama . Abstractions ;
1+ using LLama ;
2+ using LLama . Abstractions ;
23using LLama . Common ;
34using LLamaStack . Core . Common ;
5+ using LLamaStack . Core . Extensions ;
6+ using LLamaStack . Core . Models ;
47using LLamaStack . Core . Services ;
58using System . Runtime . CompilerServices ;
9+ using System . Text ;
610
711namespace LLamaStack . Core . Inference
812{
@@ -48,6 +52,11 @@ public abstract class InferenceHandlerBase<T> : IInferenceHandler
4852 /// </summary>
4953 protected ISampleService _sampleService ;
5054
55+ /// <summary>
56+ /// The token decoder
57+ /// </summary>
58+ protected StreamingTokenDecoder _tokenDecoder ;
59+
5160
5261 /// <summary>
5362 /// Initializes a new instance of the <see cref="InferenceHandlerBase{T}"/> class.
@@ -61,6 +70,7 @@ protected InferenceHandlerBase(LLamaStackModel<T> model, LLamaStackContext conte
6170 _pastTokensCount = 0 ;
6271 _consumedTokensCount = 0 ;
6372 _sampleService = new SampleService ( _context ) ;
73+ _tokenDecoder = new StreamingTokenDecoder ( _context . LLamaContext ) ;
6474 _lastTokens = new FixedSizeQueue < TokenData > ( _context . ContextSize ) . FillWith ( new ( 0 ) ) ;
6575 }
6676
@@ -136,6 +146,7 @@ public async virtual IAsyncEnumerable<TokenData> InferAsync(string text, IInfere
136146 {
137147 cancellationToken . ThrowIfCancellationRequested ( ) ;
138148 inferenceParams ??= new InferenceParams ( ) ;
149+ var antipromptProcessor = new AntipromptProcessor ( inferenceParams . AntiPrompts ) ;
139150
140151 InferStateArgs args = new InferStateArgs ( )
141152 {
@@ -158,8 +169,14 @@ public async virtual IAsyncEnumerable<TokenData> InferAsync(string text, IInfere
158169
159170 if ( args . ReturnValue )
160171 {
161- foreach ( var embed in _currentTokens )
162- yield return embed ;
172+ foreach ( var tokenData in ProcessTokens ( _currentTokens ) )
173+ {
174+ // Check if any of the antiprompts have been generated
175+ if ( ! tokenData . IsChild && antipromptProcessor . Add ( tokenData . Content ) )
176+ args . WaitForInput = true ;
177+
178+ yield return tokenData ;
179+ }
163180 }
164181
165182 var breakGeneration = await PostProcess ( inferenceParams , args ) ;
@@ -170,6 +187,18 @@ public async virtual IAsyncEnumerable<TokenData> InferAsync(string text, IInfere
170187 }
171188 }
172189
/// <summary>
/// Decodes a batch of generated tokens into text and marks parent/child relationships.
/// The first token in the batch is the parent and receives the full decoded
/// <c>Content</c>; all remaining tokens are flagged as children (data only, no content).
/// </summary>
/// <param name="tokens">The tokens produced by the current inference step.</param>
/// <returns>The same <paramref name="tokens"/> list, mutated in place.</returns>
protected List<TokenData> ProcessTokens(List<TokenData> tokens)
{
    // Guard: an empty batch has nothing to decode; without this,
    // tokens[0] below throws ArgumentOutOfRangeException.
    if (tokens == null || tokens.Count == 0)
        return tokens;

    // Feed all token ids to the streaming decoder before reading, so text
    // spanning multiple tokens is decoded as one unit.
    // NOTE(review): presumably this handles partial multi-byte sequences
    // split across tokens — confirm against StreamingTokenDecoder docs.
    _tokenDecoder.AddRange(tokens.ToTokenIds());

    // First token is parent and carries the full decoded Content;
    // the rest are children (Data only, no Content).
    tokens[0].Content = _tokenDecoder.Read();
    foreach (var token in tokens.Skip(1))
        token.IsChild = true;

    return tokens;
}
173202
174203 /// <summary>
175204 /// Gets the state.
0 commit comments