3 changes: 2 additions & 1 deletion src/crawlee/__init__.py
@@ -1,6 +1,6 @@
 from importlib import metadata

-from ._request import Request, RequestOptions
+from ._request import Request, RequestOptions, RequestState
 from ._service_locator import service_locator
 from ._types import ConcurrencySettings, EnqueueStrategy, HttpHeaders, RequestTransformAction, SkippedReason
 from ._utils.globs import Glob
@@ -14,6 +14,7 @@
     'HttpHeaders',
     'Request',
     'RequestOptions',
+    'RequestState',
     'RequestTransformAction',
     'SkippedReason',
     'service_locator',
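`RequestState` is re-exported from the package root here, so downstream code can stop importing from the private `crawlee._request` module. A minimal sketch of the import this enables (assuming `RequestState` is an enum, as the member accesses elsewhere in this diff suggest):

```python
# After this change, the public import works without touching crawlee._request.
from crawlee import RequestState

# Enumerate the lifecycle states referenced throughout this PR.
print([state.name for state in RequestState])
```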
4 changes: 2 additions & 2 deletions src/crawlee/_request.py
@@ -41,7 +41,7 @@ class CrawleeRequestData(BaseModel):
     enqueue_strategy: Annotated[EnqueueStrategy | None, Field(alias='enqueueStrategy')] = None
     """The strategy that was used for enqueuing the request."""

-    state: RequestState | None = None
+    state: RequestState = RequestState.UNPROCESSED
     """Describes the request's current lifecycle state."""

     session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None
@@ -352,7 +352,7 @@ def crawl_depth(self, new_value: int) -> None:
         self.crawlee_data.crawl_depth = new_value

     @property
-    def state(self) -> RequestState | None:
+    def state(self) -> RequestState:
         """Crawlee-specific request handling state."""
         return self.crawlee_data.state

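With the default changed from `None` to `RequestState.UNPROCESSED`, the `state` property is no longer optional and callers can drop their `None` checks. A minimal sketch, mirroring the assertion in the new tests at the bottom of this diff:

```python
from crawlee import Request, RequestState

request = Request.from_url('https://crawlee.dev')

# Previously `request.state` was None until a crawler touched the request;
# a fresh request now starts out as UNPROCESSED.
assert request.state == RequestState.UNPROCESSED
```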
3 changes: 2 additions & 1 deletion src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py
@@ -10,7 +10,7 @@
 from pydantic import ValidationError
 from typing_extensions import NotRequired, TypeVar

-from crawlee._request import Request, RequestOptions
+from crawlee._request import Request, RequestOptions, RequestState
 from crawlee._utils.docs import docs_group
 from crawlee._utils.time import SharedTimeout
 from crawlee._utils.urls import to_absolute_url_iterator
@@ -257,6 +257,7 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenerator:
             timeout=remaining_timeout,
         )

+        context.request.state = RequestState.AFTER_NAV
         yield HttpCrawlingContext.from_basic_crawling_context(context=context, http_response=result.http_response)

     async def _handle_status_code_response(
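Because `_make_http_request` flips the state to `AFTER_NAV` only once the HTTP response has arrived, a pre-navigation hook still observes the earlier `BEFORE_NAV` state (assigned in `_run_request_handler`, below). A sketch of a hook observing this, modeled on the new test at the end of the diff:

```python
from crawlee import RequestState
from crawlee._types import BasicCrawlingContext
from crawlee.crawlers import HttpCrawler

crawler = HttpCrawler()

@crawler.pre_navigation_hook
async def check_state(context: BasicCrawlingContext) -> None:
    # Runs before the HTTP request is dispatched, so the AFTER_NAV
    # transition above has not happened yet.
    assert context.request.state == RequestState.BEFORE_NAV
```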
12 changes: 5 additions & 7 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -1152,6 +1152,7 @@ async def _handle_request_retries(

             await request_manager.reclaim_request(request)
         else:
+            request.state = RequestState.ERROR
             await self._mark_request_as_handled(request)
             await self._handle_failed_request(context, error)
             self._statistics.record_request_processing_failure(request.unique_key)
@@ -1167,8 +1168,6 @@ async def _handle_request_error(self, context: TCrawlingContext | BasicCrawlingContext, error: Exception) -> None:
                     f'{self._internal_timeout.total_seconds()} seconds',
                     logger=self._logger,
                 )
-
-            context.request.state = RequestState.DONE
         except UserDefinedErrorHandlerError:
             context.request.state = RequestState.ERROR
             raise
@@ -1201,8 +1200,8 @@ async def _handle_skipped_request(
         self, request: Request | str, reason: SkippedReason, *, need_mark: bool = False
     ) -> None:
         if need_mark and isinstance(request, Request):
-            await self._mark_request_as_handled(request)
             request.state = RequestState.SKIPPED
+            await self._mark_request_as_handled(request)

         url = request.url if isinstance(request, Request) else request

@@ -1403,8 +1402,6 @@ async def __run_task_function(self) -> None:
             self._statistics.record_request_processing_start(request.unique_key)

             try:
-                request.state = RequestState.REQUEST_HANDLER
-
                 self._check_request_collision(context.request, context.session)

                 try:
@@ -1414,10 +1411,10 @@

                     await self._commit_request_handler_result(context)

-                    await self._mark_request_as_handled(request)
-
                     request.state = RequestState.DONE
+
+                    await self._mark_request_as_handled(request)

                     if context.session and context.session.is_usable:
                         context.session.mark_good()

@@ -1483,6 +1480,7 @@ async def __run_task_function(self) -> None:
                 raise

     async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
+        context.request.state = RequestState.BEFORE_NAV
         await self._context_pipeline(
             context,
             lambda final_context: wait_for(
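The reorderings above share one intent: a request's terminal state (`ERROR`, `SKIPPED`, `DONE`) is now assigned *before* `_mark_request_as_handled` persists the request, so the stored copy carries its final state. A sketch of what that makes observable, patterned on the new tests below (the queue alias and URL are illustrative):

```python
import asyncio

from crawlee import Request, RequestState
from crawlee.crawlers import HttpCrawler, HttpCrawlingContext
from crawlee.storages import RequestQueue


async def main() -> None:
    queue = await RequestQueue.open(alias='state_demo')  # illustrative alias
    crawler = HttpCrawler(request_manager=queue)

    @crawler.router.default_handler
    async def handler(context: HttpCrawlingContext) -> None:
        context.log.info(f'Handled {context.request.url}')

    request = Request.from_url('https://crawlee.dev')
    await crawler.run([request])

    # The state is written before the request is persisted as handled, so the
    # stored copy reports DONE (previously it was persisted while still in
    # REQUEST_HANDLER, as the old test_basic_crawler assertion shows).
    handled = await queue.get_request(request.unique_key)
    assert handled is not None
    assert handled.state == RequestState.DONE


asyncio.run(main())
```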
3 changes: 2 additions & 1 deletion src/crawlee/crawlers/_playwright/_playwright_crawler.py
@@ -13,7 +13,7 @@
 from typing_extensions import NotRequired, TypedDict, TypeVar

 from crawlee import service_locator
-from crawlee._request import Request, RequestOptions
+from crawlee._request import Request, RequestOptions, RequestState
 from crawlee._types import (
     BasicCrawlingContext,
     ConcurrencySettings,
@@ -323,6 +323,7 @@ async def _navigate(
             response = await context.page.goto(
                 context.request.url, timeout=remaining_timeout.total_seconds() * 1000
             )
+            context.request.state = RequestState.AFTER_NAV
         except playwright.async_api.TimeoutError as exc:
             raise asyncio.TimeoutError from exc

2 changes: 2 additions & 0 deletions src/crawlee/router.py
@@ -3,6 +3,7 @@
 from collections.abc import Awaitable, Callable
 from typing import Generic, TypeVar

+from crawlee._request import RequestState
 from crawlee._types import BasicCrawlingContext
 from crawlee._utils.docs import docs_group

@@ -89,6 +90,7 @@ def wrapper(handler: Callable[[TCrawlingContext], Awaitable]) -> Callable[[TCrawlingContext], Awaitable]:

     async def __call__(self, context: TCrawlingContext) -> None:
         """Invoke a request handler that matches the request label (or the default)."""
+        context.request.state = RequestState.REQUEST_HANDLER
         if context.request.label is None or context.request.label not in self._handlers_by_label:
             if self._default_handler is None:
                 raise RuntimeError(
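Moving the `REQUEST_HANDLER` transition into `Router.__call__` (it was removed from `BasicCrawler.__run_task_function` above) guarantees every routed handler, labeled or default, sees that state on entry. A minimal sketch of a router in isolation, assuming the usual generic `Router` usage:

```python
from crawlee import RequestState
from crawlee._types import BasicCrawlingContext
from crawlee.router import Router

router = Router[BasicCrawlingContext]()


@router.default_handler
async def default_handler(context: BasicCrawlingContext) -> None:
    # Router.__call__ sets this immediately before dispatching here.
    assert context.request.state == RequestState.REQUEST_HANDLER
```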
2 changes: 1 addition & 1 deletion tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -1829,5 +1829,5 @@ async def error_handler(context: BasicCrawlingContext, error: Exception) -> Request | None:
     assert original_request.was_already_handled

     assert error_request is not None
-    assert error_request.state == RequestState.REQUEST_HANDLER
+    assert error_request.state == RequestState.DONE
     assert error_request.was_already_handled
57 changes: 56 additions & 1 deletion tests/unit/crawlers/_http/test_http_crawler.py
@@ -7,10 +7,11 @@

 import pytest

-from crawlee import ConcurrencySettings, Request
+from crawlee import ConcurrencySettings, Request, RequestState
 from crawlee.crawlers import HttpCrawler
 from crawlee.sessions import SessionPool
 from crawlee.statistics import Statistics
+from crawlee.storages import RequestQueue
 from tests.unit.server_endpoints import HELLO_WORLD

 if TYPE_CHECKING:
@@ -577,3 +578,57 @@ async def request_handler(context: HttpCrawlingContext) -> None:
     assert len(kvs_content) == 1
     assert content_key.endswith('.html')
     assert kvs_content[content_key] == HELLO_WORLD.decode('utf8')
+
+
+async def test_request_state(server_url: URL) -> None:
+    queue = await RequestQueue.open(alias='http_request_state')
+    crawler = HttpCrawler(request_manager=queue)
+
+    success_request = Request.from_url(str(server_url))
+    assert success_request.state == RequestState.UNPROCESSED
+
+    error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True})
+
+    requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}}
+
+    @crawler.pre_navigation_hook
+    async def pre_navigation_hook(context: BasicCrawlingContext) -> None:
+        requests_states[context.request.unique_key]['pre_navigation'] = context.request.state
+
+    @crawler.router.default_handler
+    async def request_handler(context: HttpCrawlingContext) -> None:
+        if context.request.user_data.get('cause_error'):
+            raise ValueError('Caused error as requested')
+        requests_states[context.request.unique_key]['request_handler'] = context.request.state
+
+    @crawler.error_handler
+    async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['error_handler'] = context.request.state
+
+    @crawler.failed_request_handler
+    async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state
+
+    await crawler.run([success_request, error_request])
+
+    handled_success_request = await queue.get_request(success_request.unique_key)
+
+    assert handled_success_request is not None
+    assert handled_success_request.state == RequestState.DONE
+
+    assert requests_states[success_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'request_handler': RequestState.REQUEST_HANDLER,
+    }
+
+    handled_error_request = await queue.get_request(error_request.unique_key)
+    assert handled_error_request is not None
+    assert handled_error_request.state == RequestState.ERROR
+
+    assert requests_states[error_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'error_handler': RequestState.ERROR_HANDLER,
+        'failed_request_handler': RequestState.ERROR,
+    }
+
+    await queue.drop()
55 changes: 55 additions & 0 deletions tests/unit/crawlers/_playwright/test_playwright_crawler.py
@@ -19,6 +19,7 @@
     Glob,
     HttpHeaders,
     Request,
+    RequestState,
     RequestTransformAction,
     SkippedReason,
     service_locator,
@@ -991,3 +992,57 @@ async def test_slow_navigation_does_not_count_toward_handler_timeout(server_url: URL) -> None:
     assert result.requests_failed == 0
     assert result.requests_finished == 1
     assert request_handler.call_count == 1
+
+
+async def test_request_state(server_url: URL) -> None:
+    queue = await RequestQueue.open(alias='playwright_request_state')
+    crawler = PlaywrightCrawler(request_manager=queue)
+
+    success_request = Request.from_url(str(server_url))
+    assert success_request.state == RequestState.UNPROCESSED
+
+    error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True})
+
+    requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}}
+
+    @crawler.pre_navigation_hook
+    async def pre_navigation_hook(context: PlaywrightPreNavCrawlingContext) -> None:
+        requests_states[context.request.unique_key]['pre_navigation'] = context.request.state
+
+    @crawler.router.default_handler
+    async def request_handler(context: PlaywrightCrawlingContext) -> None:
+        if context.request.user_data.get('cause_error'):
+            raise ValueError('Caused error as requested')
+        requests_states[context.request.unique_key]['request_handler'] = context.request.state
+
+    @crawler.error_handler
+    async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['error_handler'] = context.request.state
+
+    @crawler.failed_request_handler
+    async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state
+
+    await crawler.run([success_request, error_request])
+
+    handled_success_request = await queue.get_request(success_request.unique_key)
+
+    assert handled_success_request is not None
+    assert handled_success_request.state == RequestState.DONE
+
+    assert requests_states[success_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'request_handler': RequestState.REQUEST_HANDLER,
+    }
+
+    handled_error_request = await queue.get_request(error_request.unique_key)
+    assert handled_error_request is not None
+    assert handled_error_request.state == RequestState.ERROR
+
+    assert requests_states[error_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'error_handler': RequestState.ERROR_HANDLER,
+        'failed_request_handler': RequestState.ERROR,
+    }
+
+    await queue.drop()