Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/deployments/train_mnist/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ async def __aexit__(self, *exc):
pass


@app.remote(benchmark_dataset=[{"pixel_values": [[0.0] * 28] * 28}])
# Optional: workload = input pixel count, so bigger images count as more work.
@app.remote(
benchmark_dataset=[{"pixel_values": [[0.0] * 28] * 28}],
workload_calculator=lambda pixel_values: float(
len(pixel_values) * len(pixel_values[0])
),
)
async def infer(pixel_values: list[list[float]]) -> dict:
"""Classify a 28x28 grayscale MNIST image.

Expand Down
6 changes: 5 additions & 1 deletion examples/deployments/vllm/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ async def __aexit__(self, *exc):
self.engine.shutdown_background_loop()


@app.remote(benchmark_dataset=[{"prompt": "Hello"}])
# Optional: workload = max_tokens, so the autoscaler sizes by total tokens/sec.
@app.remote(
benchmark_dataset=[{"prompt": "Hello"}],
workload_calculator=lambda prompt, max_tokens=128: float(max_tokens),
)
async def generate(prompt: str, max_tokens: int = 128) -> str:
from vllm import SamplingParams
import uuid
Expand Down
138 changes: 138 additions & 0 deletions tests/serverless/test_remote_workload_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests for the per-remote-function ``workload_calculator`` argument."""

from vastai.serverless.remote.serve import Deployment
from vastai.serverless.remote.serialization import serialize
from vastai.serverless.server.worker import (
WorkerConfig,
HandlerConfig,
EndpointHandlerFactory,
)


def _client_payload(deployment: Deployment, **kwargs) -> dict:
"""The request body a client sends for a remote call (see Deployment._dispatch)."""
return {
"kwargs": {
k: serialize(v, deployment.root_module) for k, v in kwargs.items()
}
}


def _handler_for(deployment: Deployment, route: str, calc):
"""Build the handler into_worker would create for a remote function."""
entry = deployment.remote_funcs[next(iter(deployment.remote_funcs))]
wrapped = (
deployment._wrap_workload_calculator(
deployment.root_module, calc, entry.globals
)
if calc is not None
else None
)
config = WorkerConfig(
handlers=[HandlerConfig(route=route, workload_calculator=wrapped)]
)
return EndpointHandlerFactory(config).get_handler(route)


def test_remote_stores_workload_calculator() -> None:
d = Deployment(name="wl-store")

def calc(a, b):
return float(len(a) * len(b))

@d.remote(workload_calculator=calc)
async def mul(a, b):
return a

entry = d.remote_funcs[next(iter(d.remote_funcs))]
assert entry.workload_calculator is calc


def test_workload_calculator_receives_deserialized_kwargs() -> None:
d = Deployment(name="wl-args")

@d.remote(workload_calculator=lambda a, b: float(len(a) * len(b)))
async def mul(a, b):
return a

handler = _handler_for(d, "/remote/mul", lambda a, b: float(len(a) * len(b)))
payload = handler.payload_cls().from_json_msg(
_client_payload(d, a=[1, 2, 3], b=[4, 5])
)
assert payload.count_workload() == 6.0


def test_workload_calculator_default_without_calculator() -> None:
d = Deployment(name="wl-default")

@d.remote()
async def mul(a, b):
return a

handler = _handler_for(d, "/remote/mul", None)
payload = handler.payload_cls().from_json_msg(
_client_payload(d, a=[1, 2, 3], b=[4, 5])
)
assert payload.count_workload() == 100.0


def test_workload_calculator_falls_back_when_it_raises() -> None:
d = Deployment(name="wl-raises")

def boom(a, b):
raise ValueError("bad input")

@d.remote(workload_calculator=boom)
async def mul(a, b):
return a

handler = _handler_for(d, "/remote/mul", boom)
payload = handler.payload_cls().from_json_msg(
_client_payload(d, a=[1, 2, 3], b=[4, 5])
)
assert payload.count_workload() == 100.0


def test_workload_calculator_falls_back_on_negative_or_nan() -> None:
d = Deployment(name="wl-bad-value")

@d.remote()
async def mul(a, b):
return a

for bad in (-1.0, float("nan"), float("inf")):
calc = (lambda v: lambda a, b: v)(bad)
handler = _handler_for(d, "/remote/mul", calc)
payload = handler.payload_cls().from_json_msg(
_client_payload(d, a=[1, 2, 3], b=[4, 5])
)
assert payload.count_workload() == 100.0


def test_into_worker_wires_workload_calculator(monkeypatch) -> None:
d = Deployment(name="wl-into-worker")

@d.remote(
benchmark_dataset=[{"a": [1, 2, 3], "b": [4, 5]}],
workload_calculator=lambda a, b: float(len(a) * len(b)),
)
async def mul(a, b):
return a

captured = {}

class StubWorker:
def __init__(self, config):
captured["config"] = config

monkeypatch.setattr("vastai.serverless.remote.serve.Worker", StubWorker)

d.into_worker()

