Skip to content

Graceful server stopping #10458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,4 @@
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
42 changes: 41 additions & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import sys
import threading
import time

import gradio as gr
Expand Down Expand Up @@ -110,8 +111,47 @@ class State:
id_live_preview = 0
textinfo = None
time_start = None
need_restart = False
server_start = None
_server_command_signal = threading.Event()
_server_command: str | None = None

@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"

@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"

@property
def server_command(self):
return self._server_command

@server_command.setter
def server_command(self, value: str | None) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()

def wait_for_server_command(self, timeout: float | None = None) -> str | None:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None

def request_restart(self) -> None:
self.interrupt()
self.server_command = True

def skip(self):
self.skipped = True
Expand Down
6 changes: 1 addition & 5 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,12 +1609,8 @@ def reload_scripts():
outputs=[]
)

def request_restart():
shared.state.interrupt()
shared.state.need_restart = True

restart_gradio.click(
fn=request_restart,
fn=shared.state.request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
Expand Down
7 changes: 2 additions & 5 deletions modules/ui_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)

shared.state.interrupt()
shared.state.need_restart = True
shared.state.request_restart()


def save_config_state(name):
Expand Down Expand Up @@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)

shared.state.interrupt()
shared.state.need_restart = True
shared.state.request_restart()

return ""

Expand Down
50 changes: 34 additions & 16 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
from threading import Thread

from fastapi import FastAPI
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
Expand Down Expand Up @@ -234,7 +234,10 @@ def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)

signal.signal(signal.SIGINT, sigint_handler)
if not os.environ.get("COVERAGE_RUN"):
# Don't install the immediate-quit handler when running under coverage,
# as then the coverage report won't be generated.
signal.signal(signal.SIGINT, sigint_handler)


def setup_middleware(app):
Expand All @@ -255,19 +258,6 @@ def create_api(app):
return api


def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5)
demo.close()
time.sleep(0.5)

modules.script_callbacks.app_reload_callback()
break


def api_only():
initialize()

Expand All @@ -280,6 +270,12 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)


def stop_route(request):
shared.state.server_command = "stop"
return Response("Stopping.")


def webui():
launch_api = cmd_opts.api
initialize()
Expand Down Expand Up @@ -328,6 +324,9 @@ def fastapi_setup(self):
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])

# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False

Expand Down Expand Up @@ -359,8 +358,27 @@ def fastapi_setup(self):
redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")

wait_on_server(shared.demo)
try:
while True:
server_command = shared.state.wait_for_server_command(timeout=5)
if server_command:
if server_command in ("stop", "restart"):
break
else:
print(f"Unknown server command: {server_command}")
except KeyboardInterrupt:
print('Caught KeyboardInterrupt, stopping...')
server_command = "stop"

if server_command == "stop":
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
break
print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
modules.script_callbacks.app_reload_callback()

startup_timer.reset()

Expand Down