Skip to content

Commit cc7d665

Browse files
committed
WIP: nexus worker
1 parent 1a8793f commit cc7d665

File tree

5 files changed

+103
-39
lines changed

5 files changed

+103
-39
lines changed

temporalio/nexus/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import logging
2+
from collections.abc import Mapping
3+
from typing import Any, Optional
4+
5+
6+
class LoggerAdapter(logging.LoggerAdapter):
    """Logger adapter used for Nexus logging.

    Thin wrapper over :py:class:`logging.LoggerAdapter` that guarantees the
    adapter always carries a mapping in ``extra``, even when the caller has no
    contextual details to attach.
    """

    def __init__(
        self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None
    ) -> None:
        """Create the adapter.

        Args:
            logger: Underlying logger that records are forwarded to.
            extra: Contextual mapping to attach to the adapter. ``None`` is
                replaced by a fresh empty dict.
        """
        # Compare against None instead of relying on truthiness: a caller may
        # pass an (initially empty) mapping that it fills in later, and that
        # exact object must be kept rather than swapped for a new dict.
        super().__init__(logger, {} if extra is None else extra)
9+
10+
11+
logger = LoggerAdapter(logging.getLogger(__name__), None)
12+
"""Logger that has additional details regarding the current Nexus operation."""

temporalio/worker/_activity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ async def drain_poll_queue(self) -> None:
201201

202202
# Only call this after run()/drain_poll_queue() have returned. This will not
203203
# raise an exception.
204+
# TODO(dan): check accuracy of this comment; I would say it *does* raise an exception.
204205
async def wait_all_completed(self) -> None:
205206
running_tasks = [v.task for v in self._running_activities.values() if v.task]
206207
if running_tasks:

