@@ -384,13 +384,24 @@ async def run(
384
384
op_name = "sync_operation"
385
385
else :
386
386
raise TypeError
387
- op_handle = await self .nexus_service .start_operation (
388
- op_name ,
389
- op_input ,
390
- headers = op_input .start_options .headers ,
391
- output_type = OpOutput ,
392
- )
393
- op_output = await op_handle
387
+
388
+ arbitrary_condition = isinstance (op_input .response_type , SyncResponse )
389
+
390
+ if arbitrary_condition :
391
+ op_handle = await self .nexus_service .start_operation (
392
+ op_name ,
393
+ op_input ,
394
+ headers = op_input .start_options .headers ,
395
+ output_type = OpOutput ,
396
+ )
397
+ op_output = await op_handle
398
+ else :
399
+ op_output = await self .nexus_service .execute_operation (
400
+ op_name ,
401
+ op_input ,
402
+ headers = op_input .start_options .headers ,
403
+ output_type = OpOutput ,
404
+ )
394
405
return CallerWfOutput (
395
406
op_output = OpOutput (
396
407
value = op_output .value ,
@@ -596,27 +607,37 @@ async def _start_wf_and_nexus_op(
596
607
"caller_reference" ,
597
608
[CallerReference .IMPL_WITH_INTERFACE , CallerReference .INTERFACE ],
598
609
)
610
+ @pytest .mark .parametrize ("response_type" , [SyncResponse , AsyncResponse ])
599
611
async def test_untyped_caller (
600
612
client : Client ,
601
613
op_definition_type : OpDefinitionType ,
602
614
caller_reference : CallerReference ,
615
+ response_type : ResponseType ,
603
616
):
604
617
task_queue = str (uuid .uuid4 ())
605
618
async with Worker (
606
619
client ,
607
- workflows = [UntypedCallerWorkflow ],
620
+ workflows = [UntypedCallerWorkflow , HandlerWorkflow ],
608
621
nexus_services = [ServiceImpl ()],
609
622
task_queue = task_queue ,
610
623
workflow_runner = UnsandboxedWorkflowRunner (),
611
624
workflow_failure_exception_types = [Exception ],
612
625
):
626
+ if response_type == SyncResponse :
627
+ response_type = SyncResponse (op_definition_type , True )
628
+ else :
629
+ response_type = AsyncResponse (
630
+ str (uuid .uuid4 ()),
631
+ False ,
632
+ op_definition_type ,
633
+ )
613
634
await create_nexus_endpoint (task_queue , client )
614
635
caller_wf_handle = await client .start_workflow (
615
636
UntypedCallerWorkflow .run ,
616
637
args = [
617
638
CallerWfInput (
618
639
op_input = OpInput (
619
- response_type = SyncResponse ( op_definition_type , True ) ,
640
+ response_type = response_type ,
620
641
start_options = nexusrpc .handler .StartOperationOptions (),
621
642
caller_reference = caller_reference ,
622
643
),
@@ -628,7 +649,11 @@ async def test_untyped_caller(
628
649
task_queue = task_queue ,
629
650
)
630
651
result = await caller_wf_handle .result ()
631
- assert result .op_output .value == "sync response"
652
+ assert result .op_output .value == (
653
+ "sync response"
654
+ if isinstance (response_type , SyncResponse )
655
+ else "workflow result"
656
+ )
632
657
assert result .op_output .start_options_received_by_handler
633
658
634
659
0 commit comments