Skip to content

Commit cfcb35f

Browse files
committed
Move StartWorkflowOperationResult type helper into temporal sdk
1 parent 98a69fa commit cfcb35f

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

temporalio/nexus/handler.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Generic,
1818
Optional,
1919
Sequence,
20+
Type,
2021
TypeVar,
2122
)
2223

@@ -118,6 +119,31 @@ def to_workflow_handle(token: str, client: Client) -> WorkflowHandle[Any, O]:
118119
return client.get_workflow_handle(workflow_id, run_id=run_id)
119120

120121

122+
def get_input_and_output_types_from_async_start_method(
123+
start_method: Callable[
124+
[S, I, nexusrpc.handler.StartOperationOptions],
125+
Awaitable[StartWorkflowOperationResult[O]],
126+
],
127+
) -> tuple[Type[I], Type[O]]:
128+
input_type, output_type = (
129+
nexusrpc.handler.get_input_and_output_types_from_sync_start_method(start_method)
130+
)
131+
origin_type = typing.get_origin(output_type)
132+
if not origin_type or not issubclass(origin_type, StartWorkflowOperationResult):
133+
raise TypeError(
134+
f"The return type of {start_method.__name__} must be a subclass of StartWorkflowOperationResult, "
135+
f"but is {output_type}"
136+
)
137+
138+
args = typing.get_args(output_type)
139+
if len(args) != 1:
140+
raise TypeError(
141+
f"The return type of {start_method.__name__} must have exactly one type parameter, "
142+
f"but has {len(args)}: {args}"
143+
)
144+
return input_type, args[0]
145+
146+
121147
# TODO(dan): overloads should use SelfType, ParamType, ReturnType?
122148

123149

@@ -323,10 +349,8 @@ def workflow_run_operation(
323349
def factory(service: S) -> WorkflowRunOperation[I, O, S]:
324350
return WorkflowRunOperation(service, start_method)
325351

326-
input_type, output_type = (
327-
nexusrpc.handler.get_input_and_output_types_from_async_start_method(
328-
start_method
329-
)
352+
input_type, output_type = get_input_and_output_types_from_async_start_method(
353+
start_method
330354
)
331355
factory.__nexus_operation__ = nexusrpc.handler.NexusOperationDefinition(
332356
name=start_method.__name__,

0 commit comments

Comments
 (0)