handler_config = next(
hc for hc in captured["config"].handlers if hc.route == "/remote/mul"
)
assert handler_config.workload_calculator is not None
assert handler_config.workload_calculator(
_client_payload(d, a=[1, 2, 3], b=[4, 5])
) == 6.0
1 change: 1 addition & 0 deletions vastai/serverless/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def remote(
benchmark_dataset: list[dict] | None = None,
benchmark_generator: Callable[[], dict] | None = None,
benchmark_runs: int = 10,
workload_calculator: Callable[..., float] | None = None,
) -> (
Callable[P, Awaitable[Any]]
| Callable[[Callable[P, Awaitable[Any]]], Callable[P, Awaitable[Any]]]
Expand Down
1 change: 1 addition & 0 deletions vastai/serverless/remote/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def remote(
benchmark_dataset: list[dict] | None = None,
benchmark_generator: Callable[[], dict] | None = None,
benchmark_runs: int = 10,
workload_calculator: Callable[..., float] | None = None,
) -> (
Callable[P, Awaitable[Any]]
| Callable[[Callable[P, Awaitable[Any]]], Callable[P, Awaitable[Any]]]
Expand Down
35 changes: 35 additions & 0 deletions vastai/serverless/remote/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class RemoteFunc:
benchmark_dataset: Optional[list[dict[str, Any]]]
benchmark_generator: Optional[Callable[[], dict[str, Any]]]
benchmark_runs: int
workload_calculator: Optional[Callable[..., float]]


T = TypeVar("T")
Expand Down Expand Up @@ -106,6 +107,31 @@ async def wrapper(*, args: list = [], kwargs: dict = {}) -> dict:

return wrapper

def _wrap_workload_calculator(
self,
root_module: str,
user_calc: Callable[..., float],
func_globals: dict,
) -> Callable[[dict], float]:
"""Deserialize the request payload and score it with the user's calculator.

Runs synchronously inside count_workload and returns a plain float,
unlike the async, result-serializing remote-function wrapper.
"""

def calculator(payload: dict) -> float:
deserialized_args = [
deserialize(a, root_module, func_globals)
for a in payload.get("args", [])
]
deserialized_kwargs = {
k: deserialize(v, root_module, func_globals)
for k, v in payload.get("kwargs", {}).items()
}
return float(user_calc(*deserialized_args, **deserialized_kwargs))

return calculator

def into_worker(self) -> Worker:
handlers: list[HandlerConfig] = []
if isinstance(self.root_module, str):
Expand Down Expand Up @@ -153,13 +179,20 @@ def into_worker(self) -> Worker:
self.root_module, entry.func, entry.globals
)

workload_calculator = None
if entry.workload_calculator is not None:
workload_calculator = self._wrap_workload_calculator(
self.root_module, entry.workload_calculator, entry.globals
)

handlers.append(
HandlerConfig(
route=route,
remote_function=wrapped,
allow_parallel_requests=entry.allow_parallel_requests,
benchmark_config=benchmark_config,
max_queue_time=entry.max_queue_time,
workload_calculator=workload_calculator,
)
)

Expand All @@ -179,6 +212,7 @@ def remote(
benchmark_dataset: list[dict] | None = None,
benchmark_generator: Callable[[], dict] | None = None,
benchmark_runs: int = 10,
workload_calculator: Callable[..., float] | None = None,
) -> (
Callable[P, Awaitable[Any]]
| Callable[[Callable[P, Awaitable[Any]]], Callable[P, Awaitable[Any]]]
Expand All @@ -193,6 +227,7 @@ def decorator(f: Callable[P, Awaitable[Any]]) -> Callable[P, Awaitable[Any]]:
benchmark_dataset=benchmark_dataset,
benchmark_generator=benchmark_generator,
benchmark_runs=benchmark_runs,
workload_calculator=workload_calculator,
)
return f

Expand Down
30 changes: 25 additions & 5 deletions vastai/serverless/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aiohttp import web, ClientResponse
import logging
import json
import math
import random
import inspect
import logging
Expand All @@ -30,6 +31,10 @@
]
WorkloadCalculator = Callable[[Dict[str, Any]], float]

log = logging.getLogger(__name__)

DEFAULT_WORKLOAD = 100.0


@dataclass
class LogActionConfig:
Expand Down Expand Up @@ -172,11 +177,26 @@ def from_dict(cls, input: Dict[str, Any]) -> "GenericApiPayload":
return cls(input=input)

def count_workload(self) -> float:
# Use custom workload calculator if provided
if user_workload_calculator:
return user_workload_calculator(self.input)
# Default to 100 unless overridden
return 100.0
if not user_workload_calculator:
return DEFAULT_WORKLOAD
# Fall back rather than let a bad calculator 500 requests or
# poison the autoscaler's load sums.
try:
workload = float(user_workload_calculator(self.input))
except Exception:
log.warning(
f'workload_calculator for "{route_path}" raised; '
f"falling back to {DEFAULT_WORKLOAD}",
exc_info=True,
)
return DEFAULT_WORKLOAD
if not math.isfinite(workload) or workload < 0:
log.warning(
f'workload_calculator for "{route_path}" returned '
f"{workload!r}; falling back to {DEFAULT_WORKLOAD}"
)
return DEFAULT_WORKLOAD
return workload

@classmethod
def from_json_msg(cls, json_msg: Dict[str, Any]) -> "GenericApiPayload":
Expand Down
Loading