@@ -160,6 +160,156 @@ def token_probability_fn(inputs):
     return prompt
 
 
+def beam_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    num_beams,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on the beam search algorithm.
+
+    At each time-step, beam search keeps the beams (sequences) with the top
+    `num_beams` highest accumulated probabilities, and uses each one of the
+    beams to predict candidate next tokens.
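+    For example, with `num_beams=2` and a vocabulary of 3 tokens, each step
+    scores 2 * 3 = 6 candidate continuations and keeps the best 2.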
+
+    Args:
+        token_probability_fn: a callable, which takes in an input sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of the
+            next token. The input shape would be `[batch_size, length]` and
+            the output should be `[batch_size, vocab_size]`, where batch_size
+            is variable.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens onto
+            which generated tokens are appended; this serves as the initial
+            beam for beam search.
+        max_length: int. The max length of generated text.
+        num_beams: int. The number of beams that should be kept at each
+            time-step. `num_beams` should be strictly positive.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence; once encountered, generation is finished for that
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The token used for padding after
+            `end_token_id` is received.
+
+    Returns:
+        A 1D or 2D int Tensor representing the generated sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.beam_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        num_beams=5,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.beam_search` currently requires an eager "
+            "execution context. Please call `beam_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if num_beams <= 0:
+        raise ValueError(
+            "`num_beams` should be strictly positive. "
+            f"Received: `num_beams={num_beams}`."
+        )
+
+    prompt = validate_prompt(prompt)
+
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    batch_size, length = prompt.shape
+    if length < max_length:
+        # Initialize beam.
+        beams = tf.expand_dims(prompt, 1)
+        beams_prob = tf.zeros([batch_size, 1])
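+        # `beams` has shape `(batch_size, beam_count, length)`; `beams_prob`
+        # holds each beam's cumulative log-probability, starting at
+        # log(1) = 0.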
+        i = length
+        while i < max_length:
+            beam_size = beams.shape[1]
+            beam_preds = []
+            for j in range(beam_size):
+                preds = token_probability_fn(beams[:, j, :])
+                if from_logits:
+                    preds = tf.keras.activations.softmax(preds, axis=-1)
+                beam_preds.append(preds)
+            stacked_preds = tf.stack(beam_preds, axis=1)
+            vocab_size = stacked_preds.shape[2]
+            logits = tf.reshape(
+                stacked_preds, [batch_size, beam_size * vocab_size]
+            )
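+            # The beam and vocab axes are now flattened so that a single
+            # `top_k` call below can rank every (beam, next-token) pair at
+            # once.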
+            probs = tf.math.log(logits) + tf.repeat(
+                beams_prob, repeats=vocab_size, axis=1
+            )
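+            # Each candidate's score is an accumulated log-probability: the
+            # beam's running score plus the log-prob of the candidate token.
+            # On the first step there are only `vocab_size` candidates, so
+            # cap the beam count accordingly.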
+            num_beams = min(beam_size * vocab_size, num_beams)
+            candidate_prob, candidate_indexes = tf.math.top_k(
+                probs, k=num_beams, sorted=False
+            )
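+            # Decode each flat index back into its (beam, token) pair:
+            # integer division recovers the source beam, and the remainder
+            # is the token id within the vocabulary.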
+            candidate_beam_indexes = candidate_indexes // vocab_size
+            next_token = candidate_indexes % vocab_size
+
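+            # Re-gather the surviving beams (`batch_dims=1` keeps the gather
+            # per-example) and append each beam's chosen next token.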
+            beams = tf.gather(
+                beams, candidate_beam_indexes, axis=1, batch_dims=1
+            )
+            beams = tf.concat([beams, next_token[..., tf.newaxis]], axis=-1)
+            beams_prob = candidate_prob
+            i += 1
+        # Get the beam with the maximum probability.
+        max_indexes = tf.math.argmax(beams_prob, axis=-1)
+        max_beams = tf.gather(
+            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
+        )
+        prompt = tf.squeeze(max_beams)
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
+
+
 def random_search(
     token_probability_fn,
     prompt,
@@ -361,7 +511,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError(f"`k` should strictly positive. Received: `k={k}`.")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -378,7 +528,7 @@ def token_probability_fn(inputs):
         # If k is greater than the vocabulary size, use the entire vocabulary.
         k = min(k, pred.shape[1])
         # Filter out top-k tokens.
-        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
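+        # `sorted=False` skips the ordering work; the categorical sampling
+        # below does not depend on candidate order.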
         # Sample the next token from the probability distribution.
         next_token = tf.random.categorical(
             tf.math.log(top_k_pred), 1, seed=seed