[#3334][feat] Support of CPU Inference for Scaffolding via PyTorch #4639

Draft · wants to merge 1 commit into main
87 changes: 87 additions & 0 deletions examples/scaffolding/contrib/PytorchCPU/pytorch_worker_run.py
@@ -0,0 +1,87 @@
import argparse
import asyncio

from tensorrt_llm.scaffolding.contrib.PytorchCPU import PytorchWorker
from tensorrt_llm.scaffolding.native_controller import \
    NativeGenerationController
from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm


def parse_arguments():
    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    parser.add_argument('--run_async', action='store_true')
    args = parser.parse_args()
    return args


def test_sync(prompts, proposer_worker):
    prototype_controller = NativeGenerationController(
        sampling_params={"temperature": 0.9})

    llm = ScaffoldingLlm(
        prototype_controller,
        {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
    )
    results = llm.generate(prompts)
    for result in results:
        print(result.output.output_str)
    print('main shutting down...')
    llm.shutdown()
    print('worker shutting down...')
    proposer_worker.shutdown()
    print('main shut down done')


def test_async(prompt, proposer_worker):

    async def test_async_func(prompt, proposer_worker):
        prototype_controller = NativeGenerationController(
            sampling_params={"temperature": 0.9})
        llm = ScaffoldingLlm(
            prototype_controller,
            {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
        )

        future = llm.generate_async(prompt)

        result = await future.aresult()
        print(result.output.output_str)

        print('main shutting down...')
        llm.shutdown()
        print('worker shutting down...')
        proposer_worker.shutdown()
        print('main shut down done')

    asyncio.run(test_async_func(prompt, proposer_worker))


def main():
    args = parse_arguments()

    prompts = [
        "Anton sold GPUs to 48 of his friends in April, and then he sold half as many GPUs in May. How many GPUs did Anton sell altogether in April and May?\r\n\r\n",
        "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
        "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
    ]

    llm_worker = PytorchWorker(
        args.model_dir,
        max_batch_size=32,
        max_num_tokens=4096,
    )

    if args.run_async:
        test_async(prompts[0], llm_worker)
    else:
        test_sync(prompts, llm_worker)


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions tensorrt_llm/scaffolding/contrib/PytorchCPU/README.md
@@ -0,0 +1,2 @@
## TODO
- Write a doc explaining the motivation and how to use it
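
## Usage (sketch)

A minimal usage sketch, adapted from the example script added in this PR (`examples/scaffolding/contrib/PytorchCPU/pytorch_worker_run.py`); the model path is a placeholder and the constructor arguments mirror that script.

```python
from tensorrt_llm.scaffolding.contrib.PytorchCPU import PytorchWorker
from tensorrt_llm.scaffolding.native_controller import \
    NativeGenerationController
from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm

# CPU-backed PyTorch worker that fills the generation role.
worker = PytorchWorker("DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B",  # placeholder model dir
                       max_batch_size=32,
                       max_num_tokens=4096)

controller = NativeGenerationController(sampling_params={"temperature": 0.9})
llm = ScaffoldingLlm(
    controller,
    {NativeGenerationController.WorkerTag.GENERATION: worker},
)

for result in llm.generate(["How many GPUs did Anton sell in April and May?"]):
    print(result.output.output_str)

llm.shutdown()
worker.shutdown()
```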
3 changes: 3 additions & 0 deletions tensorrt_llm/scaffolding/contrib/PytorchCPU/__init__.py
@@ -0,0 +1,3 @@
from .pytorch_worker import PytorchWorker

__all__ = ['PytorchWorker']
60 changes: 60 additions & 0 deletions tensorrt_llm/scaffolding/contrib/PytorchCPU/pytorch_worker.py
@@ -0,0 +1,60 @@
from transformers import AutoTokenizer

from tensorrt_llm.llmapi.llm import LLM
from tensorrt_llm.sampling_params import SamplingParams
from tensorrt_llm.scaffolding import GenerationTask, TaskStatus, Worker


class PytorchWorker(Worker):

    def __init__(
        self,
        model_path: str,
        max_batch_size: int = 32,
        max_num_tokens: int = 4096,
    ):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            legacy=False,
            trust_remote_code=False,
            use_fast=True,
        )
        self.llm = LLM(
            model_dir=model_path,  # Use model_dir for consistency
            tokenizer=self.tokenizer,  # Pass the tokenizer to LLM
            backend='pytorch',
            pytorch_backend_config={'device': 'cpu'},
            max_batch_size=max_batch_size,
            max_num_tokens=max_num_tokens,
        )

    @classmethod
    def convert_task_params(cls, task: GenerationTask) -> SamplingParams:
        sampling_params = SamplingParams(
            max_tokens=task.max_tokens,
            temperature=task.temperature,
            top_p=task.top_p,
            top_k=task.top_k,
            return_context_logits=task.return_context_logits)
        return sampling_params

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        sampling_params = self.convert_task_params(task)
        result = await self.llm.generate_async(task.input_str,
                                               sampling_params=sampling_params)

        task.output_tokens = result.outputs[0].token_ids
        task.cumulative_logprob = result.outputs[0].cumulative_logprob
        task.logprobs = result.outputs[0].logprobs
        task.output_str = result.outputs[0].text
        task.context_logits = result.context_logits

        # TODO: error handling
        return TaskStatus.SUCCESS

    def shutdown(self):
        # There is no clean-up needed
        pass

    # Register the handler for GenerationTask so the scaffolding worker can dispatch to it.
    task_handlers = {GenerationTask: generation_handler}
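
The handler above returns `TaskStatus.SUCCESS` unconditionally and leaves error handling as a TODO. One possible shape for that path, shown only as a sketch: wrap the LLM call and map failures to a failure status. The failure member name below is a placeholder assumption, since this diff only shows `TaskStatus.SUCCESS`.

```python
    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        sampling_params = self.convert_task_params(task)
        try:
            result = await self.llm.generate_async(
                task.input_str, sampling_params=sampling_params)
        except Exception as exc:
            # Record the failure on the task and report it to the controller.
            task.output_str = f"generation failed: {exc}"
            # Placeholder member name; use whatever failure value TaskStatus actually exposes.
            return TaskStatus.WORKER_GENERATION_EXCEPTION

        task.output_tokens = result.outputs[0].token_ids
        task.output_str = result.outputs[0].text
        return TaskStatus.SUCCESS
```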