temporalio/worker/_nexus.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pprint
99
from typing import (
1010
Any,
11+
Awaitable,
1112
Callable,
1213
Sequence,
1314
Union,
@@ -67,6 +68,7 @@ def __init__(
6768
self._interceptors = interceptors
6869
# TODO(dan): metric_meter
6970
self._metric_meter = metric_meter
71+
self._running_operations: dict[bytes, asyncio.Task] = {}
7072

7173
def _validate_nexus_services(
7274
self, nexus_services: Sequence[Any]
@@ -163,6 +165,11 @@ async def drain_poll_queue(self) -> None:
163165
except temporalio.bridge.worker.PollShutdownError:
164166
return
165167

168+
async def wait_all_completed(self) -> None:
    """Wait until every tracked Nexus operation task has finished.

    Failures of individual operation tasks are logged instead of raised, so
    one failed operation neither stops the wait for the remaining operations
    nor aborts the caller's shutdown sequence.
    """
    import logging

    # Snapshot the tasks: an operation task removes its own entry from
    # ``_running_operations`` when it finishes.
    pending = list(self._running_operations.values())
    if not pending:
        return

    # return_exceptions=True makes gather wait for *all* tasks and hand back
    # failures as values instead of raising on the first one.
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    log = logging.getLogger(__name__)
    for outcome in outcomes:
        # Cancelled tasks are expected during shutdown; only report real errors.
        if isinstance(outcome, BaseException) and not isinstance(
            outcome, asyncio.CancelledError
        ):
            log.error(
                "Nexus operation task failed while waiting for completion",
                exc_info=outcome,
            )
172+
166173
# TODO(dan): is it correct to import from temporalio.api.nexus?
167174
# Why are these things not exposed in temporalio.bridge?
168175
async def _handle_start_operation(
@@ -194,53 +201,77 @@ async def _handle_start_operation(
194201
print(
195202
f"🌈@@ worker received task with link: {google.protobuf.json_format.MessageToJson(l)}"
196203
)
204+
205+
# TODO(dan): shouldn't this be set in the _run_nexus_operation context? (that doesn't work currently)
197206
temporalio.nexus.handler._current_context.set(
198207
temporalio.nexus.handler._Context(
199208
client=self._client,
200209
task_queue=self._task_queue,
201210
)
202211
)
212+
self._running_operations[task_token] = asyncio.create_task(
213+
self._run_nexus_operation(task_token, operation.start, input, options)
214+
)
203215

204-
# message NexusTaskCompletion {
205-
# bytes task_token = 1;
206-
# oneof status {
207-
# temporal.api.nexus.v1.Response completed = 2;
208-
# temporal.api.nexus.v1.HandlerError error = 3;
209-
# bool ack_cancel = 4;
210-
# }
211-
# }
216+
# TODO(dan): start type
217+
async def _run_nexus_operation(
218+
self,
219+
task_token: bytes,
220+
start: Callable[..., Awaitable[Any]],
221+
input: Any,
222+
options: nexusrpc.handler.StartOperationOptions,
223+
) -> None:
224+
try:
225+
result = await start(input, options)
226+
except BaseException:
227+
# TODO(dan): mirror appropriate aspects of _run_activity error handling
228+
raise NotImplementedError(
229+
"TODO: Nexus operation error handling not implemented"
230+
)
212231

213-
result = await operation.start(input, options)
232+
try:
233+
# Send task completion
234+
if isinstance(result, nexusrpc.handler.StartOperationAsyncResult):
235+
print(f"🟢 Nexus operation started with async response {result}")
236+
op_resp = temporalio.api.nexus.v1.StartOperationResponse(
237+
async_success=temporalio.api.nexus.v1.StartOperationResponse.Async(
238+
operation_token=result.token,
239+
links=[
240+
temporalio.api.nexus.v1.Link(url=l.url, type=l.type)
241+
for l in result.links
242+
],
243+
)
244+
)
245+
else:
246+
# TODO(dan): are we going to use StartOperationSyncResult from nexusrpc?
247+
# (contains links and headers in addition to result) IIRC Go does something
248+
# like that.
249+
[payload] = await self._data_converter.encode([result])
250+
op_resp = temporalio.api.nexus.v1.StartOperationResponse(
251+
sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync(
252+
payload=payload
253+
)
254+
)
255+
# message NexusTaskCompletion {
256+
# bytes task_token = 1;
257+
# oneof status {
258+
# temporal.api.nexus.v1.Response completed = 2;
259+
# temporal.api.nexus.v1.HandlerError error = 3;
260+
# bool ack_cancel = 4;
261+
# }
262+
# }
214263

215-
# Send task completion
216-
if isinstance(result, nexusrpc.handler.StartOperationAsyncResult):
217-
print(
218-
f"🟢 Nexus operation {request.operation} started with async response {result}"
219-
)
220-
op_resp = temporalio.api.nexus.v1.StartOperationResponse(
221-
async_success=temporalio.api.nexus.v1.StartOperationResponse.Async(
222-
operation_token=result.token,
223-
links=[
224-
temporalio.api.nexus.v1.Link(url=l.url, type=l.type)
225-
for l in result.links
226-
],
264+
await self._bridge_worker().complete_nexus_task(
265+
temporalio.bridge.proto.nexus.NexusTaskCompletion(
266+
task_token=task_token,
267+
completed=temporalio.api.nexus.v1.Response(start_operation=op_resp),
227268
)
228269
)
229-
else:
230-
# TODO(dan): are we going to use StartOperationSyncResult from nexusrpc?
231-
# (contains links and headers in addition to result) IIRC Go does something
232-
# like that.
233-
[payload] = await self._data_converter.encode([result])
234-
op_resp = temporalio.api.nexus.v1.StartOperationResponse(
235-
sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync(
236-
payload=payload
237-
)
270+
del self._running_operations[task_token]
271+
except Exception:
272+
temporalio.nexus.logger.exception(
273+
"Failed to send Nexus operation completion"
238274
)
239-
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
240-
task_token=task_token,
241-
completed=temporalio.api.nexus.v1.Response(start_operation=op_resp),
242-
)
243-
await self._bridge_worker().complete_nexus_task(completion)
244275

245276
async def _handle_cancel_operation(
246277
self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes

temporalio/worker/_worker.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def __init__(
146146
maximum=5
147147
),
148148
) -> None:
149+
# TODO(dan): consider not allowing max_workers < max_concurrent_nexus_operations?
150+
# TODO(dan): Nexus tuner support?
149151
"""Create a worker to process workflows and/or activities.
150152
151153
Args:
@@ -171,6 +173,14 @@ def __init__(
171173
executor should at least be ``max_concurrent_activities`` or a
172174
warning is issued. Note, a broken-executor failure from this
173175
executor will cause the worker to fail and shutdown.
176+
nexus_operation_executor: Concurrent executor to use for non-async
177+
Nexus operations. This is required if any operation start methods
178+
are non-async. :py:class:`concurrent.futures.ThreadPoolExecutor`
179+
is recommended. If this is a
180+
:py:class:`concurrent.futures.ProcessPoolExecutor`, all non-async
181+
start methods must be picklable. ``max_workers`` on the executor
182+
should at least be ``max_concurrent_nexus_operations`` or a warning
183+
is issued.
174184
workflow_task_executor: Thread pool executor for workflow tasks. If
175185
this is not present, a new
176186
:py:class:`concurrent.futures.ThreadPoolExecutor` will be
@@ -199,6 +209,8 @@ def __init__(
199209
will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``.
200210
max_concurrent_local_activities: Maximum number of local activity
201211
tasks that will ever be given to the activity worker concurrently. Mutually exclusive with ``tuner``.
212+
max_concurrent_nexus_operations: Maximum number of Nexus operations that
213+
will ever be given to the Nexus worker concurrently. Mutually exclusive with ``tuner``.
202214
max_concurrent_workflow_tasks: Maximum allowed number of
203215
tasks that will ever be given to the workflow worker at one time. Mutually exclusive with ``tuner``.
204216
tuner: Provide a custom :py:class:`WorkerTuner`. Mutually exclusive with the
@@ -688,12 +700,15 @@ async def raise_on_shutdown():
688700
for task in tasks.values():
689701
task.cancel()
690702

691-
# If there's an activity worker, we have to let all activity completions
692-
# finish. We cannot guarantee that because poll shutdown completed
693-
# (which means activities completed) that they got flushed to the
694-
# server.
703+
# Let all activity / Nexus operation completions finish. We cannot guarantee that
704+
# because poll shutdown completed (which means activities/operations completed)
705+
# that they got flushed to the server.
695706
if self._activity_worker:
696707
await self._activity_worker.wait_all_completed()
708+
if self._nexus_worker:
709+
await self._nexus_worker.wait_all_completed()
710+
711+
# TODO(dan): check that we do all appropriate things for nexus worker that we do for activity worker
697712

698713
# Do final shutdown
699714
try:

tests/worker/test_nexus.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
from temporalio.exceptions import CancelledError, NexusOperationError
3232
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
3333

34+
# TODO(dan): test availability of Temporal client etc in async context set by worker
35+
# TODO(dan): test worker shutdown, wait_all_completed, drain etc
36+
# TODO(dan): test worker op handling failure
37+
# TODO(dan): test contextual logger
38+
3439
# -----------------------------------------------------------------------------
3540
# Test definition
3641
#

0 commit comments

Comments
 (0)