|
17 | 17 | Generic,
|
18 | 18 | Optional,
|
19 | 19 | Sequence,
|
| 20 | + Type, |
20 | 21 | TypeVar,
|
21 | 22 | )
|
22 | 23 |
|
@@ -118,6 +119,31 @@ def to_workflow_handle(token: str, client: Client) -> WorkflowHandle[Any, O]:
|
118 | 119 | return client.get_workflow_handle(workflow_id, run_id=run_id)
|
119 | 120 |
|
120 | 121 |
|
| 122 | +def get_input_and_output_types_from_async_start_method( |
| 123 | + start_method: Callable[ |
| 124 | + [S, I, nexusrpc.handler.StartOperationOptions], |
| 125 | + Awaitable[StartWorkflowOperationResult[O]], |
| 126 | + ], |
| 127 | +) -> tuple[Type[I], Type[O]]: |
| 128 | + input_type, output_type = ( |
| 129 | + nexusrpc.handler.get_input_and_output_types_from_sync_start_method(start_method) |
| 130 | + ) |
| 131 | + origin_type = typing.get_origin(output_type) |
| 132 | + if not origin_type or not issubclass(origin_type, StartWorkflowOperationResult): |
| 133 | + raise TypeError( |
| 134 | + f"The return type of {start_method.__name__} must be a subclass of StartWorkflowOperationResult, " |
| 135 | + f"but is {output_type}" |
| 136 | + ) |
| 137 | + |
| 138 | + args = typing.get_args(output_type) |
| 139 | + if len(args) != 1: |
| 140 | + raise TypeError( |
| 141 | + f"The return type of {start_method.__name__} must have exactly one type parameter, " |
| 142 | + f"but has {len(args)}: {args}" |
| 143 | + ) |
| 144 | + return input_type, args[0] |
| 145 | + |
| 146 | + |
121 | 147 | # TODO(dan): overloads should use SelfType, ParamType, ReturnType?
|
122 | 148 |
|
123 | 149 |
|
@@ -323,10 +349,8 @@ def workflow_run_operation(
|
323 | 349 | def factory(service: S) -> WorkflowRunOperation[I, O, S]:
|
324 | 350 | return WorkflowRunOperation(service, start_method)
|
325 | 351 |
|
326 |
| - input_type, output_type = ( |
327 |
| - nexusrpc.handler.get_input_and_output_types_from_async_start_method( |
328 |
| - start_method |
329 |
| - ) |
| 352 | + input_type, output_type = get_input_and_output_types_from_async_start_method( |
| 353 | + start_method |
330 | 354 | )
|
331 | 355 | factory.__nexus_operation__ = nexusrpc.handler.NexusOperationDefinition(
|
332 | 356 | name=start_method.__name__,
|
|
0 commit comments