class FFMPEGVideoWriter:
    """
    Video writer that streams raw frames to an ``ffmpeg`` subprocess.

    Exposes the subset of the ``cv2.VideoWriter`` interface used by this
    script (``write`` / ``isOpened`` / ``release``). Frames are accepted in
    BGR channel order and converted to RGB for ffmpeg rawvideo input.
    """

    def __init__(
        self,
        path: str,
        width: int,
        height: int,
        fps: float,
        use_10bit: bool = False,
        codec: Optional[str] = None,
        pix_fmt: Optional[str] = None,
        preset: str = "medium",
        crf: Optional[int] = None,
        bitrate: Optional[str] = None,
        input_pix_fmt: str = "rgb24"
    ):
        """
        Start the ffmpeg encoder process.

        Args:
            path: Output video file path (overwritten if it exists).
            width: Frame width in pixels.
            height: Frame height in pixels.
            fps: Output frame rate.
            use_10bit: Selects the 10-bit defaults (yuv420p10le / libx265)
                when ``codec`` / ``pix_fmt`` are not given.
            codec: Explicit ffmpeg video codec; overrides the default.
            pix_fmt: Explicit output pixel format; overrides the default.
            preset: ffmpeg encoder preset.
            crf: Constant rate factor; defaults to 12 when not given.
            bitrate: Target bitrate (e.g. "20M"); when set it replaces CRF.
            input_pix_fmt: Pixel format of the raw frames piped to ffmpeg.
                "rgb48le" makes ``write`` send uint16 data, anything else uint8.
        """
        # Function-scope import keeps this class self-contained.
        import tempfile

        pix_fmt_effective = pix_fmt or ('yuv420p10le' if use_10bit else 'yuv420p')
        codec_effective = codec or ('libx265' if use_10bit else 'libx264')
        crf_effective = crf if crf is not None else 12
        self._input_dtype = np.uint16 if input_pix_fmt == "rgb48le" else np.uint8

        cmd = [
            'ffmpeg', '-y', '-f', 'rawvideo', '-pix_fmt', input_pix_fmt,
            '-s', f'{width}x{height}', '-r', str(fps), '-i', '-',
            '-c:v', codec_effective, '-pix_fmt', pix_fmt_effective, '-preset', preset
        ]
        if bitrate:
            cmd.extend(['-b:v', bitrate])
        else:
            cmd.extend(['-crf', str(crf_effective)])
        cmd.append(path)

        # ffmpeg writes progress/statistics to stderr for the whole encode.
        # A PIPE that nobody drains fills up and blocks ffmpeg, which in turn
        # blocks our stdin writes (deadlock). A temp file never blocks and can
        # still be read back for diagnostics in release().
        self.proc = None
        self._stderr_file = tempfile.TemporaryFile()
        try:
            self.proc = subprocess.Popen(
                cmd,
                stdin=subprocess.PIPE, stdout=subprocess.DEVNULL, stderr=self._stderr_file
            )
        except Exception:
            self._stderr_file.close()
            self._stderr_file = None
            raise

    def write(self, frame_bgr: np.ndarray):
        """
        Send one BGR frame to the encoder.

        Raises:
            RuntimeError: If ffmpeg is not running or terminated mid-write.
        """
        if not self.isOpened():
            raise RuntimeError("FFMPEGVideoWriter: ffmpeg process is not running")
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        try:
            self.proc.stdin.write(frame_rgb.astype(self._input_dtype).tobytes())
        except BrokenPipeError as exc:
            raise RuntimeError(
                "FFMPEGVideoWriter: ffmpeg process terminated unexpectedly. "
                "Check video path, codec support, and disk space."
            ) from exc

    def isOpened(self) -> bool:
        """Return True while the ffmpeg process exists and has not exited."""
        return self.proc is not None and self.proc.poll() is None

    def release(self):
        """
        Close ffmpeg's stdin, wait for it to exit and log its stderr on failure.

        Safe to call more than once; later calls are no-ops.
        """
        proc = self.proc
        if not proc:
            return

        try:
            if proc.stdin is not None:
                proc.stdin.close()
        except OSError:
            # ffmpeg already exited (broken pipe); still reap the process below.
            pass

        proc.wait()

        stderr = b''
        stderr_file = getattr(self, "_stderr_file", None)
        if stderr_file is not None:
            try:
                stderr_file.seek(0)
                stderr = stderr_file.read()
            finally:
                stderr_file.close()
                self._stderr_file = None

        if proc.returncode != 0:
            # errors="replace": ffmpeg output is not guaranteed to be valid UTF-8.
            debug.log(f"ffmpeg error: {stderr.decode(errors='replace')}", level="WARNING", category="file")
        self.proc = None
def _format_bytes(num_bytes: int) -> str:
    """
    Render a byte count as a short human-readable string (B up to TB).
    """
    size = float(num_bytes)
    *lower_units, top_unit = ("B", "KB", "MB", "GB", "TB")
    for unit in lower_units:
        # Stop at the first unit in which the value is below 1024.
        if not size >= 1024:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    return f"{size:.2f} {top_unit}"


