CUDA-graph-compatible releasing and resuming KV cache and model weight memory #2630

Open: wants to merge 122 commits into base: main

Commits (122):
a5061cc  empty struct  (fzyzcjy, Dec 26, 2024)
5a5651b  more  (fzyzcjy, Dec 26, 2024)
1ccf84c  more  (fzyzcjy, Dec 26, 2024)
6e55282  simp  (fzyzcjy, Dec 26, 2024)
5edcf5a  more  (fzyzcjy, Dec 26, 2024)
35eb3ad  fix typing  (fzyzcjy, Dec 26, 2024)
211550e  more  (fzyzcjy, Dec 26, 2024)
619aa19  Merge branch 'feat/code_cleanup' into feat/memory_optimization  (fzyzcjy, Dec 26, 2024)
95a8db9  more  (fzyzcjy, Dec 26, 2024)
5650a75  more  (fzyzcjy, Dec 26, 2024)
ecd3d9a  more  (fzyzcjy, Dec 26, 2024)
53573cc  more  (fzyzcjy, Dec 26, 2024)
8f8bc3d  fix  (fzyzcjy, Dec 26, 2024)
eaa9808  more  (fzyzcjy, Dec 26, 2024)
f3c948c  more  (fzyzcjy, Dec 26, 2024)
94e9ec8  more  (fzyzcjy, Dec 26, 2024)
2317150  more  (fzyzcjy, Dec 26, 2024)
bc56193  Merge branch 'feat/code_cleanup' into feat/memory_optimization  (fzyzcjy, Dec 26, 2024)
8042494  more  (fzyzcjy, Dec 26, 2024)
5063ff0  Merge branch 'main' into feat/code_cleanup  (fzyzcjy, Dec 26, 2024)
711b7de  Merge branch 'main' into feat/memory_optimization  (fzyzcjy, Dec 26, 2024)
255251f  cleanup  (fzyzcjy, Dec 27, 2024)
7f737d5  fix  (fzyzcjy, Dec 27, 2024)
db53dfa  Merge branch 'main' into feat/code_cleanup  (fzyzcjy, Dec 27, 2024)
9dcba7b  Merge branch 'main' into feat/code_cleanup  (fzyzcjy, Dec 28, 2024)
adff864  Merge branch 'feat/code_cleanup' into feat/memory_optimization  (fzyzcjy, Dec 28, 2024)
23ff620  Merge branch 'feat/memory_optimization' into feat/memory_saver  (fzyzcjy, Dec 28, 2024)
59989d7  more  (fzyzcjy, Dec 28, 2024)
08d5900  more  (fzyzcjy, Dec 28, 2024)
bf890f8  enable cudagraph  (fzyzcjy, Dec 28, 2024)
ec94ea9  more  (fzyzcjy, Dec 28, 2024)
16ad9a6  more  (fzyzcjy, Dec 28, 2024)
3623b07  more  (fzyzcjy, Dec 28, 2024)
f2e9fb0  update  (fzyzcjy, Dec 28, 2024)
c2f3a41  partial hack  (fzyzcjy, Dec 28, 2024)
c5abfcd  Revert "partial hack"  (fzyzcjy, Dec 28, 2024)
4ffdaa4  more  (fzyzcjy, Dec 28, 2024)
3901795  more  (fzyzcjy, Dec 28, 2024)
ee73bed  cp back  (fzyzcjy, Dec 28, 2024)
d496f92  more  (fzyzcjy, Dec 28, 2024)
4a323c0  more  (fzyzcjy, Dec 28, 2024)
c6cb3d1  more  (fzyzcjy, Dec 28, 2024)
d4ce807  fmt  (fzyzcjy, Dec 28, 2024)
487f507  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 28, 2024)
fb7d1f1  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 28, 2024)
5f8dec5  more  (fzyzcjy, Dec 28, 2024)
8273b77  more  (fzyzcjy, Dec 28, 2024)
04e922e  model  (fzyzcjy, Dec 28, 2024)
cc50ffb  temp revert  (fzyzcjy, Dec 28, 2024)
ecf989a  more  (fzyzcjy, Dec 28, 2024)
436eb85  more  (fzyzcjy, Dec 28, 2024)
55e3906  more  (fzyzcjy, Dec 28, 2024)
7622993  Revert "temp revert"  (fzyzcjy, Dec 28, 2024)
b167dc0  more  (fzyzcjy, Dec 28, 2024)
59f7115  Revert "more"  (fzyzcjy, Dec 28, 2024)
98c6adc  more  (fzyzcjy, Dec 28, 2024)
90b66a4  Revert "Revert "temp revert""  (fzyzcjy, Dec 28, 2024)
bb5c771  more  (fzyzcjy, Dec 28, 2024)
dcce14f  more  (fzyzcjy, Dec 28, 2024)
ffff004  temp hack  (fzyzcjy, Dec 28, 2024)
35c424f  temp  (fzyzcjy, Dec 28, 2024)
f691a38  more  (fzyzcjy, Dec 28, 2024)
71a5c6d  Revert "Revert "Revert "temp revert"""  (fzyzcjy, Dec 28, 2024)
e1e1290  Revert "more"  (fzyzcjy, Dec 28, 2024)
ec0774d  Revert "temp hack"  (fzyzcjy, Dec 28, 2024)
e5514f8  more  (fzyzcjy, Dec 28, 2024)
a190bb1  more  (fzyzcjy, Dec 29, 2024)
9ed56d5  more  (fzyzcjy, Dec 29, 2024)
4b56114  more  (fzyzcjy, Dec 29, 2024)
9b13549  more  (fzyzcjy, Dec 29, 2024)
5841893  Revert "more"  (fzyzcjy, Dec 29, 2024)
6e06c6e  Revert "more"  (fzyzcjy, Dec 29, 2024)
5339e67  Revert "more"  (fzyzcjy, Dec 29, 2024)
122d7bc  Revert "more"  (fzyzcjy, Dec 29, 2024)
9478820  Revert "Auxiliary commit to revert individual files from e5514f8a387b…  (fzyzcjy, Dec 29, 2024)
e2b7b55  empty  (fzyzcjy, Dec 29, 2024)
a44f9f5  base class  (fzyzcjy, Dec 29, 2024)
72c140c  import  (fzyzcjy, Dec 29, 2024)
35fb9c6  more  (fzyzcjy, Dec 29, 2024)
e224392  more  (fzyzcjy, Dec 29, 2024)
7042f0f  more  (fzyzcjy, Dec 29, 2024)
17be8a6  more  (fzyzcjy, Dec 29, 2024)
f78d80e  more  (fzyzcjy, Dec 29, 2024)
a501c31  more  (fzyzcjy, Dec 29, 2024)
e257829  more  (fzyzcjy, Dec 29, 2024)
097cfe6  dep  (fzyzcjy, Dec 29, 2024)
b81fa88  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 29, 2024)
cdf54fa  more  (fzyzcjy, Dec 30, 2024)
27c1ad4  simp  (fzyzcjy, Dec 30, 2024)
b1d1c53  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 30, 2024)
d7d0c11  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 30, 2024)
74c68ad  more  (fzyzcjy, Dec 31, 2024)
9b7d394  Merge remote-tracking branch 'origin/feat/memory_saver' into feat/mem…  (fzyzcjy, Dec 31, 2024)
b876438  more  (fzyzcjy, Dec 31, 2024)
8e65c60  Merge branch 'feat/shell_script' into feat/memory_saver  (fzyzcjy, Dec 31, 2024)
4c5f4a2  more  (fzyzcjy, Dec 31, 2024)
79c5ebc  more  (fzyzcjy, Dec 31, 2024)
833b423  Merge branch 'feat/shell_script' into feat/memory_saver  (fzyzcjy, Dec 31, 2024)
6000d07  bump  (fzyzcjy, Dec 31, 2024)
3870214  fmt  (fzyzcjy, Dec 31, 2024)
b8741ce  more  (fzyzcjy, Dec 31, 2024)
98725f1  more  (fzyzcjy, Dec 31, 2024)
e866ec9  bump  (fzyzcjy, Dec 31, 2024)
199d286  bump  (fzyzcjy, Dec 31, 2024)
169e02c  fmt  (fzyzcjy, Dec 31, 2024)
fa81f27  bump  (fzyzcjy, Dec 31, 2024)
ae0c589  fmt  (fzyzcjy, Dec 31, 2024)
1e7b1d4  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 31, 2024)
0b5e26a  more  (fzyzcjy, Dec 31, 2024)
d9ad0e9  optional dep  (fzyzcjy, Dec 31, 2024)
3b30168  more  (fzyzcjy, Dec 31, 2024)
e8e2375  more  (fzyzcjy, Dec 31, 2024)
4ef6d46  more  (fzyzcjy, Dec 31, 2024)
ceec579  more  (fzyzcjy, Dec 31, 2024)
2bcac8e  more  (fzyzcjy, Dec 31, 2024)
7749fbf  more  (fzyzcjy, Dec 31, 2024)
0de1f1a  more  (fzyzcjy, Dec 31, 2024)
9e3d55d  more  (fzyzcjy, Dec 31, 2024)
b2a4804  more  (fzyzcjy, Dec 31, 2024)
3bb197e  Merge branch 'main' into feat/memory_saver  (fzyzcjy, Dec 31, 2024)
809be14  fmt  (fzyzcjy, Dec 31, 2024)
b7f795d  more  (fzyzcjy, Dec 31, 2024)
Files changed:
python/pyproject.toml (1 change: 1 addition, 0 deletions)

@@ -38,6 +38,7 @@ srt_hpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
+torch_memory_saver = ["torch_memory_saver"]
 test = [
     "jsonlines",
     "matplotlib",
python/sglang/srt/managers/io_struct.py (20 changes: 20 additions, 0 deletions)

@@ -469,6 +469,26 @@ class GetWeightsByNameReqOutput:
     parameter: list
 
 
+@dataclass
+class ReleaseGPUOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseGPUOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeGPUOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeGPUOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
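The four new classes are deliberately empty: they act as typed marker messages on the tokenizer-manager/scheduler IPC channel, where the Python type of the pickled object selects the handler. Below is a minimal sketch of that pattern; the socket setup and function names are hypothetical, while the real code uses the ZMQ sockets and dispatch loops shown in the scheduler and tokenizer-manager diffs further down.

```python
from dataclasses import dataclass

import zmq


@dataclass
class ReleaseGPUOccupationReqInput:
    pass  # no payload: the type itself is the whole message


@dataclass
class ReleaseGPUOccupationReqOutput:
    pass  # empty ack returned once the scheduler has released memory


def send_release_request(to_scheduler: zmq.Socket) -> None:
    # Mirrors TokenizerManager's send_to_scheduler.send_pyobj(...)
    to_scheduler.send_pyobj(ReleaseGPUOccupationReqInput())


def scheduler_dispatch(from_tokenizer: zmq.Socket, to_tokenizer: zmq.Socket) -> None:
    # Mirrors the isinstance dispatch in Scheduler.process_input_requests
    recv_req = from_tokenizer.recv_pyobj()
    if isinstance(recv_req, ReleaseGPUOccupationReqInput):
        # ... pause weights and flush the KV cache ...
        to_tokenizer.send_pyobj(ReleaseGPUOccupationReqOutput())
```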
python/sglang/srt/managers/scheduler.py (29 changes: 29 additions, 0 deletions)

@@ -46,6 +46,10 @@
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseGPUOccupationReqInput,
+    ReleaseGPUOccupationReqOutput,
+    ResumeGPUOccupationReqInput,
+    ResumeGPUOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,

@@ -86,6 +90,7 @@
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)

@@ -332,6 +337,10 @@ def __init__(
         t.start()
         self.parent_process = psutil.Process().parent()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None

@@ -492,6 +501,12 @@ def process_input_requests(self, recv_reqs: List):
             elif isinstance(recv_req, GetWeightsByNameReqInput):
                 parameter = self.get_weights_by_name(recv_req)
                 self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+            elif isinstance(recv_req, ReleaseGPUOccupationReqInput):
+                self.release_gpu_occupation()
+                self.send_to_tokenizer.send_pyobj(ReleaseGPUOccupationReqOutput())
+            elif isinstance(recv_req, ResumeGPUOccupationReqInput):
+                self.resume_gpu_occupation()
+                self.send_to_tokenizer.send_pyobj(ResumeGPUOccupationReqOutput())
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()

@@ -1497,6 +1512,20 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
 
+    def release_gpu_occupation(self):
+        self.stashed_model_static_state = (
+            self.tp_worker.worker.model_runner.model.export_static_state()
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_gpu_occupation(self):
+        self.memory_saver_adapter.resume()
+        self.tp_worker.worker.model_runner.model.import_static_state(
+            self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
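release_gpu_occupation stashes the model's static state, pauses every tensor that was allocated inside a memory-saver region, and flushes the KV cache, whose contents cannot be trusted once the backing memory has been released; resume_gpu_occupation reverses the steps in the opposite order. The TorchMemorySaverAdapter itself is defined outside this diff, so the following is only an assumption about its shape (create/region/pause/resume, degrading to no-ops when the feature is disabled), not the actual sglang/torch_memory_saver_adapter.py.

```python
from contextlib import nullcontext


class TorchMemorySaverAdapter:
    """Assumed interface: a thin wrapper over the optional torch_memory_saver package."""

    @staticmethod
    def create(enable: bool) -> "TorchMemorySaverAdapter":
        return _RealAdapter() if enable else _NoopAdapter()


class _RealAdapter(TorchMemorySaverAdapter):
    def __init__(self):
        # The package is assumed to expose a process-wide saver object.
        from torch_memory_saver import torch_memory_saver

        self._saver = torch_memory_saver

    def region(self):
        # Tensors allocated inside this context keep their virtual addresses
        # across pause/resume, which is what keeps captured CUDA graphs valid.
        return self._saver.region()

    def pause(self):
        self._saver.pause()

    def resume(self):
        self._saver.resume()


class _NoopAdapter(TorchMemorySaverAdapter):
    def region(self):
        return nullcontext()

    def pause(self):
        pass

    def resume(self):
        pass
```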
python/sglang/srt/managers/tokenizer_manager.py (32 changes: 32 additions, 0 deletions)

@@ -53,6 +53,10 @@
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseGPUOccupationReqInput,
+    ReleaseGPUOccupationReqOutput,
+    ResumeGPUOccupationReqInput,
+    ResumeGPUOccupationReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,

@@ -188,6 +192,12 @@ def __init__(
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.release_gpu_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_gpu_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         # Metrics
         if self.enable_metrics:

@@ -548,6 +558,22 @@ async def get_weights_by_name(
         else:
             return all_parameters
 
+    async def release_gpu_occupation(
+        self,
+        obj: ReleaseGPUOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_gpu_occupation_communicator(obj)
+
+    async def resume_gpu_occupation(
+        self,
+        obj: ResumeGPUOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_gpu_occupation_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):

@@ -627,6 +653,8 @@ async def handle_loop(self):
                 UpdateWeightsFromDistributedReqOutput,
                 GetWeightsByNameReqOutput,
                 InitWeightsUpdateGroupReqOutput,
+                ReleaseGPUOccupationReqOutput,
+                ResumeGPUOccupationReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):

@@ -750,6 +778,10 @@ async def handle_loop(self):
                 self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ReleaseGPUOccupationReqOutput):
+                self.release_gpu_occupation_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ResumeGPUOccupationReqOutput):
+                self.resume_gpu_occupation_communicator(recv_obj) if False else self.resume_gpu_occupation_communicator.handle_recv(recv_obj)
             else:
                 raise ValueError(f"Invalid object: {recv_obj=}")
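From a caller's point of view, the two new coroutines give a simple release/resume round trip: send the request and await the empty ack from every scheduler (dp_size of them, via the _Communicator). A hedged usage sketch follows, assuming direct access to a running TokenizerManager instance; the HTTP or Engine entry points that would normally wrap these calls are not part of this diff, and training_step is a placeholder for whatever work uses the freed memory.

```python
from sglang.srt.managers.io_struct import (
    ReleaseGPUOccupationReqInput,
    ResumeGPUOccupationReqInput,
)


async def run_training_phase(tokenizer_manager, training_step) -> None:
    # Free KV-cache and pauseable weight memory held by the inference engine.
    await tokenizer_manager.release_gpu_occupation(ReleaseGPUOccupationReqInput())

    # The freed VRAM can now be used by something else, e.g. an RLHF training
    # step that will later push updated weights back into the engine.
    training_step()

    # Re-occupy the same virtual addresses so captured CUDA graphs stay valid.
    await tokenizer_manager.resume_gpu_occupation(ResumeGPUOccupationReqInput())
```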
python/sglang/srt/mem_cache/memory_pool.py (114 changes: 66 additions, 48 deletions)

@@ -13,6 +13,8 @@
 limitations under the License.
 """
 
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 

@@ -35,13 +37,21 @@
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        memory_saver_adapter: TorchMemorySaverAdapter,
+    ):
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records

@@ -182,32 +192,35 @@ def __init__(
         head_dim: int,
         layer_num: int,
         device: str,
+        memory_saver_adapter: TorchMemorySaverAdapter,
     ):
         super().__init__(size, dtype, device)
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
+        self.memory_saver_adapter = memory_saver_adapter
         self._create_buffers()
 
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer

@@ -262,19 +275,22 @@ def __init__(
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        memory_saver_adapter: TorchMemorySaverAdapter,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:

@@ -315,26 +331,28 @@ def __init__(
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        memory_saver_adapter: TorchMemorySaverAdapter,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
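Every pool now allocates its buffers inside memory_saver_adapter.region(), which is what makes them releasable: pause() can drop the physical memory behind them and resume() can map fresh physical memory back at the same virtual addresses, so CUDA graphs captured against those pointers remain replayable. A standalone sketch of that cycle follows; it assumes the torch_memory_saver package exposes a global saver object with region()/pause()/resume(), which is an assumption about the API rather than documented fact.

```python
import torch

# Assumed API of the optional torch_memory_saver dependency added in pyproject.toml.
from torch_memory_saver import torch_memory_saver as memory_saver

# Allocate a KV-cache-like buffer inside a saver region so it becomes pauseable.
with memory_saver.region():
    k_buffer = torch.empty((4096 + 1, 32, 128), dtype=torch.float16, device="cuda")

addr_before = k_buffer.data_ptr()

# Release the physical memory backing the buffer. Its contents are assumed to be
# discarded, which is why Scheduler.release_gpu_occupation also flushes the cache.
memory_saver.pause()

# ... the freed VRAM is available to other work here ...

# Back the same virtual addresses with physical memory again.
memory_saver.resume()

# The pointer is unchanged, so CUDA graphs that captured it are still valid.
assert k_buffer.data_ptr() == addr_before
```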