Skip to content

Commit 7e8da41

Browse files
committed
Refactor test
1 parent 9f085e8 commit 7e8da41

File tree

1 file changed

+105
-29
lines changed

1 file changed

+105
-29
lines changed

tests/worker/test_nexus.py

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import uuid
33
from dataclasses import dataclass
44
from datetime import timedelta
5-
from typing import Optional, Union
5+
from enum import StrEnum
6+
from typing import Any, Callable, Optional, Union
67

78
import nexusrpc
89
import nexusrpc.handler
@@ -34,16 +35,26 @@
3435
# -----------------------------------------------------------------------------
3536
# Service interface
3637
#
38+
class CallerReference(StrEnum):
39+
IMPLEMENTATION = "implementation"
40+
INTERFACE = "interface"
41+
42+
43+
class OpDefinitionType(StrEnum):
44+
SHORTHAND = "shorthand"
45+
LONGHAND = "longhand"
46+
47+
3748
@dataclass
3849
class SyncResponse:
39-
use_shorthand_defined_operation: bool
50+
op_definition_type: OpDefinitionType
4051

4152

4253
@dataclass
4354
class AsyncResponse:
4455
operation_workflow_id: str
4556
block_forever_waiting_for_cancellation: bool
46-
use_shorthand_defined_operation: bool
57+
op_definition_type: OpDefinitionType
4758

4859

4960
# The order of the two types in this union is critical since the data converter matches
@@ -56,6 +67,7 @@ class AsyncResponse:
5667
class MyInput:
5768
response_type: ResponseType
5869
start_options: nexusrpc.handler.StartOperationOptions
70+
caller_reference: CallerReference
5971

6072

6173
@dataclass
@@ -65,7 +77,7 @@ class MyOutput:
6577

6678

6779
@nexusrpc.interface.service
68-
class MyService:
80+
class MyServiceInterface:
6981
my_sync_or_async_operation: nexusrpc.interface.Operation[MyInput, MyOutput]
7082
my_sync_operation: nexusrpc.interface.Operation[MyInput, MyOutput]
7183
my_async_operation: nexusrpc.interface.Operation[MyInput, MyOutput]
@@ -127,7 +139,7 @@ async def fetch_result(self, *args, **kwargs):
127139
raise NotImplementedError
128140

129141