def _log_ram_usage(debug: "Debug", label: str, force: bool = False) -> None:
    """
    Log this process's RSS and system memory usage when psutil is importable.

    Without psutil a single install hint is logged (once per process) and the
    call is otherwise a no-op. Telemetry failures are swallowed on purpose.
    """
    global _psutil_missing_warned  # noqa: PLW0603
    if psutil is not None:
        try:
            rss = psutil.Process(os.getpid()).memory_info().rss
            sys_mem = psutil.virtual_memory()
            debug.log(
                f"[RAM] {label}: rss={_format_bytes(rss)} | system_used={_format_bytes(sys_mem.used)}/{_format_bytes(sys_mem.total)} ({sys_mem.percent:.1f}%)",
                category="memory",
                force=force
            )
        except Exception:
            # Best-effort telemetry: never let logging break processing.
            pass
        return

    if _psutil_missing_warned:
        return
    debug.log("[RAM] psutil not installed; install psutil for RAM telemetry", category="memory", force=True)
    _psutil_missing_warned = True


def _start_memory_monitor(pids: List[int], debug: "Debug", label: str = "workers", interval: float = 10.0):
    """
    Start a daemon thread that logs the RSS of the given PIDs every `interval` seconds.

    Returns:
        (stop_event, thread); set the event to stop the monitor.
        (None, None) when psutil is not available.
    """
    if psutil is None:
        _log_ram_usage(debug, f"{label} monitor skipped (psutil missing)", force=True)
        return None, None

    stop_event = threading.Event()

    def _sample_once() -> None:
        # Sum RSS over the PIDs that still exist and report alongside system usage.
        rss_total = 0
        per_proc = []
        for pid in pids:
            try:
                rss = psutil.Process(pid).memory_info().rss
            except psutil.NoSuchProcess:
                continue
            rss_total += rss
            per_proc.append(f"{pid}:{_format_bytes(rss)}")
        sys_mem = psutil.virtual_memory()
        debug.log(
            f"[RAM] {label} total={_format_bytes(rss_total)} | system_used={_format_bytes(sys_mem.used)}/{_format_bytes(sys_mem.total)} ({sys_mem.percent:.1f}%) | per_pid={', '.join(per_proc)}",
            category="memory",
            force=True
        )

    def _poll() -> None:
        # Event.wait returns False on timeout, so this samples once per interval
        # until the stop event is set.
        while not stop_event.wait(interval):
            try:
                _sample_once()
            except Exception:
                continue

    monitor_thread = threading.Thread(target=_poll, daemon=True)
    monitor_thread.start()
    return stop_event, monitor_thread
def _is_oom_error(exc: BaseException) -> bool:
    """
    Return True if `exc` looks like a CUDA or host out-of-memory error.

    Matches MemoryError, torch.cuda.OutOfMemoryError (when this torch build
    defines it) and any exception whose message contains "out of memory".
    """
    if isinstance(exc, MemoryError):
        return True
    # Looked up defensively: an AttributeError raised here would mask the
    # original exception in the caller's except block.
    cuda_oom = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
    if cuda_oom is not None and isinstance(exc, cuda_oom):
        return True
    return "out of memory" in str(exc).lower()


def _retry_with_cleanup(
    fn,
    debug: "Debug",
    description: str = "operation",
    backoff_base: float = 1.0,
    max_attempts: Optional[int] = None
):
    """
    Call `fn` and retry on out-of-memory errors, cleaning up between attempts.

    Args:
        fn: Zero-argument callable to run.
        debug: Logger used for the retry warning and passed to clear_memory.
        description: Label used in the retry log message.
        backoff_base: Seconds of wait per failed attempt (capped at 30 s).
        max_attempts: Maximum number of OOM retries. None (default) retries
            indefinitely, matching the original behavior.

    Returns:
        Whatever `fn` returns.

    Raises:
        Any non-OOM exception from `fn` immediately; the OOM exception itself
        once `max_attempts` retries have been used up.
    """
    attempt = 0
    while True:
        try:
            return fn()
        except Exception as exc:  # noqa: BLE001
            if not _is_oom_error(exc):
                raise
            attempt += 1
            if max_attempts is not None and attempt > max_attempts:
                raise
        # Cleanup runs outside the except block on purpose: while the handler
        # is active the exception's traceback keeps the failed frames (and the
        # tensors they hold) alive, so memory could not actually be released.
        wait = min(30.0, backoff_base * attempt)
        debug.log(
            f"{description} failed with OOM (attempt {attempt}), cleaning up and retrying in {wait:.1f}s",
            level="WARNING",
            category="memory",
            force=True
        )
        clear_memory(debug=debug, deep=True, force=True, timer_name="oom_retry")
        time.sleep(wait)
