From 474ad67d530f743959f7e8f4a5a33cdb9b23bfa8 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Thu, 11 Jun 2026 15:33:55 +0200 Subject: [PATCH 1/4] fix: reconnect to platform events websocket after connection drop --- src/apify/events/_apify_event_manager.py | 86 ++++++++++++------- tests/unit/events/test_apify_event_manager.py | 39 ++++----- 2 files changed, 70 insertions(+), 55 deletions(-) diff --git a/src/apify/events/_apify_event_manager.py b/src/apify/events/_apify_event_manager.py index 217078758..9fa31c933 100644 --- a/src/apify/events/_apify_event_manager.py +++ b/src/apify/events/_apify_event_manager.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Annotated, Self import websockets.asyncio.client +import websockets.exceptions from pydantic import Discriminator, TypeAdapter from typing_extensions import Unpack, override @@ -91,49 +92,70 @@ async def __aexit__( exc_value: BaseException | None, exc_traceback: TracebackType | None, ) -> None: - if self._platform_events_websocket: - await self._platform_events_websocket.close() - + # Cancel the message-processing task first so that closing the websocket below is not treated + # as a dropped connection and followed by a reconnect attempt. if self._process_platform_messages_task and not self._process_platform_messages_task.done(): self._process_platform_messages_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._process_platform_messages_task + if self._platform_events_websocket: + await self._platform_events_websocket.close() + await super().__aexit__(exc_type, exc_value, exc_traceback) async def _process_platform_messages(self, ws_url: str) -> None: + def process_exception(exc: Exception) -> Exception | None: + # Until the first connection succeeds, treat every error as fatal so that `__aenter__` fails fast. + # Afterwards, treat every error as transient — the reconnect iterator keeps retrying with backoff + # so that platform events (e.g. `MIGRATING`) are not missed for the rest of the run. + if self._connected_to_platform_websocket is None or not self._connected_to_platform_websocket.done(): + return exc + return None + try: - async with websockets.asyncio.client.connect(ws_url) as websocket: + async for websocket in websockets.asyncio.client.connect(ws_url, process_exception=process_exception): self._platform_events_websocket = websocket - if self._connected_to_platform_websocket is not None: - self._connected_to_platform_websocket.set_result(True) - - async for message in websocket: - try: - parsed_message = event_data_adapter.validate_json(message) - - if isinstance(parsed_message, DeprecatedEvent): - continue - - if isinstance(parsed_message, UnknownEvent): - logger.info( - f'Unknown message received: event_name={parsed_message.name}, ' - f'event_data={parsed_message.data}' + connected_future = self._connected_to_platform_websocket + if connected_future is not None and not connected_future.done(): + connected_future.set_result(True) + + try: + async for message in websocket: + try: + parsed_message = event_data_adapter.validate_json(message) + + if isinstance(parsed_message, DeprecatedEvent): + continue + + if isinstance(parsed_message, UnknownEvent): + logger.info( + f'Unknown message received: event_name={parsed_message.name}, ' + f'event_data={parsed_message.data}' + ) + continue + + self.emit( + event=parsed_message.name, + event_data=parsed_message.data + if not isinstance(parsed_message.data, SystemInfoEventData) + else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), ) - continue - - self.emit( - event=parsed_message.name, - event_data=parsed_message.data - if not isinstance(parsed_message.data, SystemInfoEventData) - else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), - ) - - if parsed_message.name == Event.MIGRATING: - await self._emit_persist_state_event_rec_task.stop() - self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True)) - except Exception: - logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + + if parsed_message.name == Event.MIGRATING: + await self._emit_persist_state_event_rec_task.stop() + self.emit( + event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True) + ) + except Exception: + logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + except websockets.exceptions.ConnectionClosed: + pass + + logger.warning( + f'Connection to platform events websocket was closed ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) except Exception: logger.exception('Error in websocket connection') if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done(): diff --git a/tests/unit/events/test_apify_event_manager.py b/tests/unit/events/test_apify_event_manager.py index 13568f43a..343314d5b 100644 --- a/tests/unit/events/test_apify_event_manager.py +++ b/tests/unit/events/test_apify_event_manager.py @@ -12,7 +12,6 @@ import pytest import websockets import websockets.asyncio.server -import websockets.exceptions from crawlee.events._types import Event @@ -320,38 +319,32 @@ def migrating_listener(data: Any) -> None: assert len(migration_persist_events) >= 1 -async def test_websocket_mid_stream_disconnect_does_not_raise_invalid_state_error( - monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture -) -> None: - """Regression: a mid-stream websocket disconnect after a successful connect must not raise InvalidStateError. - - The `_connected_to_platform_websocket` future is resolved to `True` on successful connect. If the websocket - later drops, the outer `except` in `_process_platform_messages` must not call `set_result(False)` on the - already-resolved future. - """ +async def test_websocket_reconnects_after_connection_drop(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that after a mid-stream websocket drop, the manager reconnects and keeps receiving platform events.""" async with ( _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), ApifyEventManager(Configuration.get_global_configuration()) as event_manager, ): await client_connected.wait() + aborting_calls: list[Any] = [] - # Force an abnormal close from the server so the client's `async for` raises ConnectionClosedError. + def listener(data: Any) -> None: + aborting_calls.append(data) + + event_manager.on(event=Event.ABORTING, listener=listener) + + # Drop the connection abnormally from the server side. + client_connected.clear() for ws in list(connected_ws_clients): await ws.close(code=1011, reason='Simulated server error') - task = event_manager._process_platform_messages_task - assert task is not None - await asyncio.wait_for(asyncio.shield(task), timeout=2.0) - - exc = task.exception() - assert not isinstance(exc, asyncio.InvalidStateError), f'Task raised InvalidStateError: {exc}' + # The event manager should reconnect on its own. + await asyncio.wait_for(client_connected.wait(), timeout=5.0) - # Confirm the test actually exercised the disconnect path — the outer `except` in - # `_process_platform_messages` should have logged a `ConnectionClosedError`. - logged_exc_types = [ - record.exc_info[0] for record in caplog.records if record.exc_info and record.exc_info[0] is not None - ] - assert any(issubclass(exc_type, websockets.exceptions.ConnectionClosedError) for exc_type in logged_exc_types) + # Events sent over the new connection must still be received. + websockets.broadcast(connected_ws_clients, json.dumps({'name': 'aborting'})) + await poll_until_condition(lambda: bool(aborting_calls), poll_interval=0.05) + assert len(aborting_calls) == 1 async def test_malformed_message_logs_exception( From fe082f179f2a95366d0ae6806b4b082c668b6d65 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Thu, 11 Jun 2026 15:56:16 +0200 Subject: [PATCH 2/4] fix: refine events websocket reconnect error handling and logging After the first successful connection, delegate to the default `websockets` transient/fatal classification instead of retrying every error. Log graceful and abnormal closes (with close code and reason) as well as reconnect success, and cover both close paths with parametrized tests. --- src/apify/events/_apify_event_manager.py | 50 +++++++----- tests/unit/events/test_apify_event_manager.py | 78 ++++++++++++------- 2 files changed, 81 insertions(+), 47 deletions(-) diff --git a/src/apify/events/_apify_event_manager.py b/src/apify/events/_apify_event_manager.py index 9fa31c933..8ded83340 100644 --- a/src/apify/events/_apify_event_manager.py +++ b/src/apify/events/_apify_event_manager.py @@ -92,8 +92,8 @@ async def __aexit__( exc_value: BaseException | None, exc_traceback: TracebackType | None, ) -> None: - # Cancel the message-processing task first so that closing the websocket below is not treated - # as a dropped connection and followed by a reconnect attempt. + # Cancel the task before closing the websocket so that the closed connection is not treated as a drop + # and followed by a reconnect attempt. if self._process_platform_messages_task and not self._process_platform_messages_task.done(): self._process_platform_messages_task.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -104,21 +104,28 @@ async def __aexit__( await super().__aexit__(exc_type, exc_value, exc_traceback) - async def _process_platform_messages(self, ws_url: str) -> None: - def process_exception(exc: Exception) -> Exception | None: - # Until the first connection succeeds, treat every error as fatal so that `__aenter__` fails fast. - # Afterwards, treat every error as transient — the reconnect iterator keeps retrying with backoff - # so that platform events (e.g. `MIGRATING`) are not missed for the rest of the run. - if self._connected_to_platform_websocket is None or not self._connected_to_platform_websocket.done(): - return exc - return None + def _process_connection_exception(self, exc: Exception) -> Exception | None: + """Decide whether a failed connection attempt to the platform websocket should be retried. + + Before the first successful connection, every error is fatal so that `__aenter__` fails fast. After that, + the default `websockets` behavior decides which errors are transient and retried with exponential backoff. + """ + if self._connected_to_platform_websocket and self._connected_to_platform_websocket.done(): + return websockets.asyncio.client.process_exception(exc) + return exc + async def _process_platform_messages(self, ws_url: str) -> None: try: - async for websocket in websockets.asyncio.client.connect(ws_url, process_exception=process_exception): + # Used as an async iterator, `connect` reconnects with exponential backoff whenever a connection + # attempt fails with a transient error. + async for websocket in websockets.asyncio.client.connect( + ws_url, process_exception=self._process_connection_exception + ): self._platform_events_websocket = websocket - connected_future = self._connected_to_platform_websocket - if connected_future is not None and not connected_future.done(): - connected_future.set_result(True) + if self._connected_to_platform_websocket and not self._connected_to_platform_websocket.done(): + self._connected_to_platform_websocket.set_result(True) + else: + logger.info('Reconnected to the platform events websocket.') try: async for message in websocket: @@ -150,12 +157,15 @@ def process_exception(exc: Exception) -> Exception | None: except Exception: logger.exception('Cannot parse Actor event', extra={'raw_message': message}) except websockets.exceptions.ConnectionClosed: - pass - - logger.warning( - f'Connection to platform events websocket was closed ' - f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' - ) + logger.warning( + f'Connection to platform events websocket was lost ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) + else: + logger.info( + f'Connection to platform events websocket was closed ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) except Exception: logger.exception('Error in websocket connection') if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done(): diff --git a/tests/unit/events/test_apify_event_manager.py b/tests/unit/events/test_apify_event_manager.py index 343314d5b..3e80eecb0 100644 --- a/tests/unit/events/test_apify_event_manager.py +++ b/tests/unit/events/test_apify_event_manager.py @@ -25,6 +25,18 @@ from collections.abc import AsyncGenerator, Callable +DUMMY_SYSTEM_INFO = { + 'memAvgBytes': 19328860.328293584, + 'memCurrentBytes': 65171456, + 'memMaxBytes': 65171456, + 'cpuAvgUsage': 2.0761105633130397, + 'cpuMaxUsage': 53.941134593993326, + 'cpuCurrentUsage': 8.45549815498155, + 'isCpuOverloaded': False, + 'createdAt': '2024-08-09T16:04:16.161Z', +} + + @contextlib.asynccontextmanager async def _platform_ws_server( monkeypatch: pytest.MonkeyPatch, @@ -188,17 +200,7 @@ async def send_platform_event(event_name: Event, data: Any = None) -> None: websockets.broadcast(connected_ws_clients, json.dumps(message)) - dummy_system_info = { - 'memAvgBytes': 19328860.328293584, - 'memCurrentBytes': 65171456, - 'memMaxBytes': 65171456, - 'cpuAvgUsage': 2.0761105633130397, - 'cpuMaxUsage': 53.941134593993326, - 'cpuCurrentUsage': 8.45549815498155, - 'isCpuOverloaded': False, - 'createdAt': '2024-08-09T16:04:16.161Z', - } - SystemInfoEventData.model_validate(dummy_system_info) + SystemInfoEventData.model_validate(DUMMY_SYSTEM_INFO) async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager: await client_connected.wait() @@ -210,7 +212,7 @@ def listener(data: Any) -> None: event_manager.on(event=Event.SYSTEM_INFO, listener=listener) # Test sending event with data - await send_platform_event(Event.SYSTEM_INFO, dummy_system_info) + await send_platform_event(Event.SYSTEM_INFO, DUMMY_SYSTEM_INFO) await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) assert len(event_calls) == 1 assert event_calls[0] is not None @@ -319,32 +321,54 @@ def migrating_listener(data: Any) -> None: assert len(migration_persist_events) >= 1 -async def test_websocket_reconnects_after_connection_drop(monkeypatch: pytest.MonkeyPatch) -> None: - """Test that after a mid-stream websocket drop, the manager reconnects and keeps receiving platform events.""" +@pytest.mark.parametrize( + ('close_code', 'expected_log'), + [ + pytest.param(1000, 'Connection to platform events websocket was closed (code=1000', id='graceful_close'), + pytest.param(1011, 'Connection to platform events websocket was lost (code=1011', id='abnormal_close'), + ], +) +async def test_websocket_reconnects_after_connection_drop( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, close_code: int, expected_log: str +) -> None: + """Test that the event manager logs a websocket drop, reconnects, and keeps receiving platform events. + + Also a regression test for the resolved `_connected_to_platform_websocket` future: a mid-stream disconnect + must not kill the message-processing task with `InvalidStateError`. + """ + caplog.set_level(logging.INFO, logger='apify') async with ( _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), ApifyEventManager(Configuration.get_global_configuration()) as event_manager, ): await client_connected.wait() - aborting_calls: list[Any] = [] - - def listener(data: Any) -> None: - aborting_calls.append(data) + assert len(connected_ws_clients) == 1 - event_manager.on(event=Event.ABORTING, listener=listener) + event_calls: list[Any] = [] + event_manager.on(event=Event.SYSTEM_INFO, listener=event_calls.append) - # Drop the connection abnormally from the server side. + # Drop the connection from the server side and wait for the client to reconnect. client_connected.clear() for ws in list(connected_ws_clients): - await ws.close(code=1011, reason='Simulated server error') + await ws.close(code=close_code, reason='Simulated connection drop') + await asyncio.wait_for(client_connected.wait(), timeout=10) + # Poll because the old server-side handler may not have deregistered its connection yet. + await poll_until_condition(lambda: len(connected_ws_clients) == 1, poll_interval=0.05) + assert len(connected_ws_clients) == 1 + + # The message-processing task must have survived the drop. + task = event_manager._process_platform_messages_task + assert task is not None + assert not task.done() - # The event manager should reconnect on its own. - await asyncio.wait_for(client_connected.wait(), timeout=5.0) + # Events sent over the new connection must still be emitted. + websockets.broadcast(connected_ws_clients, json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) + assert len(event_calls) == 1 - # Events sent over the new connection must still be received. - websockets.broadcast(connected_ws_clients, json.dumps({'name': 'aborting'})) - await poll_until_condition(lambda: bool(aborting_calls), poll_interval=0.05) - assert len(aborting_calls) == 1 + # Both the drop and the successful reconnect must be logged. + assert expected_log in caplog.text + assert 'Reconnected to the platform events websocket.' in caplog.text async def test_malformed_message_logs_exception( From 94d757dfa1d52e4fef693f58c8d343a4c5b154e8 Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Sat, 13 Jun 2026 11:11:38 +0200 Subject: [PATCH 3/4] fix: classify non-retryable events websocket close codes and back off reconnects --- src/apify/events/_apify_event_manager.py | 45 ++++ tests/unit/events/test_apify_event_manager.py | 209 +++++++++++++++++- 2 files changed, 253 insertions(+), 1 deletion(-) diff --git a/src/apify/events/_apify_event_manager.py b/src/apify/events/_apify_event_manager.py index e24eb4c24..0f32c155c 100644 --- a/src/apify/events/_apify_event_manager.py +++ b/src/apify/events/_apify_event_manager.py @@ -2,9 +2,11 @@ import asyncio import contextlib +import time from typing import TYPE_CHECKING, Annotated, Self import websockets.asyncio.client +import websockets.client import websockets.exceptions from pydantic import Discriminator, TypeAdapter from typing_extensions import Unpack, override @@ -17,6 +19,7 @@ from apify.log import logger if TYPE_CHECKING: + from collections.abc import Generator from types import TracebackType from crawlee.events._event_manager import EventManagerOptions @@ -46,6 +49,17 @@ class ApifyEventManager(EventManager): with the event system. """ + _NON_RETRYABLE_CLOSE_CODES = frozenset({1002, 1003, 1007, 1008, 1010}) + """WebSocket close codes that signal a permanent condition, so the connection should not be re-established. + + Covers the protocol and data errors (`1002`, `1003`, `1007`), a mandatory extension failure (`1010`), and a + policy violation (`1008`). The platform sends `1008` for an unknown or missing run ID, or when the per-run + websocket connection limit is exceeded; reconnecting on any of these would fail in exactly the same way. + """ + + _HEALTHY_CONNECTION_MIN_DURATION = 1.0 + """Seconds a connection must stay open to count as healthy, after which a drop reconnects without backoff.""" + def __init__(self, configuration: Configuration, **kwargs: Unpack[EventManagerOptions]) -> None: """Initialize a new instance. @@ -117,6 +131,13 @@ def _process_connection_exception(self, exc: Exception) -> Exception | None: return exc async def _process_platform_messages(self, ws_url: str) -> None: + # Backoff between reconnects after an established connection is closed by the server. The `websockets` + # reconnect iterator only backs off on failed connection *attempts*, not on a connection that opens and is + # then closed, so a server that keeps accepting and immediately closing would otherwise be hammered. The + # generator is reset after any connection that stayed open long enough to count as healthy, so a healthy + # connection that drops reconnects immediately without missing platform events. + backoff_delays: Generator[float] | None = None + try: # Used as an async iterator, `connect` reconnects with exponential backoff whenever a connection # attempt fails with a transient error. @@ -129,6 +150,8 @@ async def _process_platform_messages(self, ws_url: str) -> None: else: logger.info('Reconnected to the platform events websocket.') + connection_opened_at = time.monotonic() + connection_lost = False try: async for message in websocket: try: @@ -159,6 +182,17 @@ async def _process_platform_messages(self, ws_url: str) -> None: except Exception: logger.exception('Cannot parse Actor event', extra={'raw_message': message}) except websockets.exceptions.ConnectionClosed: + connection_lost = True + + # Stop reconnecting on a permanent close code; otherwise the loop would reconnect forever. + if websocket.close_code in self._NON_RETRYABLE_CLOSE_CODES: + logger.error( + f'Connection to platform events websocket was closed with a non-retryable code ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}); not reconnecting.' + ) + break + + if connection_lost: logger.warning( f'Connection to platform events websocket was lost ' f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' @@ -168,6 +202,17 @@ async def _process_platform_messages(self, ws_url: str) -> None: f'Connection to platform events websocket was closed ' f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' ) + + # Reconnect a connection that stayed up long enough (including a one-off drop) immediately so platform + # events are not missed. Back off only when connections keep dropping too quickly, so a server that + # accepts and then immediately closes is not hammered. + if time.monotonic() - connection_opened_at >= self._HEALTHY_CONNECTION_MIN_DURATION: + backoff_delays = None + continue + if backoff_delays is None: + backoff_delays = websockets.client.backoff() + continue + await asyncio.sleep(next(backoff_delays)) except Exception: logger.exception('Error in websocket connection') if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done(): diff --git a/tests/unit/events/test_apify_event_manager.py b/tests/unit/events/test_apify_event_manager.py index 095233fed..eb8118c31 100644 --- a/tests/unit/events/test_apify_event_manager.py +++ b/tests/unit/events/test_apify_event_manager.py @@ -4,6 +4,8 @@ import contextlib import json import logging +import socket +import types from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING, Any @@ -22,7 +24,7 @@ from apify.events._types import SystemInfoEventData if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Callable + from collections.abc import AsyncGenerator, Awaitable, Callable DUMMY_SYSTEM_INFO = { @@ -68,6 +70,74 @@ async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None yield connected_ws_clients, client_connected +@contextlib.asynccontextmanager +async def _restartable_ws_server( + monkeypatch: pytest.MonkeyPatch, + *, + on_connect: Callable[[websockets.asyncio.server.ServerConnection], Awaitable[None]] | None = None, +) -> AsyncGenerator[Any]: + """A local `127.0.0.1` WebSocket server that can be stopped/restarted and counts connection attempts. + + Binds to a fixed free port (reserved up front) so a restart can reuse the same address, letting a test simulate the + platform server going away and coming back. Yields a control namespace with `live_clients`, a re-armable + `client_connected` event, a cumulative `attempts()` counter, and `stop()` / `start()` coroutines. Pass `on_connect` + to take over a freshly accepted connection (e.g. immediately close it with a chosen code). + """ + # Reserve a fixed free port so a restart can re-serve on the same address. + probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + probe.bind(('127.0.0.1', 0)) + port = probe.getsockname()[1] + probe.close() + + live_clients: set[websockets.asyncio.server.ServerConnection] = set() + client_connected = asyncio.Event() + attempts = 0 + server_holder: dict[str, Any] = {'srv': None} + + async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: + nonlocal attempts + attempts += 1 + if on_connect is not None: + await on_connect(websocket) + return + live_clients.add(websocket) + client_connected.set() + try: + await websocket.wait_closed() + finally: + live_clients.discard(websocket) + + async def _serve() -> None: + server_holder['srv'] = await websockets.asyncio.server.serve(handler, host='127.0.0.1', port=port) + + async def stop() -> None: + srv = server_holder['srv'] + if srv is not None: + srv.close() + await srv.wait_closed() + server_holder['srv'] = None + # Drop any live connection so the client is forced into reconnect mode. + for websocket in list(live_clients): + await websocket.close() + + async def start() -> None: + await asyncio.sleep(0.3) # Give the OS a moment to release the port before re-serving. + await _serve() + + monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://127.0.0.1:{port}') + await _serve() + try: + yield types.SimpleNamespace( + live_clients=live_clients, + client_connected=client_connected, + attempts=lambda: attempts, + stop=stop, + start=start, + ) + finally: + await stop() + + async def test_lifecycle_local(caplog: pytest.LogCaptureFixture) -> None: caplog.set_level(logging.DEBUG, logger='apify') @@ -376,6 +446,143 @@ async def test_websocket_reconnects_after_connection_drop( assert 'Reconnected to the platform events websocket.' in caplog.text +async def test_non_retryable_close_stops_reconnecting( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that a non-retryable close code (1008) stops reconnection instead of looping forever.""" + caplog.set_level(logging.ERROR, logger='apify') + + async def close_with_policy_violation(websocket: websockets.asyncio.server.ServerConnection) -> None: + await websocket.close(code=1008, reason='policy violation') + + async with ( + _restartable_ws_server(monkeypatch, on_connect=close_with_policy_violation) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + task = event_manager._process_platform_messages_task + assert task is not None + + # After a non-retryable close the processing task must give up rather than reconnect forever. + await poll_until_condition(task.done, poll_interval=0.05) + assert task.done() + assert server.attempts() <= 5, f'reconnected after a non-retryable close: {server.attempts()} attempts' + + assert 'non-retryable code' in caplog.text + + +async def test_rapid_retryable_close_backs_off( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that repeated retryable closes are retried with backoff instead of a tight reconnect loop.""" + caplog.set_level(logging.WARNING, logger='apify') + + async def close_with_internal_error(websocket: websockets.asyncio.server.ServerConnection) -> None: + await websocket.close(code=1011, reason='internal error') + + async with ( + _restartable_ws_server(monkeypatch, on_connect=close_with_internal_error) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + task = event_manager._process_platform_messages_task + assert task is not None + + # Without backoff a tight loop would make thousands of attempts in this window; backoff keeps it tiny. + await asyncio.sleep(2) + assert not task.done() + attempts = server.attempts() + + assert 0 < attempts <= 15, f'client busy-looped on a retryable close: {attempts} attempts in 2s' + assert 'was lost (code=1011' in caplog.text + + +async def test_rapid_retryable_close_after_event_backs_off(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that a server that delivers an event before each retryable close is still retried with backoff.""" + + async def send_event_then_close(websocket: websockets.asyncio.server.ServerConnection) -> None: + await websocket.send(json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await websocket.close(code=1011, reason='internal error') + + async with ( + _restartable_ws_server(monkeypatch, on_connect=send_event_then_close) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + task = event_manager._process_platform_messages_task + assert task is not None + + # A short-lived connection must back off even though it delivered an event, or it would busy-loop. + await asyncio.sleep(2) + assert not task.done() + attempts = server.attempts() + + assert 0 < attempts <= 15, f'client busy-looped after a message-bearing close: {attempts} attempts in 2s' + + +async def test_reconnects_after_server_becomes_unreachable( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that the client survives a server outage, keeps retrying, and resumes events once the server returns.""" + caplog.set_level(logging.INFO, logger='apify') + + async with ( + _restartable_ws_server(monkeypatch) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + await asyncio.wait_for(server.client_connected.wait(), timeout=10) + assert len(server.live_clients) == 1 + + event_calls: list[Any] = [] + event_manager.on(event=Event.SYSTEM_INFO, listener=event_calls.append) + + # Take the server down and drop the live connection: every reconnect attempt now hits connection-refused. + server.client_connected.clear() + await server.stop() + task = event_manager._process_platform_messages_task + assert task is not None + + # During the outage the task must keep retrying instead of crashing or exiting. + await asyncio.sleep(1) + assert not task.done() + + # Bring the server back on the same port; the client must reconnect within a bounded time. + await server.start() + await asyncio.wait_for(server.client_connected.wait(), timeout=10) + await poll_until_condition(lambda: len(server.live_clients) == 1, poll_interval=0.05) + assert len(server.live_clients) == 1 + + # Events sent over the recovered connection must still be delivered. + websockets.broadcast(server.live_clients, json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) + assert len(event_calls) == 1 + + assert 'Reconnected to the platform events websocket.' in caplog.text + + +async def test_shutdown_during_reconnect_backoff_is_clean(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that exiting the event manager while it is mid-reconnect (server down) shuts down cleanly.""" + async with _restartable_ws_server(monkeypatch) as server: + event_manager = ApifyEventManager(Configuration.get_global_configuration()) + async with event_manager: + await asyncio.wait_for(server.client_connected.wait(), timeout=10) + assert len(server.live_clients) == 1 + + # Force the client into reconnect/backoff: server down, live connection dropped. + server.client_connected.clear() + await server.stop() + task = event_manager._process_platform_messages_task + assert task is not None + await asyncio.sleep(0.5) + assert not task.done() + # __aexit__ runs here, while the client is between reconnect attempts. + + # The processing task must be finished and cancelled, not crashed with a stray error. + assert task.done() + assert task.cancelled() or task.exception() is None + assert event_manager.active is False + # The parent recurring persist-state task must be stopped too, mirroring the failed-connect lifecycle test. + persist_state_task = event_manager._emit_persist_state_event_rec_task.task + assert persist_state_task is None or persist_state_task.done() + + async def test_malformed_message_logs_exception( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ) -> None: From 3aa02e05fb0bcbd9a523aca5315d74121acfe99d Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Sat, 13 Jun 2026 11:19:22 +0200 Subject: [PATCH 4/4] refactor: split platform events websocket processing into helper methods --- src/apify/events/_apify_event_manager.py | 145 ++++++++++++----------- 1 file changed, 77 insertions(+), 68 deletions(-) diff --git a/src/apify/events/_apify_event_manager.py b/src/apify/events/_apify_event_manager.py index 0f32c155c..a9117d147 100644 --- a/src/apify/events/_apify_event_manager.py +++ b/src/apify/events/_apify_event_manager.py @@ -50,11 +50,11 @@ class ApifyEventManager(EventManager): """ _NON_RETRYABLE_CLOSE_CODES = frozenset({1002, 1003, 1007, 1008, 1010}) - """WebSocket close codes that signal a permanent condition, so the connection should not be re-established. + """WebSocket close codes for a permanent condition, on which the connection is not re-established. - Covers the protocol and data errors (`1002`, `1003`, `1007`), a mandatory extension failure (`1010`), and a - policy violation (`1008`). The platform sends `1008` for an unknown or missing run ID, or when the per-run - websocket connection limit is exceeded; reconnecting on any of these would fail in exactly the same way. + The platform sends `1008` (policy violation) for an unknown/missing run ID or an exceeded per-run + connection limit. `1002`, `1003`, and `1007` are protocol or data errors, and `1010` a mandatory + extension failure. """ _HEALTHY_CONNECTION_MIN_DURATION = 1.0 @@ -131,16 +131,13 @@ def _process_connection_exception(self, exc: Exception) -> Exception | None: return exc async def _process_platform_messages(self, ws_url: str) -> None: - # Backoff between reconnects after an established connection is closed by the server. The `websockets` - # reconnect iterator only backs off on failed connection *attempts*, not on a connection that opens and is - # then closed, so a server that keeps accepting and immediately closing would otherwise be hammered. The - # generator is reset after any connection that stayed open long enough to count as healthy, so a healthy - # connection that drops reconnects immediately without missing platform events. + # The `websockets` reconnect iterator only backs off on failed connection *attempts*, not on a connection + # that opens and is then closed. Track our own backoff here so a server that keeps accepting and immediately + # closing is not hammered; it is reset after a healthy connection so a healthy drop reconnects immediately. backoff_delays: Generator[float] | None = None try: - # Used as an async iterator, `connect` reconnects with exponential backoff whenever a connection - # attempt fails with a transient error. + # Used as an async iterator, `connect` reconnects with exponential backoff on failed connection attempts. async for websocket in websockets.asyncio.client.connect( ws_url, process_exception=self._process_connection_exception ): @@ -151,69 +148,81 @@ async def _process_platform_messages(self, ws_url: str) -> None: logger.info('Reconnected to the platform events websocket.') connection_opened_at = time.monotonic() - connection_lost = False - try: - async for message in websocket: - try: - parsed_message = event_data_adapter.validate_json(message) - - if isinstance(parsed_message, DeprecatedEvent): - continue - - if isinstance(parsed_message, UnknownEvent): - logger.info( - f'Unknown message received: event_name={parsed_message.name}, ' - f'event_data={parsed_message.data}' - ) - continue - - self.emit( - event=parsed_message.name, - event_data=parsed_message.data - if not isinstance(parsed_message.data, SystemInfoEventData) - else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), - ) - - if parsed_message.name == Event.MIGRATING: - await self._emit_persist_state_event_rec_task.stop() - self.emit( - event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True) - ) - except Exception: - logger.exception('Cannot parse Actor event', extra={'raw_message': message}) - except websockets.exceptions.ConnectionClosed: - connection_lost = True - - # Stop reconnecting on a permanent close code; otherwise the loop would reconnect forever. - if websocket.close_code in self._NON_RETRYABLE_CLOSE_CODES: - logger.error( - f'Connection to platform events websocket was closed with a non-retryable code ' - f'(code={websocket.close_code}, reason={websocket.close_reason!r}); not reconnecting.' - ) + connection_lost = await self._consume_messages(websocket) + + if not self._should_reconnect_after_close(websocket, connection_lost=connection_lost): break - if connection_lost: - logger.warning( - f'Connection to platform events websocket was lost ' - f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' - ) - else: - logger.info( - f'Connection to platform events websocket was closed ' - f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' - ) - - # Reconnect a connection that stayed up long enough (including a one-off drop) immediately so platform - # events are not missed. Back off only when connections keep dropping too quickly, so a server that - # accepts and then immediately closes is not hammered. + # Reconnect a healthy connection immediately; back off only on repeated rapid drops. if time.monotonic() - connection_opened_at >= self._HEALTHY_CONNECTION_MIN_DURATION: backoff_delays = None - continue - if backoff_delays is None: + elif backoff_delays is None: backoff_delays = websockets.client.backoff() - continue - await asyncio.sleep(next(backoff_delays)) + else: + await asyncio.sleep(next(backoff_delays)) except Exception: logger.exception('Error in websocket connection') if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done(): self._connected_to_platform_websocket.set_result(False) + + async def _consume_messages(self, websocket: websockets.asyncio.client.ClientConnection) -> bool: + """Handle platform messages until the connection closes; return whether it was lost vs. closed cleanly.""" + try: + async for message in websocket: + await self._handle_platform_message(message) + except websockets.exceptions.ConnectionClosed: + return True + return False + + async def _handle_platform_message(self, message: str | bytes) -> None: + """Parse a single platform message and emit the matching local event.""" + try: + parsed_message = event_data_adapter.validate_json(message) + + if isinstance(parsed_message, DeprecatedEvent): + return + + if isinstance(parsed_message, UnknownEvent): + logger.info( + f'Unknown message received: event_name={parsed_message.name}, event_data={parsed_message.data}' + ) + return + + self.emit( + event=parsed_message.name, + event_data=parsed_message.data + if not isinstance(parsed_message.data, SystemInfoEventData) + else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), + ) + + if parsed_message.name == Event.MIGRATING: + await self._emit_persist_state_event_rec_task.stop() + self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True)) + except Exception: + logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + + def _should_reconnect_after_close( + self, + websocket: websockets.asyncio.client.ClientConnection, + *, + connection_lost: bool, + ) -> bool: + """Log the websocket close and report whether to reconnect (`False` on a non-retryable close code).""" + if websocket.close_code in self._NON_RETRYABLE_CLOSE_CODES: + logger.error( + f'Connection to platform events websocket was closed with a non-retryable code ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}); not reconnecting.' + ) + return False + + if connection_lost: + logger.warning( + f'Connection to platform events websocket was lost ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) + else: + logger.info( + f'Connection to platform events websocket was closed ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) + return True