Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/beige-jobs-switch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:ZeroGPU native support
15 changes: 0 additions & 15 deletions gradio/block_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@

from . import utils

try:
import spaces # type: ignore
except Exception:
spaces = None


if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.components.base import Component
from gradio.renderable import Renderable
Expand Down Expand Up @@ -114,15 +108,6 @@ def __init__(
self.component_prop_inputs = component_prop_inputs or []
self.key = key

self.spaces_auto_wrap()

def spaces_auto_wrap(self):
if spaces is None:
return
if utils.get_space() is None:
return
self.fn = spaces.gradio_auto_wrap(self.fn)

def __str__(self):
return str(
{
Expand Down
16 changes: 9 additions & 7 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@
from gradio.helpers import create_tracker, skip, special_args
from gradio.i18n import I18n, I18nData
from gradio.node_server import start_node_server
from gradio.route_utils import API_PREFIX, MediaStream, slugify
from gradio.route_utils import (
API_PREFIX,
MediaStream,
maybe_setup_zerogpu_middleware,
slugify,
)
from gradio.routes import INTERNAL_ROUTES, VERSION, App, Request
from gradio.state_holder import SessionState, StateHolder
from gradio.themes import ThemeClass as Theme
Expand All @@ -95,12 +100,6 @@
get_upload_folder,
)

try:
import spaces # type: ignore
except Exception:
spaces = None


if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.components.base import Component
from gradio.mcp import GradioMCPServer
Expand Down Expand Up @@ -2753,6 +2752,9 @@ def reverse(text):
mcp_server=mcp_server,
debug=debug,
)

maybe_setup_zerogpu_middleware(self.app)

if self.mcp_error and not quiet:
print(self.mcp_error)

Expand Down
31 changes: 31 additions & 0 deletions gradio/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,37 @@ class LocalContext:
)


class MultiprocessWorkerContextualizer:
    """
    Refreshes LocalContext for persistent multiprocessing workers processing consecutive requests.

    Constructing an instance snapshots the current `event_id`, `in_event_listener`
    and `progress` values of LocalContext. Calling the instance writes that
    snapshot back into LocalContext. As in the example below, the instance is
    created by the request handler and then called inside the worker before the
    actual processing starts.

    Example usage:
    ```
    pool = ProcessPoolExecutor()

    def handler(value):
        contextualize = MultiprocessWorkerContextualizer()
        return pool.submit(process_wrapper, contextualize, value).result()

    def process_wrapper(contextualize, value):
        contextualize()
        return process(value)

    demo = gr.Interface(handler, gr.Text(), gr.Text())
    ```
    """

    def __init__(self):
        # Snapshot the values visible where the instance is created. The
        # arguments to `get` are the fallbacks used when a value is not set.
        self.event_id = LocalContext.event_id.get(None)
        self.in_event_listener = LocalContext.in_event_listener.get(False)
        self.progress = LocalContext.progress.get(None)

    def __call__(self):
        # Overwrite whatever LocalContext currently holds with the snapshot
        # taken in __init__.
        LocalContext.event_id.set(self.event_id)
        LocalContext.in_event_listener.set(self.in_event_listener)
        LocalContext.progress.set(self.progress)
Comment thread
cbensimon marked this conversation as resolved.
Comment thread
cbensimon marked this conversation as resolved.


def get_render_context() -> BlockContext | None:
if LocalContext.renderable.get(None):
return LocalContext.render_block.get(None)
Expand Down
3 changes: 2 additions & 1 deletion gradio/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
START_TIMEOUT = int(os.getenv("GRADIO_START_TIMEOUT", "5"))

GRADIO_HOT_RELOAD = os.getenv("GRADIO_HOT_RELOAD", "false").lower()

Expand Down Expand Up @@ -69,7 +70,7 @@ def run_in_thread(self):
start = time.time()
while not self.started:
time.sleep(1e-3)
if time.time() - start > 5:
if time.time() - start > START_TIMEOUT:
raise ServerFailedToStartError(
"Server failed to start. Please check that the port is available."
)
Expand Down
93 changes: 65 additions & 28 deletions gradio/queueing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import asyncio
import copy
import inspect
import multiprocessing
import os
import platform
import random
import time
import traceback
import uuid
import warnings
from asyncio import Queue as AsyncQueue
from collections import defaultdict
from multiprocessing import SimpleQueue
from threading import Thread
from typing import TYPE_CHECKING, Any, Literal, cast

