2
2
import uuid
3
3
from dataclasses import dataclass
4
4
from datetime import timedelta
5
- from typing import Optional , Union
5
+ from enum import StrEnum
6
+ from typing import Any , Callable , Optional , Union
6
7
7
8
import nexusrpc
8
9
import nexusrpc .handler
34
35
# -----------------------------------------------------------------------------
35
36
# Service interface
36
37
#
38
+ class CallerReference (StrEnum ):
39
+ IMPLEMENTATION = "implementation"
40
+ INTERFACE = "interface"
41
+
42
+
43
+ class OpDefinitionType (StrEnum ):
44
+ SHORTHAND = "shorthand"
45
+ LONGHAND = "longhand"
46
+
47
+
37
48
@dataclass
38
49
class SyncResponse :
39
- use_shorthand_defined_operation : bool
50
+ op_definition_type : OpDefinitionType
40
51
41
52
42
53
@dataclass
43
54
class AsyncResponse :
44
55
operation_workflow_id : str
45
56
block_forever_waiting_for_cancellation : bool
46
- use_shorthand_defined_operation : bool
57
+ op_definition_type : OpDefinitionType
47
58
48
59
49
60
# The order of the two types in this union is critical since the data converter matches
@@ -56,6 +67,7 @@ class AsyncResponse:
56
67
class MyInput :
57
68
response_type : ResponseType
58
69
start_options : nexusrpc .handler .StartOperationOptions
70
+ caller_reference : CallerReference
59
71
60
72
61
73
@dataclass
@@ -65,7 +77,7 @@ class MyOutput:
65
77
66
78
67
79
@nexusrpc .interface .service
68
- class MyService :
80
+ class MyServiceInterface :
69
81
my_sync_or_async_operation : nexusrpc .interface .Operation [MyInput , MyOutput ]
70
82
my_sync_operation : nexusrpc .interface .Operation [MyInput , MyOutput ]
71
83
my_async_operation : nexusrpc .interface .Operation [MyInput , MyOutput ]
@@ -127,7 +139,7 @@ async def fetch_result(self, *args, **kwargs):
127
139
raise NotImplementedError
128
140
129
141
130
- @nexusrpc .handler .service (interface = MyService )
142
+ @nexusrpc .handler .service (interface = MyServiceInterface )
131
143
class MyServiceImpl :
132
144
@nexusrpc .handler .operation
133
145
def my_sync_or_async_operation (
@@ -176,7 +188,10 @@ def __init__(
176
188
task_queue : str ,
177
189
) -> None :
178
190
self .nexus_service = workflow .NexusClient (
179
- service = MyService ,
191
+ service = {
192
+ CallerReference .IMPLEMENTATION : MyServiceImpl ,
193
+ CallerReference .INTERFACE : MyServiceInterface ,
194
+ }[input .caller_reference ],
180
195
endpoint = make_nexus_endpoint_name (task_queue ),
181
196
schedule_to_close_timeout = timedelta (seconds = 10 ),
182
197
)
@@ -190,15 +205,7 @@ async def run(
190
205
request_cancel : bool ,
191
206
task_queue : str ,
192
207
) -> MyOutput :
193
- if input .response_type .use_shorthand_defined_operation :
194
- operation = (
195
- MyService .my_sync_operation
196
- if isinstance (input .response_type , SyncResponse )
197
- else MyService .my_async_operation
198
- )
199
- else :
200
- operation = MyService .my_sync_or_async_operation
201
-
208
+ operation = self ._get_operation (input )
202
209
op_handle = await self .nexus_service .start_operation (
203
210
operation ,
204
211
input ,
@@ -231,6 +238,61 @@ async def run(
231
238
async def wait_nexus_operation_started (self ) -> None :
232
239
await workflow .wait_condition (lambda : self ._nexus_operation_started )
233
240
241
+ def _get_operation (
242
+ self , input : MyInput
243
+ ) -> Union [
244
+ nexusrpc .interface .Operation [MyInput , MyOutput ],
245
+ Callable [[Any ], nexusrpc .handler .AbstractOperation [MyInput , MyOutput ]],
246
+ ]:
247
+ return {
248
+ (
249
+ SyncResponse ,
250
+ OpDefinitionType .SHORTHAND ,
251
+ CallerReference .IMPLEMENTATION ,
252
+ ): MyServiceImpl .my_sync_operation ,
253
+ (
254
+ SyncResponse ,
255
+ OpDefinitionType .SHORTHAND ,
256
+ CallerReference .INTERFACE ,
257
+ ): MyServiceInterface .my_sync_operation ,
258
+ (
259
+ SyncResponse ,
260
+ OpDefinitionType .LONGHAND ,
261
+ CallerReference .IMPLEMENTATION ,
262
+ ): MyServiceImpl .my_sync_or_async_operation ,
263
+ (
264
+ SyncResponse ,
265
+ OpDefinitionType .LONGHAND ,
266
+ CallerReference .INTERFACE ,
267
+ ): MyServiceInterface .my_sync_or_async_operation ,
268
+ (
269
+ AsyncResponse ,
270
+ OpDefinitionType .SHORTHAND ,
271
+ CallerReference .IMPLEMENTATION ,
272
+ ): MyServiceImpl .my_async_operation ,
273
+ (
274
+ AsyncResponse ,
275
+ OpDefinitionType .SHORTHAND ,
276
+ CallerReference .INTERFACE ,
277
+ ): MyServiceInterface .my_async_operation ,
278
+ (
279
+ AsyncResponse ,
280
+ OpDefinitionType .LONGHAND ,
281
+ CallerReference .IMPLEMENTATION ,
282
+ ): MyServiceImpl .my_sync_or_async_operation ,
283
+ (
284
+ AsyncResponse ,
285
+ OpDefinitionType .LONGHAND ,
286
+ CallerReference .INTERFACE ,
287
+ ): MyServiceInterface .my_sync_or_async_operation ,
288
+ }[
289
+ {True : SyncResponse , False : AsyncResponse }[
290
+ isinstance (input .response_type , SyncResponse )
291
+ ],
292
+ input .response_type .op_definition_type ,
293
+ input .caller_reference ,
294
+ ]
295
+
234
296
235
297
# -----------------------------------------------------------------------------
236
298
# Tests
@@ -239,13 +301,20 @@ async def wait_nexus_operation_started(self) -> None:
239
301
240
302
# TODO(dan): cross-namespace tests
241
303
# TODO(dan): nexus endpoint pytest fixture?
242
- # TODO: get rid of UnsandboxedWorkflowRunner (due to xray)
304
+ # TODO(dan): get rid of UnsandboxedWorkflowRunner (due to xray)
305
+ # TODO(dan): test headers
243
306
@pytest .mark .parametrize ("request_cancel" , [False , True ])
244
- @pytest .mark .parametrize ("use_shorthand_defined_operation" , [False , True ])
307
+ @pytest .mark .parametrize (
308
+ "op_definition_type" , [OpDefinitionType .SHORTHAND , OpDefinitionType .LONGHAND ]
309
+ )
310
+ @pytest .mark .parametrize (
311
+ "caller_reference" , [CallerReference .IMPLEMENTATION , CallerReference .INTERFACE ]
312
+ )
245
313
async def test_sync_response (
246
314
client : Client ,
247
315
request_cancel : bool ,
248
- use_shorthand_defined_operation : bool ,
316
+ op_definition_type : OpDefinitionType ,
317
+ caller_reference : CallerReference ,
249
318
):
250
319
task_queue = str (uuid .uuid4 ())
251
320
async with Worker (
@@ -260,10 +329,11 @@ async def test_sync_response(
260
329
MyCallerWorkflow .run ,
261
330
args = [
262
331
MyInput (
263
- response_type = SyncResponse (use_shorthand_defined_operation ),
332
+ response_type = SyncResponse (op_definition_type ),
264
333
start_options = nexusrpc .handler .StartOperationOptions (
265
334
headers = {"my-header-key" : "my-header-value" },
266
335
),
336
+ caller_reference = caller_reference ,
267
337
),
268
338
request_cancel ,
269
339
task_queue ,
@@ -280,15 +350,19 @@ async def test_sync_response(
280
350
281
351
282
352
@pytest .mark .parametrize ("request_cancel" , [False , True ])
283
- @pytest .mark .parametrize ("use_shorthand_defined_operation" , [False , True ])
353
+ @pytest .mark .parametrize (
354
+ "op_definition_type" , [OpDefinitionType .SHORTHAND , OpDefinitionType .LONGHAND ]
355
+ )
356
+ @pytest .mark .parametrize (
357
+ "caller_reference" , [CallerReference .IMPLEMENTATION , CallerReference .INTERFACE ]
358
+ )
284
359
async def test_async_response (
285
360
client : Client ,
286
361
request_cancel : bool ,
287
- use_shorthand_defined_operation : bool ,
362
+ op_definition_type : OpDefinitionType ,
363
+ caller_reference : CallerReference ,
288
364
):
289
- print (
290
- f"🌈 { 'test_async_response' :<24} : { request_cancel = } { use_shorthand_defined_operation = } "
291
- )
365
+ print (f"🌈 { 'test_async_response' :<24} : { request_cancel = } { op_definition_type = } " )
292
366
task_queue = str (uuid .uuid4 ())
293
367
async with Worker (
294
368
client ,
@@ -298,7 +372,7 @@ async def test_async_response(
298
372
workflow_runner = UnsandboxedWorkflowRunner (),
299
373
):
300
374
caller_wf_handle , handler_wf_handle = await _start_wf_and_nexus_op (
301
- client , task_queue , request_cancel , use_shorthand_defined_operation
375
+ client , task_queue , request_cancel , op_definition_type , caller_reference
302
376
)
303
377
# TODO(dan): race here? How do we know it hasn't been canceled already?
304
378
handler_wf_info = await handler_wf_handle .describe ()
@@ -325,10 +399,10 @@ async def test_async_response(
325
399
# ID of first command after update accepted
326
400
assert e .__cause__ .scheduled_event_id == 6
327
401
assert e .__cause__ .endpoint == make_nexus_endpoint_name (task_queue )
328
- assert e .__cause__ .service == "MyService "
402
+ assert e .__cause__ .service == "MyServiceInterface "
329
403
assert (
330
404
e .__cause__ .operation == "my_async_operation"
331
- if use_shorthand_defined_operation
405
+ if op_definition_type == OpDefinitionType . SHORTHAND
332
406
else "my_sync_or_async_operation"
333
407
)
334
408
assert temporalio .nexus .handler .StartWorkflowOperationResult ._decode_token (
@@ -349,7 +423,8 @@ async def _start_wf_and_nexus_op(
349
423
client : Client ,
350
424
task_queue : str ,
351
425
request_cancel : bool ,
352
- use_shorthand_defined_operation : bool ,
426
+ op_definition_type : OpDefinitionType ,
427
+ caller_reference : CallerReference ,
353
428
) -> tuple [WorkflowHandle , WorkflowHandle ]:
354
429
"""
355
430
Start the caller workflow and wait until the Nexus operation has started.
@@ -366,11 +441,12 @@ async def _start_wf_and_nexus_op(
366
441
response_type = AsyncResponse (
367
442
operation_workflow_id ,
368
443
block_forever_waiting_for_cancellation ,
369
- use_shorthand_defined_operation ,
444
+ op_definition_type ,
370
445
),
371
446
start_options = nexusrpc .handler .StartOperationOptions (
372
447
headers = {"my-header-key" : "my-header-value" },
373
448
),
449
+ caller_reference = caller_reference ,
374
450
),
375
451
request_cancel ,
376
452
task_queue ,
0 commit comments