130-
@nexusrpc.handler.service(interface=MyService)
142+
@nexusrpc.handler.service(interface=MyServiceInterface)
131143
class MyServiceImpl:
132144
@nexusrpc.handler.operation
133145
def my_sync_or_async_operation(
@@ -176,7 +188,10 @@ def __init__(
176188
task_queue: str,
177189
) -> None:
178190
self.nexus_service = workflow.NexusClient(
179-
service=MyService,
191+
service={
192+
CallerReference.IMPLEMENTATION: MyServiceImpl,
193+
CallerReference.INTERFACE: MyServiceInterface,
194+
}[input.caller_reference],
180195
endpoint=make_nexus_endpoint_name(task_queue),
181196
schedule_to_close_timeout=timedelta(seconds=10),
182197
)
@@ -190,15 +205,7 @@ async def run(
190205
request_cancel: bool,
191206
task_queue: str,
192207
) -> MyOutput:
193-
if input.response_type.use_shorthand_defined_operation:
194-
operation = (
195-
MyService.my_sync_operation
196-
if isinstance(input.response_type, SyncResponse)
197-
else MyService.my_async_operation
198-
)
199-
else:
200-
operation = MyService.my_sync_or_async_operation
201-
208+
operation = self._get_operation(input)
202209
op_handle = await self.nexus_service.start_operation(
203210
operation,
204211
input,
@@ -231,6 +238,61 @@ async def run(
231238
async def wait_nexus_operation_started(self) -> None:
232239
await workflow.wait_condition(lambda: self._nexus_operation_started)
233240

241+
def _get_operation(
242+
self, input: MyInput
243+
) -> Union[
244+
nexusrpc.interface.Operation[MyInput, MyOutput],
245+
Callable[[Any], nexusrpc.handler.AbstractOperation[MyInput, MyOutput]],
246+
]:
247+
return {
248+
(
249+
SyncResponse,
250+
OpDefinitionType.SHORTHAND,
251+
CallerReference.IMPLEMENTATION,
252+
): MyServiceImpl.my_sync_operation,
253+
(
254+
SyncResponse,
255+
OpDefinitionType.SHORTHAND,
256+
CallerReference.INTERFACE,
257+
): MyServiceInterface.my_sync_operation,
258+
(
259+
SyncResponse,
260+
OpDefinitionType.LONGHAND,
261+
CallerReference.IMPLEMENTATION,
262+
): MyServiceImpl.my_sync_or_async_operation,
263+
(
264+
SyncResponse,
265+
OpDefinitionType.LONGHAND,
266+
CallerReference.INTERFACE,
267+
): MyServiceInterface.my_sync_or_async_operation,
268+
(
269+
AsyncResponse,
270+
OpDefinitionType.SHORTHAND,
271+
CallerReference.IMPLEMENTATION,
272+
): MyServiceImpl.my_async_operation,
273+
(
274+
AsyncResponse,
275+
OpDefinitionType.SHORTHAND,
276+
CallerReference.INTERFACE,
277+
): MyServiceInterface.my_async_operation,
278+
(
279+
AsyncResponse,
280+
OpDefinitionType.LONGHAND,
281+
CallerReference.IMPLEMENTATION,
282+
): MyServiceImpl.my_sync_or_async_operation,
283+
(
284+
AsyncResponse,
285+
OpDefinitionType.LONGHAND,
286+
CallerReference.INTERFACE,
287+
): MyServiceInterface.my_sync_or_async_operation,
288+
}[
289+
{True: SyncResponse, False: AsyncResponse}[
290+
isinstance(input.response_type, SyncResponse)
291+
],
292+
input.response_type.op_definition_type,
293+
input.caller_reference,
294+
]
295+
234296

235297
# -----------------------------------------------------------------------------
236298
# Tests
@@ -239,13 +301,20 @@ async def wait_nexus_operation_started(self) -> None:
239301

240302
# TODO(dan): cross-namespace tests
241303
# TODO(dan): nexus endpoint pytest fixture?
242-
# TODO: get rid of UnsandboxedWorkflowRunner (due to xray)
304+
# TODO(dan): get rid of UnsandboxedWorkflowRunner (due to xray)
305+
# TODO(dan): test headers
243306
@pytest.mark.parametrize("request_cancel", [False, True])
244-
@pytest.mark.parametrize("use_shorthand_defined_operation", [False, True])
307+
@pytest.mark.parametrize(
308+
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
309+
)
310+
@pytest.mark.parametrize(
311+
"caller_reference", [CallerReference.IMPLEMENTATION, CallerReference.INTERFACE]
312+
)
245313
async def test_sync_response(
246314
client: Client,
247315
request_cancel: bool,
248-
use_shorthand_defined_operation: bool,
316+
op_definition_type: OpDefinitionType,
317+
caller_reference: CallerReference,
249318
):
250319
task_queue = str(uuid.uuid4())
251320
async with Worker(
@@ -260,10 +329,11 @@ async def test_sync_response(
260329
MyCallerWorkflow.run,
261330
args=[
262331
MyInput(
263-
response_type=SyncResponse(use_shorthand_defined_operation),
332+
response_type=SyncResponse(op_definition_type),
264333
start_options=nexusrpc.handler.StartOperationOptions(
265334
headers={"my-header-key": "my-header-value"},
266335
),
336+
caller_reference=caller_reference,
267337
),
268338
request_cancel,
269339
task_queue,
@@ -280,15 +350,19 @@ async def test_sync_response(
280350

281351

282352
@pytest.mark.parametrize("request_cancel", [False, True])
283-
@pytest.mark.parametrize("use_shorthand_defined_operation", [False, True])
353+
@pytest.mark.parametrize(
354+
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
355+
)
356+
@pytest.mark.parametrize(
357+
"caller_reference", [CallerReference.IMPLEMENTATION, CallerReference.INTERFACE]
358+
)
284359
async def test_async_response(
285360
client: Client,
286361
request_cancel: bool,
287-
use_shorthand_defined_operation: bool,
362+
op_definition_type: OpDefinitionType,
363+
caller_reference: CallerReference,
288364
):
289-
print(
290-
f"🌈 {'test_async_response':<24}: {request_cancel=} {use_shorthand_defined_operation=}"
291-
)
365+
print(f"🌈 {'test_async_response':<24}: {request_cancel=} {op_definition_type=}")
292366
task_queue = str(uuid.uuid4())
293367
async with Worker(
294368
client,
@@ -298,7 +372,7 @@ async def test_async_response(
298372
workflow_runner=UnsandboxedWorkflowRunner(),
299373
):
300374
caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op(
301-
client, task_queue, request_cancel, use_shorthand_defined_operation
375+
client, task_queue, request_cancel, op_definition_type, caller_reference
302376
)
303377
# TODO(dan): race here? How do we know it hasn't been canceled already?
304378
handler_wf_info = await handler_wf_handle.describe()
@@ -325,10 +399,10 @@ async def test_async_response(
325399
# ID of first command after update accepted
326400
assert e.__cause__.scheduled_event_id == 6
327401
assert e.__cause__.endpoint == make_nexus_endpoint_name(task_queue)
328-
assert e.__cause__.service == "MyService"
402+
assert e.__cause__.service == "MyServiceInterface"
329403
assert (
330404
e.__cause__.operation == "my_async_operation"
331-
if use_shorthand_defined_operation
405+
if op_definition_type == OpDefinitionType.SHORTHAND
332406
else "my_sync_or_async_operation"
333407
)
334408
assert temporalio.nexus.handler.StartWorkflowOperationResult._decode_token(
@@ -349,7 +423,8 @@ async def _start_wf_and_nexus_op(
349423
client: Client,
350424
task_queue: str,
351425
request_cancel: bool,
352-
use_shorthand_defined_operation: bool,
426+
op_definition_type: OpDefinitionType,
427+
caller_reference: CallerReference,
353428
) -> tuple[WorkflowHandle, WorkflowHandle]:
354429
"""
355430
Start the caller workflow and wait until the Nexus operation has started.
@@ -366,11 +441,12 @@ async def _start_wf_and_nexus_op(
366441
response_type=AsyncResponse(
367442
operation_workflow_id,
368443
block_forever_waiting_for_cancellation,
369-
use_shorthand_defined_operation,
444+
op_definition_type,
370445
),
371446
start_options=nexusrpc.handler.StartOperationOptions(
372447
headers={"my-header-key": "my-header-value"},
373448
),
449+
caller_reference=caller_reference,
374450
),
375451
request_cancel,
376452
task_queue,

0 commit comments

Comments
 (0)