import fastapi
Expand Down Expand Up @@ -136,6 +140,8 @@ def __init__(
self.active_jobs: list[None | list[Event]] = []
self.delete_lock = safe_get_lock()
self.server_app = None
self.server_pid = os.getpid()
self.rpc_queue: SimpleQueue[tuple[str, EventMessage]] | None = None
self.process_time_per_fn: defaultdict[BlockFunction, ProcessTime] = defaultdict(
ProcessTime
)
Expand Down Expand Up @@ -212,6 +218,23 @@ def start(self):
run_coro_in_background(self.start_progress_updates)
if not self.live_updates:
run_coro_in_background(self.notify_clients)
if os.getenv("GRADIO_QUEUE_MULTIPROCESSING_ENABLED") == "true":
Thread(target=self.start_rpc, daemon=True).start()

def start_rpc(self):
    """
    Create the RPC queue and forward every item received on it to `_send_message_rpc`.

    The queue is built from the "fork" multiprocessing context. When that
    context is unavailable a warning is emitted and the method returns without
    creating a queue. Otherwise the method loops forever, so it is meant to be
    run on its own thread.
    """
    try:
        fork_context = multiprocessing.get_context("fork")
    except ValueError:
        warnings.warn("GRADIO_QUEUE_MULTIPROCESSING_ENABLED but fork not available")
        return

    self.rpc_queue = fork_context.SimpleQueue()
    while True:
        # Each queue item is an (event_id, message) pair.
        forwarded = self.rpc_queue.get()
        event_id, message = forwarded
        try:
            self._send_message_rpc(event_id, message)
        except Exception:
            # Report the failure and keep serving subsequent items.
            print("Exception while calling _send_message_rpc from Queue RPC thread")
            traceback.print_exc()

def create_event_queue_for_fn(self, block_fn: BlockFunction):
concurrency_id = block_fn.concurrency_id
Expand Down Expand Up @@ -582,30 +605,47 @@ async def start_progress_updates(self) -> None:

await asyncio.sleep(self.progress_update_sleep_when_free)

def _send_message_rpc(
self,
event_id: str,
message: EventMessage,
):
if os.getpid() != self.server_pid:
if self.rpc_queue is None:
warnings.warn(
"Sending queue event from child process without GRADIO_QUEUE_MULTIPROCESSING_ENABLED"
)
else:
self.rpc_queue.put((event_id, message))
return
Comment thread
cbensimon marked this conversation as resolved.
events = [evt for job in self.active_jobs if job is not None for evt in job]
for event in events:
if event._id == event_id:
match message:
case ProgressMessage():
event.progress = message
event.progress_pending = True
case _:
self.send_message(event, message)

def set_progress(
self,
event_id: str,
iterables: list[TrackedIterable] | None,
):
if iterables is None:
return
for job in self.active_jobs:
if job is None:
continue
for evt in job:
if evt._id == event_id:
progress_data: list[ProgressUnit] = []
for iterable in iterables:
progress_unit = ProgressUnit(
index=iterable.index,
length=iterable.length,
unit=iterable.unit,
progress=iterable.progress,
desc=iterable.desc,
)
progress_data.append(progress_unit)
evt.progress = ProgressMessage(progress_data=progress_data)
evt.progress_pending = True
progress_data: list[ProgressUnit] = []
for iterable in iterables:
progress_unit = ProgressUnit(
index=iterable.index,
length=iterable.length,
unit=iterable.unit,
progress=iterable.progress,
desc=iterable.desc,
)
progress_data.append(progress_unit)
self._send_message_rpc(event_id, ProgressMessage(progress_data=progress_data))

def log_message(
self,
Expand All @@ -616,17 +656,14 @@ def log_message(
duration: float | None = 10,
visible: bool = True,
):
events = [evt for job in self.active_jobs if job is not None for evt in job]
for event in events:
if event._id == event_id:
log_message = LogMessage(
log=log,
level=level,
duration=duration,
visible=visible,
title=title,
)
self.send_message(event, log_message)
log_message = LogMessage(
log=log,
level=level,
duration=duration,
visible=visible,
title=title,
)
self._send_message_rpc(event_id, log_message)

async def clean_events(
self, *, session_hash: str | None = None, event_id: str | None = None
Expand Down
31 changes: 31 additions & 0 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from gradio import processing_utils, utils
from gradio.context import MultiprocessWorkerContextualizer
from gradio.data_classes import (
BlocksConfigDict,
MediaStreamChunk,
Expand Down Expand Up @@ -1160,3 +1161,33 @@ async def iter_body(head: bytes, queue: asyncio.Queue[bytes | None]):
yield head
while (chunk := await queue.get()) is not None:
yield chunk


def maybe_setup_zerogpu_middleware(app: App | fastapi.FastAPI):
    """
    Add `ZeroGPUMiddleware` to `app` when running in a ZeroGPU Space.

    Does nothing when `utils.is_zero_gpu_space()` is false or when
    `spaces.zero` cannot be imported.

    Parameters:
        app: The application the middleware is added to.
    """
    if not utils.is_zero_gpu_space():
        return

    try:
        from spaces.zero import ZeroGPUMiddleware
    except ImportError:
        return

    from gradio.helpers import log_message

    def map_exception(err, exc):
        # An existing gradio Error is passed through with printing disabled;
        # anything else is rebuilt as an Error from the reported details.
        if isinstance(exc, Error):
            exc.print_exception = False
            return exc
        detail = err["detail"]
        return Error(
            title=detail["title"],
            message=detail["message"],
        )

    def emit_log(log):
        return log_message(
            title=log["title"],
            message=log["message"],
            level=log["level"],
        )

    app.add_middleware(
        ZeroGPUMiddleware,  # ty: ignore[invalid-argument-type]
        exception_mapper=map_exception,
        log_emitter=emit_log,
        worker_contextualizer=MultiprocessWorkerContextualizer,
    )
2 changes: 2 additions & 0 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
Request,
compare_passwords_securely,
create_lifespan_handler,
maybe_setup_zerogpu_middleware,
move_uploaded_files_to_cache,
)
from gradio.screen_recording_utils import process_video_with_ffmpeg
Expand Down Expand Up @@ -2833,6 +2834,7 @@ async def new_lifespan(app: FastAPI):
app.router.lifespan_context = new_lifespan # type: ignore

app.mount(path, gradio_app)
maybe_setup_zerogpu_middleware(app)
return app


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ exclude = [

[tool.uv]
exclude-newer = "7 days"
exclude-newer-package = {hf-gradio = false, gradio-client = false}
exclude-newer-package = {hf-gradio = false, gradio-client = false, spaces = false}

[tool.ruff]
exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py", "gradio/_vendor/*"]
Expand Down
1 change: 1 addition & 0 deletions test/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ tqdm
transformers
vega_datasets
diffusers
spaces>=0.50.dev0
8 changes: 7 additions & 1 deletion test/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was autogenerated by uv via the following command:
# uv pip compile --python-version 3.10 --exclude-newer 7 days test/requirements.in -o test/requirements.txt
# uv pip compile --exclude-newer 7 days test/requirements.in -o test/requirements.txt
aiofiles==23.2.1
# via gradio
altair==5.5.0
Expand Down Expand Up @@ -114,6 +114,7 @@ httpx==0.28.1
# openai
# respx
# safehttpx
# spaces
huggingface-hub==1.4.1
# via
# -r test/requirements.in
Expand Down Expand Up @@ -215,6 +216,7 @@ packaging==24.2
# pytest
# pytest-rerunfailures
# scikit-image
# spaces
# transformers
pandas==2.2.3
# via
Expand Down Expand Up @@ -249,6 +251,7 @@ pydantic==2.10.6
# fastapi
# gradio
# openai
# spaces
pydantic-core==2.27.2
# via pydantic
pydub==0.25.1
Expand Down Expand Up @@ -335,6 +338,8 @@ sniffio==1.3.1
# openai
sortedcontainers==2.4.0
# via hypothesis
spaces==0.50.dev0
# via -r test/requirements.in
stack-data==0.6.3
# via ipython
starlette==0.45.3
Expand Down Expand Up @@ -391,6 +396,7 @@ typing-extensions==4.12.2
# pydantic-core
# referencing
# rich
# spaces
# torch
# typer
# uvicorn
Expand Down
Loading