def _save_chunk_with_retry(chunk_np: np.ndarray, chunk_path: str, debug: "Debug", retries: int = 3) -> None:
    """
    Save a numpy chunk to disk with retries and an atomic rename.

    The array is written to `<chunk_path>.tmp`, fsynced, then moved over
    `chunk_path` with os.replace so readers never see a partial file.

    Args:
        chunk_np: Array to store (saved with allow_pickle=False).
        chunk_path: Destination .npy path; parent directories are created.
        debug: Logger used to report failed attempts.
        retries: Number of write attempts before giving up.

    Raises:
        OSError: If every attempt failed; chained to the last underlying error.
    """
    # dirname is "" for a bare filename; os.makedirs("") would raise.
    dir_path = os.path.dirname(chunk_path) or "."
    os.makedirs(dir_path, exist_ok=True)
    tmp_path = chunk_path + ".tmp"
    bytes_needed = chunk_np.nbytes
    last_error: Optional[OSError] = None

    for attempt in range(1, retries + 1):
        try:
            with open(tmp_path, "wb") as f:
                np.save(f, chunk_np, allow_pickle=False)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, chunk_path)
            return
        except OSError as exc:
            last_error = exc
            try:
                free_bytes = shutil.disk_usage(dir_path).free
            except OSError:
                # Diagnostics only; must not abort the retry loop.
                free_bytes = 0
            debug.log(
                f"Disk write failed for chunk ({bytes_needed/1e6:.1f} MB) to {chunk_path} "
                f"(attempt {attempt}/{retries}). Free space: {free_bytes/1e6:.1f} MB. Error: {exc}",
                level="WARNING",
                category="file",
                force=True
            )
        finally:
            # Remove a leftover partial temp file (absent after a successful replace).
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass
        # Only reached after a failed attempt; no point sleeping after the last one.
        if attempt < retries:
            time.sleep(min(10, attempt * 2))

    raise OSError(f"Failed to save chunk after {retries} attempts: {chunk_path}") from last_error
+ video_info=video_info, + spill_dir=spill_dir, + output_path=output_path, + fps=fps, + base_name=base_name + ) + if isinstance(result, dict) and "frames_written" in result: + frames_written = result["frames_written"] + else: + # Fallback to in-memory path if spill disabled or failed + if is_png: + save_frames_to_image(result, output_path, base_name) + else: + video_writer = save_frames_to_video(result, output_path, fps, + video_backend=args.video_backend, use_10bit=args.use_10bit or args.output_bitdepth > 8, + codec=_resolve_codec(args), pix_fmt=_resolve_pix_fmt(args), + bitrate=args.video_bitrate, crf=args.video_crf, input_bitdepth=args.output_bitdepth, + preset=args.video_preset) + if video_writer is not None: + video_writer.release() + frames_written = result.shape[0] + finally: + if created_temp_spill: + shutil.rmtree(spill_dir, ignore_errors=True) # Single GPU: stream in main process else: @@ -553,8 +740,13 @@ def process_single_file(input_path: str, args: argparse.Namespace, device_list: if is_png: save_frames_to_image(result, output_path, base_name, start_index=frames_written) else: - video_writer = save_frames_to_video(result, output_path, fps, writer=video_writer, - video_backend=args.video_backend, use_10bit=args.use_10bit) + video_writer = save_frames_to_video( + result, output_path, fps, writer=video_writer, + video_backend=args.video_backend, use_10bit=args.use_10bit or args.output_bitdepth > 8, + codec=_resolve_codec(args), pix_fmt=_resolve_pix_fmt(args), + bitrate=args.video_bitrate, crf=args.video_crf, input_bitdepth=args.output_bitdepth, + preset=args.video_preset + ) frames_written += result.shape[0] del result @@ -687,20 +879,29 @@ def _stream_video_chunks( if chunk_idx > 1: debug.log("", category="none", force=True) debug.log("━" * 60, category="none", force=True) - debug.log("", category="none", force=True) - debug.log(f"{log_prefix}Chunk {chunk_idx}/{total_chunks}: {new_frames.shape[0]} new + {context_count} context frames", + debug.log("", 
category="none", force=True) + debug.log(f"{log_prefix}Chunk {chunk_idx}/{max(1, total_chunks)}: {new_frames.shape[0]} new + {context_count} context frames", category="generation", force=True) debug.log("", category="none", force=True) + # RAM before processing + _log_ram_usage(debug, f"{log_prefix}Chunk {chunk_idx} pre-process", force=True) + # Process chunk - result = _process_frames_core( - frames_tensor=frames.to(torch.float16), - args=chunk_args, - device_id=device_id, + result = _retry_with_cleanup( + lambda: _process_frames_core( + frames_tensor=frames.to(torch.float16), + args=chunk_args, + device_id=device_id, + debug=debug, + runner_cache=runner_cache + ), debug=debug, - runner_cache=runner_cache + description=f"{log_prefix}chunk processing" ) + _log_ram_usage(debug, f"{log_prefix}Chunk {chunk_idx} post-process (before context trim)", force=True) + # Remove context frames from output if context_count > 0: result = result[context_count:] @@ -710,12 +911,16 @@ def _stream_video_chunks( # Cleanup before yield del frames + del new_frames + gc.collect() + _log_ram_usage(debug, f"{log_prefix}Chunk {chunk_idx} before yield (inputs freed)", force=True) yield result # Memory cleanup between chunks if streaming: clear_memory(debug=debug, deep=True, force=True, timer_name=cleanup_timer_name) + _log_ram_usage(debug, f"{log_prefix}Chunk {chunk_idx} after cleanup", force=True) def _save_image_bgr(frame_np: np.ndarray, file_path: str) -> None: @@ -739,7 +944,13 @@ def save_frames_to_video( fps: float = 30.0, writer: Optional[cv2.VideoWriter] = None, video_backend: str = "opencv", - use_10bit: bool = False + use_10bit: bool = False, + input_bitdepth: int = 8, + codec: Optional[str] = None, + pix_fmt: Optional[str] = None, + bitrate: Optional[str] = None, + crf: Optional[int] = None, + preset: str = "medium" ) -> Optional[cv2.VideoWriter]: """ Save frames tensor to MP4 video file. 
def _convert_tensor_for_storage(frames_tensor: torch.Tensor, bitdepth: int) -> torch.Tensor:
    """
    Quantize a float tensor in [0,1] to an integer tensor for disk spill.

    Values are clamped to [0,1] first. Bit depths up to 8 produce uint8 scaled
    to 0..255; higher bit depths produce uint16 scaled to 0..(2**bitdepth - 1).
    """
    unit_range = torch.clamp(frames_tensor, 0.0, 1.0)
    if bitdepth > 8:
        full_scale = float((1 << bitdepth) - 1)
        return (unit_range * full_scale).round().to(torch.uint16)
    return (unit_range * 255.0).round().to(torch.uint8)
