1- import asyncio
21import uuid
32from dataclasses import dataclass
43from datetime import timedelta
5- from typing import Union , cast
4+ from typing import Union
65
76import nexusrpc
87import nexusrpc .handler
@@ -34,6 +33,7 @@ class SyncResponse:
3433@dataclass
3534class AsyncResponse :
3635 operation_workflow_id : str
36+ block_forever_waiting_for_cancellation : bool
3737
3838
3939# The ordering in this union is critical since the data converter matches eagerly,
@@ -47,7 +47,6 @@ class AsyncResponse:
4747@dataclass
4848class MyInput :
4949 response_type : ResponseType
50- block_forever_waiting_for_cancellation : bool
5150
5251
5352@dataclass
@@ -86,7 +85,7 @@ async def start(
8685 elif isinstance (input .response_type , AsyncResponse ):
8786 return await temporalio .nexus .handler .start_workflow (
8887 MyHandlerWorkflow .run ,
89- input .block_forever_waiting_for_cancellation ,
88+ input .response_type . block_forever_waiting_for_cancellation ,
9089 id = input .response_type .operation_workflow_id ,
9190 options = options ,
9291 )
@@ -126,7 +125,7 @@ class MyCallerWorkflow:
126125 def __init__ (
127126 self ,
128127 response_type : ResponseType ,
129- should_cancel : bool ,
128+ request_cancel : bool ,
130129 task_queue : str ,
131130 ) -> None :
132131 self .nexus_service = workflow .NexusClient (
@@ -141,39 +140,33 @@ def __init__(
141140 async def run (
142141 self ,
143142 response_type : ResponseType ,
144- should_cancel : bool ,
143+ request_cancel : bool ,
145144 task_queue : str ,
146145 ) -> str :
147146 op_handle = await self .nexus_service .start_operation (
148147 MyService .my_operation ,
149- MyInput (
150- response_type ,
151- block_forever_waiting_for_cancellation = should_cancel ,
152- ),
148+ MyInput (response_type ),
153149 )
150+ print (f"🌈 { 'after await start' :<24} : { op_handle } " )
154151 self ._nexus_operation_started = True
155- task = cast (asyncio .Task , getattr (op_handle , "_task" ))
156152 if isinstance (response_type , SyncResponse ):
157153 assert op_handle .operation_token is None
158- # TODO(dan): I expected task to be done at this point
159- # assert task.done()
160- # assert not task.exception()
161- if should_cancel :
162- # TODO(dan): why does this assert pass (same Q as above re task.done())
163- assert op_handle .cancel ()
164- elif isinstance (response_type , AsyncResponse ):
154+ else :
165155 assert op_handle .operation_token
166- assert not task .done ()
167- # Allow the test to control when we proceed so that it can make initial
168- # assertions.
156+ # Allow the test to make assertions before signalling us to proceed.
169157 await workflow .wait_condition (lambda : self ._proceed )
158+ print (f"🌈 { 'after await proceed' :<24} : { op_handle } " )
170159
171- if should_cancel :
172- # We cannot assert that cancel() returns True because it's possible that a
173- # resolve_nexus_operation job has already come in.
174- op_handle .cancel ()
160+ if request_cancel :
161+ # Even for SyncResponse, the op_handle future is not done at this point; that
162+ # transition doesn't happen until the handle is awaited.
163+ print (f"🌈 { 'before op_handle.cancel' :<24} : { op_handle } " )
164+ cancel_ret = op_handle .cancel ()
165+ print (f"🌈 { 'cancel returned' :<24} : { cancel_ret } " )
175166
167+ print (f"🌈 { 'before await op_handle' :<24} : { op_handle } " )
176168 result = await op_handle
169+ print (f"🌈 { 'after await op_handle' :<24} : { op_handle } " )
177170 return result .val
178171
179172 @workflow .update
@@ -190,32 +183,64 @@ def proceed(self) -> None:
190183#
191184
192185
193- # TODO(dan): cross-namespace tests
194- # TODO(dan): nexus endpoint pytest fixture?
195- @pytest .mark .parametrize ("should_attempt_cancel" , [False , True ])
196- async def test_sync_response (client : Client , should_attempt_cancel : bool ):
197- task_queue = str (uuid .uuid4 ())
198- async with Worker (
199- client ,
200- nexus_services = [MyServiceImpl ()],
201- workflows = [MyCallerWorkflow , MyHandlerWorkflow ],
202- task_queue = task_queue ,
203- workflow_runner = UnsandboxedWorkflowRunner (),
204- ):
205- await create_nexus_endpoint (task_queue , client )
206- wf_handle = await client .start_workflow (
207- MyCallerWorkflow .run ,
208- args = [SyncResponse (), should_attempt_cancel , task_queue ],
209- id = str (uuid .uuid4 ()),
210- task_queue = task_queue ,
211- )
212- # The response is synchronous, so the workflow's attempt to cancel the
213- # NexusOperationHandle do not result in cancellation.
214- result = await wf_handle .result ()
215- assert result == "sync response"
186+ # When request_cancel is True, the NexusOperationHandle in the workflow evolves
187+ # through the following states:
188+ # start_fut result_fut handle_task w/ fut_waiter (task._must_cancel)
189+ #
190+ # Case 1: Sync Nexus operation response w/ cancellation of NexusOperationHandle
191+ # -----------------------------------------------------------------------------
192+ # >>>>>>>>>>>> WFT 1
193+ # after await start : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False)
194+ # before op_handle.cancel : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (False)
195+ # Future_8240[FINISHED].cancel() -> False # no state transition; fut_waiter is already finished
196+ # cancel returned : True
197+ # before await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[PENDING] fut_waiter = Future_8240[FINISHED]) (True)
198+ # --> Despite cancel having been requested, this await on the nexus op handle does not
199+ # raise CancelledError, because the task's underlying fut_waiter is already finished.
200+ # after await op_handle : Future_7856[FINISHED] Future_7984[FINISHED] Task[FINISHED] fut_waiter = None) (False)
201+ # <<<<<<<<<<<< END WFT 1
202+ #
203+
204+ # Case 2: Async Nexus operation response w/ cancellation of NexusOperationHandle
205+ # ------------------------------------------------------------------------------
206+ # >>>>>>>>>>>> WFT 1
207+ # after await start : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False)
208+ # >>>>>>>>>>>> WFT 2
209+ # >>>>>>>>>>>> WFT 3
210+ # after await proceed : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False)
211+ # before op_handle.cancel : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[PENDING]) (False)
212+ # Future_7952[PENDING].cancel() -> True # transition to cancelled state; fut_waiter was not finished
213+ # cancel returned : True
214+ # before await op_handle : Future_7568[FINISHED] Future_7696[PENDING] Task[PENDING] fut_waiter = Future_7952[CANCELLED]) (False)
215+ # --> This await on the nexus op handle raises CancelledError, because the task's underlying fut_waiter is cancelled.
216+
217+ # Thus in the sync case, although the caller workflow attempted to cancel the
218+ # NexusOperationHandle, this did not result in a CancelledError when the handle was
219+ # awaited, because both resolve_nexus_operation_start and resolve_nexus_operation jobs
220+ # were sent in the same activation and hence the task's fut_waiter was already finished.
221+ #
222+ # But in the async case, at the time that we cancel the NexusOperationHandle, only the
223+ # resolve_nexus_operation_start job had been sent; the result_fut was unresolved. Thus
224+ # when the handle was awaited, CancelledError was raised.
225+
226+ # To create output like that above, set the following __repr__s:
227+ # asyncio.Future:
228+ # def __repr__(self):
229+ # return f"{self.__class__.__name__}_{str(id(self))[-4:]}[{self._state}]"
230+ # _NexusOperationHandle:
231+ # def __repr__(self) -> str:
232+ # return (
233+ # f"{self._start_fut} "
234+ # f"{self._result_fut} "
235+ # f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})"
236+ # )
216237
217238
218- async def test_async_response (client : Client ):
239+ # TODO(dan): cross-namespace tests
240+ # TODO(dan): nexus endpoint pytest fixture?
241+ # TODO: get rid of UnsandboxedWorkflowRunner (due to xray)
242+ @pytest .mark .parametrize ("request_cancel" , [False , True ])
243+ async def test_sync_response (client : Client , request_cancel : bool ):
219244 task_queue = str (uuid .uuid4 ())
220245 async with Worker (
221246 client ,
@@ -224,37 +249,22 @@ async def test_async_response(client: Client):
224249 task_queue = task_queue ,
225250 workflow_runner = UnsandboxedWorkflowRunner (),
226251 ):
227- operation_workflow_id = str (uuid .uuid4 ())
228- operation_workflow_handle = client .get_workflow_handle (operation_workflow_id )
229252 await create_nexus_endpoint (task_queue , client )
230-
231- # Start the caller workflow
232253 wf_handle = await client .start_workflow (
233254 MyCallerWorkflow .run ,
234- args = [AsyncResponse ( operation_workflow_id ), False , task_queue ],
255+ args = [SyncResponse ( ), request_cancel , task_queue ],
235256 id = str (uuid .uuid4 ()),
236257 task_queue = task_queue ,
237258 )
238259
239- # Wait for the Nexus operation to start and check that the operation-backing workflow now exists.
240- await wf_handle .execute_update (MyCallerWorkflow .wait_nexus_operation_started )
241- wf_details = await operation_workflow_handle .describe ()
242- assert wf_details .status in [
243- WorkflowExecutionStatus .RUNNING ,
244- WorkflowExecutionStatus .COMPLETED ,
245- ]
246-
247- # Wait for the Nexus operation to complete and check that the operation-backing
248- # workflow has completed.
249- await wf_handle .signal (MyCallerWorkflow .proceed )
250-
251- wf_details = await operation_workflow_handle .describe ()
252- assert wf_details .status == WorkflowExecutionStatus .COMPLETED
260+ # The operation result is returned even when request_cancel=True, because the
261+ # response was synchronous and it could not be cancelled. See explanation above.
253262 result = await wf_handle .result ()
254- assert result == "workflow result "
263+ assert result == "sync response "
255264
256265
257- async def test_cancellation_of_async_response (client : Client ):
266+ @pytest .mark .parametrize ("request_cancel" , [False , True ])
267+ async def test_async_response (client : Client , request_cancel : bool ):
258268 task_queue = str (uuid .uuid4 ())
259269 async with Worker (
260270 client ,
@@ -268,9 +278,16 @@ async def test_cancellation_of_async_response(client: Client):
268278 await create_nexus_endpoint (task_queue , client )
269279
270280 # Start the caller workflow
281+ block_forever_waiting_for_cancellation = request_cancel
271282 wf_handle = await client .start_workflow (
272283 MyCallerWorkflow .run ,
273- args = [AsyncResponse (operation_workflow_id ), True , task_queue ],
284+ args = [
285+ AsyncResponse (
286+ operation_workflow_id , block_forever_waiting_for_cancellation
287+ ),
288+ request_cancel ,
289+ task_queue ,
290+ ],
274291 id = str (uuid .uuid4 ()),
275292 task_queue = task_queue ,
276293 )
@@ -284,15 +301,23 @@ async def test_cancellation_of_async_response(client: Client):
284301 ]
285302
286303 await wf_handle .signal (MyCallerWorkflow .proceed )
287- # The caller workflow will now cancel the op_handle, and await it.
288304
289- # TODO(dan): assert what type of exception is raised here
290- with pytest .raises (BaseException ) as ei :
291- await wf_handle .result ()
292- e = ei .value
293- print (f"🌈 workflow failed: { e .__class__ .__name__ } ({ e } )" )
294- wf_details = await operation_workflow_handle .describe ()
295- assert wf_details .status == WorkflowExecutionStatus .CANCELED
305+ # The operation response was asynchronous and so request_cancel is honored. See
306+ # explanation above.
307+ if request_cancel :
308+ # The caller workflow now cancels the op_handle, and awaits it, resulting in a
309+ # CancellationError in the caller workflow.
310+ with pytest .raises (BaseException ) as ei :
311+ await wf_handle .result ()
312+ e = ei .value
313+ print (f"🌈 workflow failed: { e .__class__ .__name__ } ({ e } )" )
314+ wf_details = await operation_workflow_handle .describe ()
315+ assert wf_details .status == WorkflowExecutionStatus .CANCELED
316+ else :
317+ wf_details = await operation_workflow_handle .describe ()
318+ assert wf_details .status == WorkflowExecutionStatus .COMPLETED
319+ result = await wf_handle .result ()
320+ assert result == "workflow result"
296321
297322
298323def make_nexus_endpoint_name (task_queue : str ) -> str :
0 commit comments