Skip to content

Commit 66e7650

Browse files
authored
Prevent re-use of update-with-start WithStartWorkflowOperation (#714)
* Add test that WithStartWorkflowOperation cannot be reused * RuntimeError if start_op is reused
1 parent 702c868 commit 66e7650

File tree

2 files changed

+269
-31
lines changed

2 files changed

+269
-31
lines changed

temporalio/client.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,30 +1073,27 @@ async def _start_update_with_start(
10731073
) -> WorkflowUpdateHandle[Any]:
10741074
if wait_for_stage == WorkflowUpdateStage.ADMITTED:
10751075
raise ValueError("ADMITTED wait stage not supported")
1076-
update_name: str
1077-
ret_type = result_type
1078-
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
1079-
defn = update._defn
1080-
if not defn.name:
1081-
raise RuntimeError("Cannot invoke dynamic update definition")
1082-
# TODO(cretz): Check count/type of args at runtime?
1083-
update_name = defn.name
1084-
ret_type = defn.ret_type
1085-
else:
1086-
update_name = str(update)
1076+
1077+
if start_workflow_operation._used:
1078+
raise RuntimeError("WithStartWorkflowOperation cannot be reused")
1079+
start_workflow_operation._used = True
1080+
1081+
update_name, result_type_from_type_hint = (
1082+
temporalio.workflow._UpdateDefinition.get_name_and_result_type(update)
1083+
)
10871084

10881085
update_input = UpdateWithStartUpdateWorkflowInput(
10891086
update_id=id,
10901087
update=update_name,
10911088
args=temporalio.common._arg_or_args(arg, args),
10921089
headers={},
1093-
ret_type=ret_type,
1090+
ret_type=result_type or result_type_from_type_hint,
10941091
rpc_metadata=rpc_metadata,
10951092
rpc_timeout=rpc_timeout,
10961093
wait_for_stage=wait_for_stage,
10971094
)
10981095

1099-
def on_start_success(
1096+
def on_start(
11001097
start_response: temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse,
11011098
):
11021099
start_workflow_operation._workflow_handle.set_result(
@@ -1109,16 +1106,16 @@ def on_start_success(
11091106
)
11101107
)
11111108

1112-
def on_start_failure(
1109+
def on_start_error(
11131110
error: BaseException,
11141111
):
11151112
start_workflow_operation._workflow_handle.set_exception(error)
11161113

11171114
input = StartWorkflowUpdateWithStartInput(
11181115
start_workflow_input=start_workflow_operation._start_workflow_input,
11191116
update_workflow_input=update_input,
1120-
_on_start=on_start_success,
1121-
_on_start_error=on_start_failure,
1117+
_on_start=on_start,
1118+
_on_start_error=on_start_error,
11221119
)
11231120

11241121
return await self._impl.start_update_with_start_workflow(input)
@@ -2621,6 +2618,7 @@ def __init__(
26212618
rpc_timeout=rpc_timeout,
26222619
)
26232620
self._workflow_handle: Future[WorkflowHandle[SelfType, ReturnType]] = Future()
2621+
self._used = False
26242622

26252623
async def workflow_handle(self) -> WorkflowHandle[SelfType, ReturnType]:
26262624
"""Wait until workflow is running and return a WorkflowHandle.

tests/worker/test_update_with_start.py

Lines changed: 255 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,30 @@ async def done(self):
6868
self.received_done_signal = True
6969

7070

71+
async def test_with_start_workflow_operation_cannot_be_reused(client: Client):
72+
async with new_worker(client, WorkflowForUpdateWithStartTest) as worker:
73+
start_op = WithStartWorkflowOperation(
74+
WorkflowForUpdateWithStartTest.run,
75+
0,
76+
id=f"wid-{uuid.uuid4()}",
77+
task_queue=worker.task_queue,
78+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
79+
)
80+
81+
async def start_update_with_start(start_op: WithStartWorkflowOperation):
82+
return await client.start_update_with_start_workflow(
83+
WorkflowForUpdateWithStartTest.my_non_blocking_update,
84+
"1",
85+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
86+
start_workflow_operation=start_op,
87+
)
88+
89+
await start_update_with_start(start_op)
90+
with pytest.raises(RuntimeError) as exc_info:
91+
await start_update_with_start(start_op)
92+
assert "WithStartWorkflowOperation cannot be reused" in str(exc_info.value)
93+
94+
7195
class ExpectErrorWhenWorkflowExists(Enum):
7296
YES = "yes"
7397
NO = "no"
@@ -387,7 +411,7 @@ def make_start_op(workflow_id: str):
387411
assert (await start_op_4.workflow_handle()).first_execution_run_id is not None
388412

389413

390-
async def test_update_with_start_failure_start_workflow_error(
414+
async def test_update_with_start_workflow_already_started_error(
391415
client: Client, env: WorkflowEnvironment
392416
):
393417
"""
@@ -520,13 +544,13 @@ def test_with_start_workflow_operation_requires_conflict_policy():
520544

521545
@dataclass
522546
class DataClass1:
523-
a: int
547+
a: str
524548
b: str
525549

526550

527551
@dataclass
528552
class DataClass2:
529-
a: int
553+
a: str
530554
b: str
531555

532556

@@ -536,32 +560,248 @@ def __init__(self) -> None:
536560
self.received_update = False
537561

538562
@workflow.run
539-
async def run(self) -> DataClass1:
563+
async def run(self, arg: str) -> DataClass1:
540564
await workflow.wait_condition(lambda: self.received_update)
541-
return DataClass1(a=1, b="workflow-result")
565+
return DataClass1(a=arg, b="workflow-result")
542566

543567
@workflow.update
544-
async def update(self) -> DataClass2:
568+
async def my_update(self, arg: str) -> DataClass2:
545569
self.received_update = True
546-
return DataClass2(a=2, b="update-result")
570+
return DataClass2(a=arg, b="update-result")
547571

548572

549573
async def test_workflow_and_update_can_return_dataclass(client: Client):
550574
async with new_worker(client, WorkflowCanReturnDataClass) as worker:
551-
start_op = WithStartWorkflowOperation(
552-
WorkflowCanReturnDataClass.run,
553-
id=f"workflow-{uuid.uuid4()}",
554-
task_queue=worker.task_queue,
555-
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
575+
576+
def make_start_op(workflow_id: str):
577+
return WithStartWorkflowOperation(
578+
WorkflowCanReturnDataClass.run,
579+
"workflow-arg",
580+
id=workflow_id,
581+
task_queue=worker.task_queue,
582+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
583+
)
584+
585+
# no-param update-function overload
586+
start_op = make_start_op(f"wf-{uuid.uuid4()}")
587+
588+
update_handle = await client.start_update_with_start_workflow(
589+
WorkflowCanReturnDataClass.my_update,
590+
"update-arg",
591+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
592+
start_workflow_operation=start_op,
593+
)
594+
595+
assert await update_handle.result() == DataClass2(
596+
a="update-arg", b="update-result"
556597
)
557598

599+
wf_handle = await start_op.workflow_handle()
600+
assert await wf_handle.result() == DataClass1(
601+
a="workflow-arg", b="workflow-result"
602+
)
603+
604+
# no-param update-string-name overload
605+
start_op = make_start_op(f"wf-{uuid.uuid4()}")
606+
558607
update_handle = await client.start_update_with_start_workflow(
559-
WorkflowCanReturnDataClass.update,
608+
"my_update",
609+
"update-arg",
560610
wait_for_stage=WorkflowUpdateStage.COMPLETED,
561611
start_workflow_operation=start_op,
612+
result_type=DataClass2,
562613
)
563614

564-
assert await update_handle.result() == DataClass2(a=2, b="update-result")
615+
assert await update_handle.result() == DataClass2(
616+
a="update-arg", b="update-result"
617+
)
565618

566619
wf_handle = await start_op.workflow_handle()
567-
assert await wf_handle.result() == DataClass1(a=1, b="workflow-result")
620+
assert await wf_handle.result() == DataClass1(
621+
a="workflow-arg", b="workflow-result"
622+
)
623+
624+
625+
@dataclass
626+
class WorkflowResult:
627+
result: str
628+
629+
630+
@dataclass
631+
class UpdateResult:
632+
result: str
633+
634+
635+
@workflow.defn
636+
class NoParamWorkflow:
637+
def __init__(self) -> None:
638+
self.received_update = False
639+
640+
@workflow.run
641+
async def my_workflow_run(self) -> WorkflowResult:
642+
await workflow.wait_condition(lambda: self.received_update)
643+
return WorkflowResult(result="workflow-result")
644+
645+
@workflow.update(name="my_update")
646+
async def update(self) -> UpdateResult:
647+
self.received_update = True
648+
return UpdateResult(result="update-result")
649+
650+
651+
@workflow.defn
652+
class OneParamWorkflow:
653+
def __init__(self) -> None:
654+
self.received_update = False
655+
656+
@workflow.run
657+
async def my_workflow_run(self, arg: str) -> WorkflowResult:
658+
await workflow.wait_condition(lambda: self.received_update)
659+
return WorkflowResult(result=arg)
660+
661+
@workflow.update(name="my_update")
662+
async def update(self, arg: str) -> UpdateResult:
663+
self.received_update = True
664+
return UpdateResult(result=arg)
665+
666+
667+
@workflow.defn
668+
class TwoParamWorkflow:
669+
def __init__(self) -> None:
670+
self.received_update = False
671+
672+
@workflow.run
673+
async def my_workflow_run(self, arg1: str, arg2: str) -> WorkflowResult:
674+
await workflow.wait_condition(lambda: self.received_update)
675+
return WorkflowResult(result=arg1 + "-" + arg2)
676+
677+
@workflow.update(name="my_update")
678+
async def update(self, arg1: str, arg2: str) -> UpdateResult:
679+
self.received_update = True
680+
return UpdateResult(result=arg1 + "-" + arg2)
681+
682+
683+
async def test_update_with_start_no_param(client: Client):
684+
async with new_worker(client, NoParamWorkflow) as worker:
685+
# No-params typed
686+
no_param_start_op = WithStartWorkflowOperation(
687+
NoParamWorkflow.my_workflow_run,
688+
id=f"wf-{uuid.uuid4()}",
689+
task_queue=worker.task_queue,
690+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
691+
)
692+
update_handle = await client.start_update_with_start_workflow(
693+
NoParamWorkflow.update,
694+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
695+
start_workflow_operation=no_param_start_op,
696+
)
697+
assert await update_handle.result() == UpdateResult(result="update-result")
698+
wf_handle = await no_param_start_op.workflow_handle()
699+
assert await wf_handle.result() == WorkflowResult(result="workflow-result")
700+
701+
# No-params string name
702+
no_param_start_op = WithStartWorkflowOperation(
703+
"NoParamWorkflow",
704+
id=f"wf-{uuid.uuid4()}",
705+
task_queue=worker.task_queue,
706+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
707+
result_type=WorkflowResult,
708+
)
709+
update_handle = await client.start_update_with_start_workflow(
710+
"my_update",
711+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
712+
start_workflow_operation=no_param_start_op,
713+
result_type=UpdateResult,
714+
)
715+
assert await update_handle.result() == UpdateResult(result="update-result")
716+
wf_handle = await no_param_start_op.workflow_handle()
717+
assert await wf_handle.result() == WorkflowResult(result="workflow-result")
718+
719+
720+
async def test_update_with_start_one_param(client: Client):
721+
async with new_worker(client, OneParamWorkflow) as worker:
722+
# One-param typed
723+
one_param_start_op = WithStartWorkflowOperation(
724+
OneParamWorkflow.my_workflow_run,
725+
"workflow-arg",
726+
id=f"wf-{uuid.uuid4()}",
727+
task_queue=worker.task_queue,
728+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
729+
)
730+
update_handle = await client.start_update_with_start_workflow(
731+
OneParamWorkflow.update,
732+
"update-arg",
733+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
734+
start_workflow_operation=one_param_start_op,
735+
)
736+
assert await update_handle.result() == UpdateResult(result="update-arg")
737+
wf_handle = await one_param_start_op.workflow_handle()
738+
assert await wf_handle.result() == WorkflowResult(result="workflow-arg")
739+
740+
# One-param string name
741+
one_param_start_op = WithStartWorkflowOperation(
742+
"OneParamWorkflow",
743+
"workflow-arg",
744+
id=f"wf-{uuid.uuid4()}",
745+
task_queue=worker.task_queue,
746+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
747+
result_type=WorkflowResult,
748+
)
749+
update_handle = await client.start_update_with_start_workflow(
750+
"my_update",
751+
"update-arg",
752+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
753+
start_workflow_operation=one_param_start_op,
754+
result_type=UpdateResult,
755+
)
756+
assert await update_handle.result() == UpdateResult(result="update-arg")
757+
wf_handle = await one_param_start_op.workflow_handle()
758+
assert await wf_handle.result() == WorkflowResult(result="workflow-arg")
759+
760+
761+
async def test_update_with_start_two_param(client: Client):
762+
async with new_worker(client, TwoParamWorkflow) as worker:
763+
# Two-params typed
764+
two_param_start_op = WithStartWorkflowOperation(
765+
TwoParamWorkflow.my_workflow_run,
766+
args=("workflow-arg1", "workflow-arg2"),
767+
id=f"wf-{uuid.uuid4()}",
768+
task_queue=worker.task_queue,
769+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
770+
)
771+
update_handle = await client.start_update_with_start_workflow(
772+
TwoParamWorkflow.update,
773+
args=("update-arg1", "update-arg2"),
774+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
775+
start_workflow_operation=two_param_start_op,
776+
)
777+
assert await update_handle.result() == UpdateResult(
778+
result="update-arg1-update-arg2"
779+
)
780+
wf_handle = await two_param_start_op.workflow_handle()
781+
assert await wf_handle.result() == WorkflowResult(
782+
result="workflow-arg1-workflow-arg2"
783+
)
784+
785+
# Two-params string name
786+
two_param_start_op = WithStartWorkflowOperation(
787+
"TwoParamWorkflow",
788+
args=("workflow-arg1", "workflow-arg2"),
789+
id=f"wf-{uuid.uuid4()}",
790+
task_queue=worker.task_queue,
791+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
792+
result_type=WorkflowResult,
793+
)
794+
update_handle = await client.start_update_with_start_workflow(
795+
"my_update",
796+
args=("update-arg1", "update-arg2"),
797+
wait_for_stage=WorkflowUpdateStage.COMPLETED,
798+
start_workflow_operation=two_param_start_op,
799+
result_type=UpdateResult,
800+
)
801+
assert await update_handle.result() == UpdateResult(
802+
result="update-arg1-update-arg2"
803+
)
804+
wf_handle = await two_param_start_op.workflow_handle()
805+
assert await wf_handle.result() == WorkflowResult(
806+
result="workflow-arg1-workflow-arg2"
807+
)

0 commit comments

Comments
 (0)