Skip to content

Commit 0ac2053

Browse files
committed
Test exception in Nexus start method
1 parent cc7d665 commit 0ac2053

File tree

1 file changed

+82
-10
lines changed

1 file changed

+82
-10
lines changed

tests/worker/test_nexus.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,15 @@ class OpDefinitionType(IntEnum):
5656
class SyncResponse:
5757
op_definition_type: OpDefinitionType
5858
use_async_def: bool
59+
exception_in_operation_start: bool
5960

6061

6162
@dataclass
6263
class AsyncResponse:
6364
operation_workflow_id: str
6465
block_forever_waiting_for_cancellation: bool
6566
op_definition_type: OpDefinitionType
67+
exception_in_operation_start: bool
6668

6769

6870
# The order of the two types in this union is critical since the data converter matches
@@ -101,6 +103,10 @@ class ServiceInterface:
101103
#
102104

103105

106+
class CustomError(Exception):
107+
pass
108+
109+
104110
@dataclass
105111
class HandlerWfInput:
106112
op_input: OpInput
@@ -136,6 +142,8 @@ async def start(
136142
OpOutput,
137143
temporalio.nexus.handler.StartWorkflowOperationResult[HandlerWfOutput],
138144
]:
145+
if input.response_type.exception_in_operation_start:
146+
raise CustomError("Error in Nexus operation start method")
139147
if isinstance(input.response_type, SyncResponse):
140148
return OpOutput(
141149
value="sync response",
@@ -178,6 +186,8 @@ async def sync_operation(
178186
self, input: OpInput, options: nexusrpc.handler.StartOperationOptions
179187
) -> OpOutput:
180188
assert isinstance(input.response_type, SyncResponse)
189+
if input.response_type.exception_in_operation_start:
190+
raise CustomError("Error in Nexus operation start method")
181191
return OpOutput(
182192
value="sync response",
183193
start_options_received_by_handler=options,
@@ -188,6 +198,8 @@ def non_async_sync_operation(
188198
self, input: OpInput, options: nexusrpc.handler.StartOperationOptions
189199
) -> OpOutput:
190200
assert isinstance(input.response_type, SyncResponse)
201+
if input.response_type.exception_in_operation_start:
202+
raise CustomError("Error in Nexus operation start method")
191203
return OpOutput(
192204
value="sync response",
193205
start_options_received_by_handler=options,
@@ -198,6 +210,8 @@ async def async_operation(
198210
self, input: OpInput, options: nexusrpc.handler.StartOperationOptions
199211
) -> temporalio.nexus.handler.StartWorkflowOperationResult[HandlerWfOutput]:
200212
assert isinstance(input.response_type, AsyncResponse)
213+
if input.response_type.exception_in_operation_start:
214+
raise CustomError("Error in Nexus operation start method")
201215
return await temporalio.nexus.handler.start_workflow(
202216
HandlerWorkflow.run,
203217
args=[HandlerWfInput(op_input=input), options],
@@ -424,6 +438,8 @@ async def run(
424438
# TODO(dan): nexus endpoint pytest fixture?
425439
# TODO(dan): get rid of UnsandboxedWorkflowRunner (due to xray)
426440
# TODO(dan): test headers
441+
# TODO(dan): enable True for exception_in_operation_start
442+
@pytest.mark.parametrize("exception_in_operation_start", [False])
427443
@pytest.mark.parametrize("request_cancel", [False, True])
428444
@pytest.mark.parametrize(
429445
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
@@ -434,6 +450,7 @@ async def run(
434450
)
435451
async def test_sync_response(
436452
client: Client,
453+
exception_in_operation_start: bool,
437454
request_cancel: bool,
438455
op_definition_type: OpDefinitionType,
439456
caller_reference: CallerReference,
@@ -453,7 +470,11 @@ async def test_sync_response(
453470
args=[
454471
CallerWfInput(
455472
op_input=OpInput(
456-
response_type=SyncResponse(op_definition_type, True),
473+
response_type=SyncResponse(
474+
op_definition_type=op_definition_type,
475+
use_async_def=True,
476+
exception_in_operation_start=exception_in_operation_start,
477+
),
457478
start_options=nexusrpc.handler.StartOperationOptions(
458479
headers={"header-key": "header-value"},
459480
),
@@ -469,11 +490,30 @@ async def test_sync_response(
469490

470491
# The operation result is returned even when request_cancel=True, because the
471492
# response was synchronous and it could not be cancelled. See explanation below.
472-
result = await caller_wf_handle.result()
473-
assert result.op_output.value == "sync response"
474-
assert result.op_output.start_options_received_by_handler
493+
if exception_in_operation_start:
494+
with pytest.raises(WorkflowFailureError) as ei:
495+
await caller_wf_handle.result()
496+
e = ei.value
497+
assert isinstance(e, WorkflowFailureError)
498+
assert isinstance(e.__cause__, NexusOperationError)
499+
assert isinstance(e.__cause__.__cause__, CustomError)
500+
# ID of first command after update accepted
501+
assert e.__cause__.scheduled_event_id == 6
502+
assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue)
503+
assert e.__cause__.service == "ServiceInterface"
504+
assert (
505+
e.__cause__.operation == "sync_operation"
506+
if op_definition_type == OpDefinitionType.SHORTHAND
507+
else "sync_or_async_operation"
508+
)
509+
else:
510+
result = await caller_wf_handle.result()
511+
assert result.op_output.value == "sync response"
512+
assert result.op_output.start_options_received_by_handler
475513

476514

515+
# TODO(dan): enable True for exception_in_operation_start
516+
@pytest.mark.parametrize("exception_in_operation_start", [False])
477517
@pytest.mark.parametrize("request_cancel", [False, True])
478518
@pytest.mark.parametrize(
479519
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
@@ -484,6 +524,7 @@ async def test_sync_response(
484524
)
485525
async def test_async_response(
486526
client: Client,
527+
exception_in_operation_start: bool,
487528
request_cancel: bool,
488529
op_definition_type: OpDefinitionType,
489530
caller_reference: CallerReference,
@@ -499,7 +540,12 @@ async def test_async_response(
499540
workflow_failure_exception_types=[Exception],
500541
):
501542
caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op(
502-
client, task_queue, request_cancel, op_definition_type, caller_reference
543+
client,
544+
task_queue,
545+
exception_in_operation_start,
546+
request_cancel,
547+
op_definition_type,
548+
caller_reference,
503549
)
504550
# TODO(dan): race here? How do we know it hasn't been canceled already?
505551
handler_wf_info = await handler_wf_handle.describe()
@@ -514,7 +560,23 @@ async def test_async_response(
514560
caller_wf_handle, handler_wf_handle
515561
)
516562

517-
if request_cancel:
563+
if exception_in_operation_start:
564+
with pytest.raises(WorkflowFailureError) as ei:
565+
await caller_wf_handle.result()
566+
e = ei.value
567+
assert isinstance(e, WorkflowFailureError)
568+
assert isinstance(e.__cause__, NexusOperationError)
569+
assert isinstance(e.__cause__.__cause__, CustomError)
570+
# ID of first command after update accepted
571+
assert e.__cause__.scheduled_event_id == 6
572+
assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue)
573+
assert e.__cause__.service == "ServiceInterface"
574+
assert (
575+
e.__cause__.operation == "async_operation"
576+
if op_definition_type == OpDefinitionType.SHORTHAND
577+
else "sync_or_async_operation"
578+
)
579+
elif request_cancel:
518580
# The operation response was asynchronous and so request_cancel is honored. See
519581
# explanation below.
520582
with pytest.raises(WorkflowFailureError) as ei:
@@ -549,6 +611,7 @@ async def test_async_response(
549611
async def _start_wf_and_nexus_op(
550612
client: Client,
551613
task_queue: str,
614+
exception_in_operation_start: bool,
552615
request_cancel: bool,
553616
op_definition_type: OpDefinitionType,
554617
caller_reference: CallerReference,
@@ -574,6 +637,7 @@ async def _start_wf_and_nexus_op(
574637
operation_workflow_id,
575638
block_forever_waiting_for_cancellation,
576639
op_definition_type,
640+
exception_in_operation_start=exception_in_operation_start,
577641
),
578642
start_options=nexusrpc.handler.StartOperationOptions(
579643
headers={"header-key": "header-value"},
@@ -605,6 +669,8 @@ async def _start_wf_and_nexus_op(
605669
return caller_wf_handle, handler_wf_handle
606670

607671

672+
# TODO(dan): enable True for exception_in_operation_start
673+
@pytest.mark.parametrize("exception_in_operation_start", [False])
608674
@pytest.mark.parametrize(
609675
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
610676
)
@@ -615,6 +681,7 @@ async def _start_wf_and_nexus_op(
615681
@pytest.mark.parametrize("response_type", [SyncResponse, AsyncResponse])
616682
async def test_untyped_caller(
617683
client: Client,
684+
exception_in_operation_start: bool,
618685
op_definition_type: OpDefinitionType,
619686
caller_reference: CallerReference,
620687
response_type: ResponseType,
@@ -629,12 +696,17 @@ async def test_untyped_caller(
629696
workflow_failure_exception_types=[Exception],
630697
):
631698
if response_type == SyncResponse:
632-
response_type = SyncResponse(op_definition_type, True)
699+
response_type = SyncResponse(
700+
op_definition_type=op_definition_type,
701+
use_async_def=True,
702+
exception_in_operation_start=exception_in_operation_start,
703+
)
633704
else:
634705
response_type = AsyncResponse(
635-
str(uuid.uuid4()),
636-
False,
637-
op_definition_type,
706+
operation_workflow_id=str(uuid.uuid4()),
707+
block_forever_waiting_for_cancellation=False,
708+
op_definition_type=op_definition_type,
709+
exception_in_operation_start=exception_in_operation_start,
638710
)
639711
await create_nexus_endpoint(task_queue, client)
640712
caller_wf_handle = await client.start_workflow(

0 commit comments

Comments
 (0)