Skip to content

Commit 3aebee7

Browse files
fix: generation functions for torch.compile llm combination (#85)
* fix: take an arg or kwargs as input to the llm and stop when the llm has generated EOS token * fix: perform quantization on cpu always * fix: mypy errors * fix: handle review comments * fix: mypy error
1 parent f6dbc50 commit 3aebee7

File tree

3 files changed

+115
-43
lines changed

3 files changed

+115
-43
lines changed

src/pruna/algorithms/compilation/torch_compile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def causal_lm_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
343343
compile_mode=smash_config["mode"],
344344
compile_fullgraph=smash_config["fullgraph"],
345345
batch_size=smash_config["batch_size"],
346+
device=smash_config.device,
346347
)
347348
# If we are using max-autotune-no-cudagraphs, we need to handle the cudagraphs manually.
348349
if smash_config["mode"] == "max-autotune-no-cudagraphs":

src/pruna/algorithms/compilation/utils.py

Lines changed: 108 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from torch.nn.attention import SDPBackend, sdpa_kernel
2020
from transformers.cache_utils import StaticCache
2121

22+
from pruna.logging.logger import pruna_logger
23+
2224

2325
class TransformersGenerator:
2426
"""
@@ -43,6 +45,8 @@ class TransformersGenerator:
4345
Whether to compile the full computation graph or use partial graph compilation.
4446
batch_size : int, default=1
4547
The batch size to use for text generation.
48+
device : str, default='cuda'
49+
The device to use for text generation.
4650
"""
4751