def _load_spilled_chunk(chunk_path: str, bitdepth: int) -> np.ndarray:
    """
    Load a spilled chunk from disk and normalize it to float32 [0,1].

    Integer data is divided by its full-scale value: 255 for uint8, otherwise
    2**bitdepth - 1. Non-integer data is only cast to float32.
    """
    # Memory-map the file; astype below materializes the float32 copy.
    chunk = np.load(chunk_path, mmap_mode="r")
    if np.issubdtype(chunk.dtype, np.integer):
        max_val = 255.0 if chunk.dtype == np.uint8 else float((1 << bitdepth) - 1)
        return (chunk.astype(np.float32) / max_val)
    return chunk.astype(np.float32)


def _resolve_pix_fmt(args: argparse.Namespace) -> Optional[str]:
    """
    Pick the ffmpeg output pixel format.

    An explicit ``args.video_pix_fmt`` always wins. Otherwise the format is
    derived from ``args.output_bitdepth``: 8 -> yuv420p, 10 -> yuv420p10le,
    12 -> yuv420p12le, anything else -> yuv420p16le.
    """
    if args.video_pix_fmt:
        return args.video_pix_fmt
    by_bitdepth = {
        8: "yuv420p",
        10: "yuv420p10le",
        # 12-bit gets its own format instead of falling through to 16-bit.
        12: "yuv420p12le",
    }
    return by_bitdepth.get(args.output_bitdepth, "yuv420p16le")


def _resolve_codec(args: argparse.Namespace) -> Optional[str]:
    """
    Pick the ffmpeg video codec.

    An explicit ``args.video_codec`` wins; otherwise libx264 for bit depths
    up to 8 and libx265 above that.
    """
    if args.video_codec:
        return args.video_codec
    return "libx264" if args.output_bitdepth <= 8 else "libx265"
