Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions api_server/routes/internal/internal_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger

class InternalRoutes:
Expand All @@ -11,14 +12,17 @@ class InternalRoutes:
Check README.md for more information.

'''
def __init__(self):

def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)

def setup_routes(self):
@self.routes.get('/files')
Expand All @@ -34,7 +38,28 @@ async def list_files(request):

@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))

@self.routes.get('/logs/raw')
async def get_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
})

@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)

return web.Response(status=200)


@self.routes.get('/folder_paths')
async def get_folder_paths(request):
Expand Down
47 changes: 47 additions & 0 deletions api_server/services/terminal_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from app.logger import on_flush
import os


class TerminalService:
def __init__(self, server):
self.server = server
self.cols = None
self.rows = None
self.subscriptions = set()
on_flush(self.send_messages)

def update_size(self):
sz = os.get_terminal_size()
changed = False
if sz.columns != self.cols:
self.cols = sz.columns
changed = True

if sz.lines != self.rows:
self.rows = sz.lines
changed = True

if changed:
return {"cols": self.cols, "rows": self.rows}

return None

def subscribe(self, client_id):
self.subscriptions.add(client_id)

def unsubscribe(self, client_id):
self.subscriptions.discard(client_id)

def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return

new_size = self.update_size()

for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected
self.unsubscribe(client_id)
continue

self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
64 changes: 53 additions & 11 deletions app/logger.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,73 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading

logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stdout_interceptor = None
stderr_interceptor = None


class LogInterceptor(io.TextIOWrapper):
def __init__(self, stream, *args, **kwargs):
buffer = stream.buffer
encoding = stream.encoding
super().__init__(buffer, *args, **kwargs, encoding=encoding)
self._lock = threading.Lock()
self._flush_callbacks = []
self._logs_since_flush = []

def write(self, data):
entry = {"t": datetime.now().isoformat(), "m": data}
with self._lock:
self._logs_since_flush.append(entry)

# Simple handling for cr to overwrite the last output if it isnt a full line
# else logs just get full of progress messages
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
logs.pop()
logs.append(entry)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems here the handling of progress messages is altering original terminal behavior.

super().write(data)

def flush(self):
super().flush()
for cb in self._flush_callbacks:
cb(self._logs_since_flush)
self._logs_since_flush = []

def on_flush(self, callback):
self._flush_callbacks.append(callback)


def get_logs():
return "\n".join([formatter.format(x) for x in logs])
return logs


def on_flush(callback):
if stdout_interceptor is not None:
stdout_interceptor.on_flush(callback)
if stderr_interceptor is not None:
stderr_interceptor.on_flush(callback)

def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return

# Override output streams and log to buffer
logs = deque(maxlen=capacity)

global stdout_interceptor
global stderr_interceptor
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)

# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)
2 changes: 1 addition & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(self, loop):
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
Expand Down
6 changes: 3 additions & 3 deletions tests-unit/server/routes/internal_routes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@pytest.fixture
def internal_routes():
return InternalRoutes()
return InternalRoutes(None)

@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
Expand Down Expand Up @@ -102,7 +102,7 @@ async def test_file_service_initialization():
# Create a mock instance
mock_file_service_instance = MagicMock(spec=FileService)
MockFileService.return_value = mock_file_service_instance
internal_routes = InternalRoutes()
internal_routes = InternalRoutes(None)

# Check if FileService was initialized with the correct parameters
MockFileService.assert_called_once_with({
Expand All @@ -112,4 +112,4 @@ async def test_file_service_initialization():
})

# Verify that the file_service attribute of InternalRoutes is set
assert internal_routes.file_service == mock_file_service_instance
assert internal_routes.file_service == mock_file_service_instance
Loading