4852
def __init__(
@@ -54,6 +58,7 @@ def __init__(
5458
compile_mode: str = "reduce-overhead",
5559
compile_fullgraph: bool = True,
5660
batch_size: int = 1,
61+
device: str = "cuda",
5762
):
5863
"""
5964
Initialize the TransformersGenerator.
@@ -87,21 +92,23 @@ def __init__(
8792
torch._dynamo.config.inline_inbuilt_nn_modules = False # torch 2.5.0 fix
8893

8994
self.model = model
90-
self.device = model.device
95+
self.device = device
9196
self.temperature = temperature
9297
self.top_k = top_k
9398
self.use_cache = True
9499
self.compile_mode = compile_mode
95100
self.compile_fullgraph = compile_fullgraph
96101
self.batch_size = batch_size
102+
self.cache_batch_size = batch_size
97103
self.cache_size = max_kv_cache_size
104+
self.eos_token_id = getattr(model.config, "eos_token_id", None)
105+
if self.eos_token_id is None:
106+
pruna_logger.warning("Warning: eos_token_id is None. This may affect generation stopping criteria.")
98107

99108
self.setup_cache()
100109

101110
self.decode_one_token = torch.compile( # type: ignore
102-
self.decode_one_token,
103-
mode=self.compile_mode,
104-
fullgraph=self.compile_fullgraph
111+
self.decode_one_token, mode=self.compile_mode, fullgraph=self.compile_fullgraph
105112
)
106113

107114
self.init()
@@ -110,6 +117,7 @@ def __init__(
110117
# Cuda Graph section
111118
self.static_input = torch.zeros((1, 1), device=self.device, dtype=torch.int32)
112119
self.static_output = torch.zeros((1, 1), device=self.device, dtype=torch.int32)
120+
self.original_gen_next_token = self.gen_next_token
113121
self.cuda_graph = None
114122
self.do_capture_graph = False
115123
############################
@@ -198,10 +206,7 @@ def logits_to_probs(self, logits: torch.Tensor, temperature: float = 1.0, top_k:
198206
return probs
199207

200208
def sample(
201-
self,
202-
logits: torch.Tensor,
203-
temperature: float = 1.0,
204-
top_k: int | None = None
209+
self, logits: torch.Tensor, temperature: float = 1.0, top_k: int | None = None
205210
) -> tuple[torch.Tensor, torch.Tensor]:
206211
"""
207212
Sample one token from the model.
@@ -286,18 +291,38 @@ def setup(self, inputs: torch.Tensor, max_new_tokens: int):
286291
None
287292
This method initializes internal state for generation but does not return a value.
288293
"""
294+
new_batch_size = inputs.shape[0]
295+
296+
# Check if batch size changed compared to the cache configuration
297+
if new_batch_size != self.cache_batch_size:
298+
pruna_logger.info(
299+
f"Batch size changed from {self.cache_batch_size} to {new_batch_size}. Re-initializing StaticCache."
300+
)
301+
self.batch_size = new_batch_size
302+
self.cache_batch_size = new_batch_size
303+
self.setup_cache()
304+
305+
# If CUDA graph was used, it's now invalid
306+
if hasattr(self, "cuda_graph") and self.cuda_graph is not None:
307+
pruna_logger.warning("CUDA graph is invalidated due to batch size change. Disabling CUDA graph usage.")
308+
self.cuda_graph = None
309+
self.gen_next_token = self.original_gen_next_token
310+
self.do_capture_graph = False
311+
312+
# Reset cache contents (does not change shape)
289313
self.reset_cache()
314+
290315
self.inputs = inputs
291316
self.batch_size, self.seq_length = self.inputs.shape
292317
self.cache_position = torch.arange(self.seq_length, device=self.device)
293-
# initialize the generated ids with zeros of the shape (batch_size, seq_length + max_new_tokens + 1)
318+
# initialize the generated ids with zeros
294319
self.generated_ids = torch.zeros(
295320
self.batch_size,
296321
self.seq_length + max_new_tokens + 1,
297322
dtype=torch.int,
298323
device=self.device,
299324
)
300-
# copy the input ids to the generated ids at the cache position.
325+
# copy the input ids to the generated ids
301326
self.generated_ids[:, self.cache_position] = self.inputs.to(torch.int)
302327

303328
def prefill(self) -> torch.Tensor:
@@ -340,11 +365,11 @@ def gen_next_token(self, current_token: torch.Tensor) -> torch.Tensor:
340365
The next token generated by the model.
341366
"""
342367
next_token = self.decode_one_token(
343-
current_token.clone(),
344-
cache_position=self.cache_position + 1,
345-
past_key_values=self.past_key_values,
346-
temperature=self.temperature,
347-
top_k=self.top_k,
368+
current_token.clone(),
369+
cache_position=self.cache_position + 1,
370+
past_key_values=self.past_key_values,
371+
temperature=self.temperature,
372+
top_k=self.top_k,
348373
)
349374
self.cache_position += 1
350375
self.generated_ids[:, self.cache_position] = next_token.int()
@@ -354,7 +379,7 @@ def enable_cuda_graph(
354379
self,
355380
iters: int = 2,
356381
prompt_tokenized: list[int] = [596, 8830, 315, 6913, 19476, 11, 1778, 439, 279, 12939],
357-
max_kv_cache_size: int = 1024
382+
max_kv_cache_size: int = 1024,
358383
) -> None:
359384
"""
360385
Enable the CUDA graph and capture the graph on random prompt.
@@ -375,17 +400,15 @@ def enable_cuda_graph(
375400
but does not return any value.
376401
"""
377402
_ = self.generate(
378-
torch.tensor(prompt_tokenized, device=self.model.device).unsqueeze(0),
379-
max_new_tokens=max_kv_cache_size
403+
torch.tensor(prompt_tokenized, device=self.model.device).unsqueeze(0), max_new_tokens=max_kv_cache_size
380404
)
381405
for _ in range(iters):
382406
# need to reset the graph before capturing it at each iteration
383407
# to avoid block/thread errors.
384408
self.do_capture_graph = True
385409
self.gen_next_token = self.gen_next_token_withgraph # type: ignore
386410
_ = self.generate(
387-
torch.tensor(prompt_tokenized, device=self.model.device).unsqueeze(0),
388-
max_new_tokens=max_kv_cache_size
411+
torch.tensor(prompt_tokenized, device=self.model.device).unsqueeze(0), max_new_tokens=max_kv_cache_size
389412
)
390413

391414
def gen_next_token_withgraph(self, current_token: torch.Tensor) -> torch.Tensor:
@@ -426,54 +449,100 @@ def gen_next_token_withgraph(self, current_token: torch.Tensor) -> torch.Tensor:
426449
return next_token
427450

428451
def next_token_iterator(
429-
self,
430-
current_token: torch.Tensor,
431-
max_new_tokens: int,
432-
cleanup: bool = True
452+
self, current_token: torch.Tensor, max_new_tokens: int, cleanup: bool = True
433453
) -> torch.Tensor:
434454
"""
435-
Generate the next token.
455+
Generate the next token, stopping at max_new_tokens or EOS for each sequence in the batch.
436456
437457
Parameters
438458
----------
439459
current_token : torch.Tensor
440-
The current token.
460+
The current token tensor of shape (batch_size, 1).
441461
max_new_tokens : int
442462
The maximum number of new tokens to generate.
443463
cleanup : bool
444-
Whether to cleanup the inputs, generated ids, and cache position.
464+
Whether to cleanup the inputs, generated ids, and cache position after generation.
445465
446466
Returns
447467
-------
448468
torch.Tensor
449-
The generated tokens.
469+
The generated tokens tensor of shape (batch_size, seq_length + generated_length),
470+
including the input prompt and potentially EOS tokens. Sequences that finish early
471+
will have EOS followed by padding (initial zeros).
450472
"""
473+
# Keep track of sequences that haven't finished yet (encountered EOS)
474+
# Assumes initial state is unfinished for all sequences in the batch
475+
unfinished_sequences = torch.ones(self.batch_size, dtype=torch.bool, device=self.device)
476+
477+
# Loop for a maximum of max_new_tokens - 1 steps (as prefill generates the first)
451478
for i in range(1, max_new_tokens):
452-
current_token = self.gen_next_token(current_token)
453-
output_tokens = self.generated_ids
479+
# Generate the next token for all sequences
480+
current_token = self.gen_next_token(current_token) # Updates self.generated_ids internally
481+
482+
# Check if the generated token is the EOS token for any currently unfinished sequence
483+
if self.eos_token_id is not None:
484+
# Check which sequences produced the EOS token THIS step
485+
# current_token shape is (batch_size, 1), squeeze to (batch_size,)
486+
# Only consider sequences that were previously unfinished
487+
finished_this_step = (current_token.squeeze(-1) == self.eos_token_id) & unfinished_sequences
488+
# Update the overall tracker for unfinished sequences
489+
unfinished_sequences &= ~finished_this_step
490+
491+
# Stop generation if all sequences in the batch have finished
492+
if not unfinished_sequences.any():
493+
break
494+
495+
# Determine the actual length generated (up to the current cache position)
496+
# .item() is safe as cache_position should be a 0-dim tensor
497+
final_seq_len = self.cache_position.item() + 1
498+
# Clone the relevant part of generated_ids before potential cleanup
499+
output_tokens = self.generated_ids[:, : int(final_seq_len)].clone()
454500

455501
if cleanup:
502+
# Delete internal state tensors, but not output_tokens which is returned
456503
del self.inputs, self.generated_ids, self.cache_position
457504
torch.cuda.empty_cache()
458505

459506
return output_tokens
460507

461508
@torch.inference_mode()
462-
def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100) -> torch.Tensor:
509+
def generate(self, *args, **kwargs) -> torch.Tensor:
463510
"""
464-
Generate the tokens.
511+
Generate tokens using the model.
465512
466513
Parameters
467514
----------
468-
input_ids : torch.Tensor
469-
The input ids.
470-
max_new_tokens : int
471-
The maximum number of new tokens to generate.
515+
*args : tuple
516+
Variable length argument list (not used directly).
517+
**kwargs : dict
518+
Keyword arguments dictionary that must contain:
519+
- input_ids : torch.Tensor
520+
The input token ids that serve as the prompt.
521+
- max_new_tokens : int
522+
The maximum number of new tokens to generate.
472523
473524
Returns
474525
-------
475526
torch.Tensor
476-
The generated tokens.
527+
The generated tokens, including the input prompt and potentially an EOS token.
477528
"""
478-
self.setup(inputs=input_ids, max_new_tokens=max_new_tokens)
479-
return self.next_token_iterator(self.prefill(), max_new_tokens)
529+
# Extract parameters from kwargs with defaults from instance variables
530+
self.temperature = kwargs.pop("temperature", self.temperature)
531+
self.top_k = kwargs.pop("top_k", self.top_k)
532+
self.use_cache = kwargs.pop("use_cache", self.use_cache)
533+
534+
# Log any kwargs that are not explicitly handled
535+
unhandled_kwargs = {
536+
k: v
537+
for k, v in kwargs.items()
538+
if k not in ["input_ids", "max_new_tokens", "temperature", "top_k", "batch_size"]
539+
}
540+
if unhandled_kwargs:
541+
pruna_logger.warning(f"Unhandled kwargs in generate method: {unhandled_kwargs}")
542+
543+
# Update instance variables with any provided values
544+
self.setup(
545+
inputs=kwargs["input_ids"] if "input_ids" in kwargs else args[0],
546+
max_new_tokens=kwargs["max_new_tokens"] if "max_new_tokens" in kwargs else args[1],
547+
)
548+
return self.next_token_iterator(self.prefill(), kwargs["max_new_tokens"])

