diff --git a/CHANGELOG.md b/CHANGELOG.md index f002d37d8..6a6c3a3c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ to include examples, links to docs, or any other relevant information. signaled event is attached to the caller workflow's Nexus operation history event. This makes the caller and callee mutually navigable in the UI for signal-based Nexus operations. - Exposed `backoff_start_interval` for continue-as-new, to allow the new workflow to start after a delay. +- Added an optional `signal_shutdown` argument to `Worker.run`. When set (e.g. + `signal_shutdown=[signal.SIGTERM, signal.SIGINT]`), the listed OS signals initiate a graceful + worker shutdown. Handlers are installed for the duration of the call and removed on exit. ### Changed diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 2ad1d42c6..0c0ff6103 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -4,11 +4,13 @@ import asyncio import concurrent.futures +import contextlib import hashlib import logging +import signal import sys import warnings -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Awaitable, Callable, Iterator, Sequence from dataclasses import dataclass from datetime import timedelta from typing import ( @@ -739,7 +741,11 @@ def is_shutdown(self) -> bool: """ return self._shutdown_complete_event.is_set() - async def run(self) -> None: + async def run( + self, + *, + signal_shutdown: Sequence[signal.Signals] = (), + ) -> None: """Run the worker and wait on it to be shut down. This will not return until shutdown is complete. This means that @@ -755,6 +761,15 @@ async def run(self) -> None: async function assuming that it is currently running. A cancel could also cancel the shutdown process. Therefore users are encouraged to use explicit shutdown instead. + + Args: + signal_shutdown: OS signals that, when received, initiate a + graceful worker shutdown. For example, + ``signal_shutdown=[signal.SIGTERM, signal.SIGINT]``. Handlers + are installed on entry and removed on exit. On platforms where + :py:meth:`asyncio.AbstractEventLoop.add_signal_handler` is not + supported (e.g. Windows), :py:func:`signal.signal` is used as a + fallback. """ def make_lambda(plugin: Plugin, next: Callable[[Worker], Awaitable[None]]): @@ -764,7 +779,48 @@ def make_lambda(plugin: Plugin, next: Callable[[Worker], Awaitable[None]]): for plugin in reversed(self._plugins): next_function = make_lambda(plugin, next_function) - await next_function(self) + with self._install_signal_handlers(signal_shutdown): + await next_function(self) + + @contextlib.contextmanager + def _install_signal_handlers( + self, sigs: Sequence[signal.Signals] + ) -> Iterator[None]: + if not sigs: + yield + return + + def request_shutdown() -> None: + self._shutdown_event.set() + + loop = asyncio.get_running_loop() + added_via_loop: list[signal.Signals] = [] + previous_handlers: dict[signal.Signals, Any] = {} + try: + for sig in sigs: + try: + loop.add_signal_handler(sig, request_shutdown) + added_via_loop.append(sig) + except NotImplementedError: + # Fallback (e.g. Windows): the handler runs outside the loop, + # so schedule the set threadsafe to wake a blocked loop. + previous_handlers[sig] = signal.signal( + sig, lambda *_: loop.call_soon_threadsafe(request_shutdown) + ) + yield + finally: + for sig in added_via_loop: + try: + loop.remove_signal_handler(sig) + except (NotImplementedError, ValueError): + pass + for sig, prev in previous_handlers.items(): + # prev is None when the prior handler was not set from Python, + # which signal.signal rejects with TypeError. + try: + signal.signal(sig, prev) + except (ValueError, OSError, TypeError): + pass async def _run(self): # Eagerly validate which will do a namespace check in Core diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index dda754a5b..63dfd59d9 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -5,6 +5,8 @@ import multiprocessing import multiprocessing.context import os +import signal +import sys import uuid from collections.abc import Awaitable, Callable, Sequence from contextlib import contextmanager @@ -209,6 +211,97 @@ async def test_worker_cancel_run(client: Client): assert not worker.is_running and worker.is_shutdown +async def test_worker_run_signal_shutdown_default_unchanged(client: Client): + worker = create_worker(client) + run_task = asyncio.create_task(worker.run()) + await asyncio.sleep(0.3) + assert worker.is_running and not worker.is_shutdown + await worker.shutdown() + await run_task + assert not worker.is_running and worker.is_shutdown + + +async def test_worker_run_signal_shutdown_empty_is_noop(client: Client): + # An empty sequence installs no handlers and leaves existing ones untouched. + worker = create_worker(client) + original_handler = ( + signal.getsignal(signal.SIGUSR1) if sys.platform != "win32" else None + ) + run_task = asyncio.create_task(worker.run(signal_shutdown=[])) + await asyncio.sleep(0.3) + assert worker.is_running + if sys.platform != "win32": + assert signal.getsignal(signal.SIGUSR1) is original_handler + await worker.shutdown() + await run_task + assert worker.is_shutdown + + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only signal") +async def test_worker_run_signal_shutdown_triggers_shutdown(client: Client): + worker = create_worker(client) + run_task = asyncio.create_task(worker.run(signal_shutdown=[signal.SIGUSR1])) + await asyncio.sleep(0.3) + assert worker.is_running and not worker.is_shutdown + os.kill(os.getpid(), signal.SIGUSR1) + await asyncio.wait_for(run_task, timeout=10) + assert not worker.is_running and worker.is_shutdown + + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only signals") +async def test_worker_run_signal_shutdown_multiple_signals(client: Client): + # Any one of the registered signals should trigger shutdown. + worker = create_worker(client) + run_task = asyncio.create_task( + worker.run(signal_shutdown=[signal.SIGUSR1, signal.SIGUSR2]) + ) + await asyncio.sleep(0.3) + assert worker.is_running + os.kill(os.getpid(), signal.SIGUSR2) + await asyncio.wait_for(run_task, timeout=10) + assert worker.is_shutdown + + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only signal") +async def test_worker_run_signal_handlers_removed_after_run(client: Client): + worker = create_worker(client) + loop = asyncio.get_running_loop() + + run_task = asyncio.create_task(worker.run(signal_shutdown=[signal.SIGUSR1])) + await asyncio.sleep(0.3) + # remove_signal_handler returns True only while the handler is installed. + assert loop.remove_signal_handler(signal.SIGUSR1) is True + await worker.shutdown() + await run_task + assert loop.remove_signal_handler(signal.SIGUSR1) is False + + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only signal") +async def test_worker_run_signal_shutdown_explicit_shutdown_still_works( + client: Client, +): + # The signal path is additive: explicit shutdown() must still work. + worker = create_worker(client) + run_task = asyncio.create_task(worker.run(signal_shutdown=[signal.SIGUSR1])) + await asyncio.sleep(0.3) + assert worker.is_running + await worker.shutdown() + await run_task + assert worker.is_shutdown + + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only signal") +async def test_worker_run_signal_handlers_removed_on_fatal_error(client: Client): + # Handlers must be removed even when run() exits via a fatal worker error. + worker = create_worker(client) + loop = asyncio.get_running_loop() + with pytest.raises(RuntimeError): + with WorkerFailureInjector(worker) as inj: + inj.workflow.poll_fail_queue.put_nowait(RuntimeError("OH NO")) + await worker.run(signal_shutdown=[signal.SIGUSR1]) + assert loop.remove_signal_handler(signal.SIGUSR1) is False + + @activity.defn async def say_hello(name: str) -> str: return f"Hello, {name}!"