1919from torch .nn .attention import SDPBackend , sdpa_kernel
2020from transformers .cache_utils import StaticCache
2121
22+ from pruna .logging .logger import pruna_logger
23+
2224
2325class TransformersGenerator :
2426 """
@@ -43,6 +45,8 @@ class TransformersGenerator:
4345 Whether to compile the full computation graph or use partial graph compilation.
4446 batch_size : int, default=1
4547 The batch size to use for text generation.
48+ device : str, default='cuda'
49+ The device to use for text generation.
4650 """
4751
4852 def __init__ (
@@ -54,6 +58,7 @@ def __init__(
5458 compile_mode : str = "reduce-overhead" ,
5559 compile_fullgraph : bool = True ,
5660 batch_size : int = 1 ,
61+ device : str = "cuda" ,
5762 ):
5863 """
5964 Initialize the TransformersGenerator.
@@ -87,21 +92,23 @@ def __init__(
8792 torch ._dynamo .config .inline_inbuilt_nn_modules = False # torch 2.5.0 fix
8893
8994 self .model = model
90- self .device = model . device
95+ self .device = device
9196 self .temperature = temperature
9297 self .top_k = top_k
9398 self .use_cache = True
9499 self .compile_mode = compile_mode
95100 self .compile_fullgraph = compile_fullgraph
96101 self .batch_size = batch_size
102+ self .cache_batch_size = batch_size
97103 self .cache_size = max_kv_cache_size
104+ self .eos_token_id = getattr (model .config , "eos_token_id" , None )
105+ if self .eos_token_id is None :
106+ pruna_logger .warning ("Warning: eos_token_id is None. This may affect generation stopping criteria." )
98107
99108 self .setup_cache ()
100109
101110 self .decode_one_token = torch .compile ( # type: ignore
102- self .decode_one_token ,
103- mode = self .compile_mode ,
104- fullgraph = self .compile_fullgraph
111+ self .decode_one_token , mode = self .compile_mode , fullgraph = self .compile_fullgraph
105112 )
106113
107114 self .init ()
@@ -110,6 +117,7 @@ def __init__(
110117 # Cuda Graph section
111118 self .static_input = torch .zeros ((1 , 1 ), device = self .device , dtype = torch .int32 )
112119 self .static_output = torch .zeros ((1 , 1 ), device = self .device , dtype = torch .int32 )
120+ self .original_gen_next_token = self .gen_next_token
113121 self .cuda_graph = None
114122 self .do_capture_graph = False
115123 ############################
@@ -198,10 +206,7 @@ def logits_to_probs(self, logits: torch.Tensor, temperature: float = 1.0, top_k:
198206 return probs
199207
200208 def sample (
201- self ,
202- logits : torch .Tensor ,
203- temperature : float = 1.0 ,
204- top_k : int | None = None
209+ self , logits : torch .Tensor , temperature : float = 1.0 , top_k : int | None = None
205210 ) -> tuple [torch .Tensor , torch .Tensor ]:
206211 """
207212 Sample one token from the model.
@@ -286,18 +291,38 @@ def setup(self, inputs: torch.Tensor, max_new_tokens: int):
286291 None
287292 This method initializes internal state for generation but does not return a value.
288293 """
294+ new_batch_size = inputs .shape [0 ]
295+
296+ # Check if batch size changed compared to the cache configuration
297+ if new_batch_size != self .cache_batch_size :
298+ pruna_logger .info (
299+ f"Batch size changed from { self .cache_batch_size } to { new_batch_size } . Re-initializing StaticCache."
300+ )
301+ self .batch_size = new_batch_size
302+ self .cache_batch_size = new_batch_size
303+ self .setup_cache ()
304+
305+ # If CUDA graph was used, it's now invalid
306+ if hasattr (self , "cuda_graph" ) and self .cuda_graph is not None :
307+ pruna_logger .warning ("CUDA graph is invalidated due to batch size change. Disabling CUDA graph usage." )
308+ self .cuda_graph = None
309+ self .gen_next_token = self .original_gen_next_token
310+ self .do_capture_graph = False
311+
312+ # Reset cache contents (does not change shape)
289313 self .reset_cache ()
314+
290315 self .inputs = inputs
291316 self .batch_size , self .seq_length = self .inputs .shape
292317 self .cache_position = torch .arange (self .seq_length , device = self .device )
293- # initialize the generated ids with zeros of the shape (batch_size, seq_length + max_new_tokens + 1)
318+ # initialize the generated ids with zeros
294319 self .generated_ids = torch .zeros (
295320 self .batch_size ,
296321 self .seq_length + max_new_tokens + 1 ,
297322 dtype = torch .int ,
298323 device = self .device ,
299324 )
300- # copy the input ids to the generated ids at the cache position.
325+ # copy the input ids to the generated ids
301326 self .generated_ids [:, self .cache_position ] = self .inputs .to (torch .int )
302327
303328 def prefill (self ) -> torch .Tensor :
@@ -340,11 +365,11 @@ def gen_next_token(self, current_token: torch.Tensor) -> torch.Tensor:
340365 The next token generated by the model.
341366 """
342367 next_token = self .decode_one_token (
343- current_token .clone (),
344- cache_position = self .cache_position + 1 ,
345- past_key_values = self .past_key_values ,
346- temperature = self .temperature ,
347- top_k = self .top_k ,
368+ current_token .clone (),
369+ cache_position = self .cache_position + 1 ,
370+ past_key_values = self .past_key_values ,
371+ temperature = self .temperature ,
372+ top_k = self .top_k ,
348373 )
349374 self .cache_position += 1
350375 self .generated_ids [:, self .cache_position ] = next_token .int ()
@@ -354,7 +379,7 @@ def enable_cuda_graph(
354379 self ,
355380 iters : int = 2 ,
356381 prompt_tokenized : list [int ] = [596 , 8830 , 315 , 6913 , 19476 , 11 , 1778 , 439 , 279 , 12939 ],
357- max_kv_cache_size : int = 1024
382+ max_kv_cache_size : int = 1024 ,
358383 ) -> None :
359384 """
360385 Enable the CUDA graph and capture the graph on random prompt.
@@ -375,17 +400,15 @@ def enable_cuda_graph(
375400 but does not return any value.
376401 """
377402 _ = self .generate (
378- torch .tensor (prompt_tokenized , device = self .model .device ).unsqueeze (0 ),
379- max_new_tokens = max_kv_cache_size
403+ torch .tensor (prompt_tokenized , device = self .model .device ).unsqueeze (0 ), max_new_tokens = max_kv_cache_size
380404 )
381405 for _ in range (iters ):
382406 # need to reset the graph before capturing it at each iteration
383407 # to avoid block/thread errors.
384408 self .do_capture_graph = True
385409 self .gen_next_token = self .gen_next_token_withgraph # type: ignore
386410 _ = self .generate (
387- torch .tensor (prompt_tokenized , device = self .model .device ).unsqueeze (0 ),
388- max_new_tokens = max_kv_cache_size
411+ torch .tensor (prompt_tokenized , device = self .model .device ).unsqueeze (0 ), max_new_tokens = max_kv_cache_size
389412 )
390413
391414 def gen_next_token_withgraph (self , current_token : torch .Tensor ) -> torch .Tensor :
@@ -426,54 +449,100 @@ def gen_next_token_withgraph(self, current_token: torch.Tensor) -> torch.Tensor:
426449 return next_token
427450
428451 def next_token_iterator (
429- self ,
430- current_token : torch .Tensor ,
431- max_new_tokens : int ,
432- cleanup : bool = True
452+ self , current_token : torch .Tensor , max_new_tokens : int , cleanup : bool = True
433453 ) -> torch .Tensor :
434454 """
435- Generate the next token.
455+ Generate the next token, stopping at max_new_tokens or EOS for each sequence in the batch .
436456
437457 Parameters
438458 ----------
439459 current_token : torch.Tensor
440- The current token.
460+ The current token tensor of shape (batch_size, 1) .
441461 max_new_tokens : int
442462 The maximum number of new tokens to generate.
443463 cleanup : bool
444- Whether to cleanup the inputs, generated ids, and cache position.
464+ Whether to cleanup the inputs, generated ids, and cache position after generation .
445465
446466 Returns
447467 -------
448468 torch.Tensor
449- The generated tokens.
469+ The generated tokens tensor of shape (batch_size, seq_length + generated_length),
470+ including the input prompt and potentially EOS tokens. Sequences that finish early
471+ will have EOS followed by padding (initial zeros).
450472 """
473+ # Keep track of sequences that haven't finished yet (encountered EOS)
474+ # Assumes initial state is unfinished for all sequences in the batch
475+ unfinished_sequences = torch .ones (self .batch_size , dtype = torch .bool , device = self .device )
476+
477+ # Loop for a maximum of max_new_tokens - 1 steps (as prefill generates the first)
451478 for i in range (1 , max_new_tokens ):
452- current_token = self .gen_next_token (current_token )
453- output_tokens = self .generated_ids
479+ # Generate the next token for all sequences
480+ current_token = self .gen_next_token (current_token ) # Updates self.generated_ids internally
481+
482+ # Check if the generated token is the EOS token for any currently unfinished sequence
483+ if self .eos_token_id is not None :
484+ # Check which sequences produced the EOS token THIS step
485+ # current_token shape is (batch_size, 1), squeeze to (batch_size,)
486+ # Only consider sequences that were previously unfinished
487+ finished_this_step = (current_token .squeeze (- 1 ) == self .eos_token_id ) & unfinished_sequences
488+ # Update the overall tracker for unfinished sequences
489+ unfinished_sequences &= ~ finished_this_step
490+
491+ # Stop generation if all sequences in the batch have finished
492+ if not unfinished_sequences .any ():
493+ break
494+
495+ # Determine the actual length generated (up to the current cache position)
496+ # .item() is safe as cache_position should be a 0-dim tensor
497+ final_seq_len = self .cache_position .item () + 1
498+ # Clone the relevant part of generated_ids before potential cleanup
499+ output_tokens = self .generated_ids [:, : int (final_seq_len )].clone ()
454500
455501 if cleanup :
502+ # Delete internal state tensors, but not output_tokens which is returned
456503 del self .inputs , self .generated_ids , self .cache_position
457504 torch .cuda .empty_cache ()
458505
459506 return output_tokens
460507
461508 @torch .inference_mode ()
462- def generate (self , input_ids : torch . Tensor , max_new_tokens : int = 100 ) -> torch .Tensor :
509+ def generate (self , * args , ** kwargs ) -> torch .Tensor :
463510 """
464- Generate the tokens .
511+ Generate tokens using the model .
465512
466513 Parameters
467514 ----------
468- input_ids : torch.Tensor
469- The input ids.
470- max_new_tokens : int
471- The maximum number of new tokens to generate.
515+ *args : tuple
516+ Variable length argument list (not used directly).
517+ **kwargs : dict
518+ Keyword arguments dictionary that must contain:
519+ - input_ids : torch.Tensor
520+ The input token ids that serve as the prompt.
521+ - max_new_tokens : int
522+ The maximum number of new tokens to generate.
472523
473524 Returns
474525 -------
475526 torch.Tensor
476- The generated tokens.
527+ The generated tokens, including the input prompt and potentially an EOS token .
477528 """
478- self .setup (inputs = input_ids , max_new_tokens = max_new_tokens )
479- return self .next_token_iterator (self .prefill (), max_new_tokens )
529+ # Extract parameters from kwargs with defaults from instance variables
530+ self .temperature = kwargs .pop ("temperature" , self .temperature )
531+ self .top_k = kwargs .pop ("top_k" , self .top_k )
532+ self .use_cache = kwargs .pop ("use_cache" , self .use_cache )
533+
534+ # Log any kwargs that are not explicitly handled
535+ unhandled_kwargs = {
536+ k : v
537+ for k , v in kwargs .items ()
538+ if k not in ["input_ids" , "max_new_tokens" , "temperature" , "top_k" , "batch_size" ]
539+ }
540+ if unhandled_kwargs :
541+ pruna_logger .warning (f"Unhandled kwargs in generate method: { unhandled_kwargs } " )
542+
543+ # Update instance variables with any provided values
544+ self .setup (
545+ inputs = kwargs ["input_ids" ] if "input_ids" in kwargs else args [0 ],
546+ max_new_tokens = kwargs ["max_new_tokens" ] if "max_new_tokens" in kwargs else args [1 ],
547+ )
548+ return self .next_token_iterator (self .prefill (), kwargs ["max_new_tokens" ])
0 commit comments