diff --git a/build.rs b/build.rs index 51d2b7ac..5f3f3b0e 100644 --- a/build.rs +++ b/build.rs @@ -138,36 +138,69 @@ fn compile_cuda_kernels() { }; let out_dir = std::env::var("OUT_DIR").unwrap(); - let ptx_path = format!("{out_dir}/decode_kernels.ptx"); - if is_output_fresh(&[cu_src], &[&ptx_path]) { + // Compile sm_80 PTX (works on Ampere SM86/3090, Ada SM89, Hopper, Blackwell via JIT) + let ptx_sm80 = format!("{out_dir}/decode_kernels_sm80.ptx"); + + if is_output_fresh(&[cu_src], &[&ptx_sm80]) { println!("cargo:rustc-cfg=has_decode_kernels"); - println!("cargo:warning=Reusing cached GPU decode kernels at {ptx_path}"); - return; + println!("cargo:warning=Reusing cached GPU decode kernels at {ptx_sm80}"); + } else { + let status = std::process::Command::new(&nvcc) + .args([ + "-ptx", + "-arch=sm_80", + "-O3", + "--use_fast_math", + "-o", &ptx_sm80, + cu_src, + ]) + .status(); + + match status { + Ok(s) if s.success() => { + println!("cargo:rustc-cfg=has_decode_kernels"); + println!("cargo:warning=Compiled GPU decode kernels to sm_80 PTX ({ptx_sm80})"); + } + Ok(s) => { + println!("cargo:warning=nvcc failed with status {s} — GPU decode kernels disabled"); + return; + } + Err(e) => { + println!("cargo:warning=nvcc execution error: {e} — GPU decode kernels disabled"); + return; + } + } } - // Compile .cu to .ptx targeting sm_80 (works on Ampere, Ada, Hopper) - let status = std::process::Command::new(&nvcc) - .args([ - "-ptx", - "-arch=sm_80", - "-O3", - "--use_fast_math", - "-o", &ptx_path, - cu_src, - ]) - .status(); + // Compile sm_120 PTX (native Blackwell — better warp scheduling on RTX 50x0) + let ptx_sm120 = format!("{out_dir}/decode_kernels_sm120.ptx"); + if is_output_fresh(&[cu_src], &[&ptx_sm120]) { + println!("cargo:rustc-cfg=has_decode_kernels_sm120"); + println!("cargo:warning=Reusing cached sm_120 GPU decode kernels at {ptx_sm120}"); + } else { + let status_120 = std::process::Command::new(&nvcc) + .args([ + "-ptx", + "-arch=sm_120", + "-O3", + "--use_fast_math", + "-o", &ptx_sm120, + cu_src, + ]) + .status(); - match status { - Ok(s) if s.success() => { - println!("cargo:rustc-cfg=has_decode_kernels"); - println!("cargo:warning=Compiled GPU decode kernels to PTX ({ptx_path})"); - } - Ok(s) => { - println!("cargo:warning=nvcc failed with status {s} — GPU decode kernels disabled"); - } - Err(e) => { - println!("cargo:warning=nvcc execution error: {e} — GPU decode kernels disabled"); + match status_120 { + Ok(s) if s.success() => { + println!("cargo:rustc-cfg=has_decode_kernels_sm120"); + println!("cargo:warning=Compiled GPU decode kernels to sm_120 PTX ({ptx_sm120})"); + } + Ok(s) => { + println!("cargo:warning=nvcc sm_120 failed with status {s} — sm_80 PTX will be used via JIT"); + } + Err(e) => { + println!("cargo:warning=nvcc sm_120 execution error: {e} — sm_80 PTX will be used via JIT"); + } } } } diff --git a/python/krasis/model.py b/python/krasis/model.py index b909027d..06c0541b 100644 --- a/python/krasis/model.py +++ b/python/krasis/model.py @@ -1783,16 +1783,22 @@ def _init_gpu_prefill(self): len(self.gpu_prefill_managers), self.gpu_prefill_threshold, ) - def _start_ram_watchdog(self, floor_pct: float = 5.0): + def _start_ram_watchdog(self, floor_pct: float = 0.5): """Start daemon thread that monitors system RAM and exits if too low. - Checks /proc/meminfo every second. If MemAvailable drops below - floor_pct% of MemTotal, logs an error and calls os._exit() to + Checks /proc/meminfo every second. If (MemAvailable + SwapFree) drops + below floor_pct% of MemTotal, logs an error and calls os._exit() to prevent a full system OOM that kills desktop processes. + Swap is counted as available headroom because NVMe swap is fast enough + to absorb transient spikes (e.g. during WriteCombined expert migration). + Args: - floor_pct: Minimum % free RAM before forced exit (default 5%) + floor_pct: Minimum % free RAM+swap before forced exit (default 0.5%). + Override via KRASIS_RAM_FLOOR_PERCENT env var. """ + floor_pct = float(os.environ.get("KRASIS_RAM_FLOOR_PERCENT", str(floor_pct))) + def _watchdog(): while True: time.sleep(1.0) @@ -1800,13 +1806,13 @@ def _watchdog(): if not meminfo: continue total_kb = meminfo.get("MemTotal", 0) - avail_kb = meminfo.get("MemAvailable", 0) + avail_kb = meminfo.get("MemAvailable", 0) + meminfo.get("SwapFree", 0) if total_kb == 0: continue pct_free = 100.0 * avail_kb / total_kb if pct_free < floor_pct: logger.error( - "RAM WATCHDOG: %.1f%% free (%.1f GB available / %.1f GB total) " + "RAM WATCHDOG: %.1f%% free (%.1f GB available+swap / %.1f GB total) " "— below %.1f%% floor. Exiting to prevent system OOM!", pct_free, avail_kb / 1024 / 1024, total_kb / 1024 / 1024, floor_pct, @@ -1815,7 +1821,7 @@ def _watchdog(): t = threading.Thread(target=_watchdog, daemon=True, name="ram-watchdog") t.start() - logger.info("RAM watchdog started: will exit if < %.1f%% free", floor_pct) + logger.info("RAM watchdog started: will exit if < %.1f%% free (RAM+swap)", floor_pct) # ── Multi-GPU calibration: replicate weights + measure inference cost ── @@ -4518,6 +4524,13 @@ def _register_attn_weight(w: torch.Tensor, layer_idx: int = -1, if self.krasis_engine is not None: store.setup_from_engine(self.krasis_engine) + # WriteCombined DMA staging: migrate expert weights from heap to WC memory. + # Must happen after setup_from_engine (pointers exist) but before + # allocate_prefill_engine (which reads the updated pointers). + if getattr(self, 'wc_alloc', False): + wc_msg = store.setup_wc_expert_memory(self.krasis_engine) + logger.info("WC expert memory: %s", wc_msg) + # Register Nemotron MoE config (relu2, ungated, latent projections) — must come # after setup_from_engine which populates moe_layers[abs_layer_idx]. if self.cfg.model_type == "nemotron_h": diff --git a/python/krasis/server.py b/python/krasis/server.py index f20e806f..3bb0d17c 100644 --- a/python/krasis/server.py +++ b/python/krasis/server.py @@ -710,6 +710,11 @@ def _force_exit_handler(sig, frame): help="Enable Session messenger bridge (default: off)") parser.add_argument("--test-endpoints", action="store_true", default=False, help="Enable test-only endpoints (/v1/internal/prefill_logits)") + parser.add_argument("--wc-alloc", action="store_true", default=False, + help="Use WriteCombined host memory for expert DMA staging. " + "Bypasses CPU cache for ~64%% higher PCIe bandwidth " + "(~46 GB/s vs ~28 GB/s on Gen5). Requires NVMe swap on " + "RAM-constrained systems.") # Apply config file defaults, then parse CLI (CLI wins over config file) if config_defaults: parser.set_defaults(**config_defaults) @@ -912,6 +917,11 @@ def fileno(self): # ── Set decode mode (GPU only) ── _model.decode_mode = "gpu" + # ── WriteCombined DMA staging ── + if getattr(args, 'wc_alloc', False): + _model.wc_alloc = True + _detail("WriteCombined expert DMA staging enabled (--wc-alloc)") + # ── GPU decode store setup (before warmup so decode warmup can use it) ── _status("Setting up GPU decode store") gpu_store = _model.setup_gpu_decode_store() diff --git a/src/gpu_decode.rs b/src/gpu_decode.rs index ef83cfe7..1a4d1f32 100644 --- a/src/gpu_decode.rs +++ b/src/gpu_decode.rs @@ -34,8 +34,12 @@ fn prefill_debug_enabled() -> bool { } // PTX compiled from src/cuda/decode_kernels.cu at build time. +// sm_80 works on Ampere/Ada/Hopper (JIT-compiled on newer arches). +// sm_120 is native Blackwell with better warp scheduling on RTX 50x0. #[cfg(has_decode_kernels)] -const DECODE_KERNELS_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/decode_kernels.ptx")); +const DECODE_KERNELS_PTX_SM80: &str = include_str!(concat!(env!("OUT_DIR"), "/decode_kernels_sm80.ptx")); +#[cfg(has_decode_kernels_sm120)] +const DECODE_KERNELS_PTX_SM120: &str = include_str!(concat!(env!("OUT_DIR"), "/decode_kernels_sm120.ptx")); // CUDA graph types from cudarc's sys bindings (dynamically loaded via cuda_sys::lib()) type CUgraph = cuda_sys::CUgraph; @@ -1916,6 +1920,40 @@ struct SingleSlotSwapEntry { simple_scales_host: Vec, } +// ── WriteCombined DMA staging ────────────────────────────────────────── + +/// RAII wrapper for WriteCombined host memory allocated via cuMemHostAlloc. +/// WC memory bypasses CPU cache, giving the GPU DMA engine uncontested access +/// to the memory bus — ~46 GB/s vs ~28 GB/s for regular pinned memory on PCIe Gen5. +struct WcBuffer { + ptr: *mut u8, + size: usize, +} + +impl Drop for WcBuffer { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { cuda_sys::lib().cuMemFreeHost(self.ptr as *mut std::ffi::c_void); } + } + } +} + +unsafe impl Send for WcBuffer {} +unsafe impl Sync for WcBuffer {} + +/// Per-layer WC buffer pointers matching LayerExpertBacking layout. +/// Used to rebuild prefill tensor views after WC migration. +struct WcLayerPtrs { + w13p_ptr: usize, + w13p_len: usize, + w13s_ptr: usize, + w13s_len: usize, + w2p_ptr: usize, + w2p_len: usize, + w2s_ptr: usize, + w2s_len: usize, +} + // ── PyO3 wrapper ─────────────────────────────────────────────────────── #[pyclass] @@ -2012,6 +2050,11 @@ pub struct GpuDecodeStore { /// Stored when the prefill engine is created so it's available for eviction checks /// even after the engine is taken by the RustServer. prefill_scratch_info: Option<(usize, usize)>, + /// WriteCombined host memory allocations for expert DMA staging. + /// Freed on drop via cuMemFreeHost. + wc_expert_buffers: Vec, + /// Per-layer WC buffer section pointers (per-component layout for prefill views). + wc_layer_ptrs: Vec, } #[pymethods] @@ -2096,12 +2139,40 @@ impl GpuDecodeStore { let mut gqa_smem_limit: u32 = 48 * 1024; // default - // Load CUDA decode kernels from embedded PTX + // Load CUDA decode kernels from embedded PTX, selecting arch at runtime #[cfg(has_decode_kernels)] { use cudarc::nvrtc::Ptx; + + // Detect GPU compute capability for PTX selection + let sm_major = unsafe { + let mut dev: i32 = 0; + cuda_sys::lib().cuCtxGetDevice(&mut dev); + let mut major = 0i32; + cuda_sys::lib().cuDeviceGetAttribute( + &mut major, + cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + dev, + ); + major + }; + + #[cfg(has_decode_kernels_sm120)] + let ptx_src = if sm_major >= 12 { + log::info!("GpuDecodeStore: using native sm_120 PTX for Blackwell (SM {}.x)", sm_major); + DECODE_KERNELS_PTX_SM120 + } else { + log::info!("GpuDecodeStore: using sm_80 PTX for SM {}.x (JIT to native)", sm_major); + DECODE_KERNELS_PTX_SM80 + }; + #[cfg(not(has_decode_kernels_sm120))] + let ptx_src = { + log::info!("GpuDecodeStore: using sm_80 PTX for SM {}.x (sm_120 not compiled)", sm_major); + DECODE_KERNELS_PTX_SM80 + }; + device.load_ptx( - Ptx::from_src(DECODE_KERNELS_PTX), + Ptx::from_src(ptx_src), MODULE_NAME, KERNEL_NAMES, ).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err( @@ -2269,6 +2340,8 @@ impl GpuDecodeStore { debug_layer_captures: Vec::new(), prefill_engine_slot: None, prefill_scratch_info: None, + wc_expert_buffers: Vec::new(), + wc_layer_ptrs: Vec::new(), }) } @@ -2292,6 +2365,179 @@ impl GpuDecodeStore { } } + // ── WriteCombined DMA staging ────────────────────────────────────── + + /// Migrate expert weights from heap memory to WriteCombined host memory. + /// + /// WC memory bypasses CPU caches, giving the GPU DMA engine uncontested + /// access to the memory bus (~46 GB/s vs ~28 GB/s on PCIe Gen5). + /// After migration, the original heap backing is freed incrementally + /// to keep peak RAM manageable. + #[pyo3(signature = (engine))] + fn setup_wc_expert_memory(&mut self, engine: &crate::moe::KrasisEngine) -> PyResult { + let t0 = std::time::Instant::now(); + let graph = self.graph.as_mut() + .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err( + "setup_wc_expert_memory: call setup_from_engine first"))?; + + let store = engine.get_weight_store() + .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err( + "setup_wc_expert_memory: engine has no weight store"))?; + + let mut total_wc_bytes: usize = 0; + let mut wc_buffers: Vec = Vec::new(); + let num_moe_layers = graph.moe_layers.len(); + let mut wc_layer_ptrs: Vec = (0..num_moe_layers).map(|_| WcLayerPtrs { + w13p_ptr: 0, w13p_len: 0, w13s_ptr: 0, w13s_len: 0, + w2p_ptr: 0, w2p_len: 0, w2s_ptr: 0, w2s_len: 0, + }).collect(); + + for (layer_idx, moe_layer_opt) in graph.moe_layers.iter_mut().enumerate() { + let moe_layer = match moe_layer_opt { + Some(ref mut ml) => ml, + None => continue, + }; + + if layer_idx >= store.layer_backings_gpu.len() { + continue; + } + let backing = &store.layer_backings_gpu[layer_idx]; + let w13p_len = backing.w13_packed.len(); + let w13s_len = backing.w13_scales.len(); + let w2p_len = backing.w2_packed.len(); + let w2s_len = backing.w2_scales.len(); + let layer_bytes = w13p_len + w13s_len + w2p_len + w2s_len; + + if layer_bytes == 0 { + continue; + } + + // Allocate WC buffer: WRITECOMBINED (0x01) | PORTABLE (0x02) + let flags: u32 = 0x03; + let mut wc_ptr: *mut std::ffi::c_void = std::ptr::null_mut(); + let err = unsafe { + cuda_sys::lib().cuMemHostAlloc(&mut wc_ptr, layer_bytes, flags) + }; + if err != cuda_sys::CUresult::CUDA_SUCCESS { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + format!("cuMemHostAlloc WC failed for layer {} ({} MB): {:?}", + layer_idx, layer_bytes / (1024 * 1024), err))); + } + let wc_base = wc_ptr as *mut u8; + + // Copy per-component from LayerExpertBacking into WC + let mut off = 0usize; + let w13p_ptr = unsafe { wc_base.add(off) }; + unsafe { std::ptr::copy_nonoverlapping(backing.w13_packed.as_ptr(), w13p_ptr, w13p_len); } + off += w13p_len; + + let w13s_ptr = unsafe { wc_base.add(off) }; + unsafe { std::ptr::copy_nonoverlapping(backing.w13_scales.as_ptr(), w13s_ptr, w13s_len); } + off += w13s_len; + + let w2p_ptr = unsafe { wc_base.add(off) }; + unsafe { std::ptr::copy_nonoverlapping(backing.w2_packed.as_ptr(), w2p_ptr, w2p_len); } + off += w2p_len; + + let w2s_ptr = unsafe { wc_base.add(off) }; + unsafe { std::ptr::copy_nonoverlapping(backing.w2_scales.as_ptr(), w2s_ptr, w2s_len); } + + // Store per-component pointers for prefill bulk DMA views + wc_layer_ptrs[layer_idx] = WcLayerPtrs { + w13p_ptr: w13p_ptr as usize, w13p_len, + w13s_ptr: w13s_ptr as usize, w13s_len, + w2p_ptr: w2p_ptr as usize, w2p_len, + w2s_ptr: w2s_ptr as usize, w2s_len, + }; + + // Update per-expert decode pointers to WC memory + for (ei, expert) in moe_layer.experts.iter_mut().enumerate() { + let e_w13p = backing.per_expert_w13p; + let e_w13s = backing.per_expert_w13s; + let e_w2p = backing.per_expert_w2p; + let e_w2s = backing.per_expert_w2s; + + expert.w13_packed_ptr = w13p_ptr as usize + ei * e_w13p; + expert.w13_scales_ptr = w13s_ptr as usize + ei * e_w13s; + expert.w2_packed_ptr = w2p_ptr as usize + ei * e_w2p; + expert.w2_scales_ptr = w2s_ptr as usize + ei * e_w2s; + expert.w13_packed_bytes = e_w13p; + expert.w13_scales_bytes = e_w13s; + expert.w2_packed_bytes = e_w2p; + expert.w2_scales_bytes = e_w2s; + // Per-component layout: force 4-call DMA path + expert.contiguous_ptr = 0; + expert.contiguous_bytes = 0; + } + + // Update shared expert if present + let n_experts = moe_layer.experts.len(); + if let Some(ref mut se) = moe_layer.shared { + if n_experts > 0 && w13p_len > n_experts * backing.per_expert_w13p { + se.w13_packed_ptr = w13p_ptr as usize + n_experts * backing.per_expert_w13p; + se.w13_scales_ptr = w13s_ptr as usize + n_experts * backing.per_expert_w13s; + se.w2_packed_ptr = w2p_ptr as usize + n_experts * backing.per_expert_w2p; + se.w2_scales_ptr = w2s_ptr as usize + n_experts * backing.per_expert_w2s; + se.contiguous_ptr = 0; + se.contiguous_bytes = 0; + } + } + + wc_buffers.push(WcBuffer { ptr: wc_base, size: layer_bytes }); + total_wc_bytes += layer_bytes; + + // Free original heap backing (keeps peak RAM manageable) + { + let bufs: [&[u8]; 4] = [ + &backing.w13_packed, &backing.w13_scales, + &backing.w2_packed, &backing.w2_scales, + ]; + for buf in &bufs { + if !buf.is_empty() { + unsafe { + let _ = cuda_sys::lib().cuMemHostUnregister( + buf.as_ptr() as *mut std::ffi::c_void, + ); + } + } + } + } + let freed = store.free_layer_backing_gpu(layer_idx); + if freed > 0 && (layer_idx + 1) % 10 == 0 { + log::info!("WC migration: {}/{} layers, freed {:.1} MB heap", + layer_idx + 1, num_moe_layers, freed as f64 / 1e6); + } + } + + self.wc_expert_buffers = wc_buffers; + self.wc_layer_ptrs = wc_layer_ptrs; + + let elapsed = t0.elapsed().as_secs_f64(); + let msg = format!( + "WC expert memory: {:.1} GB across {} layers in {:.1}s (heap freed incrementally)", + total_wc_bytes as f64 / (1024.0 * 1024.0 * 1024.0), + self.wc_expert_buffers.len(), elapsed, + ); + log::info!("{}", msg); + Ok(msg) + } + + /// Return per-component WC buffer pointers for a MoE layer (for prefill view rebuilding). + fn get_wc_layer_buffer_ptrs(&self, moe_layer_idx: usize) -> (usize, usize, usize, usize, usize, usize, usize, usize) { + if moe_layer_idx < self.wc_layer_ptrs.len() { + let p = &self.wc_layer_ptrs[moe_layer_idx]; + (p.w13p_ptr, p.w13p_len, p.w13s_ptr, p.w13s_len, + p.w2p_ptr, p.w2p_len, p.w2s_ptr, p.w2s_len) + } else { + (0, 0, 0, 0, 0, 0, 0, 0) + } + } + + /// Check if WC expert memory has been set up. + fn has_wc_expert_memory(&self) -> bool { + !self.wc_expert_buffers.is_empty() + } + /// Return the max VRAM (MB) that prefill scratch will need for a given prompt size. /// Used by the VRAM budget calculator to reserve space for dynamic scratch growth. #[pyo3(signature = (max_tokens))] diff --git a/src/gpu_prefill.rs b/src/gpu_prefill.rs index 4d5020df..179343b9 100644 --- a/src/gpu_prefill.rs +++ b/src/gpu_prefill.rs @@ -9088,7 +9088,7 @@ mod kernel_tests { ).expect("Failed to load prefill kernels PTX"); #[cfg(has_decode_kernels)] dev.load_ptx( - cudarc::nvrtc::Ptx::from_src(include_str!(concat!(env!("OUT_DIR"), "/decode_kernels.ptx"))), + cudarc::nvrtc::Ptx::from_src(include_str!(concat!(env!("OUT_DIR"), "/decode_kernels_sm80.ptx"))), "decode_kernels", &[ "kv_cache_write_polar4", diff --git a/src/weights/mod.rs b/src/weights/mod.rs index 7161c029..81f74fb9 100644 --- a/src/weights/mod.rs +++ b/src/weights/mod.rs @@ -1503,6 +1503,38 @@ impl WeightStore { gpu_num_bits: 4, } } + + /// Free the heap-backed Vecs in a LayerExpertBacking, returning total bytes freed. + /// + /// Used by WriteCombined DMA migration: after expert data is copied into WC memory + /// and decode pointers are redirected, the original heap backing is dead weight. + /// Freeing it incrementally (one layer at a time) keeps peak RAM manageable. + pub fn free_layer_backing_gpu(&self, layer_idx: usize) -> usize { + if layer_idx >= self.layer_backings_gpu.len() { + return 0; + } + // SAFETY: The backing data is dead — pointers have been redirected to WC memory. + // No concurrent access occurs during WC setup. We use ptr::read + ptr::write to + // swap Vecs without creating &mut through a const-to-mut cast. + unsafe { + let backing_ptr = &self.layer_backings_gpu[layer_idx] as *const LayerExpertBacking + as *mut LayerExpertBacking; + + macro_rules! swap_free { + ($field:ident) => {{ + let field_ptr = std::ptr::addr_of_mut!((*backing_ptr).$field); + let old = std::ptr::read(field_ptr); + let len = old.len(); + std::ptr::write(field_ptr, Vec::new()); + drop(old); + len + }}; + } + + swap_free!(w13_packed) + swap_free!(w13_scales) + + swap_free!(w2_packed) + swap_free!(w2_scales) + } + } } impl WeightStore {