[Feature] Add arguments mapping between SGLang / vllm / trt-llm #2657

zhaochenyang20 opened this issue Dec 30, 2024 · 2 comments

Motivation

This is what I need to do to integrate SGLang into OpenRLHF. OpenRLHF already supports vllm; adding sglang means mapping the server and sampling parameters from vllm to sglang. I think this is a good issue for helping our users switch smoothly between mainstream engines.

I have attached what I am doing right now, but it may be wrong.

Related resources

The Args Mapping from vllm to SGLang

These are the server parameters of vllm:

pretrain,
noset_visible_devices=noset_visible_devices,
trust_remote_code=True,
tensor_parallel_size=tensor_parallel_size,
dtype="bfloat16",
seed=seed + i,
enable_prefix_caching=enable_prefix_caching,
enforce_eager=enforce_eager,
max_model_len=max_model_len,
backend=backend,
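
For context, here is a rough sketch of how these arguments would reach vllm if the engine were constructed directly with vllm.LLM (noset_visible_devices and backend appear to be handled on the OpenRLHF side rather than passed to vllm):

from vllm import LLM

# Sketch only: the OpenRLHF wrapper differs, but the kwargs map onto vllm.LLM like this.
llm = LLM(
    model=pretrain,
    trust_remote_code=True,
    tensor_parallel_size=tensor_parallel_size,
    dtype="bfloat16",
    seed=seed + i,
    enable_prefix_caching=enable_prefix_caching,
    enforce_eager=enforce_eager,
    max_model_len=max_model_len,
)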

Among them, pretrain is the model path. This is my mapping to sglang:

#! TODO chenyang check engine params
sglang_params = {
    "model_path": args[0],  # vllm: pretrain (the model path)
    "trust_remote_code": kwargs.get("trust_remote_code", True),
    "dtype": kwargs.get("dtype", "auto"),
    "tp_size": kwargs.get("tensor_parallel_size", 1),
    "device": "cuda",
    # vllm turns prefix caching on; sglang turns its radix cache off, so the flag is inverted
    "disable_radix_cache": not kwargs.get("enable_prefix_caching", False),
    "random_seed": kwargs.get("seed", 42),
    # vllm's enforce_eager=True means no CUDA graphs, i.e. disable_cuda_graph=True
    "disable_cuda_graph": kwargs.get("enforce_eager", False),
    "disable_cuda_graph_padding": not kwargs.get("enable_prefix_caching", False),
    "context_length": kwargs.get("max_model_len", None),
    "log_level": "info",
    "return_token_ids": True,
}
self.llm = sglang.Engine(**sglang_params)
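
For the eventual documentation, the server-side correspondence I am assuming above is roughly as follows (note that enable_prefix_caching is inverted into disable_radix_cache; all of this still needs to be confirmed):

pretrain (model path)   ->  model_path
trust_remote_code       ->  trust_remote_code
dtype                   ->  dtype
tensor_parallel_size    ->  tp_size
seed                    ->  random_seed
max_model_len           ->  context_length
enable_prefix_caching   ->  disable_radix_cache (inverted)
enforce_eager           ->  disable_cuda_graph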

The Sampling Params Mapping from vllm to SGLang

if self.backend == "vllm":
    outputs = self.llm.generate(
        sampling_params=kwargs["sampling_params"], prompt_token_ids=kwargs["prompt_token_ids"]
    )
elif self.backend == "sglang":
    # Note that sglang sampling params differ from vllm's
    sampling_params = kwargs["sampling_params"]
    all_prompts = kwargs["all_prompts"]

    # min_tokens and include_stop_str_in_output are not used in sglang

    sampling_params = dict(
        max_new_tokens=sampling_params.max_tokens,
        top_p=sampling_params.top_p,
        top_k=sampling_params.top_k,
        temperature=sampling_params.temperature,
        repetition_penalty=sampling_params.repetition_penalty,
        skip_special_tokens=sampling_params.skip_special_tokens,
    )
    outputs = self.llm.generate(all_prompts, sampling_params)

Of course, the sampling params passed in from the front end are as follows:

sampling_params = SamplingParams(
    temperature=kwargs.get("temperature", 1.0),
    top_p=kwargs.get("top_p", 1.0),
    top_k=kwargs.get("top_k", -1),
    max_tokens=kwargs.get("max_new_tokens", 1024),
    min_tokens=kwargs.get("min_new_tokens", 1),
    skip_special_tokens=kwargs.get("skip_special_tokens", False),
    include_stop_str_in_output=True,
)
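
Putting the two pieces above together, a minimal conversion helper could look like the sketch below (based on my mapping, not a confirmed sglang API; the sglang-side key names still need checking):

def vllm_to_sglang_sampling_params(sp) -> dict:
    """Convert a vllm SamplingParams object into an sglang sampling dict."""
    return {
        "max_new_tokens": sp.max_tokens,
        "temperature": sp.temperature,
        "top_p": sp.top_p,
        "top_k": sp.top_k,
        "repetition_penalty": sp.repetition_penalty,
        "skip_special_tokens": sp.skip_special_tokens,
        # min_tokens and include_stop_str_in_output have no direct sglang
        # equivalent in this mapping and are dropped here.
    }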

There may be problems with my mappings here. We need documentation as a guide.

@zhaochenyang20 (Collaborator, Author)

@shuaills Would you like to take this first? At a minimum, please verify that my mapping between vllm and sglang in OpenRLHF is right. The docs can come later.

@zhaochenyang20 (Collaborator, Author)

@minleminzui thanks for your interest!
