|
8 | 8 | import pprint
|
9 | 9 | from typing import (
|
10 | 10 | Any,
|
| 11 | + Awaitable, |
11 | 12 | Callable,
|
12 | 13 | Sequence,
|
13 | 14 | Union,
|
@@ -67,6 +68,7 @@ def __init__(
|
67 | 68 | self._interceptors = interceptors
|
68 | 69 | # TODO(dan): metric_meter
|
69 | 70 | self._metric_meter = metric_meter
|
| 71 | + self._running_operations: dict[bytes, asyncio.Task] = {} |
70 | 72 |
|
71 | 73 | def _validate_nexus_services(
|
72 | 74 | self, nexus_services: Sequence[Any]
|
@@ -163,6 +165,11 @@ async def drain_poll_queue(self) -> None:
|
163 | 165 | except temporalio.bridge.worker.PollShutdownError:
|
164 | 166 | return
|
165 | 167 |
|
| 168 | + async def wait_all_completed(self) -> None: |
| 169 | + await asyncio.gather( |
| 170 | + *self._running_operations.values(), return_exceptions=False |
| 171 | + ) |
| 172 | + |
166 | 173 | # TODO(dan): is it correct to import from temporalio.api.nexus?
|
167 | 174 | # Why are these things not exposed in temporalio.bridge?
|
168 | 175 | async def _handle_start_operation(
|
@@ -194,53 +201,77 @@ async def _handle_start_operation(
|
194 | 201 | print(
|
195 | 202 | f"🌈@@ worker received task with link: {google.protobuf.json_format.MessageToJson(l)}"
|
196 | 203 | )
|
| 204 | + |
| 205 | + # TODO(dan): shouldn't this be set in the _run_nexus_operation context? (that doesn't work currently) |
197 | 206 | temporalio.nexus.handler._current_context.set(
|
198 | 207 | temporalio.nexus.handler._Context(
|
199 | 208 | client=self._client,
|
200 | 209 | task_queue=self._task_queue,
|
201 | 210 | )
|
202 | 211 | )
|
| 212 | + self._running_operations[task_token] = asyncio.create_task( |
| 213 | + self._run_nexus_operation(task_token, operation.start, input, options) |
| 214 | + ) |
203 | 215 |
|
204 |
| - # message NexusTaskCompletion { |
205 |
| - # bytes task_token = 1; |
206 |
| - # oneof status { |
207 |
| - # temporal.api.nexus.v1.Response completed = 2; |
208 |
| - # temporal.api.nexus.v1.HandlerError error = 3; |
209 |
| - # bool ack_cancel = 4; |
210 |
| - # } |
211 |
| - # } |
| 216 | + # TODO(dan): start type |
| 217 | + async def _run_nexus_operation( |
| 218 | + self, |
| 219 | + task_token: bytes, |
| 220 | + start: Callable[..., Awaitable[Any]], |
| 221 | + input: Any, |
| 222 | + options: nexusrpc.handler.StartOperationOptions, |
| 223 | + ) -> None: |
| 224 | + try: |
| 225 | + result = await start(input, options) |
| 226 | + except BaseException: |
| 227 | + # TODO(dan): mirror appropriate aspects of _run_activity error handling |
| 228 | + raise NotImplementedError( |
| 229 | + "TODO: Nexus operation error handling not implemented" |
| 230 | + ) |
212 | 231 |
|
213 |
| - result = await operation.start(input, options) |
| 232 | + try: |
| 233 | + # Send task completion |
| 234 | + if isinstance(result, nexusrpc.handler.StartOperationAsyncResult): |
| 235 | + print(f"🟢 Nexus operation started with async response {result}") |
| 236 | + op_resp = temporalio.api.nexus.v1.StartOperationResponse( |
| 237 | + async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( |
| 238 | + operation_token=result.token, |
| 239 | + links=[ |
| 240 | + temporalio.api.nexus.v1.Link(url=l.url, type=l.type) |
| 241 | + for l in result.links |
| 242 | + ], |
| 243 | + ) |
| 244 | + ) |
| 245 | + else: |
| 246 | + # TODO(dan): are we going to use StartOperationSyncResult from nexusrpc? |
| 247 | + # (contains links and headers in addition to result) IIRC Go does something |
| 248 | + # like that. |
| 249 | + [payload] = await self._data_converter.encode([result]) |
| 250 | + op_resp = temporalio.api.nexus.v1.StartOperationResponse( |
| 251 | + sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( |
| 252 | + payload=payload |
| 253 | + ) |
| 254 | + ) |
| 255 | + # message NexusTaskCompletion { |
| 256 | + # bytes task_token = 1; |
| 257 | + # oneof status { |
| 258 | + # temporal.api.nexus.v1.Response completed = 2; |
| 259 | + # temporal.api.nexus.v1.HandlerError error = 3; |
| 260 | + # bool ack_cancel = 4; |
| 261 | + # } |
| 262 | + # } |
214 | 263 |
|
215 |
| - # Send task completion |
216 |
| - if isinstance(result, nexusrpc.handler.StartOperationAsyncResult): |
217 |
| - print( |
218 |
| - f"🟢 Nexus operation {request.operation} started with async response {result}" |
219 |
| - ) |
220 |
| - op_resp = temporalio.api.nexus.v1.StartOperationResponse( |
221 |
| - async_success=temporalio.api.nexus.v1.StartOperationResponse.Async( |
222 |
| - operation_token=result.token, |
223 |
| - links=[ |
224 |
| - temporalio.api.nexus.v1.Link(url=l.url, type=l.type) |
225 |
| - for l in result.links |
226 |
| - ], |
| 264 | + await self._bridge_worker().complete_nexus_task( |
| 265 | + temporalio.bridge.proto.nexus.NexusTaskCompletion( |
| 266 | + task_token=task_token, |
| 267 | + completed=temporalio.api.nexus.v1.Response(start_operation=op_resp), |
227 | 268 | )
|
228 | 269 | )
|
229 |
| - else: |
230 |
| - # TODO(dan): are we going to use StartOperationSyncResult from nexusrpc? |
231 |
| - # (contains links and headers in addition to result) IIRC Go does something |
232 |
| - # like that. |
233 |
| - [payload] = await self._data_converter.encode([result]) |
234 |
| - op_resp = temporalio.api.nexus.v1.StartOperationResponse( |
235 |
| - sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync( |
236 |
| - payload=payload |
237 |
| - ) |
| 270 | + del self._running_operations[task_token] |
| 271 | + except Exception: |
| 272 | + temporalio.nexus.logger.exception( |
| 273 | + "Failed to send Nexus operation completion" |
238 | 274 | )
|
239 |
| - completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( |
240 |
| - task_token=task_token, |
241 |
| - completed=temporalio.api.nexus.v1.Response(start_operation=op_resp), |
242 |
| - ) |
243 |
| - await self._bridge_worker().complete_nexus_task(completion) |
244 | 275 |
|
245 | 276 | async def _handle_cancel_operation(
|
246 | 277 | self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes
|
|
0 commit comments