diff --git a/temporalio/client/_impl.py b/temporalio/client/_impl.py index 8e33ff910..7ba54746f 100644 --- a/temporalio/client/_impl.py +++ b/temporalio/client/_impl.py @@ -787,6 +787,11 @@ async def start_workflow_update( ): break + # Add response link if its a Nexus operation + nexus_ctx = temporalio.nexus._operation_context._try_start_operation_context() + if nexus_ctx is not None and resp.HasField("link"): + nexus_ctx._add_response_link(resp.link) + # Build the handle. If the user's wait stage is COMPLETED, make sure we # poll for result. handle: WorkflowUpdateHandle[Any] = WorkflowUpdateHandle( @@ -852,6 +857,23 @@ async def _build_update_workflow_execution_request( ) ), ) + # Only set Nexus fields for StartWorkflowUpdateInput, skip for UpdateWithStartUpdateWorkflowInput + if isinstance(input, StartWorkflowUpdateInput): + if input.request_id: + req.request.request_id = input.request_id + if input.links: + req.request.links.extend(input.links) + if input.callbacks: + req.request.completion_callbacks.extend( + temporalio.api.common.v1.Callback( + nexus=temporalio.api.common.v1.Callback.Nexus( + url=callback.url, + header=callback.headers, + ), + links=input.links or [], + ) + for callback in input.callbacks + ) if input.args: req.request.input.args.payloads.extend( await data_converter.encode(input.args) diff --git a/temporalio/client/_interceptor.py b/temporalio/client/_interceptor.py index 5333d487f..4eaab9264 100644 --- a/temporalio/client/_interceptor.py +++ b/temporalio/client/_interceptor.py @@ -322,6 +322,10 @@ class StartWorkflowUpdateInput: ret_type: type | None rpc_metadata: Mapping[str, str | bytes] rpc_timeout: timedelta | None + # The following options are for Nexus Operation-backed updates. Experimental and unstable + callbacks: Sequence[Callback] | None = None + links: Sequence[temporalio.api.common.v1.Link] | None = None + request_id: str | None = None @dataclass diff --git a/temporalio/client/_workflow.py b/temporalio/client/_workflow.py index 8579e8433..f2eef165c 100644 --- a/temporalio/client/_workflow.py +++ b/temporalio/client/_workflow.py @@ -59,6 +59,7 @@ ReturnType, SelfType, ) +from ._callback import Callback from ._exceptions import ( WorkflowContinuedAsNewError, WorkflowFailureError, @@ -896,6 +897,8 @@ async def start_update( rpc_timeout: timedelta | None = None, ) -> WorkflowUpdateHandle[Any]: ... + # draft-review: check why this doesnt currently support run_id and first_execution_run_id + # If it can be supported, wire it up for nexus operation-backed updates as well async def start_update( self, update: str | Callable, @@ -955,6 +958,12 @@ async def _start_update( result_type: type | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, + # run_id: str | None = None, + # first_execution_run_id: str | None = None, + # The following options are for Nexus Operation-backed updates. Experimental and unstable + callbacks: Sequence[Callback] | None = None, + links: Sequence[temporalio.api.common.v1.Link] | None = None, + request_id: str | None = None, ) -> WorkflowUpdateHandle[Any]: if wait_for_stage == WorkflowUpdateStage.ADMITTED: raise ValueError("ADMITTED wait stage not supported") @@ -976,6 +985,11 @@ async def _start_update( rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, wait_for_stage=wait_for_stage, + # run_id=run_id, + # first_execution_run_id=first_execution_run_id, + callbacks=callbacks, + links=links, + request_id=request_id, ) ) diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index 402e4b04e..c3d65c60b 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -25,15 +25,17 @@ wait_for_worker_shutdown_sync, ) from ._operation_handlers import ( + CancelUpdateWorkflowOptions, CancelWorkflowRunOptions, TemporalOperationHandler, ) from ._temporal_client import TemporalNexusClient, TemporalOperationResult -from ._token import WorkflowHandle +from ._token import UpdateHandle, WorkflowHandle __all__ = ( "workflow_run_operation", "CancelWorkflowRunOptions", + "CancelUpdateWorkflowOptions", "Info", "LoggerAdapter", "NexusCallback", @@ -49,6 +51,7 @@ "wait_for_worker_shutdown", "wait_for_worker_shutdown_sync", "WorkflowHandle", + "UpdateHandle", "TemporalNexusClient", "TemporalOperationStartHandlerFunc", "TemporalOperationHandler", diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 1128d1b71..4a92d1a50 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -715,3 +715,43 @@ async def _start_nexus_backing_workflow( ) return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) + + +async def _start_nexus_backed_workflow_update( + *, + temporal_context: _TemporalStartOperationContext, + workflow_id: str, + update: str | Callable, + arg: Any = temporalio.common._arg_unset, + args: Sequence[Any] = [], + id: str | None = None, + result_type: type | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + # run_id: str | None = None, + # first_execution_run_id: str | None = None, +) -> temporalio.client.WorkflowUpdateHandle[Any]: + # Default update ID to the Nexus request ID for retry-safety (matches sdk-go). + update_id = id or temporal_context.nexus_context.request_id + token = OperationToken( + type=OperationTokenType.UPDATE_WORKFLOW, + namespace=temporal_context.client.namespace, + workflow_id=workflow_id, + update_id=update_id, + ).encode() + workflow_handle = temporal_context.client.get_workflow_handle(workflow_id) + return await workflow_handle._start_update( + update, + arg, + args=args, + wait_for_stage=temporalio.client.WorkflowUpdateStage.ACCEPTED, # hardcoded as nexus only supports async updates + id=update_id, + result_type=result_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + callbacks=temporal_context._get_callbacks(token), + links=temporal_context._get_request_links(), + request_id=temporal_context.nexus_context.request_id, + # run_id=run_id, + # first_execution_run_id=first_execution_run_id, + ) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index e5c3bd762..df6ff994b 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -138,6 +138,23 @@ class CancelWorkflowRunOptions: """The ID of the workflow to cancel.""" +@dataclass(frozen=True) +class CancelUpdateWorkflowOptions: + """Options for cancelling the workflow update backing a Nexus operation. + + These options are built by :py:class:`TemporalOperationHandler` and passed to + :py:meth:`TemporalOperationHandler.cancel_workflow_update`. + + .. warning:: + This API is experimental and unstable. + """ + + workflow_id: str + """The ID of the workflow where the update is running.""" + update_id: str + """The ID of the update to cancel.""" + + class TemporalOperationHandler(OperationHandler[InputT, OutputT], ABC): """Operation handler for Nexus operations that interact with Temporal. Implementations override the start_operation method. @@ -190,6 +207,13 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: workflow_id=operation_token.workflow_id ) await self.cancel_workflow_run(cancel_ctx, options) + case OperationTokenType.UPDATE_WORKFLOW: + assert operation_token.update_id is not None + options = CancelUpdateWorkflowOptions( + workflow_id=operation_token.workflow_id, + update_id=operation_token.update_id, + ) + await self.cancel_workflow_update(cancel_ctx, options) async def cancel_workflow_run( self, @@ -205,3 +229,23 @@ async def cancel_workflow_run( options.workflow_id ) await workflow_handle.cancel() + + # draft-review: maybe just move it inline, no need for a function just to error out + # check after review in case theres some other way to override/supply custom cancels + async def cancel_workflow_update( + self, + ctx: TemporalCancelOperationContext, # pyright: ignore[reportUnusedParameter] + options: CancelUpdateWorkflowOptions, # pyright: ignore[reportUnusedParameter] + ) -> None: + """Cancels the workflow update backing the Nexus operation. + + .. warning:: + This API is experimental and unstable. + """ + raise HandlerError( + """ + Cancellation is not natively supported for update-workflow Nexus operations. + Override a TemporalOperationHandler and implement this method to run cancellable workflow updates. + """, + type=HandlerErrorType.NOT_IMPLEMENTED, + ) diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py index 08204d89b..60fab42db 100644 --- a/temporalio/nexus/_temporal_client.py +++ b/temporalio/nexus/_temporal_client.py @@ -15,15 +15,17 @@ overload, ) -from nexusrpc import HandlerError, HandlerErrorType +from nexusrpc import HandlerError, HandlerErrorType, OperationError, OperationErrorState from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync from typing_extensions import Self import temporalio.common from temporalio.nexus._operation_context import ( + _start_nexus_backed_workflow_update, _start_nexus_backing_workflow, _TemporalStartOperationContext, ) +from temporalio.nexus._token import UpdateHandle from temporalio.types import ( MethodAsyncNoParam, MethodAsyncSingleParam, @@ -35,6 +37,7 @@ if TYPE_CHECKING: import temporalio.client + import temporalio.workflow _ResultT = TypeVar("_ResultT") @@ -279,6 +282,85 @@ async def start_workflow( """ ... + # Overload for no-param update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], ReturnType], + *, + id: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for single-param update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: temporalio.workflow.UpdateMethodMultiParam[ + [SelfType, ParamType], ReturnType + ], + arg: ParamType, + *, + id: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for multi-param update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: temporalio.workflow.UpdateMethodMultiParam[MultiParamSpec, ReturnType], + *, + args: MultiParamSpec.args, # type: ignore + id: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for string-name update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str | None = None, + result_type: type[ReturnType] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # draft-review: check why run_id and first_execution_run_id are not used + # for update workflow in python sdk + @abstractmethod + async def start_workflow_update( + self, + workflow_id: str, + update: str | Callable, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str | None = None, + result_type: type | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + # run_id: str | None = None, + # first_execution_run_id: str | None = None, + ) -> TemporalOperationResult[Any]: + """Start a workflow update as the backing asynchronous Nexus operation. + + .. warning:: + This API is experimental and unstable. + """ + ... + class _TemporalNexusClient(TemporalNexusClient): # pyright: ignore[reportUnusedClass] """Nexus-aware wrapper around a Temporal Client. @@ -377,3 +459,107 @@ async def start_workflow( ) return TemporalOperationResult.async_token(wf_handle.to_token()) + + # Overload for no-param update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], ReturnType], + *, + id: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for single-param update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: temporalio.workflow.UpdateMethodMultiParam[ + [SelfType, ParamType], ReturnType + ], + arg: ParamType, + *, + id: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for multi-param update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: temporalio.workflow.UpdateMethodMultiParam[MultiParamSpec, ReturnType], + *, + args: MultiParamSpec.args, # type: ignore + id: str | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for string-name update + @overload + async def start_workflow_update( + self, + workflow_id: str, + update: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str | None = None, + result_type: type[ReturnType] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + async def start_workflow_update( + self, + workflow_id: str, + update: str | Callable, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str | None = None, + result_type: type | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + # run_id: str | None = None, + # first_execution_run_id: str | None = None, + ) -> TemporalOperationResult[Any]: + """Start a workflow update as the backing asynchronous Nexus operation.""" + if not self._temporal_context.nexus_context.callback_url: + raise HandlerError( + "callback URL is required for a workflow update Nexus operation", + type=HandlerErrorType.BAD_REQUEST, + ) + with self._reserve_async_start(): + update_handle = await _start_nexus_backed_workflow_update( + temporal_context=self._temporal_context, + workflow_id=workflow_id, + update=update, + arg=arg, + args=args, + id=id, + result_type=result_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + # run_id=run_id, + # first_execution_run_id=first_execution_run_id, + ) + # If the update has already completed, return the result synchronously + # This is in-line with the Go implementation as well + if update_handle._known_outcome is not None: + try: + result = await update_handle.result() + except temporalio.client.WorkflowUpdateFailedError as err: + raise OperationError( + str(err), state=OperationErrorState.FAILED + ) from err + return TemporalOperationResult.sync(result) + nexus_handle = UpdateHandle._unsafe_from_client_workflow_update_handle( + update_handle + ) + return TemporalOperationResult.async_token(nexus_handle.to_token()) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index d52b54180..0d14b687c 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -14,6 +14,7 @@ class OperationTokenType(IntEnum): """Type discriminator for Nexus operation tokens.""" WORKFLOW = 1 + UPDATE_WORKFLOW = 3 if TYPE_CHECKING: @@ -28,6 +29,7 @@ class OperationToken: type: OperationTokenType namespace: str workflow_id: str + update_id: str | None = None def encode(self) -> str: """Convert handle to a base64url-encoded token string.""" @@ -38,6 +40,8 @@ def encode(self) -> str: } if self.version is not None: token_details["v"] = self.version + if self.update_id is not None: + token_details["uid"] = self.update_id return _base64url_encode_no_padding( json.dumps( token_details, @@ -88,9 +92,24 @@ def decode(cls, token: str) -> Self: f"invalid token: expected workflow id to be a string, got {type(workflow_id)}" ) - if token_type == OperationTokenType.WORKFLOW and not workflow_id: + if ( + token_type == OperationTokenType.WORKFLOW + or token_type == OperationTokenType.UPDATE_WORKFLOW + ): + if not workflow_id: + raise TypeError( + f"invalid token: expected non-empty workflow id for token type `{token_type.name}`" + ) + + update_id = token_details.get("uid") + if not isinstance(update_id, str | None): raise TypeError( - "invalid token: expected non-empty workflow id for token type `WORKFLOW`" + f"invalid token: expected update_id id to be a string or None, got {type(update_id)}" + ) + + if token_type == OperationTokenType.UPDATE_WORKFLOW and not update_id: + raise TypeError( + "invalid token: expected non-empty update id for token type `UPDATE_WORKFLOW`" ) namespace = token_details.get("ns") @@ -105,6 +124,7 @@ def decode(cls, token: str) -> Self: namespace=namespace, workflow_id=workflow_id, version=version, + update_id=update_id, ) @@ -180,6 +200,87 @@ def from_token(cls, token: str) -> WorkflowHandle[OutputT]: ) +@dataclass(frozen=True) +class UpdateHandle(Generic[OutputT]): + """A handle to a workflow update that is backing a Nexus operation. + + Do not instantiate this directly. Use + :py:func:`temporalio.nexus.TemporalNexusClient.start_workflow_update` to create a + handle. + """ + + namespace: str + workflow_id: str + update_id: str + # Version of the token. Treated as v1 if missing. This field is not included in the + # serialized token; it's only used to reject newer token versions on load. + version: int | None = None + + def _to_client_workflow_update_handle( + self, + client: temporalio.client.Client, + result_type: type[OutputT] | None = None, + ) -> temporalio.client.WorkflowUpdateHandle[Any]: + """Create a :py:class:`temporalio.client.WorkflowUpdateHandle` from the token.""" + if client.namespace != self.namespace: + raise ValueError( + f"Client namespace {client.namespace} does not match " + f"operation token namespace {self.namespace}" + ) + workflow_handle = client.get_workflow_handle(self.workflow_id) + return workflow_handle.get_update_handle( + self.update_id, result_type=result_type + ) + + @classmethod + def _unsafe_from_client_workflow_update_handle( + cls, workflow_update_handle: temporalio.client.WorkflowUpdateHandle[Any] + ) -> UpdateHandle[OutputT]: + """Create a :py:class:`UpdateHandle` from a :py:class:`temporalio.client.WorkflowUpdateHandle`. + + This is a private method not intended to be used by users. It does not check + that the supplied client.WorkflowUpdateHandle references a workflow that has been + instrumented to supply the result of a Nexus operation. + """ + return cls( + namespace=workflow_update_handle._client.namespace, + workflow_id=workflow_update_handle.workflow_id, + update_id=workflow_update_handle._id, + ) + + def to_token(self) -> str: + """Convert handle to a base64url-encoded token string.""" + return OperationToken( + type=OperationTokenType.UPDATE_WORKFLOW, + namespace=self.namespace, + workflow_id=self.workflow_id, + update_id=self.update_id, + ).encode() + + @classmethod + def from_token(cls, token: str) -> UpdateHandle[OutputT]: + """Decodes and validates a token from its base64url-encoded string representation.""" + op_token = OperationToken.decode(token) + if op_token.type != OperationTokenType.UPDATE_WORKFLOW: + raise TypeError( + f"invalid update token type: {op_token.type}, expected: {OperationTokenType.UPDATE_WORKFLOW}" + ) + + if op_token.version is not None and op_token.version != 0: + raise TypeError( + "invalid update token: 'v' field, if present, must be 0 or null/absent" + ) + + assert op_token.update_id is not None + + return cls( + namespace=op_token.namespace, + workflow_id=op_token.workflow_id, + update_id=op_token.update_id, + version=op_token.version, + ) + + def _base64url_encode_no_padding(data: bytes) -> str: return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") diff --git a/tests/conftest.py b/tests/conftest.py index e01773e7e..bfe005ede 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -132,7 +132,7 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "--dynamic-config-value", "history.enableTransitionHistory=true", "--dynamic-config-value", - "history.enableChasmCallbacks=true", + "history.enableCHASMCallbacks=true", "--dynamic-config-value", "history.enableCHASMSignalBacklinks=true", "--dynamic-config-value", @@ -141,6 +141,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: 'system.system.refreshNexusEndpointsMinWait="0s"', "--dynamic-config-value", "history.enableSignalWithStartFromWorkflow=true", + "--dynamic-config-value", + "history.enableUpdateCallbacks=true", ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) diff --git a/tests/nexus/test_operation_token.py b/tests/nexus/test_operation_token.py index 385f4f872..ee5e4fac4 100644 --- a/tests/nexus/test_operation_token.py +++ b/tests/nexus/test_operation_token.py @@ -7,6 +7,7 @@ from temporalio.nexus._token import ( OperationToken, OperationTokenType, + UpdateHandle, WorkflowHandle, ) @@ -42,6 +43,13 @@ def test_workflow_handle_to_from_token_round_trip(): assert WorkflowHandle[str].from_token(handle.to_token()) == handle +def test_update_handle_to_from_token_round_trip(): + handle = UpdateHandle[str]( + namespace="default", workflow_id="workflow-id", update_id="update-id" + ) + assert UpdateHandle[str].from_token(handle.to_token()) == handle + + @pytest.mark.parametrize( ("token", "expected"), [ diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index c97792c8d..7961cd585 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -23,6 +23,8 @@ class Input: value: str task_queue: str + update_value: str = "" + update_id: str = "" def test_temporal_operation_result_validates_single_result_kind() -> None: @@ -63,6 +65,7 @@ class TestService: retry_after_failed_start: Operation[Input, str] sync_result: Operation[Input, str] custom_cancel: Operation[str, None] + update_op: Operation[Input, str] @service_handler(service=TestService) @@ -216,6 +219,19 @@ async def cancel_workflow_run( return CustomCancelNexusOpHandler() + @nexus.temporal_operation + async def update_op( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[str]: + # input.value carries the target workflow_id, input.update_value has actual update + update_id = input.update_id or str(uuid.uuid4()) + return await client.start_workflow_update( + input.value, UpdatableWorkflow.do_update, input.update_value, id=update_id + ) + @workflow.defn class EchoWorkflowCaller: @@ -227,6 +243,19 @@ async def run(self, input: Input) -> str: return await client.execute_operation(TestService.echo, input) +@workflow.defn +class UpdateWorkflowCaller: + """Simple caller workflow that triggers a workflow update via nexus op""" + + @workflow.run + async def run(self, input: Input) -> str: + client = workflow.create_nexus_client( + service=TestService, + endpoint=make_nexus_endpoint_name(input.task_queue), + ) + return await client.execute_operation(TestService.update_op, input) + + async def test_temporal_operation_start_workflow( client: Client, env: WorkflowEnvironment ): @@ -258,6 +287,87 @@ async def test_temporal_operation_start_workflow( ) +async def test_temporal_operation_update_workflow( + client: Client, env: WorkflowEnvironment +) -> None: + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + workflows=[UpdatableWorkflow, UpdateWorkflowCaller], + ): + update_workflow_id = f"updatable-workflow-{uuid.uuid4()}" + await client.start_workflow( + UpdatableWorkflow.run, id=update_workflow_id, task_queue=task_queue + ) + wf_handle = await client.start_workflow( + UpdateWorkflowCaller.run, + Input( + value=update_workflow_id, task_queue=task_queue, update_value="Created" + ), + task_queue=task_queue, + id=f"update-workflow-caller-created", + ) + result = await wf_handle.result() + assert result == "Updated workflow status from Pending to Created" + + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + EventType.EVENT_TYPE_NEXUS_OPERATION_STARTED, + EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED, + ], + ) + + stable_update_id = str(uuid.uuid4()) + wf_handle = await client.start_workflow( + UpdateWorkflowCaller.run, + Input( + value=update_workflow_id, + task_queue=task_queue, + update_value="Processed", + update_id=stable_update_id, + ), + task_queue=task_queue, + id=f"update-workflow-caller-processed", + ) + result = await wf_handle.result() + assert result == "Updated workflow status from Created to Processed" + + # same update_id -> wont be processed again, receives a sync result + wf_handle = await client.start_workflow( + UpdateWorkflowCaller.run, + Input( + value=update_workflow_id, + task_queue=task_queue, + update_value="Processed", + update_id=stable_update_id, + ), + task_queue=task_queue, + id=f"update-workflow-caller-processed", + ) + # wf_handle. + result = await wf_handle.result() + assert result == "Updated workflow status from Created to Processed" + + wf_handle = await client.start_workflow( + UpdateWorkflowCaller.run, + Input( + value=update_workflow_id, + task_queue=task_queue, + update_value="Completed", + ), + task_queue=task_queue, + id=f"update-workflow-caller-completed", + ) + result = await wf_handle.result() + assert result == "Updated workflow status from Processed to Completed" + + @workflow.defn class BlockingWorkflow: def __init__(self) -> None: @@ -724,3 +834,23 @@ async def test_temporal_operation_includes_token_in_callback( ).encode() assert token == expected_token + + +@workflow.defn +class UpdatableWorkflow: + """Workflow that accepts updates and exits when it receives a specific status""" + + def __init__(self) -> None: + self.order_status = "Pending" + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self.order_status == "Completed") + # some more order processing etc + + @workflow.update + async def do_update(self, value: str) -> str: + status = self.order_status + self.order_status = value + update_result = f"Updated workflow status from {status} to {value}" + return update_result