[Minor] Improve code style (#2422)
merrymercy authored Dec 9, 2024
1 parent 0ce091a commit 641b7d0
Showing 15 changed files with 33 additions and 21 deletions.
3 changes: 2 additions & 1 deletion python/pyproject.toml
@@ -33,7 +33,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
srt_xpu = ["sglang[runtime_common]"]
#For Intel Gaudi(device : hpu) follow the installation guide
#https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
-srt_hpu = ["sglang[runtime_common]"]
+srt_hpu = ["sglang[runtime_common]"]

openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
@@ -50,6 +50,7 @@
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+
dev = ["sglang[all]", "sglang[test]"]
dev_hip = ["sglang[all_hip]", "sglang[test]"]
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
2 changes: 1 addition & 1 deletion python/sglang/bench_offline_throughput.py
@@ -285,7 +285,7 @@ def throughput_test(
else:
raise ValueError('Please set backend to either "engine" or "runtime"')

-tokenizer_id = server_args.model_path
+tokenizer_id = server_args.tokenizer_path or server_args.model_path
tokenizer = get_tokenizer(tokenizer_id)

# Set global environmnets
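A minimal sketch of the fallback this change introduces, assuming ServerArgs exposes tokenizer_path (possibly None) alongside model_path; the helper name is hypothetical:

def resolve_tokenizer_id(tokenizer_path, model_path):
    # `or` falls through when tokenizer_path is None or an empty string,
    # so an explicitly specified tokenizer wins over the model path.
    return tokenizer_path or model_path

assert resolve_tokenizer_id(None, "meta-llama/Llama-3-8B") == "meta-llama/Llama-3-8B"
assert resolve_tokenizer_id("my/tokenizer", "meta-llama/Llama-3-8B") == "my/tokenizer"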
5 changes: 4 additions & 1 deletion python/sglang/srt/constrained/xgrammar_backend.py
@@ -117,7 +117,10 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
key_type, key_string = key
if key_type == "json":
try:
-ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
+if key_string == "$$ANY$$":
+    ctx = self.grammar_compiler.compile_builtin_json_grammar()
+else:
+    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
except RuntimeError as e:
logging.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
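A hedged sketch of the dispatch added above: "$$ANY$$" acts as a sentinel for "any well-formed JSON", which maps to xgrammar's builtin JSON grammar instead of compiling a user-supplied schema; the error handling from the surrounding try/except is omitted here:

def compile_json_grammar(grammar_compiler, key_string):
    if key_string == "$$ANY$$":
        # No schema to enforce: accept any valid JSON document.
        return grammar_compiler.compile_builtin_json_grammar()
    # Otherwise compile the schema string into a grammar context.
    return grammar_compiler.compile_json_schema(schema=key_string)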
1 change: 0 additions & 1 deletion python/sglang/srt/layers/attention/triton_backend.py
@@ -5,7 +5,6 @@
import torch

from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
9 changes: 8 additions & 1 deletion python/sglang/srt/layers/radix_attention.py
@@ -48,7 +48,14 @@ def __init__(
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention

-def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
+def forward(
+    self,
+    q,
+    k,
+    v,
+    forward_batch: ForwardBatch,
+    save_kv_cache: bool = True,
+):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -484,7 +484,7 @@ def __repr__(self):

@dataclasses.dataclass
class ScheduleBatch:
-"""Store all inforamtion of a batch on the scheduler."""
+"""Store all information of a batch on the scheduler."""

# Request, memory pool, and cache
reqs: List[Req]
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
@@ -22,7 +22,7 @@
import sys
import time
import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union

import fastapi
import uvloop
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -127,7 +127,7 @@ def __init__(self, model_runner: "ModelRunner"):

# Batch sizes to capture
if model_runner.server_args.disable_cuda_graph_padding:
-self.capture_bs = list(range(1, 32)) + [64, 128]
+self.capture_bs = list(range(1, 33)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

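A small sketch of what the two branches above enumerate, assuming a real batch is padded up to the nearest captured size when padding is enabled (the lookup rule shown is illustrative, not the actual implementation):

disable_cuda_graph_padding = False

if disable_cuda_graph_padding:
    capture_bs = list(range(1, 33)) + [64, 128]             # 1..32, 64, 128
else:
    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8..160

def padded_batch_size(bs):
    # Smallest captured size that can hold the batch.
    return min(s for s in capture_bs if s >= bs)

print(padded_batch_size(5))  # -> 8 with the padded grid

Note the off-by-one fix above: range(1, 32) stops at 31, so range(1, 33) is needed to capture every batch size through 32.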
12 changes: 8 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -242,20 +242,22 @@ def load_model(self):
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")

-# Prepare the vllm model config
+# Prepare the model config
self.load_config = LoadConfig(
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
)

if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()

+# Load the model
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)

+# Parse other args
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +272,10 @@ def load_model(self):
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)

-def update_weights_from_disk(self, model_path: str, load_format: str):
-    """Update engine weights online from disk."""
+def update_weights_from_disk(
+    self, model_path: str, load_format: str
+) -> tuple[bool, str]:
+    """Update engine weights in-place from the disk."""
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
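A hedged usage sketch of the new return contract: the method now reports a (success, message) pair, matching the /update_weights_from_disk endpoint changed below; the caller shown is illustrative, not actual scheduler code:

def apply_weight_update(model_runner, model_path, load_format):
    success, message = model_runner.update_weights_from_disk(model_path, load_format)
    if not success:
        raise RuntimeError(f"weight update failed: {message}")
    return message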
1 change: 0 additions & 1 deletion python/sglang/srt/models/gemma2_reward.py
@@ -32,7 +32,6 @@ def __init__(
) -> None:
super().__init__()
self.config = config
-self.torchao_config = None
self.quant_config = quant_config
self.num_labels = config.num_labels
self.model = Gemma2Model(config, quant_config=quant_config)
1 change: 0 additions & 1 deletion python/sglang/srt/models/llama_classification.py
@@ -33,7 +33,6 @@ def __init__(
) -> None:
super().__init__()
self.config = config
-self.torchao_config = None
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config=quant_config)

2 changes: 0 additions & 2 deletions python/sglang/srt/models/llama_reward.py
@@ -21,7 +21,6 @@
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -33,7 +32,6 @@ def __init__(
) -> None:
super().__init__()
self.config = config
-self.torchao_config = None
self.quant_config = quant_config
self.num_labels = config.num_labels
self.model = LlamaModel(config, quant_config=quant_config)
2 changes: 1 addition & 1 deletion python/sglang/srt/server.py
@@ -196,7 +196,7 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk")
@time_func_latency
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
-"""Update the weights from disk inplace without re-launching the server."""
+"""Update the weights from disk in-place without re-launching the server."""
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
content = {"success": success, "message": message}
if success:
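An illustrative client call for this endpoint, assuming a server on the default local port; the request body mirrors UpdateWeightFromDiskReqInput only as far as the diff shows it, so the exact fields are an assumption:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/update_weights_from_disk",
    json={"model_path": "/path/to/updated/checkpoint"},  # hypothetical path
)
payload = resp.json()
print(payload["success"], payload["message"])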
9 changes: 6 additions & 3 deletions python/sglang/srt/utils.py
@@ -169,7 +169,7 @@ def inner_func(*args, **kwargs):
return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -184,7 +184,8 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
"which may cause useless memory allocation for torch CUDA context.",
)

-torch.cuda.empty_cache()
+if empty_cache:
+    torch.cuda.empty_cache()
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

elif device == "xpu":
@@ -196,7 +197,9 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
"which may cause useless memory allocation for torch XPU context.",
)
-torch.xpu.empty_cache()
+
+if empty_cache:
+    torch.xpu.empty_cache()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory
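A minimal sketch of why the new empty_cache flag exists: emptying the allocator cache before querying free memory gives a truer reading but costs time, so hot paths can now skip it. CUDA-only and simplified relative to the real helper:

import torch

def available_gpu_memory_gb(gpu_id, empty_cache=True):
    if empty_cache:
        # Release cached allocator blocks so mem_get_info reports them as free.
        torch.cuda.empty_cache()
    free_bytes, _total = torch.cuda.mem_get_info(gpu_id)
    return free_bytes / (1 << 30)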
1 change: 0 additions & 1 deletion test/srt/run_suite.py
@@ -15,7 +15,6 @@
"test_double_sparsity.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
-"test_fused_moe.py",
"test_get_weights_by_name.py",
"test_gguf.py",
"test_input_embeddings.py",
