@@ -205,6 +205,8 @@ def deserialize(cls, obj: dict) -> ContextOverflowError:
 
 def get_mem_size_str(n_bytes: int) -> str:
     """Convert number of bytes to human-readable string."""
+    if n_bytes == 0:
+        return "0 bytes"
     for exp, suffix in ((4, "TB"), (3, "GB"), (2, "MB"), (1, "KB"), (0, "bytes")):
         nquery = int(1024**exp)
         if round(n_bytes / nquery) >= 1:
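
Note: a runnable sketch of the patched helper, for reference. The hunk above does not show the loop body's return statement, so the `f"{n_bytes / nquery:.1f} {suffix}"` line below is an assumption; the new zero-byte guard matters because an input of 0 would otherwise match no unit, fall off the end of the loop, and return `None`.

```python
# Sketch of the patched helper; the return line inside the loop is assumed,
# since the diff only shows the loop header and its condition.
def get_mem_size_str(n_bytes: int) -> str:
    """Convert number of bytes to human-readable string."""
    if n_bytes == 0:
        return "0 bytes"  # new early-out: 0 matches no unit below
    for exp, suffix in ((4, "TB"), (3, "GB"), (2, "MB"), (1, "KB"), (0, "bytes")):
        nquery = int(1024**exp)
        if round(n_bytes / nquery) >= 1:
            return f"{n_bytes / nquery:.1f} {suffix}"  # assumed formatting
    return f"{n_bytes} bytes"  # defensive fallback, not in the original

print(get_mem_size_str(0))            # 0 bytes
print(get_mem_size_str(3 * 1024**3))  # 3.0 GB
```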
@@ -449,6 +451,26 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
             buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio))
             paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio))
 
+            block_count = buffer_size_bytes // self.block_size_bytes
+            block_count = max(2, block_count)  # need >= 1 active block + 1 dummy block
+            paused_block_count = paused_buffer_size_bytes // self.block_size_bytes
+        elif self.is_hybrid_model and inference_config.max_requests is not None:
+            # Auto-derive the mamba/KV split from max_requests. Allocate exactly enough
+            # mamba memory for max_requests, and give the rest to KV cache blocks.
+            total_memory = buffer_size_bytes + paused_buffer_size_bytes
+            mamba_memory_needed = inference_config.max_requests * mamba_states_memory_per_request
+            assert mamba_memory_needed < total_memory, (
+                f"Not enough memory for {inference_config.max_requests} mamba requests. "
+                f"Need {mamba_memory_needed / 1024**3:.2f} GB for mamba states, "
+                f"but total buffer is {total_memory / 1024**3:.2f} GB."
+            )
+            mamba_max_requests = inference_config.max_requests
+
+            # Subtract mamba memory proportionally from the active and paused buffers.
+            mamba_memory_ratio = mamba_memory_needed / total_memory
+            buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio))
+            paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio))
+
         block_count = buffer_size_bytes // self.block_size_bytes
         block_count = max(2, block_count)  # need >= 1 active block + 1 dummy block
         paused_block_count = paused_buffer_size_bytes // self.block_size_bytes
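
Note: the arithmetic of the new `elif` branch, as a standalone sketch with hypothetical sizes (the real values come from the model and inference configs). Both buffers shrink by the same ratio, so the active/paused proportion is preserved while exactly `max_requests` worth of mamba state memory is carved out.

```python
# All numbers here are hypothetical illustrations.
buffer_size_bytes = 8 * 1024**3          # active KV buffer: 8 GiB
paused_buffer_size_bytes = 2 * 1024**3   # paused KV buffer: 2 GiB
block_size_bytes = 2 * 1024**2           # 2 MiB per KV block
max_requests = 512
mamba_states_memory_per_request = 4 * 1024**2  # 4 MiB of conv + ssm state

total_memory = buffer_size_bytes + paused_buffer_size_bytes   # 10 GiB
mamba_memory_needed = max_requests * mamba_states_memory_per_request  # 2 GiB
assert mamba_memory_needed < total_memory

# Shrink both buffers by the same ratio so the active/paused split is kept.
mamba_memory_ratio = mamba_memory_needed / total_memory       # 0.2
buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio))
paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio))

block_count = max(2, buffer_size_bytes // block_size_bytes)
print(block_count)  # 3276 blocks (~6.4 GiB) left for the active KV cache
```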
@@ -594,13 +616,76 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
         self.initialize_all_tensors()
 
         # Print info.
-        logging.info(
-            "DynamicInferenceContext: allocated context with active buffer size %s (%d blocks)."
-            % (
-                get_mem_size_str(self.kv_block_allocator.active_count * self.block_size_bytes),
-                self.kv_block_allocator.active_count,
+        active_blocks = self.kv_block_allocator.active_count
+        total_blocks = self.kv_block_allocator.total_count
+        paused_blocks = self.kv_block_allocator.paused_count
+        active_kv_bytes = active_blocks * self.block_size_bytes
+        total_kv_bytes = total_blocks * self.block_size_bytes
+        paused_kv_bytes = paused_blocks * self.block_size_bytes
+
+        log_lines = [
+            "DynamicInferenceContext: configuration summary",
+            f"  max_requests: {self.max_requests}",
+            f"  max_tokens: {self.max_tokens}",
+            f"  max_sequence_length: {self.max_sequence_length}",
+            f"  block_size_tokens: {self.block_size_tokens}",
+            f"  max_kv_blocks_per_req: {self.max_kv_block_count}",
+            "  KV cache:",
+            f"    block_size_bytes: {get_mem_size_str(self.block_size_bytes)}",
+            f"    active_blocks: {active_blocks} ({get_mem_size_str(active_kv_bytes)})",
+            f"    paused_blocks: {paused_blocks} ({get_mem_size_str(paused_kv_bytes)})",
+            f"    total_blocks: {total_blocks} ({get_mem_size_str(total_kv_bytes)})",
+        ]
+
+        if self.is_hybrid_model:
+            mamba_conv_bytes = (
+                math.prod(self.mamba_conv_states_shape)
+                * self.mamba_conv_states_dtype.itemsize
+                * self.num_mamba_layers
             )
-        )
+            mamba_ssm_bytes = (
+                math.prod(self.mamba_ssm_states_shape)
+                * self.mamba_ssm_states_dtype.itemsize
+                * self.num_mamba_layers
+            )
+            mamba_bytes_per_req = mamba_conv_bytes + mamba_ssm_bytes
+            mamba_total_bytes = mamba_bytes_per_req * self.max_requests
+            log_lines += [
+                "  Mamba states:",
+                f"    num_mamba_layers: {self.num_mamba_layers}",
+                f"    conv_state_shape: {self.mamba_conv_states_shape}",
+                f"    ssm_state_shape: {self.mamba_ssm_states_shape}",
+                f"    per_request: {get_mem_size_str(mamba_bytes_per_req)}",
+                f"    total ({self.max_requests} requests): {get_mem_size_str(mamba_total_bytes)}",
+            ]
+
+            if self.num_speculative_tokens > 0:
+                spec_multiplier = self.num_speculative_tokens + 1
+                spec_bytes_per_req = mamba_bytes_per_req * spec_multiplier
+                spec_total_bytes = spec_bytes_per_req * self.max_requests
+                log_lines += [
+                    f"  Mamba speculative buffers (num_speculative_tokens={self.num_speculative_tokens}):",
+                    f"    per_request: {get_mem_size_str(spec_bytes_per_req)}",
+                    f"    total ({self.max_requests} requests): {get_mem_size_str(spec_total_bytes)}",
+                ]
+
+            prefix_caching_mamba_gb = inference_config.prefix_caching_mamba_gb
+            if (
+                inference_config.enable_prefix_caching
+                and prefix_caching_mamba_gb is not None
+                and prefix_caching_mamba_gb > 0
+            ):
+                prefix_cache_bytes = int(prefix_caching_mamba_gb * 1024**3)
+                prefix_cache_slots = prefix_cache_bytes // mamba_bytes_per_req
+                log_lines += [
+                    "  Mamba prefix cache:",
+                    f"    budget: {get_mem_size_str(prefix_cache_bytes)}",
+                    f"    slots: {prefix_cache_slots}",
+                    f"    per_slot: {get_mem_size_str(mamba_bytes_per_req)}",
+                ]
+
+        if inference_config._verbose and torch.distributed.get_rank() == 0:
+            logging.info("\n".join(log_lines))
 
     def _allocate_memory_buffer(self):
         """Allocate the KV cache memory buffer."""
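
Note: the per-request mamba footprint computed in the hybrid branch is just conv-state bytes plus ssm-state bytes, summed over all mamba layers. A standalone sketch with hypothetical shapes and dtype (the real ones come from the model config); `math.prod` and `dtype.itemsize` are exactly what the diff uses.

```python
import math
import torch

# Hypothetical values; the context reads these off the model config.
mamba_conv_states_shape = (4, 5440)     # assumed conv state shape
mamba_ssm_states_shape = (128, 64, 80)  # assumed ssm state shape
num_mamba_layers = 28
dtype = torch.bfloat16                  # itemsize == 2

conv_bytes = math.prod(mamba_conv_states_shape) * dtype.itemsize * num_mamba_layers
ssm_bytes = math.prod(mamba_ssm_states_shape) * dtype.itemsize * num_mamba_layers
per_request_bytes = conv_bytes + ssm_bytes
print(per_request_bytes / 1024**2)  # ~36.2 MiB of mamba state per request
```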
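
Note: with `_verbose` set and on rank 0, the new summary renders along these lines. All values are illustrative and independent of the sketches above; only the layout follows the format strings in the diff.

```
DynamicInferenceContext: configuration summary
  max_requests: 512
  max_tokens: 8192
  max_sequence_length: 4096
  block_size_tokens: 128
  max_kv_blocks_per_req: 32
  KV cache:
    block_size_bytes: 2.0 MB
    active_blocks: 3276 (6.4 GB)
    paused_blocks: 819 (1.6 GB)
    total_blocks: 4095 (8.0 GB)
  Mamba states:
    num_mamba_layers: 28
    conv_state_shape: (4, 5440)
    ssm_state_shape: (128, 64, 80)
    per_request: 36.2 MB
    total (512 requests): 18.1 GB
```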