Skip to content

Commit

Permalink
Add multiple session detection
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink committed Feb 18, 2025
1 parent 43c7e14 commit 6086a1c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 17 deletions.
9 changes: 6 additions & 3 deletions src/run_tribler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def show_error(exc: Exception, shutdown: bool = True) -> NoReturn:
import argparse
import asyncio
import encodings.idna # noqa: F401 (https://github.com/pyinstaller/pyinstaller/issues/1113)
import json
import logging.config
import os
import threading
Expand Down Expand Up @@ -230,15 +231,17 @@ async def main() -> None:
session = Session(config)

torrent_uri = load_torrent_uri(parsed_args)
server_url = await session.find_api_server()
server_url, initial_message = await session.find_api_server()

headless = parsed_args.get("server")
if server_url:
logger.info("Core already running at %s", server_url)
if torrent_uri:
logger.info("Starting torrent using existing core")
await start_download(config, server_url, torrent_uri)
if not headless:
# Don't open a new tab if we're (a) only adding a torrent with a session open, or (b) running headless.
if not headless and (not torrent_uri or (initial_message is not None
and json.loads(initial_message).get("sessions", "0") == "0")):
open_webbrowser_tab(server_url + f"?key={config.get('api/key')}")
logger.info("Shutting down")
return
Expand All @@ -247,7 +250,7 @@ async def main() -> None:
except Exception as exc:
show_error(exc, True)

server_url = await session.find_api_server()
server_url, _ = await session.find_api_server()
if server_url and torrent_uri:
await start_download(config, server_url, torrent_uri)
icon = None if headless else spawn_tray_icon(session, config)
Expand Down
24 changes: 21 additions & 3 deletions src/tribler/core/restapi/events_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ipv8.REST.schema import schema

from tribler.core.notifier import Notification, Notifier
from tribler.core.restapi.rest_endpoint import RESTEndpoint
from tribler.core.restapi.rest_endpoint import RESTEndpoint, RESTResponse

if TYPE_CHECKING:
from aiohttp.abc import Request
Expand Down Expand Up @@ -68,7 +68,8 @@ def __init__(self, notifier: Notifier, public_key: str | None = None) -> None:
notifier.add(Notification.circuit_removed, self.on_circuit_removed)
notifier.delegates.add(self.on_notification)

self.app.add_routes([web.get("", self.get_events)])
self.app.add_routes([web.get("", self.get_events),
web.get("/info", self.get_info)])

@property
def _shutdown(self) -> bool:
Expand Down Expand Up @@ -109,7 +110,7 @@ def initial_message(self) -> MessageDict:
v = "git"
return {
"topic": Notification.events_start.value.name,
"kwargs": {"public_key": self.public_key or "", "version": v}
"kwargs": {"public_key": self.public_key or "", "version": v, "sessions": str(len(self.events_responses))}
}

def error_message(self, reported_error: Exception) -> MessageDict:
Expand Down Expand Up @@ -213,6 +214,23 @@ def on_tribler_exception(self, reported_error: Exception) -> None:
# If there are several undelivered errors, we store the first error as more important and skip other
self.undelivered_error = reported_error

@docs(
tags=["General"],
summary="Get the general info of the events endpoint.",
responses={
200: {
"schema": schema(EventsInfoResponse={"public_key": marshmallow.fields.String,
"version": marshmallow.fields.Integer,
"sessions": marshmallow.fields.Integer})
}
}
)
async def get_info(self, request: Request) -> RESTResponse:
"""
Get the general info of the events endpoint.
"""
return RESTResponse(self.initial_message()["kwargs"])

@docs(
tags=["General"],
summary="Open an EventStream for receiving Tribler events.",
Expand Down
24 changes: 14 additions & 10 deletions src/tribler/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ def rust_enhancements(session: Session) -> Generator[None, None, None]:
if_specs[i]["worker_threads"] = previous_value


async def _is_url_available(url: str, timeout: int=1) -> bool:
async def _is_url_available(url: str, timeout: int=1) -> tuple[bool, bytes | None]:
async with aiohttp.ClientSession() as session:
try:
async with session.get(url, timeout=timeout):
return True
async with session.get(url, timeout=timeout) as response:
return True, await response.read()
except (asyncio.TimeoutError, aiohttp.client_exceptions.ClientConnectorError,
aiohttp.client_exceptions.ClientResponseError):
return False
return False, None


def rescue_keys(config: TriblerConfigManager) -> None:
Expand Down Expand Up @@ -243,21 +243,25 @@ async def start(self) -> None:
self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None,
True)

async def find_api_server(self) -> str | None:
async def find_api_server(self) -> tuple[str | None, bytes | None]:
"""
Find the API server, if available.
"""
info_route = f'/api/events/info?key={self.config.get("api/key")}'

if port := self.config.get("api/http_port_running"):
http_url = f'http://{self.config.get("api/http_host")}:{port}'
if await _is_url_available(http_url):
return http_url
available, response = await _is_url_available(http_url + info_route)
if available:
return http_url, response

if port := self.config.get("api/https_port_running"):
https_url = f'https://{self.config.get("api/https_host")}:{port}'
if await _is_url_available(https_url):
return https_url
available, response = await _is_url_available(https_url + info_route)
if available:
return https_url, response

return None
return None, None

async def shutdown(self) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/tribler/test_unit/core/restapi/test_events_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def test_establish_connection(self) -> None:

self.assertEqual(200, response.status)
self.assertEqual((b'event: events_start\n'
b'data: {"public_key": "", "version": "git"}'
b'data: {"public_key": "", "version": "git", "sessions": "0"}'
b'\n\n'), request.payload_writer.captured[0])

async def test_establish_connection_with_error(self) -> None:
Expand Down

0 comments on commit 6086a1c

Please sign in to comment.