diff --git a/src/tribler/core/components.py b/src/tribler/core/components.py index 6282996dbe..42352d5092 100644 --- a/src/tribler/core/components.py +++ b/src/tribler/core/components.py @@ -9,6 +9,7 @@ from ipv8.loader import CommunityLauncher, after, kwargs, overlay, precondition, set_in_session, walk_strategy from ipv8.overlay import Overlay, SettingsClass from ipv8.peerdiscovery.discovery import DiscoveryStrategy, RandomWalk +from ipv8.taskmanager import TaskManager if TYPE_CHECKING: from ipv8.bootstrapping.bootstrapper_interface import Bootstrapper @@ -319,17 +320,19 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None: @after("ContentDiscoveryComponent", "TorrentCheckerComponent") @precondition('session.config.get("user_activity/enabled")') -class UserActivityComponent(ComponentLauncher): +@overlay("tribler.core.user_activity.community", "UserActivityCommunity") +class UserActivityComponent(BaseLauncher): """ Launch instructions for the user activity community. """ - def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None: + def get_kwargs(self, session: Session) -> dict: """ - When we are done launching, start listening for GUI events. + Create and forward the rendezvous database for the Community. """ from tribler.core.user_activity.manager import UserActivityManager - component = cast(Component, community) + out = super().get_kwargs(session) max_query_history = session.config.get("user_activity/max_query_history") - component.settings.manager = UserActivityManager(component, session, max_query_history) + out["manager"] = UserActivityManager(TaskManager(), session, max_query_history) + return out diff --git a/src/tribler/core/database/restapi/database_endpoint.py b/src/tribler/core/database/restapi/database_endpoint.py index e6ef233d9b..d6e3864207 100644 --- a/src/tribler/core/database/restapi/database_endpoint.py +++ b/src/tribler/core/database/restapi/database_endpoint.py @@ -295,7 +295,7 @@ def search_db() -> tuple[list[dict], int, int]: total = max_rowid = None if self.download_manager is not None: self.download_manager.notifier.notify(Notification.local_query_results, data={ - "query": request.query.get("txt_filter"), + "query": request.query.get("fts_text"), "results": list(pony_query) }) return search_results, total, max_rowid diff --git a/src/tribler/core/session.py b/src/tribler/core/session.py index 2a30940112..e21673bff5 100644 --- a/src/tribler/core/session.py +++ b/src/tribler/core/session.py @@ -16,6 +16,7 @@ RendezvousComponent, TorrentCheckerComponent, TunnelComponent, + UserActivityComponent, ) from tribler.core.libtorrent.download_manager.download_manager import DownloadManager from tribler.core.libtorrent.restapi.create_torrent_endpoint import CreateTorrentEndpoint @@ -98,7 +99,7 @@ def register_launchers(self) -> None: Register all IPv8 launchers that allow communities to be loaded. """ for launcher_class in [ContentDiscoveryComponent, DatabaseComponent, DHTDiscoveryComponent, KnowledgeComponent, - RendezvousComponent, TorrentCheckerComponent, TunnelComponent]: + RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent]: instance = launcher_class() for rest_ep in instance.get_endpoints(): self.rest_manager.add_endpoint(rest_ep) diff --git a/src/tribler/core/user_activity/community.py b/src/tribler/core/user_activity/community.py index 96eaeab9e6..54ef94bf6c 100644 --- a/src/tribler/core/user_activity/community.py +++ b/src/tribler/core/user_activity/community.py @@ -43,6 +43,13 @@ def __init__(self, settings: UserActivitySettings) -> None: self.register_task("Gossip random preference", self.gossip, interval=5.0) + async def unload(self) -> None: + """ + Unload our activity manager. + """ + await self.composition.manager.task_manager.shutdown_task_manager() + await super().unload() + def gossip(self, receivers: list[Peer] | None = None) -> None: """ Select a random database entry and send it to a random peer. diff --git a/src/tribler/test_unit/core/user_activity/test_community.py b/src/tribler/test_unit/core/user_activity/test_community.py index bd6cd7d61d..c4a60e5396 100644 --- a/src/tribler/test_unit/core/user_activity/test_community.py +++ b/src/tribler/test_unit/core/user_activity/test_community.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, call +from unittest.mock import AsyncMock, Mock, call from ipv8.test.base import TestBase from ipv8.test.mocking.endpoint import MockEndpointListener @@ -31,8 +31,8 @@ async def test_gossip_aggregate(self) -> None: """ Test if valid aggregates are gossiped to a random connected peer. """ - self.overlay(0).composition.manager = Mock() - self.overlay(1).composition.manager = Mock() + self.overlay(0).composition.manager = AsyncMock() + self.overlay(1).composition.manager = AsyncMock() self.database_manager(0).get_random_query_aggregate = Mock(return_value=( "test", [b"\x00" * 20, b"\x01" * 20], [1.0, 2.0] )) @@ -54,8 +54,8 @@ async def test_gossip_no_aggregate(self) -> None: """ Test if missing aggregates are not gossiped. """ - self.overlay(0).composition.manager = Mock() - self.overlay(1).composition.manager = Mock() + self.overlay(0).composition.manager = AsyncMock() + self.overlay(1).composition.manager = AsyncMock() self.database_manager(0).get_random_query_aggregate = Mock(return_value=None) with self.assertReceivedBy(1, []): @@ -67,9 +67,9 @@ async def test_gossip_target_peer(self) -> None: Test if gossip can be sent to a target peer. """ self.add_node_to_experiment(self.create_node()) - self.overlay(0).composition.manager = Mock() - self.overlay(1).composition.manager = Mock() - self.overlay(2).composition.manager = Mock() + self.overlay(0).composition.manager = AsyncMock() + self.overlay(1).composition.manager = AsyncMock() + self.overlay(2).composition.manager = AsyncMock() self.database_manager(0).get_random_query_aggregate = Mock(return_value=( "test", [b"\x00" * 20, b"\x01" * 20], [1.0, 2.0] )) @@ -91,8 +91,8 @@ async def test_pull_known_crawler(self) -> None: """ Test if a known crawler is allowed to crawl. """ - self.overlay(0).composition.manager = Mock() - self.overlay(1).composition.manager = Mock() + self.overlay(0).composition.manager = AsyncMock() + self.overlay(1).composition.manager = AsyncMock() self.overlay(1).composition.crawler_mid = self.mid(0) self.database_manager(1).get_random_query_aggregate = Mock(return_value=( "test", [b"\x00" * 20, b"\x01" * 20], [1.0, 2.0] @@ -112,8 +112,8 @@ async def test_pull_unknown_crawler(self) -> None: """ Test if an unknown crawler does not receive any information. """ - self.overlay(0).composition.manager = Mock() - self.overlay(1).composition.manager = Mock() + self.overlay(0).composition.manager = AsyncMock() + self.overlay(1).composition.manager = AsyncMock() self.overlay(1).composition.crawler_mid = bytes(b ^ 255 for b in self.mid(0)) self.database_manager(1).get_random_query_aggregate = Mock(return_value=( "test", [b"\x00" * 20, b"\x01" * 20], [1.0, 2.0] @@ -128,9 +128,9 @@ async def test_pull_replay_attack(self) -> None: Test if an unknown crawler does not receive any information. """ self.add_node_to_experiment(self.create_node()) - self.overlay(0).composition.manager = Mock() - self.overlay(1).composition.manager = Mock() - self.overlay(2).composition.manager = Mock() + self.overlay(0).composition.manager = AsyncMock() + self.overlay(1).composition.manager = AsyncMock() + self.overlay(2).composition.manager = AsyncMock() self.overlay(1).composition.crawler_mid = self.mid(0) self.overlay(2).composition.crawler_mid = self.mid(0) self.database_manager(1).get_random_query_aggregate = Mock(return_value=(