def _stitch_spilled_chunks(
    worker_chunks: Dict[int, List[str]],
    args: argparse.Namespace,
    fps: float,
    output_path: str,
    base_name: str,
    writer: "Optional[cv2.VideoWriter]" = None,
    pix_fmt: Optional[str] = None,
    codec: Optional[str] = None
) -> "Tuple[int, Optional[cv2.VideoWriter]]":
    """
    Stitch spilled chunks from multiple workers with overlap blending and stream to disk.

    Workers are visited in ascending index order and each worker's chunk files
    in sorted path order. For every worker except the last, the final
    `args.temporal_overlap` frames of its last chunk are held back and blended
    with the head of the next chunk before being written. Each chunk file is
    deleted once consumed.

    Args:
        worker_chunks: Mapping of worker index -> list of spilled .npy paths.
        args: Parsed CLI arguments (temporal_overlap, prepend_frames,
            output_format, output_bitdepth and video_* options are read).
        fps: Output frame rate for video output.
        output_path: Destination video file or PNG directory.
        base_name: Base file name for PNG output.
        writer: Existing video writer to continue appending to, if any.
        pix_fmt: Pixel format override; resolved from `args` when None.
        codec: Codec override; resolved from `args` when None.

    Returns:
        (frames_written, writer): frames written by this call and the writer
        (created on first video write when none was supplied).
    """
    # Nothing to stitch: max() below would raise on an empty mapping.
    if not worker_chunks:
        return 0, writer

    frames_written = 0
    pending_tail: Optional[torch.Tensor] = None
    overlap = args.temporal_overlap
    last_worker_idx = max(worker_chunks.keys())
    pix_fmt = pix_fmt or _resolve_pix_fmt(args)
    codec = codec or _resolve_codec(args)
    # NOTE(review): the prepend skip and the PNG start index below are relative
    # to this call only; a caller that stitches in several passes would restart
    # both on each pass - confirm against the caller.
    frames_to_skip = args.prepend_frames

    def write_out(chunk_tensor: torch.Tensor) -> int:
        """Write frames to the output, dropping leading prepended frames first."""
        nonlocal writer, frames_to_skip
        if chunk_tensor.numel() == 0:
            return 0
        if frames_to_skip > 0:
            if chunk_tensor.shape[0] <= frames_to_skip:
                frames_to_skip -= chunk_tensor.shape[0]
                return 0
            chunk_tensor = chunk_tensor[frames_to_skip:]
            frames_to_skip = 0
        if args.output_format == "png":
            return save_frames_to_image(chunk_tensor, output_path, base_name, start_index=frames_written)
        _log_ram_usage(debug, "Stitch write_out pre-video", force=True)
        writer_local = save_frames_to_video(
            chunk_tensor,
            output_path,
            fps,
            writer=writer,
            video_backend=args.video_backend,
            use_10bit=args.use_10bit or args.output_bitdepth > 8,
            codec=codec,
            pix_fmt=pix_fmt,
            bitrate=args.video_bitrate,
            crf=args.video_crf,
            preset=args.video_preset,
            input_bitdepth=args.output_bitdepth
        )
        _log_ram_usage(debug, "Stitch write_out post-video", force=True)
        if writer is None:
            writer = writer_local
        return chunk_tensor.shape[0]

    for worker_idx in sorted(worker_chunks.keys()):
        chunk_paths = sorted(worker_chunks[worker_idx])
        for chunk_idx, chunk_path in enumerate(chunk_paths):
            chunk_np = _load_spilled_chunk(chunk_path, args.output_bitdepth)
            chunk_tensor = torch.from_numpy(chunk_np)
            is_last_worker = worker_idx == last_worker_idx
            is_last_chunk = chunk_idx == len(chunk_paths) - 1

            if pending_tail is not None:
                if overlap > 0 and chunk_tensor.shape[0] > 0:
                    # Blend the held-back tail with the head of this chunk.
                    blend_len = min(overlap, pending_tail.shape[0], chunk_tensor.shape[0])
                    blended = blend_overlapping_frames(
                        pending_tail[-blend_len:],
                        chunk_tensor[:blend_len],
                        blend_len
                    )
                    frames_written += write_out(blended)
                    chunk_tensor = chunk_tensor[blend_len:]
                else:
                    frames_written += write_out(pending_tail)
                pending_tail = None

            if not is_last_worker and overlap > 0 and is_last_chunk:
                # Hold back the tail for blending with the next worker's first chunk.
                if chunk_tensor.shape[0] <= overlap:
                    pending_tail = chunk_tensor
                    chunk_body = chunk_tensor[:0]
                else:
                    pending_tail = chunk_tensor[-overlap:]
                    chunk_body = chunk_tensor[:-overlap]
            else:
                chunk_body = chunk_tensor

            if chunk_body.numel() > 0:
                frames_written += write_out(chunk_body)

            try:
                os.remove(chunk_path)
            except OSError:
                pass
            del chunk_tensor, chunk_np
            gc.collect()

    if pending_tail is not None:
        frames_written += write_out(pending_tail)

    return frames_written, writer
