diff --git a/python/pyproject.toml b/python/pyproject.toml index c51e21f50e..9549f80791 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -38,6 +38,7 @@ srt_hpu = ["sglang[runtime_common]"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] +torch_memory_saver = ["torch_memory_saver"] test = [ "jsonlines", "matplotlib", diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 13eb233bd1..6ff83f61ee 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -469,6 +469,26 @@ class GetWeightsByNameReqOutput: parameter: list +@dataclass +class ReleaseGPUOccupationReqInput: + pass + + +@dataclass +class ReleaseGPUOccupationReqOutput: + pass + + +@dataclass +class ResumeGPUOccupationReqInput: + pass + + +@dataclass +class ResumeGPUOccupationReqOutput: + pass + + @dataclass class AbortReq: # The request id diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4bf41aaf39..57f545765a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -46,6 +46,10 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseGPUOccupationReqInput, + ReleaseGPUOccupationReqOutput, + ResumeGPUOccupationReqInput, + ResumeGPUOccupationReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, @@ -86,6 +90,7 @@ set_random_seed, suppress_other_loggers, ) +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -332,6 +337,10 @@ def __init__( t.start() self.parent_process = psutil.Process().parent() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.memory_saver + ) + # Init profiler if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": self.profiler = None @@ -492,6 +501,12 @@ def process_input_requests(self, recv_reqs: List): elif isinstance(recv_req, GetWeightsByNameReqInput): parameter = self.get_weights_by_name(recv_req) self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) + elif isinstance(recv_req, ReleaseGPUOccupationReqInput): + self.release_gpu_occupation() + self.send_to_tokenizer.send_pyobj(ReleaseGPUOccupationReqOutput()) + elif isinstance(recv_req, ResumeGPUOccupationReqInput): + self.resume_gpu_occupation() + self.send_to_tokenizer.send_pyobj(ResumeGPUOccupationReqOutput()) elif isinstance(recv_req, ProfileReq): if recv_req == ProfileReq.START_PROFILE: self.start_profile() @@ -1497,6 +1512,20 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return parameter + def release_gpu_occupation(self): + self.stashed_model_static_state = ( + self.tp_worker.worker.model_runner.model.export_static_state() + ) + self.memory_saver_adapter.pause() + self.flush_cache() + + def resume_gpu_occupation(self): + self.memory_saver_adapter.resume() + self.tp_worker.worker.model_runner.model.import_static_state( + self.stashed_model_static_state + ) + del self.stashed_model_static_state + def start_profile(self) -> None: if self.profiler is None: raise RuntimeError("Profiler is not enabled.") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3196e60cb6..1b3adfcee1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ 
b/python/sglang/srt/managers/tokenizer_manager.py @@ -53,6 +53,10 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseGPUOccupationReqInput, + ReleaseGPUOccupationReqOutput, + ResumeGPUOccupationReqInput, + ResumeGPUOccupationReqOutput, SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -188,6 +192,12 @@ def __init__( self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.release_gpu_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.resume_gpu_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) # Metrics if self.enable_metrics: @@ -548,6 +558,22 @@ async def get_weights_by_name( else: return all_parameters + async def release_gpu_occupation( + self, + obj: ReleaseGPUOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.release_gpu_occupation_communicator(obj) + + async def resume_gpu_occupation( + self, + obj: ResumeGPUOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.resume_gpu_occupation_communicator(obj) + async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): @@ -627,6 +653,8 @@ async def handle_loop(self): UpdateWeightsFromDistributedReqOutput, GetWeightsByNameReqOutput, InitWeightsUpdateGroupReqOutput, + ReleaseGPUOccupationReqOutput, + ResumeGPUOccupationReqOutput, ] = await self.recv_from_detokenizer.recv_pyobj() if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): @@ -750,6 +778,10 @@ async def handle_loop(self): self.update_weights_from_tensor_communicator.handle_recv(recv_obj) elif isinstance(recv_obj, GetWeightsByNameReqOutput): self.get_weights_by_name_communicator.handle_recv(recv_obj) + elif isinstance(recv_obj, ReleaseGPUOccupationReqOutput): + self.release_gpu_occupation_communicator.handle_recv(recv_obj) + elif isinstance(recv_obj, ResumeGPUOccupationReqOutput): + self.resume_gpu_occupation_communicator.handle_recv(recv_obj) else: raise ValueError(f"Invalid object: {recv_obj=}") diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 06b9f62684..dc62196fd6 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -13,6 +13,8 @@ limitations under the License. """ +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter + """ Memory pool. 
@@ -35,13 +37,21 @@ class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" - def __init__(self, size: int, max_context_len: int, device: str, use_records: bool): + def __init__( + self, + size: int, + max_context_len: int, + device: str, + use_records: bool, + memory_saver_adapter: TorchMemorySaverAdapter, + ): self.size = size self.max_context_len = max_context_len self.device = device - self.req_to_token = torch.zeros( - (size, max_context_len), dtype=torch.int32, device=device - ) + with memory_saver_adapter.region(): + self.req_to_token = torch.zeros( + (size, max_context_len), dtype=torch.int32, device=device + ) self.free_slots = list(range(size)) self.write_records = [] self.use_records = use_records @@ -182,32 +192,35 @@ def __init__( head_dim: int, layer_num: int, device: str, + memory_saver_adapter: TorchMemorySaverAdapter, ): super().__init__(size, dtype, device) self.head_num = head_num self.head_dim = head_dim self.layer_num = layer_num + self.memory_saver_adapter = memory_saver_adapter self._create_buffers() def _create_buffers(self): - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.k_buffer = [ - torch.empty( - (self.size + 1, self.head_num, self.head_dim), - dtype=self.store_dtype, - device=self.device, - ) - for _ in range(self.layer_num) - ] - self.v_buffer = [ - torch.empty( - (self.size + 1, self.head_num, self.head_dim), - dtype=self.store_dtype, - device=self.device, - ) - for _ in range(self.layer_num) - ] + with self.memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + # The padded slot 0 is used for writing dummy outputs from padded tokens. + self.k_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] def _clear_buffers(self): del self.k_buffer @@ -262,19 +275,22 @@ def __init__( qk_rope_head_dim: int, layer_num: int, device: str, + memory_saver_adapter: TorchMemorySaverAdapter, ): super().__init__(size, dtype, device) self.kv_lora_rank = kv_lora_rank - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.kv_buffer = [ - torch.empty( - (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=self.store_dtype, - device=device, - ) - for _ in range(layer_num) - ] + + with memory_saver_adapter.region(): + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
+ self.kv_buffer = [ + torch.empty( + (size + 1, 1, kv_lora_rank + qk_rope_head_dim), + dtype=self.store_dtype, + device=device, + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): if self.store_dtype != self.dtype: @@ -315,26 +331,28 @@ def __init__( layer_num: int, device: str, heavy_channel_num: int, + memory_saver_adapter: TorchMemorySaverAdapter, ): super().__init__(size, dtype, device) - # [size, head_num, head_dim] for each layer - self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - - # [size, head_num, heavy_channel_num] for each layer - self.label_buffer = [ - torch.empty( - (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] + with memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + + # [size, head_num, heavy_channel_num] for each layer + self.label_buffer = [ + torch.empty( + (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): return self.k_buffer[layer_id] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 786f654ded..56fe117aaf 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -58,6 +58,7 @@ monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter logger = logging.getLogger(__name__) @@ -152,6 +153,10 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=self.server_args.memory_saver + ) + # Load the model self.sampler = Sampler() self.load_model() @@ -254,11 +259,12 @@ def load_model(self): monkey_patch_vllm_gguf_config() # Load the model - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=DeviceConfig(self.device), - ) + with self.memory_saver_adapter.region(): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + ) # Parse other args self.sliding_window_size = ( @@ -376,7 +382,7 @@ def init_weights_update_group( logger.info( f"init custom process group: master_address={master_address}, master_port={master_port}, " - f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}" ) try: @@ -536,6 +542,7 @@ def init_memory_pool( max_context_len=self.model_config.context_len + 4, device=self.device, use_records=False, + memory_saver_adapter=self.memory_saver_adapter, ) if ( self.model_config.attention_arch == AttentionArch.MLA @@ -548,6 +555,7 @@ def init_memory_pool( qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + 
memory_saver_adapter=self.memory_saver_adapter, ) elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( @@ -558,6 +566,7 @@ def init_memory_pool( layer_num=self.model_config.num_hidden_layers, device=self.device, heavy_channel_num=self.server_args.ds_heavy_channel_num, + memory_saver_adapter=self.memory_saver_adapter, ) else: self.token_to_kv_pool = MHATokenToKVPool( @@ -567,6 +576,7 @@ def init_memory_pool( head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + memory_saver_adapter=self.memory_saver_adapter, ) logger.info( f"Memory pool end. " diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 3bd60c25d3..7e09647f5f 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -46,6 +46,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -307,7 +308,7 @@ def forward( return hidden_states -class BaiChuanBaseForCausalLM(nn.Module): +class BaiChuanBaseForCausalLM(BaseCausalLM): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ diff --git a/python/sglang/srt/models/base.py b/python/sglang/srt/models/base.py new file mode 100644 index 0000000000..c1beda92ff --- /dev/null +++ b/python/sglang/srt/models/base.py @@ -0,0 +1,15 @@ +from torch import nn + + +class BaseCausalLM(nn.Module): + def export_static_state(self): + return dict( + buffers=[ + (name, buffer.detach().clone()) for name, buffer in self.named_buffers() + ] + ) + + def import_static_state(self, static_params): + self_named_buffers = dict(self.named_buffers()) + for name, tensor in static_params["buffers"]: + self_named_buffers[name][...] 
= tensor diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 9c3bc2ee9e..d99f3367db 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -41,6 +41,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM LoraConfig = None @@ -332,7 +333,7 @@ def forward( return hidden_states -class ChatGLMForCausalLM(nn.Module): +class ChatGLMForCausalLM(BaseCausalLM): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"], diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 83ac3d8671..489bc0f961 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -62,6 +62,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import get_compiler_backend, set_weight_attrs @@ -313,7 +314,7 @@ def forward( return hidden_states -class CohereForCausalLM(nn.Module): +class CohereForCausalLM(BaseCausalLM): def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 852f58a710..6a945c1cb9 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -43,6 +43,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import set_weight_attrs @@ -361,7 +362,7 @@ def forward( return hidden_states -class DbrxForCausalLM(nn.Module): +class DbrxForCausalLM(BaseCausalLM): def __init__( self, config: DbrxConfig, diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index d840cb866b..563031bdd2 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,6 +46,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class DeepseekMLP(nn.Module): @@ -362,7 +363,7 @@ def forward( return hidden_states -class DeepseekForCausalLM(nn.Module): +class DeepseekForCausalLM(BaseCausalLM): def __init__( self, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a9c0b59cea..3f86c207e7 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -56,6 +56,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import is_flashinfer_available, is_hip is_hip_ = is_hip() @@ -823,7 +824,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module): +class DeepseekV2ForCausalLM(BaseCausalLM): def __init__( self, diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 536c253c33..c518dbdf86 100644 --- a/python/sglang/srt/models/exaone.py +++ 
b/python/sglang/srt/models/exaone.py @@ -39,6 +39,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class ExaoneGatedMLP(nn.Module): @@ -288,7 +289,7 @@ def forward( return hidden_states -class ExaoneForCausalLM(nn.Module): +class ExaoneForCausalLM(BaseCausalLM): def __init__( self, config, diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 10949a2f57..4be83ee150 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class GemmaMLP(nn.Module): @@ -249,7 +250,7 @@ def forward( return hidden_states -class GemmaForCausalLM(nn.Module): +class GemmaForCausalLM(BaseCausalLM): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 58d9ce02f2..51965fcc09 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -35,6 +35,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers @@ -306,7 +307,7 @@ def forward( return hidden_states -class Gemma2ForCausalLM(nn.Module): +class Gemma2ForCausalLM(BaseCausalLM): # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 144ad8bbf7..d3ca436d41 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class GPT2Attention(nn.Module): @@ -217,7 +218,7 @@ def forward( return hidden_states -class GPT2LMHeadModel(nn.Module): +class GPT2LMHeadModel(BaseCausalLM): def __init__( self, diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index f2f5ebd520..712cbc6a4d 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,6 +35,7 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class GPTBigCodeAttention(nn.Module): @@ -219,7 +220,7 @@ def forward( return hidden_states -class GPTBigCodeForCausalLM(nn.Module): +class GPTBigCodeForCausalLM(BaseCausalLM): packed_modules_mapping = {"c_attn": ["c_attn"]} supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index d207ff61b2..873da2feb7 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -42,6 
+42,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -295,7 +296,7 @@ def forward( return hidden_states -class GraniteForCausalLM(nn.Module): +class GraniteForCausalLM(BaseCausalLM): def __init__( self, config: GraniteConfig, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index cb6a72a3f6..c85a22f3fd 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -43,6 +43,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class Grok1MLP(nn.Module): @@ -333,7 +334,7 @@ def forward( return hidden_states -class Grok1ForCausalLM(nn.Module): +class Grok1ForCausalLM(BaseCausalLM): def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 0a737c1388..130804510c 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -38,6 +38,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class InternLM2MLP(nn.Module): @@ -246,7 +247,7 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module): +class InternLM2ForCausalLM(BaseCausalLM): def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index c06637962c..c1cec5e742 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -42,6 +42,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers from sglang.utils import get_exception_traceback @@ -293,8 +294,7 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module): - +class LlamaForCausalLM(BaseCausalLM): # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index c8ce9302b4..510cdf039f 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -39,12 +39,13 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM -class LlavaBaseForCausalLM(nn.Module): +class LlavaBaseForCausalLM(BaseCausalLM): def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 7b5f236a5b..75d8b41811 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -25,10 +25,11 @@ from sglang.srt.managers.schedule_batch import ImageInputs from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.models.llama import LlamaForCausalLM -class LlavaVidForCausalLM(nn.Module): +class LlavaVidForCausalLM(BaseCausalLM): def __init__( self, config: LlavaConfig, diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 3482a82813..6fe6139fcc 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -37,6 +37,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class MiniCPMMLP(nn.Module): @@ -270,7 +271,7 @@ def forward( return hidden_states -class MiniCPMForCausalLM(nn.Module): +class MiniCPMForCausalLM(BaseCausalLM): def __init__( self, config, diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index b0c93274e2..01efb86309 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -40,6 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import is_flashinfer_available if is_flashinfer_available(): @@ -537,7 +538,7 @@ def forward( return hidden_states -class MiniCPM3ForCausalLM(nn.Module): +class MiniCPM3ForCausalLM(BaseCausalLM): def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 9dbdb46ff9..562d01e6d3 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -45,6 +45,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class MixtralMoE(nn.Module): @@ -290,7 +291,7 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module): +class MixtralForCausalLM(BaseCausalLM): def __init__( self, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index e5f49f5662..728616181b 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,6 +45,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class MixtralMLP(nn.Module): @@ -319,7 +320,7 @@ def forward( return hidden_states -class QuantMixtralForCausalLM(nn.Module): +class QuantMixtralForCausalLM(BaseCausalLM): def __init__( self, config: MixtralConfig, diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 019d21c208..b23c52e3e2 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -34,6 +34,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.models.llama import 
LlamaDecoderLayer, LlamaMLP @@ -719,7 +720,7 @@ def forward( return hidden_states -class MllamaForCausalLM(nn.Module): +class MllamaForCausalLM(BaseCausalLM): config_class = config_mllama.MllamaTextConfig base_model_prefix = "language_model" _no_split_modules = [ diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 1cfa27309f..bdc6248238 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -38,6 +38,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers @@ -266,7 +267,7 @@ def forward( return hidden_states -class OlmoForCausalLM(nn.Module): +class OlmoForCausalLM(BaseCausalLM): """ Extremely barebones HF model wrapper. """ diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index 0944b57209..50d5533a2e 100755 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -45,6 +45,7 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers @@ -304,7 +305,7 @@ def forward( return hidden_states -class Olmo2ForCausalLM(nn.Module): +class Olmo2ForCausalLM(BaseCausalLM): """ Extremely barebones HF model wrapper. """ diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index df96be3bc9..b03e22730c 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -47,6 +47,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers, print_warning_once @@ -292,7 +293,7 @@ def forward( return hidden_states -class OlmoeForCausalLM(nn.Module): +class OlmoeForCausalLM(BaseCausalLM): fall_back_to_pt_during_load = False diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 1e70c7d787..f155a2ba69 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -24,6 +24,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers @@ -328,7 +329,7 @@ def forward( return hidden_states -class Phi3SmallForCausalLM(nn.Module): +class Phi3SmallForCausalLM(BaseCausalLM): _tied_weights_keys = ["lm_head.weight"] def __init__( diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 5492a3e122..d8d60f3663 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,6 +39,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class QWenMLP(nn.Module): @@ -237,7 +238,7 @@ def forward( return hidden_states -class QWenLMHeadModel(nn.Module): +class QWenLMHeadModel(BaseCausalLM): def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 2a20d6c50d..e844402c8d 100644 --- a/python/sglang/srt/models/qwen2.py +++ 
b/python/sglang/srt/models/qwen2.py @@ -40,6 +40,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM from sglang.srt.utils import make_layers Qwen2Config = None @@ -266,7 +267,7 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module): +class Qwen2ForCausalLM(BaseCausalLM): # BitandBytes specific attributes default_bitsandbytes_target_modules = [ diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 9db2d53823..bea7c1ff2a 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -46,6 +46,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class Qwen2MoeMLP(nn.Module): @@ -338,7 +339,7 @@ def forward( return hidden_states -class Qwen2MoeForCausalLM(nn.Module): +class Qwen2MoeForCausalLM(BaseCausalLM): fall_back_to_pt_during_load = False diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 079d54e3c8..c7c5d00443 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -42,6 +42,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class StablelmMLP(nn.Module): @@ -237,7 +238,7 @@ def forward( return hidden_states -class StableLmForCausalLM(nn.Module): +class StableLmForCausalLM(BaseCausalLM): def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 7a55d50457..c632cef529 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -64,6 +64,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -381,7 +382,7 @@ def forward( return hidden_states -class TorchNativeLlamaForCausalLM(nn.Module): +class TorchNativeLlamaForCausalLM(BaseCausalLM): def __init__( self, config: LlamaConfig, diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index e655142151..72fd173f67 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -40,6 +40,7 @@ ) from sglang.srt.model_executor.model_runner import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class XverseMLP(nn.Module): @@ -290,7 +291,7 @@ def forward( return hidden_states -class XverseForCausalLM(nn.Module): +class XverseForCausalLM(BaseCausalLM): def __init__( self, config: LlamaConfig, diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 9b4b27f07d..9e7a0f4941 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -43,6 +43,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.base import BaseCausalLM class 
XverseMLP(nn.Module): @@ -358,7 +359,7 @@ def forward( return hidden_states -class XverseMoeForCausalLM(nn.Module): +class XverseMoeForCausalLM(BaseCausalLM): def __init__( self, diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index d95ce5931b..6feef67671 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -29,6 +29,8 @@ from http import HTTPStatus from typing import AsyncIterator, Dict, List, Optional, Union +from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -55,6 +57,8 @@ GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, OpenSessionReqInput, + ReleaseGPUOccupationReqInput, + ResumeGPUOccupationReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, @@ -254,6 +258,24 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): return _create_error_response(e) +@app.api_route("/release_gpu_occupation", methods=["GET", "POST"]) +async def release_gpu_occupation(obj: ReleaseGPUOccupationReqInput, request: Request): + """Release GPU occupation temporarily""" + try: + await tokenizer_manager.release_gpu_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/resume_gpu_occupation", methods=["GET", "POST"]) +async def resume_gpu_occupation(obj: ResumeGPUOccupationReqInput, request: Request): + """Resume GPU occupation""" + try: + await tokenizer_manager.resume_gpu_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + @app.api_route("/open_session", methods=["GET", "POST"]) async def open_session(obj: OpenSessionReqInput, request: Request): """Open a session, and return its unique session id.""" @@ -437,6 +459,10 @@ def launch_engine( server_args.model_path, server_args.tokenizer_path ) + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.memory_saver + ) + if server_args.dp_size == 1: # Launch tensor parallel scheduler processes scheduler_procs = [] @@ -453,7 +479,8 @@ def launch_engine( target=run_scheduler_process, args=(server_args, port_args, gpu_id, tp_rank, None, writer), ) - proc.start() + with memory_saver_adapter.configure_subprocess(): + proc.start() scheduler_procs.append(proc) scheduler_pipe_readers.append(reader) @@ -470,7 +497,8 @@ def launch_engine( target=run_data_parallel_controller_process, args=(server_args, port_args, writer), ) - proc.start() + with memory_saver_adapter.configure_subprocess(): + proc.start() # Launch detokenizer process detoken_proc = mp.Process( @@ -888,6 +916,18 @@ def get_weights_by_name(self, name, truncate_size=100): loop = asyncio.get_event_loop() return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None)) + def release_gpu_occupation(self): + """Release GPU occupation temporarily""" + obj = ReleaseGPUOccupationReqInput() + loop = asyncio.get_event_loop() + loop.run_until_complete(tokenizer_manager.release_gpu_occupation(obj, None)) + + def resume_gpu_occupation(self): + """Resume GPU occupation""" + obj = ResumeGPUOccupationReqInput() + loop = asyncio.get_event_loop() + loop.run_until_complete(tokenizer_manager.resume_gpu_occupation(obj, None)) + class Runtime: """ diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 58a6a6a825..8162690e50 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -147,6 
+147,7 @@ class ServerArgs: triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False + memory_saver: bool = False def __post_init__(self): # Set missing default values @@ -792,6 +793,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Delete the model checkpoint after loading the model.", ) + parser.add_argument( + "--memory-saver", + action="store_true", + help="Allow saving memory using release_gpu_occupation and resume_gpu_occupation", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/torch_memory_saver_adapter.py b/python/sglang/torch_memory_saver_adapter.py new file mode 100644 index 0000000000..31f8ebf2f0 --- /dev/null +++ b/python/sglang/torch_memory_saver_adapter.py @@ -0,0 +1,59 @@ +from abc import ABC +from contextlib import contextmanager + +try: + import torch_memory_saver + + _primary_memory_saver = torch_memory_saver.TorchMemorySaver() +except ImportError: + pass + + +class TorchMemorySaverAdapter(ABC): + @staticmethod + def create(enable: bool): + return ( + _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop() + ) + + def configure_subprocess(self): + raise NotImplementedError + + def region(self): + raise NotImplementedError + + def pause(self): + raise NotImplementedError + + def resume(self): + raise NotImplementedError + + +class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter): + def configure_subprocess(self): + return torch_memory_saver.configure_subprocess() + + def region(self): + return _primary_memory_saver.region() + + def pause(self): + return _primary_memory_saver.pause() + + def resume(self): + return _primary_memory_saver.resume() + + +class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self): + yield + + def pause(self): + pass + + def resume(self): + pass diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 5fa202a32a..c6484ee6c4 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -14,6 +14,7 @@ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2 # Force reinstall flashinfer pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip uninstall -y torch_memory_saver && pip install torch_memory_saver pip install transformers==4.45.2 sentence_transformers accelerate peft diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 02fe8032e0..bb0957bf62 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -26,6 +26,7 @@ "test_openai_server.py", "test_pytorch_sampling_backend.py", "test_radix_attention.py", + "test_release_gpu_occupation.py", "test_retract_decode.py", "test_server_args.py", "test_session_control.py", diff --git a/test/srt/test_release_gpu_occupation.py b/test/srt/test_release_gpu_occupation.py new file mode 100644 index 0000000000..34437c37ae --- /dev/null +++ b/test/srt/test_release_gpu_occupation.py @@ -0,0 +1,99 @@ +import time +import unittest + +import torch +from transformers import AutoModelForCausalLM + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +# (temporarily) set to true to observe memory usage in nvidia-smi more clearly +_DEBUG_EXTRA = True + + +class TestReleaseGPUOccupation(unittest.TestCase): + def test_release_and_resume_occupation(self): + prompt = "Today is a sunny 
day and I like"
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        expect_output = " to spend it outdoors. I decided to"
+
+        engine = sgl.Engine(
+            model_path=model_name,
+            random_seed=42,
+            memory_saver=True,
+            # disable_cuda_graph=True,  # for debugging only
+        )
+        hf_model_new = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype="bfloat16"
+        )
+
+        print("generate (#1)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output)
+
+        if _DEBUG_EXTRA:
+            time.sleep(3)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            False,
+            "Should not be able to allocate big tensors before releasing",
+        )
+
+        print("release_gpu_occupation start")
+        t = time.time()
+        engine.release_gpu_occupation()
+        if _DEBUG_EXTRA:
+            print("release_gpu_occupation", time.time() - t)
+
+        if _DEBUG_EXTRA:
+            time.sleep(5)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            True,
+            "Should be able to allocate big tensors after releasing",
+        )
+
+        if _DEBUG_EXTRA:
+            time.sleep(5)
+
+        print("resume_gpu_occupation start")
+        t = time.time()
+        engine.resume_gpu_occupation()
+        if _DEBUG_EXTRA:
+            print("resume_gpu_occupation", time.time() - t)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            False,
+            "Should not be able to allocate big tensors after resuming",
+        )
+
+        print("update_weights_from_tensor")
+        # As if: PPO has updated hf model's weights, and now we sync it to SGLang
+        for name, tensor in hf_model_new.named_parameters():
+            engine.update_weights_from_tensor(name, tensor)
+
+        print("generate (#2)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output)
+
+        if _DEBUG_EXTRA:
+            time.sleep(4)
+
+        engine.shutdown()
+
+
+def _try_allocate_big_tensor(size: int = 20_000_000_000):
+    try:
+        torch.empty((size,), dtype=torch.uint8, device="cuda")
+        torch.cuda.empty_cache()
+        return True
+    except torch.cuda.OutOfMemoryError:
+        return False
+
+
+if __name__ == "__main__":
+    unittest.main()
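
Usage note (not part of the patch): the new test above exercises the in-process Engine path (engine.release_gpu_occupation() / engine.resume_gpu_occupation()); the sketch below shows how a client could drive the equivalent HTTP endpoints added in server.py. It assumes a server launched with the new --memory-saver flag (for example `python -m sglang.launch_server --model-path <model> --memory-saver --port 30000`); the base URL, port, and empty JSON body are assumptions here, the body simply mirroring the field-less request dataclasses added in io_struct.py.

    import requests

    BASE_URL = "http://127.0.0.1:30000"  # assumed host/port; adjust to the running server

    # Release GPU occupation: the scheduler stashes the model's static state (buffers),
    # pauses the torch_memory_saver regions (weights + KV cache), and flushes the cache.
    requests.post(f"{BASE_URL}/release_gpu_occupation", json={}).raise_for_status()

    # ... another GPU workload (e.g. a training step) can run here ...

    # Resume GPU occupation: re-map the paused regions and restore the stashed static state
    # before serving traffic again.
    requests.post(f"{BASE_URL}/resume_gpu_occupation", json={}).raise_for_status()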