|
1 | | -import asyncio |
| 1 | +from dataclasses import dataclass |
| 2 | +from typing import Protocol, Optional |
2 | 3 |
|
| 4 | +from ockam import RemoteNode, WorkerProtocol, LocalNodeProtocol |
3 | 5 |
|
4 | | -async def call_agent(agent, query): |
5 | | - return await agent.send(query) |
| 6 | +import json |
| 7 | + |
| 8 | + |
| 9 | +class SquadWorkerFactoryProtocol(Protocol): |
| 10 | + async def create(self, node: LocalNodeProtocol) -> WorkerProtocol: ... |
| 11 | + |
| 12 | + |
| 13 | +class Sharding: |
| 14 | + pass |
| 15 | + |
| 16 | + |
| 17 | +@dataclass |
| 18 | +class PerNode(Sharding): |
| 19 | + n: int = 1 |
| 20 | + |
| 21 | + |
| 22 | +@dataclass |
| 23 | +class PerItem(Sharding): |
| 24 | + pass |
| 25 | + |
| 26 | + |
| 27 | +@dataclass |
| 28 | +class SquadStep: |
| 29 | + name: str |
| 30 | + worker_factory: SquadWorkerFactoryProtocol |
| 31 | + sharding: Optional[Sharding] = None |
| 32 | + |
| 33 | + |
| 34 | +def split_list_into_n_parts(lst, n): |
| 35 | + q, r = divmod(len(lst), n) |
| 36 | + return [lst[i * q + min(i, r) : (i + 1) * q + min(i + 1, r)] for i in range(n)] |
| 37 | + |
| 38 | + |
| 39 | +class SquadWorker: |
| 40 | + node: LocalNodeProtocol |
| 41 | + runners: list[str] |
| 42 | + prev_step_sharding: Optional[Sharding] |
| 43 | + steps: list[SquadStep] |
| 44 | + |
| 45 | + def __init__(self, runners: list[str], prev_step_sharding: Optional[SquadStep], steps: list[SquadStep]): |
| 46 | + self.runners = runners |
| 47 | + self.prev_step_sharding = prev_step_sharding |
| 48 | + self.steps = steps |
| 49 | + |
| 50 | + def random_address(self) -> str: |
| 51 | + import secrets |
| 52 | + |
| 53 | + return secrets.token_hex(6) |
| 54 | + |
| 55 | + async def handle_message(self, context, message): |
| 56 | + import json |
| 57 | + import ockam |
| 58 | + |
| 59 | + step = self.steps.pop(0) |
| 60 | + |
| 61 | + if not self.prev_step_sharding or isinstance(self.prev_step_sharding, PerItem): |
| 62 | + worker_name = self.random_address() |
| 63 | + worker = await step.worker_factory.create(self.node) |
| 64 | + await self.node.start_worker(worker_name, worker) |
| 65 | + reply = await self.node.send_and_receive(worker_name, message, timeout=600) |
| 66 | + await self.node.stop_worker(worker_name) |
| 67 | + elif isinstance(self.prev_step_sharding, PerNode): |
| 68 | + worker_name = self.random_address() |
| 69 | + worker = await step.worker_factory.create(self.node) |
| 70 | + await self.node.start_worker(worker_name, worker) |
| 71 | + |
| 72 | + items = json.loads(message) |
| 73 | + replies = [] |
| 74 | + for item in items: |
| 75 | + reply = await self.node.send_and_receive(worker_name, json.dumps(item), timeout=600) |
| 76 | + reply = json.loads(reply) |
| 77 | + if isinstance(reply, list): |
| 78 | + replies.extend(reply) |
| 79 | + else: |
| 80 | + replies.append(reply) |
| 81 | + |
| 82 | + await self.node.stop_worker(worker_name) |
| 83 | + reply = json.dumps(replies) |
| 84 | + else: |
| 85 | + raise ValueError(f"Invalid sharding: {self.prev_step_sharding}") |
| 86 | + |
| 87 | + if not step.sharding: |
| 88 | + await context.reply(reply) |
| 89 | + elif isinstance(step.sharding, PerNode): |
| 90 | + # n = step.sharding.n For now support only 1 |
| 91 | + |
| 92 | + async def start(runner: RemoteNode, chunk): |
| 93 | + worker_name = self.random_address() |
| 94 | + worker = SquadWorker(self.runners, step.sharding, self.steps) |
| 95 | + await runner.start_worker(worker_name, worker) |
| 96 | + |
| 97 | + mailbox_name = self.random_address() |
| 98 | + mailbox = await self.node.create_mailbox(mailbox_name) |
| 99 | + |
| 100 | + await mailbox.send_to_remote(runner.name, worker_name, json.dumps(chunk)) |
| 101 | + |
| 102 | + return runner, worker_name, mailbox |
| 103 | + |
| 104 | + async def finish(runner: RemoteNode, worker_name: str, mailbox): |
| 105 | + result = await mailbox.receive(timeout=1000) |
| 106 | + await runner.stop_worker(worker_name) |
| 107 | + return json.loads(result) |
| 108 | + |
| 109 | + items = json.loads(reply) |
| 110 | + |
| 111 | + if len(self.runners) > len(items): |
| 112 | + runners = self.runners[: len(items)] |
| 113 | + else: |
| 114 | + runners = self.runners |
| 115 | + |
| 116 | + chunks = split_list_into_n_parts(items, len(runners)) |
| 117 | + |
| 118 | + futures = [start(RemoteNode(self.node, runner), chunk) for runner, chunk in zip(runners, chunks)] |
| 119 | + handles = await ockam.gather(*futures, batch_size=10) |
| 120 | + |
| 121 | + futures = [finish(runner, worker_name, mailbox) for runner, worker_name, mailbox in handles] |
| 122 | + results = await ockam.gather(*futures, batch_size=100) |
| 123 | + |
| 124 | + await context.reply(json.dumps(results)) |
| 125 | + |
| 126 | + elif isinstance(step.sharding, PerItem): |
| 127 | + |
| 128 | + async def start(item): |
| 129 | + worker_name = self.random_address() |
| 130 | + worker = SquadWorker(self.runners.copy(), step.sharding, self.steps.copy()) |
| 131 | + worker.node = self.node |
| 132 | + await self.node.start_worker(worker_name, worker) |
| 133 | + |
| 134 | + mailbox_name = self.random_address() |
| 135 | + mailbox = await self.node.create_mailbox(mailbox_name) |
| 136 | + |
| 137 | + await mailbox.send(worker_name, json.dumps(item)) |
| 138 | + |
| 139 | + return worker_name, mailbox |
| 140 | + |
| 141 | + async def finish(worker_name, mailbox): |
| 142 | + result = await mailbox.receive(timeout=1000) |
| 143 | + await self.node.stop_worker(worker_name) |
| 144 | + return json.loads(result) |
| 145 | + |
| 146 | + items = json.loads(reply) |
| 147 | + |
| 148 | + futures = [start(item) for item in items] |
| 149 | + handles = await ockam.gather(*futures, batch_size=10) |
| 150 | + |
| 151 | + futures = [finish(worker_name, mailbox) for worker_name, mailbox in handles] |
| 152 | + results = await ockam.gather(*futures, batch_size=100) |
| 153 | + |
| 154 | + await context.reply(json.dumps(results)) |
| 155 | + else: |
| 156 | + raise ValueError(f"Invalid sharding: {step.sharding}") |
6 | 157 |
|
7 | 158 |
|
8 | 159 | class Squad: |
9 | | - def __init__(self): |
10 | | - self.coroutines = [] |
| 160 | + @staticmethod |
| 161 | + async def run(node: LocalNodeProtocol, runners: list[RemoteNode], data, steps: list[SquadStep]): |
| 162 | + if not steps: |
| 163 | + return [] |
| 164 | + |
| 165 | + runners = list(map(lambda x: x.name, runners)) |
| 166 | + |
| 167 | + worker = SquadWorker(runners, None, steps) |
| 168 | + worker.node = node |
| 169 | + worker_name = worker.random_address() |
| 170 | + |
| 171 | + await node.start_worker(worker_name, worker) |
| 172 | + |
| 173 | + result = await node.send_and_receive(worker_name, json.dumps(data), timeout=1000) |
11 | 174 |
|
12 | | - def add(self, agent, query, portals=[]): |
13 | | - self.coroutines.append(call_agent(agent, query)) |
| 175 | + await node.stop_worker(worker_name) |
14 | 176 |
|
15 | | - async def run(self): |
16 | | - async with asyncio.TaskGroup() as tg: |
17 | | - tasks = [tg.create_task(coro) for coro in self.coroutines] |
18 | | - return [task.result() for task in tasks] |
| 177 | + return result |
0 commit comments