-977,6 +1349,7 @@ def _process_frames_core( input_noise_scale=args.input_noise_scale, color_correction=args.color_correction ) + _log_ram_usage(debug, "After encode", force=True) # Phase 2: Upscale ctx = upscale_all_batches( @@ -985,12 +1358,14 @@ def _process_frames_core( latent_noise_scale=args.latent_noise_scale, cache_model=cache_dit ) + _log_ram_usage(debug, "After upscale", force=True) # Phase 3: Decode ctx = decode_all_batches( runner, ctx=ctx, debug=debug, progress_callback=None, cache_model=cache_vae ) + _log_ram_usage(debug, "After decode", force=True) # Phase 4: Post-process ctx = postprocess_all_batches( @@ -1000,6 +1375,7 @@ def _process_frames_core( temporal_overlap=args.temporal_overlap, batch_size=args.batch_size ) + _log_ram_usage(debug, "After postprocess", force=True) result_tensor = ctx['final_video'] @@ -1009,6 +1385,35 @@ def _process_frames_core( if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): result_tensor = result_tensor.to(torch.float32) + _log_ram_usage(debug, "After CPU move/finalize", force=True) + # Aggressively drop intermediate tensors from ctx to avoid accumulation across chunks + keep_keys = {"cache_context", "dit_device", "vae_device", "dit_offload_device", + "vae_offload_device", "tensor_offload_device", "compute_dtype", "text_embeds"} + for key in list(ctx.keys()): + if key not in keep_keys: + try: + del ctx[key] + except Exception: + pass + + # If we are not caching runners, drop everything else as well + if runner_cache is None: + try: + ctx.clear() + except Exception: + pass + else: + # Even with caching, drop cache_context/text_embeds to avoid CPU bloat between chunks + for key in ("cache_context", "text_embeds"): + try: + if key in ctx: + del ctx[key] + except Exception: + pass + + clear_memory(debug=debug, deep=True, force=True, timer_name="ctx_cleanup") + gc.collect() + _log_ram_usage(debug, "After ctx cleanup", force=True) return result_tensor @@ -1018,7 +1423,7 @@ def _worker_process( 
frames_np: Optional[np.ndarray], shared_args: Dict[str, Any], return_queue: mp.Queue, - done_barrier: mp.Barrier, + done_barrier: Optional[mp.Barrier], video_info: Optional[Dict[str, Any]] = None ) -> None: """ @@ -1062,46 +1467,76 @@ def _worker_process( if proc_idx != 0: worker_args.prepend_frames = 0 - # Enable model caching within worker only if requested - runner_cache = {} if (args.cache_dit or args.cache_vae) else None + # Disable runner caching per chunk to avoid CPU accumulation; reload models per chunk + runner_cache = None total_chunks = (segment_frames + chunk_size - 1) // chunk_size - results = [] - for result in _stream_video_chunks( - cap=cap, - frames_to_process=segment_frames, - chunk_size=chunk_size, - overlap=args.temporal_overlap, - args=worker_args, - device_id="0", - debug=worker_debug, - runner_cache=runner_cache, - log_progress=total_chunks > 1, - total_chunks=total_chunks, - log_prefix=f"[GPU {proc_idx}] " - ): - results.append(result.cpu()) + chunk_paths: List[str] = [] + spill_root = shared_args.get("spill_dir") + if spill_root is None: + raise RuntimeError("spill_dir not provided for streaming worker") + worker_spill_dir = os.path.join(spill_root, f"worker_{proc_idx}") + os.makedirs(worker_spill_dir, exist_ok=True) + + for chunk_idx, result in enumerate(_stream_video_chunks( + cap=cap, + frames_to_process=segment_frames, + chunk_size=chunk_size, + overlap=args.temporal_overlap, + args=worker_args, + device_id="0", + debug=worker_debug, + runner_cache=runner_cache, + log_progress=total_chunks > 1, + total_chunks=total_chunks, + log_prefix=f"[GPU {proc_idx}] " + ), start=1): + storage_tensor = _convert_tensor_for_storage(result.cpu(), args.output_bitdepth) + chunk_path = os.path.join(worker_spill_dir, f"chunk_{chunk_idx:05d}.npy") + try: + _log_ram_usage(worker_debug, f"Worker {proc_idx} chunk {chunk_idx} pre-save ({storage_tensor.shape}, {storage_tensor.dtype})", force=True) + _save_chunk_with_retry(storage_tensor.numpy(), chunk_path, 
worker_debug) + chunk_paths.append(chunk_path) + _log_ram_usage(worker_debug, f"Worker {proc_idx} chunk {chunk_idx} post-save", force=True) + except Exception as exc: # noqa: BLE001 + if return_queue is not None: + return_queue.put((proc_idx, {"error": str(exc)})) + return + finally: + del result, storage_tensor + clear_memory(debug=worker_debug, deep=True, force=True, timer_name="worker_chunk_cleanup") + _log_ram_usage(worker_debug, f"Worker {proc_idx} chunk {chunk_idx} after cleanup", force=True) + gc.collect() cap.release() - result_tensor = torch.cat(results, dim=0) if results else torch.empty(0, dtype=torch.float32) + if return_queue is not None: + return_queue.put((proc_idx, chunk_paths)) + return # Pre-loaded frames mode (original behavior) else: frames_tensor = torch.from_numpy(frames_np).to(torch.float16) - result_tensor = _process_frames_core( - frames_tensor=frames_tensor, - args=args, - device_id="0", + result_tensor = _retry_with_cleanup( + lambda: _process_frames_core( + frames_tensor=frames_tensor, + args=args, + device_id="0", + debug=worker_debug, + runner_cache=None + ), debug=worker_debug, - runner_cache=None + description="worker processing" ) - + _log_ram_usage(worker_debug, f"Worker {proc_idx} post-process (pre-return)", force=True) + clear_memory(debug=worker_debug, deep=True, force=True, timer_name="worker_post_process_cleanup") # Share tensor memory for efficient cross-process transfer (avoids pickling large arrays) - return_queue.put((proc_idx, result_tensor.share_memory_())) + if return_queue is not None: + return_queue.put((proc_idx, result_tensor.share_memory_())) # Wait for parent to copy shared tensors before exiting # (shared memory requires creating process to stay alive during access) - done_barrier.wait() + if done_barrier is not None: + done_barrier.wait() def _single_gpu_direct_processing( @@ -1115,12 +1550,16 @@ def _single_gpu_direct_processing( Uses main process and shared runner cache for efficient multi-file processing. 
""" - return _process_frames_core( - frames_tensor=frames_tensor, - args=args, - device_id=device_id, + return _retry_with_cleanup( + lambda: _process_frames_core( + frames_tensor=frames_tensor, + args=args, + device_id=device_id, + debug=debug, + runner_cache=runner_cache + ), debug=debug, - runner_cache=runner_cache + description="single GPU processing" ) @@ -1128,8 +1567,12 @@ def _gpu_processing( frames_tensor: Optional[torch.Tensor], device_list: List[str], args: argparse.Namespace, - video_info: Optional[Dict[str, Any]] = None -) -> torch.Tensor: + video_info: Optional[Dict[str, Any]] = None, + spill_dir: Optional[str] = None, + output_path: Optional[str] = None, + fps: float = 30.0, + base_name: str = "" +) -> Any: """ Orchestrate multi-GPU parallel video upscaling with temporal overlap blending. @@ -1147,50 +1590,119 @@ def _gpu_processing( for streaming mode where workers read video directly Returns: - Upscaled frames tensor [T', H', W', C], Float32, range [0,1] + Upscaled frames tensor [T', H', W', C], Float32, range [0,1] or + dict with {'frames_written': int} when streaming to disk """ num_devices = len(device_list) overlap = args.temporal_overlap return_queue = mp.Queue(maxsize=0) - done_barrier = mp.Barrier(num_devices + 1) + done_barrier: Optional[mp.Barrier] = None if video_info is not None and spill_dir else mp.Barrier(num_devices + 1) workers = [] shared_args = vars(args).copy() + if spill_dir: + shared_args["spill_dir"] = spill_dir # Video streaming mode: distribute frame ranges to workers - if video_info is not None: + recycle_every = max(1, getattr(args, "recycle_workers_every", 1)) + if video_info is not None and spill_dir: total_frames = video_info['frames_to_process'] start_frame = video_info['start_frame'] video_path = video_info['video_path'] - base_per_gpu = total_frames // num_devices remainder = total_frames % num_devices - - current_start = start_frame - for idx, device_id in enumerate(device_list): + + cycle_span = (args.chunk_size if 
args.chunk_size > 0 else max(1, base_per_gpu)) + cycle_span *= recycle_every + + # Build per-device segments + device_states = [] + seg_start = start_frame + for idx in range(num_devices): gpu_frames = base_per_gpu + (1 if idx < remainder else 0) - gpu_end = current_start + gpu_frames - - # Add overlap frames for blending (except last GPU) - if idx < num_devices - 1 and overlap > 0: - gpu_end = min(gpu_end + overlap, start_frame + total_frames) - - worker_video_info = { - 'video_path': video_path, - 'start_frame': current_start, - 'end_frame': gpu_end, - } - - os.environ["CUDA_VISIBLE_DEVICES"] = device_id - p = mp.Process( - target=_worker_process, - args=(idx, device_id, None, shared_args, return_queue, done_barrier), - kwargs={'video_info': worker_video_info} + base_end = seg_start + gpu_frames + final_end = base_end + (overlap if idx < num_devices - 1 else 0) + device_states.append({ + "idx": idx, + "device": device_list[idx], + "cursor": seg_start, + "final_end": final_end, + "start_base": seg_start, + "total_frames": final_end - seg_start + }) + seg_start = base_end + + frames_written_total = 0 + cycle_index = 0 + writer = None + pix_fmt = _resolve_pix_fmt(args) + codec = _resolve_codec(args) + + # Process cycles until all device segments are consumed + while any(state["cursor"] < state["final_end"] for state in device_states): + cycle_index += 1 + workers = [] + worker_chunks: Dict[int, List[str]] = {} + + # Spawn workers for devices that still have remaining frames + for state in device_states: + if state["cursor"] >= state["final_end"]: + continue + start_cur = state["cursor"] + end_cur = min(state["final_end"], start_cur + cycle_span) + state["cursor"] = end_cur + + worker_video_info = { + 'video_path': video_path, + 'start_frame': start_cur, + 'end_frame': end_cur, + } + + os.environ["CUDA_VISIBLE_DEVICES"] = state["device"] + p = mp.Process( + target=_worker_process, + args=(state["idx"], state["device"], None, shared_args, return_queue, done_barrier), 
+ kwargs={'video_info': worker_video_info} + ) + p.start() + workers.append(p) + + monitor_stop = None + monitor_thread = None + if workers: + monitor_stop, monitor_thread = _start_memory_monitor( + [p.pid for p in workers if p.pid], + debug, + label=f"workers_cycle_{cycle_index}", + interval=10.0 + ) + + collected = 0 + while collected < len(workers): + proc_idx, payload = return_queue.get() + if isinstance(payload, dict) and "error" in payload: + raise RuntimeError(f"Worker {proc_idx} failed to spill chunk: {payload['error']}") + worker_chunks[proc_idx] = payload + collected += 1 + + for p in workers: + p.join() + + if monitor_stop: + monitor_stop.set() + if monitor_thread: + monitor_thread.join(timeout=2.0) + + _log_ram_usage(debug, f"Parent pre-stitch cycle {cycle_index}", force=True) + frames_written, writer = _stitch_spilled_chunks( + worker_chunks, args, fps, output_path, base_name, writer=writer, pix_fmt=pix_fmt, codec=codec ) - p.start() - workers.append(p) - - current_start += gpu_frames + frames_written_total += frames_written + _log_ram_usage(debug, f"Parent post-stitch cycle {cycle_index}", force=True) + + if writer is not None: + writer.release() + return {"frames_written": frames_written_total} # Pre-loaded frames mode (original behavior for images or non-streaming) else: @@ -1221,9 +1733,18 @@ def _gpu_processing( ) p.start() workers.append(p) + monitor_stop, monitor_thread = _start_memory_monitor( + [p.pid for p in workers if p.pid], + debug, + label="workers", + interval=10.0 + ) + + if video_info is not None and spill_dir: + # This path is now handled in the loop above + return {"frames_written": 0} - # Collect results before joining to prevent deadlock - # Tensors arrive via shared memory - copy to numpy while workers still alive + # Collect results before joining to prevent deadlock (pre-loaded frames path) results_np = [None] * num_devices collected = 0 while collected < num_devices: @@ -1232,7 +1753,12 @@ def _gpu_processing( collected += 1 # 
Release workers now that shared tensors are copied - done_barrier.wait() + if done_barrier is not None: + done_barrier.wait() + if monitor_stop: + monitor_stop.set() + if monitor_thread: + monitor_thread.join(timeout=2.0) # Now safe to join for p in workers: @@ -1352,10 +1878,27 @@ def parse_arguments() -> argparse.Namespace: io_group.add_argument("--output_format", type=str, default=None, choices=["mp4", "png", None], help="Output format: 'mp4' (video) or 'png' (image sequence). Default: auto-detect from input type") io_group.add_argument("--video_backend", type=str, default="opencv", choices=["opencv", "ffmpeg"], - help="Video encoder backend: 'opencv' (default) or 'ffmpeg' (requires ffmpeg in PATH)") + help="Video encoder backend: 'opencv' (default) or 'ffmpeg' (requires ffmpeg in PATH)") io_group.add_argument("--10bit", dest="use_10bit", action="store_true", help="Save 10-bit video with x265 codec (reduces banding). Without this flag, " "ffmpeg uses x264 for maximum compatibility. Requires --video_backend ffmpeg") + io_group.add_argument("--output_bitdepth", type=int, default=8, choices=[8, 10, 12, 16], + help="Bit depth for output frames. Influences spill-to-disk dtype and default ffmpeg pixel format. Default: 8") + io_group.add_argument("--spill_dir", type=str, default=None, + help="Directory for spilling streamed chunks in multi-GPU mode. Default: system temp dir") + io_group.add_argument("--video_codec", type=str, default=None, + help="Override video codec for ffmpeg backend (default: libx264 for <=8-bit, libx265 for >8-bit)") + io_group.add_argument("--video_pix_fmt", type=str, default=None, + help="Override ffmpeg pixel format (default derives from output_bitdepth: yuv420p/yuv420p10le/yuv420p16le)") + io_group.add_argument("--video_crf", type=int, default=None, + help="CRF value for ffmpeg (default: 12). Ignored if --video_bitrate is set") + io_group.add_argument("--video_bitrate", type=str, default=None, + help="Bitrate target for ffmpeg, e.g., '20M'. 
Overrides CRF when provided") + io_group.add_argument("--video_preset", type=str, default="medium", + help="ffmpeg preset (default: medium)") + io_group.add_argument("--recycle_workers_every", type=int, default=1, + help="Chunks processed per worker before recycling (respawn). " + "Default: 1 (recycle every chunk). Increase to reuse workers across chunks at the cost of higher peak RAM.") io_group.add_argument("--model_dir", type=str, default=None, help=f"Model directory (default: ./models/{SEEDVR2_FOLDER_NAME})") @@ -1709,4 +2252,4 @@ def main() -> None: debug.print_footer() if __name__ == "__main__": - main() \ No newline at end of file + main()