From d8f95b801e2937857c3db1302432eec7bb9dd109 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 24 Jun 2026 10:38:18 -0700 Subject: [PATCH 1/2] feat: context aware replay status --- .../examples-catalog.json | 15 + .../src/logger_example/replay_logging.py | 93 ++++ .../template.yaml | 18 + .../logger_example/test_replay_logging.py | 52 ++ .../concurrency/executor.py | 1 - .../context.py | 459 +++++++++++------- .../execution.py | 15 +- .../logger.py | 40 +- .../aws_durable_execution_sdk_python/state.py | 66 +-- .../tests/context_test.py | 200 +++++++- .../tests/execution_test.py | 5 +- .../tests/logger_test.py | 37 +- .../tests/state_test.py | 85 ++-- 13 files changed, 755 insertions(+), 331 deletions(-) create mode 100644 packages/aws-durable-execution-sdk-python-examples/src/logger_example/replay_logging.py create mode 100644 packages/aws-durable-execution-sdk-python-examples/test/logger_example/test_replay_logging.py diff --git a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json index 754fb386..ca775ab2 100644 --- a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json +++ b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json @@ -291,6 +291,21 @@ }, "path": "./src/logger_example/logger_example.py" }, + { + "name": "Replay Logging", + "description": "Demonstrating replay-aware logger de-duplication across a wait/replay boundary", + "handler": "replay_logging.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "loggingConfig": { + "ApplicationLogLevel": "INFO", + "LogFormat": "JSON" + }, + "path": "./src/logger_example/replay_logging.py" + }, { "name": "Steps with Retry", "description": "Multiple steps with retry logic in a polling pattern", diff --git a/packages/aws-durable-execution-sdk-python-examples/src/logger_example/replay_logging.py 
b/packages/aws-durable-execution-sdk-python-examples/src/logger_example/replay_logging.py new file mode 100644 index 00000000..2f8c39b7 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/src/logger_example/replay_logging.py @@ -0,0 +1,93 @@ +"""Example demonstrating replay-aware logging across a wait boundary.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import Duration +from aws_durable_execution_sdk_python.context import ( + DurableContext, + StepContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_step +def prepare(step_context: StepContext, item: str) -> str: + """A step that runs before the wait. + + Its log is emitted on the first invocation. On replay this step is not + re-executed (it returns its checkpointed result), so this log does not + repeat. + """ + step_context.logger.info("Preparing item", extra={"item": item}) + return f"prepared:{item}" + + +@durable_step +def finalize(step_context: StepContext, prepared: str) -> str: + """A step that runs after the wait (new work on the replay invocation).""" + step_context.logger.info("Finalizing item", extra={"prepared": prepared}) + return f"done:{prepared}" + + +@durable_with_child_context +def audit(child_ctx: DurableContext, prepared: str) -> str: + """Child context with its own logger and its own replay status.""" + child_ctx.logger.info( + "Auditing in child context (before child wait)", + extra={"prepared": prepared, "child_is_replaying": child_ctx.is_replaying()}, + ) + + # The child's own replay boundary. + child_ctx.wait(duration=Duration.from_seconds(5), name="audit_cooldown") + + # After the child's wait: emitted as new work on the child's replay. 
+ child_ctx.logger.info( + "Resumed in child context (after child wait)", + extra={"child_is_replaying": child_ctx.is_replaying()}, + ) + + return child_ctx.step(lambda _: f"audited:{prepared}", name="record_audit") + + +@durable_execution +def handler(event: Any, context: DurableContext) -> dict[str, Any]: + """Handler demonstrating replay-aware logging across a wait.""" + item: str = event.get("item", "widget") if isinstance(event, dict) else "widget" + + # --- Before the wait --- + # On the replay invocation these lines are de-duplicated by the replay-aware + # logger because the context is still replaying when it reaches them. + context.logger.info( + "Workflow started (before wait)", + extra={"item": item, "is_replaying": context.is_replaying()}, + ) + + prepared: str = context.step(prepare(item), name="prepare") + + context.logger.info( + "Prepared, about to wait", + extra={"prepared": prepared, "is_replaying": context.is_replaying()}, + ) + + # --- The replay boundary --- + # The wait suspends the execution. When it resumes, the handler replays from + # the top; everything above is de-duplicated, and everything below is new. + context.wait(duration=Duration.from_seconds(5), name="cooldown") + + # --- After the wait --- + # These logs are emitted on the replay invocation because the context has + # crossed its replay boundary and is no longer replaying. 
+ context.logger.info( + "Resumed after wait", + extra={"is_replaying": context.is_replaying()}, + ) + + audited: str = context.run_in_child_context(audit(prepared), name="audit") + + result: str = context.step(finalize(audited), name="finalize") + + context.logger.info("Workflow completed", extra={"result": result}) + + return {"result": result, "item": item} diff --git a/packages/aws-durable-execution-sdk-python-examples/template.yaml b/packages/aws-durable-execution-sdk-python-examples/template.yaml index 5e39d9e6..90a682fc 100644 --- a/packages/aws-durable-execution-sdk-python-examples/template.yaml +++ b/packages/aws-durable-execution-sdk-python-examples/template.yaml @@ -492,6 +492,24 @@ } } }, + "ReplayLogging": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "replay_logging.handler", + "Description": "Demonstrating replay-aware logger de-duplication across a wait/replay boundary", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } + }, "StepsWithRetry": { "Type": "AWS::Serverless::Function", "Properties": { diff --git a/packages/aws-durable-execution-sdk-python-examples/test/logger_example/test_replay_logging.py b/packages/aws-durable-execution-sdk-python-examples/test/logger_example/test_replay_logging.py new file mode 100644 index 00000000..bb42e30b --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/test/logger_example/test_replay_logging.py @@ -0,0 +1,52 @@ +"""Tests for the replay_logging example. + +These tests do not assert on emitted log lines (the replay-aware +de-duplication is best observed in CloudWatch after deploying). They verify the +workflow runs end-to-end across the wait/replay boundary and produces the +expected operations and result. 
+""" + +import pytest + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationType +from src.logger_example import replay_logging +from test.conftest import deserialize_operation_payload + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=replay_logging.handler, + lambda_function_name="Replay Logging", +) +def test_replay_logging(durable_runner): + """Test the replay-aware logging example runs across the wait boundary.""" + with durable_runner: + result = durable_runner.run(input={"item": "widget"}, timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == { + "result": "done:audited:prepared:widget", + "item": "widget", + } + + # Two wait operations force suspend/replay cycles: one in the parent context + # and one inside the child (audit) context. This exercises per-context replay + # status in different contexts. + wait_ops = [ + op for op in result.operations if op.operation_type == OperationType.WAIT + ] + assert len(wait_ops) >= 1 + + # Steps before (prepare) and after (finalize) the wait both ran. The child + # context's record_audit step is nested inside the CONTEXT operation. + step_ops = [ + op for op in result.operations if op.operation_type == OperationType.STEP + ] + assert len(step_ops) >= 2 + + # The audit child context produces a CONTEXT operation. 
+ context_ops = [ + op for op in result.operations if op.operation_type.value == "CONTEXT" + ] + assert len(context_ops) >= 1 diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py index 0cdf40e8..7767b155 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -489,7 +489,6 @@ def run_in_child_handler() -> ResultType: is_virtual=is_virtual, ), ) - child_context.state.track_replay(operation_id=operation_id) return result def replay(self, execution_state: ExecutionState, executor_context: DurableContext): diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py index c2e1d326..f03b7edc 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py @@ -2,7 +2,9 @@ import hashlib import logging +from contextlib import contextmanager from dataclasses import dataclass +from threading import Lock from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar from aws_durable_execution_sdk_python.config import ( @@ -46,7 +48,10 @@ SerDes, deserialize, ) -from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001 +from aws_durable_execution_sdk_python.state import ( # noqa: TCH001 + ExecutionState, + ReplayStatus, +) from aws_durable_execution_sdk_python.threading import OrderedCounter from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol from aws_durable_execution_sdk_python.types import ( @@ -289,6 +294,7 
@@ def __init__( parent_id: str | None = None, logger: Logger | None = None, step_id_prefix: str | None = None, + replay_status: ReplayStatus = ReplayStatus.NEW, ) -> None: self.state: ExecutionState = state self.execution_context: ExecutionContext = execution_context @@ -304,15 +310,30 @@ def __init__( self._is_virtual: bool = self._parent_id != self._step_id_prefix self._step_counter: OrderedCounter = OrderedCounter() + # Replay status is tracked per-context. + # A context starts in the status inherited from its creator and refines + # itself to NEW via look-ahead as it reaches its own replay boundary. + # Concurrent branches normally each get their own child context; the lock + # guards refinement in case branches ever share a context reference. + self._replay_status: ReplayStatus = replay_status + self._replay_status_lock: Lock = Lock() + log_info = LogInfo( execution_state=state, parent_id=parent_id, ) self._log_info = log_info - self.logger: Logger = logger or Logger.from_log_info( - logger=logging.getLogger(), - info=log_info, - ) + # The logger consults THIS context's replay status for de-duplication. + # A child inherits the parent's underlying logger/extra but must report + # its own status, so rebind the replay source onto self. 
+ if logger is not None: + self.logger: Logger = logger.with_is_replaying(self.is_replaying) + else: + self.logger = Logger.from_log_info( + logger=logging.getLogger(), + info=log_info, + is_replaying=self.is_replaying, + ) @property def is_virtual(self) -> bool: @@ -331,6 +352,7 @@ def is_virtual(self) -> bool: def from_lambda_context( state: ExecutionState, lambda_context: LambdaContext, + replay_status: ReplayStatus = ReplayStatus.NEW, ): return DurableContext( state=state, @@ -339,6 +361,7 @@ def from_lambda_context( ), lambda_context=lambda_context, parent_id=None, + replay_status=replay_status, ) def create_child_context( @@ -374,6 +397,11 @@ def create_child_context( lambda_context=self.lambda_context, parent_id=child_parent_id, step_id_prefix=operation_id, + # Inherit the creator's current replay status; the child refines + # itself to NEW via look-ahead against its own step ids. + replay_status=( + ReplayStatus.REPLAY if self.is_replaying() else ReplayStatus.NEW + ), logger=self.logger.with_log_info( LogInfo( execution_state=self.state, @@ -399,6 +427,7 @@ def set_logger(self, new_logger: LoggerInterface): self.logger = Logger.from_log_info( logger=new_logger, info=self._log_info, + is_replaying=self.is_replaying, ) def _create_step_id_for_logical_step(self, step: int) -> str: @@ -420,6 +449,75 @@ def _create_step_id(self) -> str: new_counter: int = self._step_counter.increment() return self._create_step_id_for_logical_step(new_counter) + # region replay status + + def is_replaying(self) -> bool: + """Return True if this context is currently replaying prior operations.""" + with self._replay_status_lock: + return self._replay_status is ReplayStatus.REPLAY + + def _set_replay_status_new(self) -> None: + with self._replay_status_lock: + self._replay_status = ReplayStatus.NEW + + def _peek_next_operation_id(self) -> str: + """Return the operation id the next operation will take, without consuming it.""" + return self._create_step_id_for_logical_step( + 
self._step_counter.get_current() + 1 + ) + + def _next_operation_is_terminal_checkpoint(self) -> bool: + """True if this context's next operation already completed in a prior invocation.""" + result: CheckpointedResult = self.state.get_checkpoint_result( + self._peek_next_operation_id() + ) + + return result.is_succeeded() or result.is_failed() + + def _next_operation_exists(self) -> bool: + """True if a checkpoint exists for this context's next operation.""" + return self.state.get_checkpoint_result( + self._peek_next_operation_id() + ).is_existent() + + @contextmanager + def _replay_aware(self): + """Wrap a single operation with replay-boundary detection. + + The boundary has three parts: + + - Existence flip (before the op): if we are replaying and the next + operation has no checkpoint at all, it is brand-new code, so flip to + NEW immediately so the operation and its logs count as new. + - Deferred status flip (after the op): if we are replaying and the next + operation exists but is NOT terminal, that operation is the resume + point. We keep replay status through the operation and flip to NEW + afterwards, so the resuming operation's own logs stay de-duplicated + but subsequent code counts as new. + - Post-op existence flip (after the op): if we are still replaying once + the operation completes and the *following* operation does not exist + yet, we have reached the replay boundary. + """ + was_replaying: bool = self.is_replaying() + # Status: exists but not terminal -> defer flip until after the op runs. + flip_after: bool = ( + was_replaying + and self._next_operation_exists() + and not self._next_operation_is_terminal_checkpoint() + ) + # Existence: brand-new next op -> flip before the op runs. 
+ if was_replaying and not self._next_operation_exists(): + self._set_replay_status_new() + try: + yield + finally: + if flip_after: + self._set_replay_status_new() + elif self.is_replaying() and not self._next_operation_exists(): + self._set_replay_status_new() + + # endregion replay status + # region Operations def create_callback( @@ -441,26 +539,25 @@ def create_callback( """ if not config: config = CallbackConfig() - operation_id: str = self._create_step_id() - executor: CallbackOperationExecutor = CallbackOperationExecutor( - state=self.state, - operation_identifier=OperationIdentifier( + with self._replay_aware(): + operation_id: str = self._create_step_id() + executor: CallbackOperationExecutor = CallbackOperationExecutor( + state=self.state, + operation_identifier=OperationIdentifier( + operation_id=operation_id, + sub_type=OperationSubType.CALLBACK, + parent_id=self._parent_id, + name=name, + ), + config=config, + ) + callback_id: str = executor.process() + return Callback( + callback_id=callback_id, operation_id=operation_id, - sub_type=OperationSubType.CALLBACK, - parent_id=self._parent_id, - name=name, - ), - config=config, - ) - callback_id: str = executor.process() - result: Callback = Callback( - callback_id=callback_id, - operation_id=operation_id, - state=self.state, - serdes=config.serdes, - ) - self.state.track_replay(operation_id=operation_id) - return result + state=self.state, + serdes=config.serdes, + ) def invoke( self, @@ -482,22 +579,21 @@ def invoke( """ if not config: config = InvokeConfig[P, R]() - operation_id = self._create_step_id() - executor: InvokeOperationExecutor[R] = InvokeOperationExecutor( - function_name=function_name, - payload=payload, - state=self.state, - operation_identifier=OperationIdentifier( - operation_id=operation_id, - sub_type=OperationSubType.CHAINED_INVOKE, - parent_id=self._parent_id, - name=name, - ), - config=config, - ) - result: R = executor.process() - self.state.track_replay(operation_id=operation_id) 
- return result + with self._replay_aware(): + operation_id = self._create_step_id() + executor: InvokeOperationExecutor[R] = InvokeOperationExecutor( + function_name=function_name, + payload=payload, + state=self.state, + operation_identifier=OperationIdentifier( + operation_id=operation_id, + sub_type=OperationSubType.CHAINED_INVOKE, + parent_id=self._parent_id, + name=name, + ), + config=config, + ) + return executor.process() def map( self, @@ -509,44 +605,43 @@ def map( """Execute a callable for each item in parallel.""" map_name: str | None = self._resolve_step_name(name, func) - operation_id = self._create_step_id() - operation_identifier = OperationIdentifier( - operation_id=operation_id, - sub_type=OperationSubType.MAP, - parent_id=self._parent_id, - name=map_name, - ) - map_context = self.create_child_context(operation_id=operation_id) - - def map_in_child_context() -> BatchResult[R]: - # map_context is a child_context of the context upon which `.map` - # was called. We are calling it `map_context` to make it explicit - # that any operations happening from hereon are done on the context - # that owns the branches - return map_handler( - items=inputs, - func=func, - config=config, - execution_state=self.state, - map_context=map_context, - operation_identifier=operation_identifier, + with self._replay_aware(): + operation_id = self._create_step_id() + operation_identifier = OperationIdentifier( + operation_id=operation_id, + sub_type=OperationSubType.MAP, + parent_id=self._parent_id, + name=map_name, ) + map_context = self.create_child_context(operation_id=operation_id) + + def map_in_child_context() -> BatchResult[R]: + # map_context is a child_context of the context upon which `.map` + # was called. 
We are calling it `map_context` to make it explicit + # that any operations happening from hereon are done on the context + # that owns the branches + return map_handler( + items=inputs, + func=func, + config=config, + execution_state=self.state, + map_context=map_context, + operation_identifier=operation_identifier, + ) - result: BatchResult[R] = child_handler( - func=map_in_child_context, - state=self.state, - operation_identifier=operation_identifier, - config=ChildConfig( - sub_type=OperationSubType.MAP, - serdes=getattr(config, "serdes", None), - # child_handler should only know the serdes of the parent serdes, - # the item serdes will be passed when we are actually executing - # the branch within its own child_handler. - item_serdes=None, - ), - ) - self.state.track_replay(operation_id=operation_id) - return result + return child_handler( + func=map_in_child_context, + state=self.state, + operation_identifier=operation_identifier, + config=ChildConfig( + sub_type=OperationSubType.MAP, + serdes=getattr(config, "serdes", None), + # child_handler should only know the serdes of the parent serdes, + # the item serdes will be passed when we are actually executing + # the branch within its own child_handler. + item_serdes=None, + ), + ) def parallel( self, @@ -555,45 +650,44 @@ def parallel( config: ParallelConfig | None = None, ) -> BatchResult[T]: """Execute multiple callables in parallel.""" - # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id - operation_id = self._create_step_id() - parallel_context = self.create_child_context(operation_id=operation_id) - operation_identifier = OperationIdentifier( - operation_id=operation_id, - sub_type=OperationSubType.PARALLEL, - parent_id=self._parent_id, - name=name, - ) + with self._replay_aware(): + # _create_step_id() is thread-safe. 
rest of method is safe, since using local copy of parent id + operation_id = self._create_step_id() + parallel_context = self.create_child_context(operation_id=operation_id) + operation_identifier = OperationIdentifier( + operation_id=operation_id, + sub_type=OperationSubType.PARALLEL, + parent_id=self._parent_id, + name=name, + ) - def parallel_in_child_context() -> BatchResult[T]: - # parallel_context is a child_context of the context upon which `.map` - # was called. We are calling it `parallel_context` to make it explicit - # that any operations happening from hereon are done on the context - # that owns the branches - return parallel_handler( - callables=functions, - config=config, - execution_state=self.state, - parallel_context=parallel_context, + def parallel_in_child_context() -> BatchResult[T]: + # parallel_context is a child_context of the context upon which `.parallel` + # was called. We are calling it `parallel_context` to make it explicit + # that any operations happening from hereon are done on the context + # that owns the branches + return parallel_handler( + callables=functions, + config=config, + execution_state=self.state, + parallel_context=parallel_context, + operation_identifier=operation_identifier, + ) + + return child_handler( + func=parallel_in_child_context, + state=self.state, operation_identifier=operation_identifier, + config=ChildConfig( + sub_type=OperationSubType.PARALLEL, + serdes=getattr(config, "serdes", None), + # child_handler should only know the serdes of the parent serdes, + # the item serdes will be passed when we are actually executing + # the branch within its own child_handler. 
+ item_serdes=None, + ), ) - result: BatchResult[T] = child_handler( - func=parallel_in_child_context, - state=self.state, - operation_identifier=operation_identifier, - config=ChildConfig( - sub_type=OperationSubType.PARALLEL, - serdes=getattr(config, "serdes", None), - # child_handler should only know the serdes of the parent serdes, - # the item serdes will be passed when we are actually executing - # the branch within its own child_handler. - item_serdes=None, - ), - ) - self.state.track_replay(operation_id=operation_id) - return result - def run_in_child_context( self, func: Callable[[DurableContext], T], @@ -613,36 +707,35 @@ def run_in_child_context( T: The result of the callable. """ step_name: str | None = self._resolve_step_name(name, func) - # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id - operation_id = self._create_step_id() - sub_type = ( - config.sub_type - if config and config.sub_type - else OperationSubType.RUN_IN_CHILD_CONTEXT - ) + with self._replay_aware(): + # _create_step_id() is thread-safe. 
rest of method is safe, since using local copy of parent id + operation_id = self._create_step_id() + sub_type = ( + config.sub_type + if config and config.sub_type + else OperationSubType.RUN_IN_CHILD_CONTEXT + ) - is_virtual: bool = config.is_virtual if config else False + is_virtual: bool = config.is_virtual if config else False - def callable_with_child_context(): - return func( - self.create_child_context( - operation_id=operation_id, is_virtual=is_virtual + def callable_with_child_context(): + return func( + self.create_child_context( + operation_id=operation_id, is_virtual=is_virtual + ) ) - ) - result: T = child_handler( - func=callable_with_child_context, - state=self.state, - operation_identifier=OperationIdentifier( - operation_id=operation_id, - sub_type=sub_type, - parent_id=self._parent_id, - name=step_name, - ), - config=config, - ) - self.state.track_replay(operation_id=operation_id) - return result + return child_handler( + func=callable_with_child_context, + state=self.state, + operation_identifier=OperationIdentifier( + operation_id=operation_id, + sub_type=sub_type, + parent_id=self._parent_id, + name=step_name, + ), + config=config, + ) def step( self, @@ -654,22 +747,21 @@ def step( logger.debug("Step name: %s", step_name) if not config: config = StepConfig() - operation_id = self._create_step_id() - executor: StepOperationExecutor[T] = StepOperationExecutor( - func=func, - config=config, - state=self.state, - operation_identifier=OperationIdentifier( - operation_id=operation_id, - sub_type=OperationSubType.STEP, - parent_id=self._parent_id, - name=step_name, - ), - context_logger=self.logger, - ) - result: T = executor.process() - self.state.track_replay(operation_id=operation_id) - return result + with self._replay_aware(): + operation_id = self._create_step_id() + executor: StepOperationExecutor[T] = StepOperationExecutor( + func=func, + config=config, + state=self.state, + operation_identifier=OperationIdentifier( + 
operation_id=operation_id, + sub_type=OperationSubType.STEP, + parent_id=self._parent_id, + name=step_name, + ), + context_logger=self.logger, + ) + return executor.process() def wait(self, duration: Duration, name: str | None = None) -> None: """Wait for a specified amount of time. @@ -682,20 +774,20 @@ def wait(self, duration: Duration, name: str | None = None) -> None: if seconds < 1: msg = "duration must be at least 1 second" raise ValidationError(msg) - operation_id = self._create_step_id() - wait_seconds = duration.seconds - executor: WaitOperationExecutor = WaitOperationExecutor( - seconds=wait_seconds, - state=self.state, - operation_identifier=OperationIdentifier( - operation_id=operation_id, - sub_type=OperationSubType.WAIT, - parent_id=self._parent_id, - name=name, - ), - ) - executor.process() - self.state.track_replay(operation_id=operation_id) + with self._replay_aware(): + operation_id = self._create_step_id() + wait_seconds = duration.seconds + executor: WaitOperationExecutor = WaitOperationExecutor( + seconds=wait_seconds, + state=self.state, + operation_identifier=OperationIdentifier( + operation_id=operation_id, + sub_type=OperationSubType.WAIT, + parent_id=self._parent_id, + name=name, + ), + ) + executor.process() def wait_for_callback( self, @@ -738,24 +830,23 @@ def wait_for_condition( msg = "`config` is required for wait_for_condition" raise ValidationError(msg) - operation_id = self._create_step_id() - executor: WaitForConditionOperationExecutor[T] = ( - WaitForConditionOperationExecutor( - check=check, - config=config, - state=self.state, - operation_identifier=OperationIdentifier( - operation_id=operation_id, - sub_type=OperationSubType.WAIT_FOR_CONDITION, - parent_id=self._parent_id, - name=name, - ), - context_logger=self.logger, + with self._replay_aware(): + operation_id = self._create_step_id() + executor: WaitForConditionOperationExecutor[T] = ( + WaitForConditionOperationExecutor( + check=check, + config=config, + state=self.state, 
+ operation_identifier=OperationIdentifier( + operation_id=operation_id, + sub_type=OperationSubType.WAIT_FOR_CONDITION, + parent_id=self._parent_id, + name=name, + ), + context_logger=self.logger, + ) ) - ) - result: T = executor.process() - self.state.track_replay(operation_id=operation_id) - return result + return executor.process() # endregion Operations diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py index e564764c..95d5e756 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py @@ -227,7 +227,6 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: initial_checkpoint_token=invocation_input.checkpoint_token, operations={}, service_client=service_client, - replay_status=ReplayStatus.NEW, plugin_executor=plugin_executor, ) @@ -252,7 +251,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ).to_dict() raise - execution_state.mark_replaying_if_prior_operations_exist() + # Determine whether this is a replay (prior operations exist) at the + # execution level. This seeds the root context's replay status and the + # plugin's is_first_invocation flag. Replay status itself is then + # tracked per-context as execution proceeds. 
+ has_prior_operations: bool = execution_state.has_prior_operations() raw_input_payload: str | None = execution_state.get_input_payload() @@ -270,7 +273,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: raise durable_context: DurableContext = DurableContext.from_lambda_context( - state=execution_state, lambda_context=context + state=execution_state, + lambda_context=context, + replay_status=( + ReplayStatus.REPLAY if has_prior_operations else ReplayStatus.NEW + ), ) # Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing @@ -291,7 +298,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: if execution_operation is not None else None ), - is_first_invocation=not execution_state.is_replaying(), + is_first_invocation=not has_prior_operations, ) # Thread 1: Run background checkpoint processing executor.submit(execution_state.checkpoint_batches_forever) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/logger.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/logger.py index c2a2be71..7b330d3a 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/logger.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/logger.py @@ -7,6 +7,7 @@ from aws_durable_execution_sdk_python.types import LoggerInterface + if TYPE_CHECKING: from collections.abc import Callable, Mapping, MutableMapping @@ -54,15 +55,25 @@ def __init__( self, logger: LoggerInterface, default_extra: Mapping[str, object], - execution_state: ExecutionState, + is_replaying: Callable[[], bool] = lambda: False, ) -> None: self._logger = logger self._default_extra = default_extra - self._execution_state = execution_state + # Replay status is owned by the DurableContext this logger belongs to, + # not by the execution as a whole. 
This callable reads that context's + # current replay status so log de-duplication is decided per-context. + # The default (never replaying) suits a standalone logger with no + # owning context, so it always logs. + self._is_replaying = is_replaying @classmethod - def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger: - """Create a new logger with the given LogInfo.""" + def from_log_info( + cls, + logger: LoggerInterface, + info: LogInfo, + is_replaying: Callable[[], bool] = lambda: False, + ) -> Logger: + """Create a new logger with the given LogInfo and replay-status source.""" extra: MutableMapping[str, object] = { "executionArn": info.execution_state.durable_execution_arn } @@ -75,15 +86,26 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger: extra["attempt"] = info.attempt if info.operation_id: extra["operationId"] = info.operation_id - return cls( - logger=logger, default_extra=extra, execution_state=info.execution_state - ) + return cls(logger=logger, default_extra=extra, is_replaying=is_replaying) def with_log_info(self, info: LogInfo) -> Logger: - """Clone the existing logger with new LogInfo.""" + """Clone the existing logger with new LogInfo, preserving the replay-status source.""" return Logger.from_log_info( logger=self._logger, info=info, + is_replaying=self._is_replaying, + ) + + def with_is_replaying(self, is_replaying: Callable[[], bool]) -> Logger: + """Clone the logger, rebinding it to a new replay-status source. + + Used when a child context inherits a parent's underlying logger but must + report its own (the child's) replay status rather than the parent's. 
+ """ + return Logger( + logger=self._logger, + default_extra=self._default_extra, + is_replaying=is_replaying, ) def get_logger(self) -> LoggerInterface: @@ -128,4 +150,4 @@ def _log( log_func(msg, *args, extra=merged_extra) def _should_log(self) -> bool: - return not self._execution_state.is_replaying() + return not self._is_replaying() diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index c24bd96d..3e87a899 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, Callable, Any +from typing import TYPE_CHECKING, Any, Callable from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, @@ -29,10 +29,10 @@ Operation, OperationAction, OperationStatus, + OperationSubType, OperationType, OperationUpdate, StateOutput, - OperationSubType, ) from aws_durable_execution_sdk_python.plugin import ( PluginExecutor, @@ -40,6 +40,7 @@ ) from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock + if TYPE_CHECKING: import datetime from collections.abc import MutableMapping @@ -246,7 +247,6 @@ def __init__( service_client: DurableServiceClient, plugin_executor: PluginExecutor, batcher_config: CheckpointBatcherConfig | None = None, - replay_status: ReplayStatus = ReplayStatus.NEW, ): self.durable_execution_arn: str = durable_execution_arn self._current_checkpoint_token: str = initial_checkpoint_token @@ -275,9 +275,6 @@ def __init__( # Protects parent_to_children and parent_done self._parent_done_lock: Lock = Lock() - self._replay_status: ReplayStatus = replay_status - self._replay_status_lock: Lock = Lock() - 
self._visited_operations: set[str] = set() @property def operations(self) -> dict[str, Operation]: @@ -367,63 +364,20 @@ def get_execution_operation(self) -> Operation | None: return candidate - def track_replay(self, operation_id: str) -> None: - """Check if operation exists with completed status; if not, transition to NEW status. - - This method is called before each operation (step, wait, invoke, etc.) to determine - if we've reached the replay boundary. Once we encounter an operation that doesn't - exist or isn't completed, we transition from REPLAY to NEW status, which enables - logging for all subsequent code. - - Args: - operation_id: The operation ID to check - """ - with self._replay_status_lock: - if self._replay_status == ReplayStatus.REPLAY: - self._visited_operations.add(operation_id) - # Lock order: _replay_status_lock then _operations_lock. - with self._operations_lock: - completed_ops = { - op_id - for op_id, op in self._operations.items() - if op.operation_type != OperationType.EXECUTION - and op.status - in { - OperationStatus.SUCCEEDED, - OperationStatus.FAILED, - OperationStatus.CANCELLED, - OperationStatus.STOPPED, - OperationStatus.TIMED_OUT, - } - } - if completed_ops.issubset(self._visited_operations): - logger.debug( - "Transitioning from REPLAY to NEW status at operation %s", - operation_id, - ) - self._replay_status = ReplayStatus.NEW - - def is_replaying(self) -> bool: - """Check if execution is currently in replay mode. + def has_prior_operations(self) -> bool: + """Return True if any non-execution operation already exists. - Returns: - True if in REPLAY status, False if in NEW status + Used at execution setup to decide whether this invocation is a replay + (prior operations were checkpointed in an earlier invocation) versus a + first invocation. Per-operation replay status is tracked per-context on + DurableContext, not here. 
""" - with self._replay_status_lock: - return self._replay_status is ReplayStatus.REPLAY - - def mark_replaying_if_prior_operations_exist(self) -> None: - """Mark execution state as replaying when non-execution operations exist.""" with self._operations_lock: - has_prior_operations: bool = any( + return any( op.operation_type is not OperationType.EXECUTION for op in self._operations.values() ) - if has_prior_operations: - with self._replay_status_lock: - self._replay_status = ReplayStatus.REPLAY - def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult: """Get checkpoint result. diff --git a/packages/aws-durable-execution-sdk-python/tests/context_test.py b/packages/aws-durable-execution-sdk-python/tests/context_test.py index 42c0205b..27f18526 100644 --- a/packages/aws-durable-execution-sdk-python/tests/context_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/context_test.py @@ -35,10 +35,15 @@ ErrorObject, Operation, OperationStatus, - OperationType, OperationSubType, + OperationType, +) +from aws_durable_execution_sdk_python.plugin import PluginExecutor +from aws_durable_execution_sdk_python.state import ( + CheckpointedResult, + ExecutionState, + ReplayStatus, ) -from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState from aws_durable_execution_sdk_python.waits import ( WaitForConditionConfig, WaitForConditionDecision, @@ -2320,3 +2325,194 @@ def named(ctx: DurableContext) -> str: # endregion durable_parallel_branch + + +# region per-context replay status + + +def _replay_state(operations: dict[str, Operation]) -> ExecutionState: + """Build a real ExecutionState seeded with the given operations.""" + return ExecutionState( + durable_execution_arn="arn:aws:durable:us-east-1:123456789012:execution/test", + initial_checkpoint_token="token", # noqa: S106 + operations=operations, + service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), + ) + + +def _step_op(operation_id: str, status: 
OperationStatus) -> Operation: + return Operation( + operation_id=operation_id, + operation_type=OperationType.STEP, + status=status, + ) + + +def test_is_replaying_defaults_to_new_for_fresh_context(): + """A context created without a replay seed is not replaying.""" + ctx = create_test_context(state=_replay_state({})) + assert ctx.is_replaying() is False + + +def test_replay_aware_flips_before_brand_new_operation(): + """When replaying and the next op has no checkpoint, flip to NEW before it runs.""" + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + assert ctx.is_replaying() is True + + inside_status: list[bool] = [] + with ctx._replay_aware(): # noqa: SLF001 + inside_status.append(ctx.is_replaying()) + + # Brand-new op (no checkpoint) flips before the body runs. + assert inside_status == [False] + assert ctx.is_replaying() is False + + +def test_replay_aware_flips_after_completed_op_when_nothing_follows(): + """A completed op with no following checkpoint crosses the boundary after the op. + + The op stays replaying through its own execution (so its logs de-dup), then + flips to NEW afterwards because the next operation is brand-new. + """ + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + # The next op id this context will allocate. + next_id = ctx._peek_next_operation_id() # noqa: SLF001 + ctx.state._operations[next_id] = _step_op( # noqa: SLF001 + next_id, OperationStatus.SUCCEEDED + ) + + inside_status: list[bool] = [] + with ctx._replay_aware(): # noqa: SLF001 + # consume the id so the counter advances like a real operation + ctx._create_step_id() # noqa: SLF001 + inside_status.append(ctx.is_replaying()) + + # Still replaying THROUGH the completed op, then flips because nothing follows. 
+ assert inside_status == [True] + assert ctx.is_replaying() is False + + +def test_replay_aware_defers_flip_until_after_resume_point(): + """A next op that exists but is NOT terminal is the resume point. + + The context stays replaying through the op's execution (so its logs are still + de-duplicated) and flips to NEW only afterwards. + """ + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + next_id = ctx._peek_next_operation_id() # noqa: SLF001 + # STARTED is non-terminal: e.g. a wait whose timer just fired. + ctx.state._operations[next_id] = _step_op( # noqa: SLF001 + next_id, OperationStatus.STARTED + ) + + inside_status: list[bool] = [] + with ctx._replay_aware(): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 + inside_status.append(ctx.is_replaying()) + + # Still replaying THROUGH the resume op, then flipped afterwards. + assert inside_status == [True] + assert ctx.is_replaying() is False + + +def test_child_context_inherits_replaying_status(): + """A child context inherits the parent's current replay status at creation.""" + state = _replay_state({}) + parent = DurableContext( + state=state, + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + child = parent.create_child_context(operation_id="op-1") + assert child.is_replaying() is True + + parent._set_replay_status_new() # noqa: SLF001 + child_after = parent.create_child_context(operation_id="op-2") + assert child_after.is_replaying() is False + + +def test_child_context_replay_status_is_independent_of_parent(): + """Refining a child's status does not mutate the parent's, and vice versa.""" + state = _replay_state({}) + parent = DurableContext( + state=state, + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + child = parent.create_child_context(operation_id="op-1") + + 
child._set_replay_status_new() # noqa: SLF001 + + assert child.is_replaying() is False + assert parent.is_replaying() is True + + +def test_replay_aware_flips_after_completed_op_when_next_is_brand_new(): + """A completed op followed by a brand-new op crosses the boundary after the op. + + This covers logs emitted between a completed operation (e.g. a `wait` that + already fired) and the next, not-yet-existing operation. Such logs must be + treated as new work rather than suppressed. + """ + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + # Next op (the one this _replay_aware wraps) already completed; the op AFTER + # it has no checkpoint yet. + next_id = ctx._peek_next_operation_id() # noqa: SLF001 + ctx.state._operations[next_id] = _step_op( # noqa: SLF001 + next_id, OperationStatus.SUCCEEDED + ) + + inside_status: list[bool] = [] + with ctx._replay_aware(): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 - consume the completed op's id + inside_status.append(ctx.is_replaying()) + + # Still replaying THROUGH the completed op, then flipped because the next + # operation is brand-new. + assert inside_status == [True] + assert ctx.is_replaying() is False + + +def test_replay_aware_stays_replaying_between_two_completed_ops(): + """Logs between two completed ops stay suppressed (still replaying).""" + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + # Both the wrapped op and the following op already completed. 
+ first_id = ctx._create_step_id_for_logical_step(1) # noqa: SLF001 + second_id = ctx._create_step_id_for_logical_step(2) # noqa: SLF001 + ctx.state._operations[first_id] = _step_op( # noqa: SLF001 + first_id, OperationStatus.SUCCEEDED + ) + ctx.state._operations[second_id] = _step_op( # noqa: SLF001 + second_id, OperationStatus.SUCCEEDED + ) + + with ctx._replay_aware(): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 - consume first completed op + assert ctx.is_replaying() is True + + # Next op also completed, so we remain replaying. + assert ctx.is_replaying() is True + + +# endregion per-context replay status diff --git a/packages/aws-durable-execution-sdk-python/tests/execution_test.py b/packages/aws-durable-execution-sdk-python/tests/execution_test.py index ed79bedf..26b235d8 100644 --- a/packages/aws-durable-execution-sdk-python/tests/execution_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/execution_test.py @@ -34,6 +34,7 @@ CheckpointOutput, CheckpointUpdatedExecutionState, ContextDetails, + DurableExecutionInvocationOutput, DurableServiceClient, ErrorObject, ExecutionDetails, @@ -45,10 +46,10 @@ StateOutput, StepDetails, WaitDetails, - DurableExecutionInvocationOutput, ) from aws_durable_execution_sdk_python.plugin import DurableInstrumentationPlugin + LARGE_RESULT = "large_success" * 1024 * 1024 # region Models @@ -2708,7 +2709,7 @@ def test_durable_execution_replays_when_paginated_state_has_prior_operations(): @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: - return {"is_replaying": context.state.is_replaying()} + return {"is_replaying": context.is_replaying()} result = test_handler(invocation_input, _make_lambda_context()) diff --git a/packages/aws-durable-execution-sdk-python/tests/logger_test.py b/packages/aws-durable-execution-sdk-python/tests/logger_test.py index 9ce3719d..42074ad9 100644 --- a/packages/aws-durable-execution-sdk-python/tests/logger_test.py +++ 
b/packages/aws-durable-execution-sdk-python/tests/logger_test.py @@ -8,12 +8,12 @@ from aws_durable_execution_sdk_python.lambda_service import ( Operation, OperationStatus, - OperationType, OperationSubType, + OperationType, ) from aws_durable_execution_sdk_python.logger import Logger, LoggerInterface, LogInfo from aws_durable_execution_sdk_python.plugin import PluginExecutor -from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus +from aws_durable_execution_sdk_python.state import ExecutionState class PowertoolsLoggerStub: @@ -375,50 +375,49 @@ def test_logger_replay_no_logging(): operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) - replay_execution_state = ExecutionState( + execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation}, service_client=Mock(), - replay_status=ReplayStatus.REPLAY, plugin_executor=PluginExecutor([]), ) - log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5) + log_info = LogInfo(execution_state, "parent123", "test_name", 5) mock_logger = Mock() - logger = Logger.from_log_info(mock_logger, log_info) + # Logger consults its owning context's replay status. While replaying, it + # suppresses logs. 
+ logger = Logger.from_log_info(mock_logger, log_info, is_replaying=lambda: True) logger.info("logging info") - replay_execution_state.track_replay(operation_id="op1") mock_logger.info.assert_not_called() def test_logger_replay_then_new_logging(): - operation1 = Operation( + operation = Operation( operation_id="op1", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation1, "op2": operation2}, + operations={"op1": operation}, service_client=Mock(), - replay_status=ReplayStatus.REPLAY, plugin_executor=PluginExecutor([]), ) log_info = LogInfo(execution_state, "parent123", "test_name", 5) mock_logger = Mock() - logger = Logger.from_log_info(mock_logger, log_info) - execution_state.track_replay(operation_id="op1") - logger.info("logging info") + # Drive the replay status through a mutable flag standing in for the + # owning context's per-context status. + replaying = {"value": True} + logger = Logger.from_log_info( + mock_logger, log_info, is_replaying=lambda: replaying["value"] + ) + logger.info("logging info") mock_logger.info.assert_not_called() - execution_state.track_replay(operation_id="op2") + # Once the context reaches its replay boundary, logging resumes. 
+ replaying["value"] = False logger.info("logging info") mock_logger.info.assert_called_once() diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py index f026d357..601fe62c 100644 --- a/packages/aws-durable-execution-sdk-python/tests/state_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py @@ -9,7 +9,7 @@ import time import unittest.mock from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, call, patch, create_autospec +from unittest.mock import Mock, call, create_autospec, patch import pytest @@ -32,11 +32,11 @@ Operation, OperationAction, OperationStatus, + OperationSubType, OperationType, OperationUpdate, StateOutput, StepDetails, - OperationSubType, ) from aws_durable_execution_sdk_python.plugin import ( DurableInstrumentationPlugin, @@ -47,7 +47,6 @@ CheckpointedResult, ExecutionState, QueuedOperation, - ReplayStatus, ) from aws_durable_execution_sdk_python.threading import CompletionEvent @@ -3452,64 +3451,47 @@ def test_create_checkpoint_sync_always_synchronous(): executor.shutdown(wait=True) -def test_state_replay_mode(): +def test_state_has_prior_operations_true_when_non_execution_op_exists(): operation1 = Operation( operation_id="op1", operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, - ) execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation1, "op2": operation2}, + operations={"op1": operation1}, service_client=Mock(), plugin_executor=PluginExecutor(plugins=None), - replay_status=ReplayStatus.REPLAY, ) - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op1") - assert execution_state.is_replaying() is True - 
execution_state.track_replay(operation_id="op2") - assert execution_state.is_replaying() is False - - -def test_state_replay_mode_with_timed_out(): - """Test that TIMED_OUT operations are treated as terminal states for replay tracking. + assert execution_state.has_prior_operations() is True - This test verifies that when an operation has TIMED_OUT status, it is correctly - recognized as a completed/terminal state, allowing the replay status to transition - from REPLAY to NEW once all completed operations have been visited. - Regression test for: https://github.com/aws/aws-durable-execution-sdk-python/issues/262 - """ - operation1 = Operation( - operation_id="op1", - operation_type=OperationType.STEP, - status=OperationStatus.TIMED_OUT, +def test_state_has_prior_operations_false_when_only_execution_op_exists(): + execution_op = Operation( + operation_id="exec1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, ) - operation2 = Operation( - operation_id="op2", - operation_type=OperationType.STEP, - status=OperationStatus.SUCCEEDED, + execution_state = ExecutionState( + durable_execution_arn="arn:aws:test", + initial_checkpoint_token="test_token", # noqa: S106 + operations={"exec1": execution_op}, + service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), ) + assert execution_state.has_prior_operations() is False + + +def test_state_has_prior_operations_false_when_empty(): execution_state = ExecutionState( durable_execution_arn="arn:aws:test", initial_checkpoint_token="test_token", # noqa: S106 - operations={"op1": operation1, "op2": operation2}, + operations={}, service_client=Mock(), plugin_executor=PluginExecutor(plugins=None), - replay_status=ReplayStatus.REPLAY, ) - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op1") - assert execution_state.is_replaying() is True - execution_state.track_replay(operation_id="op2") - assert execution_state.is_replaying() is False + assert 
execution_state.has_prior_operations() is False # Tests for empty checkpoint coalescing (issue #325) @@ -4262,16 +4244,13 @@ def test_plugin_executor_not_called_for_pending_operations(): # endregion Plugin Executor Integration Tests -def _make_execution_state_for_operations( - mock_lambda_client, *, replay_status=ReplayStatus.NEW, operations=None -): +def _make_execution_state_for_operations(mock_lambda_client, *, operations=None): return ExecutionState( durable_execution_arn="test_arn", initial_checkpoint_token="token123", # noqa: S106 operations=operations or {}, service_client=mock_lambda_client, plugin_executor=PluginExecutor(plugins=None), - replay_status=replay_status, ) @@ -4295,18 +4274,16 @@ def test_operations_property_returns_snapshot_copy(): assert len(state.operations) == 1 -def test_track_replay_iteration_safe_under_concurrent_update(): - """track_replay must not raise when operations are updated concurrently. +def test_has_prior_operations_iteration_safe_under_concurrent_update(): + """has_prior_operations must not raise when operations are updated concurrently. - A worker thread iterates operations inside track_replay while the checkpoint - path updates the same map. Without consistent locking this raises + A worker thread iterates operations inside has_prior_operations while the + checkpoint path updates the same map. Without consistent locking this raises "dictionary changed size during iteration". """ mock_lambda_client = Mock(spec=LambdaClient) - state = _make_execution_state_for_operations( - mock_lambda_client, replay_status=ReplayStatus.REPLAY - ) - # Seed completed operations so track_replay keeps iterating (stays REPLAY). + state = _make_execution_state_for_operations(mock_lambda_client) + # Seed completed operations so iteration has work to do. 
for i in range(50): state._operations[f"seed{i}"] = Operation( operation_id=f"seed{i}", @@ -4331,7 +4308,7 @@ def writer(): def reader(): try: for _ in range(2000): - state.track_replay(operation_id="probe") + state.has_prior_operations() except Exception as e: # noqa: BLE001 errors.append(e) @@ -4343,4 +4320,4 @@ def reader(): stop.set() writer_t.join(timeout=5) - assert not errors, f"track_replay raced with concurrent update: {errors}" + assert not errors, f"has_prior_operations raced with concurrent update: {errors}" From 9a60f6bba521237f0da2f3b52987677f91bb1107 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 24 Jun 2026 13:57:47 -0700 Subject: [PATCH 2/2] feat: flip status to new before user code --- .../context.py | 58 +++++++++---- .../tests/context_test.py | 81 +++++++++++++++++++ 2 files changed, 125 insertions(+), 14 deletions(-) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py index f03b7edc..8491f0a6 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py @@ -481,32 +481,62 @@ def _next_operation_exists(self) -> bool: ).is_existent() @contextmanager - def _replay_aware(self): + def _replay_aware(self, *, executes_user_code: bool = False): """Wrap a single operation with replay-boundary detection. - The boundary has three parts: + Args: + executes_user_code: True for operations that invoke a user-provided + function on (re-)entry — `step` and `wait_for_condition`'s check + (and, transitively, `wait_for_callback`'s submitter, which runs + as a step). When such an operation actually runs the user + function it is, by definition, doing new work (a cached + SUCCEEDED operation returns its checkpoint without invoking the + function). 
So a non-terminal user-code operation flips the + context to NEW *before* its body runs, including on retries. For + all other operations (default False), a non-terminal next + operation is a pure resume point with no user body, so we keep + replay status through it and flip afterwards. + + The boundary has these parts: - Existence flip (before the op): if we are replaying and the next operation has no checkpoint at all, it is brand-new code, so flip to NEW immediately so the operation and its logs count as new. - - Deferred status flip (after the op): if we are replaying and the next - operation exists but is NOT terminal, that operation is the resume - point. We keep replay status through the operation and flip to NEW - afterwards, so the resuming operation's own logs stay de-duplicated - but subsequent code counts as new. + - User-code flip (before the op): if `executes_user_code` and the next + operation is non-terminal (brand-new OR retrying/re-executing), the + user function is about to run, so flip to NEW before it. + - Deferred status flip (after the op): for non-user-code operations, if + we are replaying and the next operation exists but is NOT terminal, + that operation is the resume point. We keep replay status through the + operation and flip to NEW afterwards, so the resuming operation's own + logs stay de-duplicated but subsequent code counts as new. - Post-op existence flip (after the op): if we are still replaying once the operation completes and the *following* operation does not exist yet, we have reached the replay boundary. """ was_replaying: bool = self.is_replaying() - # Status: exists but not terminal -> defer flip until after the op runs. + # Only peek when replaying; avoids unnecessary checkpoint lookups (and + # any step-id side effects) on the common non-replay path. 
+ next_exists: bool = was_replaying and self._next_operation_exists() + next_terminal: bool = ( + was_replaying and self._next_operation_is_terminal_checkpoint() + ) + + # Deferred flip applies only to non-user-code resume points. For + # user-code ops we flip before instead, so don't defer. flip_after: bool = ( was_replaying - and self._next_operation_exists() - and not self._next_operation_is_terminal_checkpoint() + and not executes_user_code + and next_exists + and not next_terminal ) - # Existence: brand-new next op -> flip before the op runs. - if was_replaying and not self._next_operation_exists(): + # Before-the-op flips: + # - brand-new next op (no checkpoint): always flip to NEW. + # - user-code op that is non-terminal (brand-new or retrying): the user + # function is about to run real work, so flip to NEW before it. + if was_replaying and ( + not next_exists or (executes_user_code and not next_terminal) + ): self._set_replay_status_new() try: yield @@ -747,7 +777,7 @@ def step( logger.debug("Step name: %s", step_name) if not config: config = StepConfig() - with self._replay_aware(): + with self._replay_aware(executes_user_code=True): operation_id = self._create_step_id() executor: StepOperationExecutor[T] = StepOperationExecutor( func=func, @@ -830,7 +860,7 @@ def wait_for_condition( msg = "`config` is required for wait_for_condition" raise ValidationError(msg) - with self._replay_aware(): + with self._replay_aware(executes_user_code=True): operation_id = self._create_step_id() executor: WaitForConditionOperationExecutor[T] = ( WaitForConditionOperationExecutor( diff --git a/packages/aws-durable-execution-sdk-python/tests/context_test.py b/packages/aws-durable-execution-sdk-python/tests/context_test.py index 27f18526..6127b114 100644 --- a/packages/aws-durable-execution-sdk-python/tests/context_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/context_test.py @@ -2515,4 +2515,85 @@ def 
test_replay_aware_stays_replaying_between_two_completed_ops(): assert ctx.is_replaying() is True +def test_replay_aware_user_code_flips_before_retrying_op(): + """A retrying/re-executing user-code op flips to NEW BEFORE the body runs. + + A step whose checkpoint is non-terminal (e.g. PENDING/STARTED from a retry) + is about to re-run the user function, which is real new work. With + executes_user_code=True the context flips to NEW before the body so the + step's logs (and future plugin state) reflect an executing attempt. + """ + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + next_id = ctx._peek_next_operation_id() # noqa: SLF001 + # STARTED == non-terminal: a retry attempt about to re-execute. + ctx.state._operations[next_id] = _step_op( # noqa: SLF001 + next_id, OperationStatus.STARTED + ) + + inside_status: list[bool] = [] + with ctx._replay_aware(executes_user_code=True): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 + inside_status.append(ctx.is_replaying()) + + # Flipped BEFORE the body (contrast with the non-user-code resume point, + # which stays replaying through the op). + assert inside_status == [False] + assert ctx.is_replaying() is False + + +def test_replay_aware_user_code_stays_replaying_for_completed_op(): + """A cached SUCCEEDED user-code op does not run its body, so stays replaying.""" + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + # Wrapped op completed; a following op also completed so nothing flips. 
+ first_id = ctx._create_step_id_for_logical_step(1) # noqa: SLF001 + second_id = ctx._create_step_id_for_logical_step(2) # noqa: SLF001 + ctx.state._operations[first_id] = _step_op( # noqa: SLF001 + first_id, OperationStatus.SUCCEEDED + ) + ctx.state._operations[second_id] = _step_op( # noqa: SLF001 + second_id, OperationStatus.SUCCEEDED + ) + + with ctx._replay_aware(executes_user_code=True): # noqa: SLF001 + ctx._create_step_id() # noqa: SLF001 + assert ctx.is_replaying() is True + + assert ctx.is_replaying() is True + + +def test_replay_aware_non_user_code_stays_replaying_through_resume_point(): + """A non-user-code resume point (e.g. wait) stays replaying through the op. + + Contrast with the user-code case: a non-terminal wait is a pure resume + point with no user body, so logs emitted by the resuming op stay + de-duplicated and the flip is deferred until after. + """ + ctx = DurableContext( + state=_replay_state({}), + execution_context=ExecutionContext(durable_execution_arn="arn"), + replay_status=ReplayStatus.REPLAY, + ) + next_id = ctx._peek_next_operation_id() # noqa: SLF001 + ctx.state._operations[next_id] = _step_op( # noqa: SLF001 + next_id, OperationStatus.STARTED + ) + + inside_status: list[bool] = [] + with ctx._replay_aware(): # noqa: SLF001 - executes_user_code defaults False + ctx._create_step_id() # noqa: SLF001 + inside_status.append(ctx.is_replaying()) + + # Stayed replaying THROUGH the op, flipped after. + assert inside_status == [True] + assert ctx.is_replaying() is False + + # endregion per-context replay status