src/pruna/algorithms/quantization/hqq.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pruna.config.smash_config import SmashConfigPrefixWrapper
2525
from pruna.engine.model_checks import is_causal_lm
2626
from pruna.engine.save import SAVE_FUNCTIONS
27+
from pruna.engine.utils import move_to_device, safe_memory_cleanup
2728
from pruna.logging.filter import SuppressOutput
2829
from pruna.logging.logger import pruna_logger
2930

@@ -116,9 +117,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
116117

117118
quant_config_hqq = imported_modules["BaseQuantizeConfig"](nbits=weight_quantization_bits, group_size=group_size)
118119
quant_config_hf = imported_modules["HqqConfig"](nbits=weight_quantization_bits, group_size=group_size)
119-
120+
move_to_device(model, "cpu")
121+
safe_memory_cleanup()
120122
try: # Try to quantize the model using HQQ
121-
smashed_model = imported_modules["AutoHQQHFModel"].quantize_model(
123+
model = imported_modules["AutoHQQHFModel"].quantize_model(
122124
model,
123125
quant_config=quant_config_hqq,
124126
device=smash_config["device"],
@@ -131,7 +133,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
131133
temp_dir = tempfile.mkdtemp(dir=base_temp_dir)
132134
model.save_pretrained(temp_dir)
133135

134-
smashed_model = AutoModelForCausalLM.from_pretrained(
136+
model = AutoModelForCausalLM.from_pretrained(
135137
temp_dir,
136138
quantization_config=quant_config_hf,
137139
trust_remote_code=True,
@@ -149,7 +151,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
149151
except Exception as e:
150152
pruna_logger.error(f"Error: {e}")
151153
pass
152-
return smashed_model
154+
return model
153155

154156
def import_algorithm_packages(self) -> Dict[str, Any]:
155157
"""

0 commit comments

Comments
 (0)