Skip to content

Commit 0d5f576

Browse files
fix: make kv cache dynamic based on input (#87)
* fix: this makes the kv cache dynamic based on the input
* fix: add the input size as a parameter as well
* fix: recompile cuda graph in case of max-autotune-no-cudagraphs mode
* fix: remove support for max-autotune-no-cudagraphs mode
1 parent 3aebee7 commit 0d5f576

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/pruna/algorithms/compilation/torch_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def causal_lm_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
347347
)
348348
# If we are using max-autotune-no-cudagraphs, we need to handle the cudagraphs manually.
349349
if smash_config["mode"] == "max-autotune-no-cudagraphs":
350-
gen.enable_cuda_graph(max_kv_cache_size=smash_config["seqlen_manual_cuda_graph"])
350+
pruna_logger.error("max-autotune-no-cudagraphs is not supported for causal language models.")
351351
model.generate = gen.generate
352352
return model
353353

src/pruna/algorithms/compilation/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -294,20 +294,24 @@ def setup(self, inputs: torch.Tensor, max_new_tokens: int):
294294
new_batch_size = inputs.shape[0]
295295

296296
# Check if batch size changed compared to the cache configuration
297-
if new_batch_size != self.cache_batch_size:
297+
# Round up max_new_tokens to the nearest 1000 for better memory allocation
298+
rounded_cache_size = ((inputs.shape[1] + max_new_tokens + 999) // 1000) * 1000
299+
if new_batch_size != self.cache_batch_size or self.cache_size != rounded_cache_size:
298300
pruna_logger.info(
299-
f"Batch size changed from {self.cache_batch_size} to {new_batch_size}. Re-initializing StaticCache."
301+
f"Cache size changed from {self.cache_batch_size}x{self.cache_size} to "
302+
f"{new_batch_size}x{rounded_cache_size}. Re-initializing StaticCache."
300303
)
301304
self.batch_size = new_batch_size
302305
self.cache_batch_size = new_batch_size
306+
self.cache_size = rounded_cache_size
303307
self.setup_cache()
304308

305-
# If CUDA graph was used, it's now invalid
309+
# If CUDA graph was used, recompile the graph
306310
if hasattr(self, "cuda_graph") and self.cuda_graph is not None:
307-
pruna_logger.warning("CUDA graph is invalidated due to batch size change. Disabling CUDA graph usage.")
308-
self.cuda_graph = None
309-
self.gen_next_token = self.original_gen_next_token
310-
self.do_capture_graph = False
311+
pruna_logger.warning(
312+
"CUDA graph is invalidated due to batch size or cache size change. Recompiling the graph."
313+
)
314+
self.enable_cuda_graph(max_kv_cache_size=self.cache_size)
311315

312316
# Reset cache contents (does not change shape)
313317
self.reset_cache()

0 commit comments

Comments (0)