From 8bce4255addf7c675d4c10f154d9af8a0ceec406 Mon Sep 17 00:00:00 2001 From: Josef Prochazka Date: Wed, 29 Jan 2025 10:24:42 +0100 Subject: [PATCH] Use context result map for handling request handler results --- src/crawlee/_types.py | 4 +++ .../_adaptive_playwright_crawler.py | 17 +++++------ src/crawlee/crawlers/_basic/_basic_crawler.py | 29 ++++++++++--------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index 1aa8609d77..99f798a252 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -558,3 +558,7 @@ class BasicCrawlingContext: log: logging.Logger """Logger instance.""" + + def __hash__(self) -> int: + """Return hash of the context. Each context is considered unique.""" + return id(self) diff --git a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py index e26816e3c7..5c62883c84 100644 --- a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py +++ b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py @@ -319,9 +319,7 @@ async def from_pw_pipeline_to_top_router(context: PlaywrightCrawlingContext) -> ) @override - async def _run_request_handler( - self, context: BasicCrawlingContext, result: RequestHandlerRunResult - ) -> RequestHandlerRunResult: + async def _run_request_handler(self, context: BasicCrawlingContext) -> None: """Override BasicCrawler method that delegates request processing to sub crawlers. To decide which sub crawler should process the request it runs `rendering_type_predictor`. @@ -343,7 +341,8 @@ async def _run_request_handler( static_run = await self._crawl_one(rendering_type='static', context=context) if static_run.result and self.result_checker(static_run.result): - return static_run.result + self._context_result_map[context] = static_run.result + return if static_run.exception: context.log.exception( msg=f'Static crawler: failed for {context.request.url}', exc_info=static_run.exception @@ -367,7 +366,12 @@ async def _run_request_handler( pw_run = await self._crawl_one('client only', context=context) self.track_browser_request_handler_runs() + if pw_run.exception is not None: + raise pw_run.exception + if pw_run.result: + self._context_result_map[context] = pw_run.result + if should_detect_rendering_type: detection_result: RenderingType static_run = await self._crawl_one('static', context=context, state=old_state_copy) @@ -379,11 +383,6 @@ async def _run_request_handler( context.log.debug(f'Detected rendering type {detection_result} for {context.request.url}') self.rendering_type_predictor.store_result(context.request, detection_result) - return pw_run.result - if pw_run.exception is not None: - raise pw_run.exception - # Unreachable code, but mypy can't know it. - raise RuntimeError('Missing both result and exception.') def pre_navigation_hook( self, diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 4528dd9121..2e6f81365d 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -14,6 +14,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast from urllib.parse import ParseResult, urlparse +from weakref import WeakKeyDictionary from tldextract import TLDExtract from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, assert_never @@ -290,6 +291,9 @@ def __init__( self._failed_request_handler: FailedRequestHandler[TCrawlingContext | BasicCrawlingContext] | None = None self._abort_on_error = abort_on_error + # Context of each request with matching result. + self._context_result_map = WeakKeyDictionary[BasicCrawlingContext, RequestHandlerRunResult]() + # Context pipeline self._context_pipeline = (_context_pipeline or ContextPipeline()).compose(self._check_url_after_redirects) @@ -908,9 +912,9 @@ async def send_request( return send_request - async def _commit_request_handler_result( - self, context: BasicCrawlingContext, result: RequestHandlerRunResult - ) -> None: + async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None: + result = self._context_result_map[context] + request_manager = await self.get_request_manager() origin = context.request.loaded_url or context.request.url @@ -1018,19 +1022,20 @@ async def __run_task_function(self) -> None: session = await self._get_session() proxy_info = await self._get_proxy_info(request, session) - empty_result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store) + result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store) context = BasicCrawlingContext( request=request, session=session, proxy_info=proxy_info, send_request=self._prepare_send_request_function(session, proxy_info), - add_requests=empty_result.add_requests, - push_data=empty_result.push_data, - get_key_value_store=empty_result.get_key_value_store, + add_requests=result.add_requests, + push_data=result.push_data, + get_key_value_store=result.get_key_value_store, use_state=self._use_state, log=self._logger, ) + self._context_result_map[context] = result statistics_id = request.id or request.unique_key self._statistics.record_request_processing_start(statistics_id) @@ -1039,12 +1044,11 @@ async def __run_task_function(self) -> None: request.state = RequestState.REQUEST_HANDLER try: - result = await self._run_request_handler(context=context, result=empty_result) + await self._run_request_handler(context=context) except asyncio.TimeoutError as e: raise RequestHandlerError(e, context) from e - await self._commit_request_handler_result(context, result) - + await self._commit_request_handler_result(context) await wait_for( lambda: request_manager.mark_request_as_handled(context.request), timeout=self._internal_timeout, @@ -1132,9 +1136,7 @@ async def __run_task_function(self) -> None: ) raise - async def _run_request_handler( - self, context: BasicCrawlingContext, result: RequestHandlerRunResult - ) -> RequestHandlerRunResult: + async def _run_request_handler(self, context: BasicCrawlingContext) -> None: await wait_for( lambda: self._context_pipeline(context, self.router), timeout=self._request_handler_timeout, @@ -1142,7 +1144,6 @@ async def _run_request_handler( f'{self._request_handler_timeout.total_seconds()} seconds', logger=self._logger, ) - return result def _is_session_blocked_status_code(self, session: Session | None, status_code: int) -> bool: """Check if the HTTP status code indicates that the session was blocked by the target website.