Skip to content

Commit 702c868

Browse files
authored
Fix uws client-side result type instantiation (#712)
* Add tests for UwS return types * Fix UwS client-side workflow and update result type instantiation * Refactor
1 parent 540faeb commit 702c868

File tree

3 files changed

+74
-15
lines changed

3 files changed

+74
-15
lines changed

temporalio/client.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ async def start_workflow(
514514
temporalio.common._warn_on_deprecated_search_attributes(
515515
search_attributes, stack_level=stack_level
516516
)
517-
name, result_type_from_run_fn = (
517+
name, result_type_from_type_hint = (
518518
temporalio.workflow._Definition.get_name_and_result_type(workflow)
519519
)
520520

@@ -539,7 +539,7 @@ async def start_workflow(
539539
static_details=static_details,
540540
start_signal=start_signal,
541541
start_signal_args=start_signal_args,
542-
ret_type=result_type or result_type_from_run_fn,
542+
ret_type=result_type or result_type_from_type_hint,
543543
rpc_metadata=rpc_metadata,
544544
rpc_timeout=rpc_timeout,
545545
request_eager_start=request_eager_start,
@@ -1105,7 +1105,7 @@ def on_start_success(
11051105
start_workflow_operation._start_workflow_input.id,
11061106
first_execution_run_id=start_response.run_id,
11071107
result_run_id=start_response.run_id,
1108-
result_type=result_type,
1108+
result_type=start_workflow_operation._start_workflow_input.ret_type,
11091109
)
11101110
)
11111111

@@ -2335,17 +2335,10 @@ async def _start_update(
23352335
) -> WorkflowUpdateHandle[Any]:
23362336
if wait_for_stage == WorkflowUpdateStage.ADMITTED:
23372337
raise ValueError("ADMITTED wait stage not supported")
2338-
update_name: str
2339-
ret_type = result_type
2340-
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
2341-
defn = update._defn
2342-
if not defn.name:
2343-
raise RuntimeError("Cannot invoke dynamic update definition")
2344-
# TODO(cretz): Check count/type of args at runtime?
2345-
update_name = defn.name
2346-
ret_type = defn.ret_type
2347-
else:
2348-
update_name = str(update)
2338+
2339+
update_name, result_type_from_type_hint = (
2340+
temporalio.workflow._UpdateDefinition.get_name_and_result_type(update)
2341+
)
23492342

23502343
return await self._client._impl.start_workflow_update(
23512344
StartWorkflowUpdateInput(
@@ -2356,7 +2349,7 @@ async def _start_update(
23562349
update=update_name,
23572350
args=temporalio.common._arg_or_args(arg, args),
23582351
headers={},
2359-
ret_type=ret_type,
2352+
ret_type=result_type or result_type_from_type_hint,
23602353
rpc_metadata=rpc_metadata,
23612354
rpc_timeout=rpc_timeout,
23622355
wait_for_stage=wait_for_stage,
@@ -6183,6 +6176,7 @@ async def _start_workflow_update_with_start(
61836176
workflow_id=start_input.id,
61846177
workflow_run_id=start_response.run_id,
61856178
known_outcome=known_outcome,
6179+
result_type=update_input.ret_type,
61866180
)
61876181
if update_input.wait_for_stage == WorkflowUpdateStage.COMPLETED:
61886182
await handle._poll_until_outcome()

temporalio/workflow.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import temporalio.common
5454
import temporalio.converter
5555
import temporalio.exceptions
56+
import temporalio.workflow
5657

5758
from .types import (
5859
AnyType,
@@ -1783,6 +1784,20 @@ def set_validator(self, validator: Callable[..., None]) -> None:
17831784
raise RuntimeError(f"Validator already set for update {self.name}")
17841785
object.__setattr__(self, "validator", validator)
17851786

1787+
@classmethod
1788+
def get_name_and_result_type(
1789+
cls,
1790+
name_or_update_fn: Union[str, Callable[..., Any]],
1791+
) -> Tuple[str, Optional[Type]]:
1792+
if isinstance(name_or_update_fn, temporalio.workflow.UpdateMethodMultiParam):
1793+
defn = name_or_update_fn._defn
1794+
if not defn.name:
1795+
raise RuntimeError("Cannot invoke dynamic update definition")
1796+
# TODO(cretz): Check count/type of args at runtime?
1797+
return defn.name, defn.ret_type
1798+
else:
1799+
return str(name_or_update_fn), None
1800+
17861801

17871802
# See https://mypy.readthedocs.io/en/latest/runtime_troubles.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
17881803
if TYPE_CHECKING:

tests/worker/test_update_with_start.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import uuid
44
from contextlib import contextmanager
5+
from dataclasses import dataclass
56
from datetime import timedelta
67
from enum import Enum
78
from typing import Any, Iterator
@@ -515,3 +516,52 @@ def test_with_start_workflow_operation_requires_conflict_policy():
515516
id="wid-1",
516517
task_queue="test-queue",
517518
)
519+
520+
521+
@dataclass
522+
class DataClass1:
523+
a: int
524+
b: str
525+
526+
527+
@dataclass
528+
class DataClass2:
529+
a: int
530+
b: str
531+
532+
533+
@workflow.defn
534+
class WorkflowCanReturnDataClass:
535+
def __init__(self) -> None:
536+
self.received_update = False
537+
538+
@workflow.run
539+
async def run(self) -> DataClass1:
540+
await workflow.wait_condition(lambda: self.received_update)
541+
return DataClass1(a=1, b="workflow-result")
542+
543+
@workflow.update
544+
async def update(self) -> DataClass2:
545+
self.received_update = True
546+
return DataClass2(a=2, b="update-result")
547+
548+
549+
async def test_workflow_and_update_can_return_dataclass(client: Client):
550+
async with new_worker(client, WorkflowCanReturnDataClass) as worker:
551+
start_op = WithStartWorkflowOperation(
552+
WorkflowCanReturnDataClass.run,
553+
id=f"workflow-{uuid.uuid4()}",
554+
task_queue=worker.task_queue,
555+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
556+
)
557+
558+
update_handle = await client.start_update_with_start_workflow(
559+
WorkflowCanReturnDataClass.update,
560+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
561+
start_workflow_operation=start_op,
562+
)
563+
564+
assert await update_handle.result() == DataClass2(a=2, b="update-result")
565+
566+
wf_handle = await start_op.workflow_handle()
567+
assert await wf_handle.result() == DataClass1(a=1, b="workflow-result")

0 commit comments

Comments
 (0)