Skip to content

Commit f1edf18

Browse files
committed
feat(python): implement Squad
1 parent 0dca24f commit f1edf18

File tree

1 file changed

+170
-11
lines changed
  • implementations/python/python/ockam/squads

1 file changed

+170
-11
lines changed
Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,177 @@
1-
import asyncio
1+
from dataclasses import dataclass
2+
from typing import Protocol, Optional
23

4+
from ockam import RemoteNode, WorkerProtocol, LocalNodeProtocol
35

4-
async def call_agent(agent, query):
5-
return await agent.send(query)
6+
import json
7+
8+
9+
class SquadWorkerFactoryProtocol(Protocol):
    """Structural type for objects that build the worker behind a squad step.

    ``SquadWorker.handle_message`` awaits ``create(node)`` each time it
    executes a step and starts the returned worker on that node.
    """

    async def create(self, node: LocalNodeProtocol) -> WorkerProtocol:
        """Build and return the worker that should run on ``node``."""
        ...
11+
12+
13+
class Sharding:
    """Marker base class for the ways a step's output can be distributed.

    It carries no state or behavior of its own; it only gives the concrete
    sharding strategies a common type for annotations and ``isinstance``
    checks.
    """
15+
16+
17+
@dataclass
class PerNode(Sharding):
    """Sharding strategy that spreads a step's output across the runner nodes.

    ``SquadWorker.handle_message`` splits the step's JSON list into one
    contiguous chunk per runner and sends each chunk to a worker started on
    that runner.
    """

    # Declared for future use: the fan-out code has a note that only 1 is
    # supported for now and does not read this value.
    n: int = 1
20+
21+
22+
@dataclass
class PerItem(Sharding):
    """Sharding strategy that handles every output item independently.

    ``SquadWorker.handle_message`` starts one follow-up worker on its own
    node for each element of the step's JSON list.
    """
25+
26+
27+
@dataclass
class SquadStep:
    """One stage of a squad pipeline."""

    # Label for the step; the code visible in this module never reads it.
    name: str
    # Builds the worker that performs this step (awaited as ``create(node)``).
    worker_factory: SquadWorkerFactoryProtocol
    # How this step's reply is handed to the remaining steps. ``None`` means
    # the reply is returned to the sender as-is, with no further fan-out.
    sharding: Optional[Sharding] = None
32+
33+
34+
def split_list_into_n_parts(lst, n):
    """Split ``lst`` into ``n`` contiguous parts whose sizes differ by at most one.

    The first ``len(lst) % n`` parts receive one extra element, so the longer
    parts come first. When ``n`` is larger than ``len(lst)`` the trailing
    parts are empty.

    Args:
        lst: Sliceable sequence to split.
        n: Number of parts to produce.

    Returns:
        A list with ``n`` slices of ``lst``, in order. An empty list is
        returned when ``n`` is not positive and ``lst`` is empty (there is
        nothing to distribute).

    Raises:
        ValueError: If ``n`` is not positive while ``lst`` still holds
            elements, because those elements could not be placed anywhere.
    """
    if n <= 0:
        # Previously n == 0 crashed with ZeroDivisionError (even for an empty
        # list) and a negative n silently dropped every element.
        if len(lst):
            raise ValueError(f"Cannot split {len(lst)} item(s) into {n} parts")
        return []

    base_size, extra = divmod(len(lst), n)

    parts = []
    start = 0
    for index in range(n):
        # The first `extra` parts absorb the remainder, one element each.
        end = start + base_size + (1 if index < extra else 0)
        parts.append(lst[start:end])
        start = end
    return parts
37+
38+
39+
class SquadWorker:
    """Worker that runs the next pipeline step and fans its output out.

    Each handled message consumes the first entry of ``steps``, runs it through
    a temporary worker built by that step's factory, and then, depending on
    the step's ``sharding``, either replies directly or starts further
    ``SquadWorker`` instances for the remaining steps and replies with their
    collected results.

    Attributes:
        node: Local node used to start/stop workers and create mailboxes. It
            is not set by ``__init__``; the code that creates the worker
            assigns it (NOTE(review): for workers started on a remote runner
            nothing here assigns it - presumably the runtime does; confirm).
        runners: Names of the remote nodes available for ``PerNode`` fan-out.
        prev_step_sharding: Sharding of the step that produced the incoming
            message. ``None``/``PerItem`` means the message is one payload;
            ``PerNode`` means it is a JSON list processed item by item.
        steps: Remaining steps; the first one is removed on every message.
    """

    node: LocalNodeProtocol
    runners: list[str]
    prev_step_sharding: Optional[Sharding]
    steps: list[SquadStep]

    def __init__(
        self,
        runners: list[str],
        prev_step_sharding: Optional[Sharding],
        steps: list[SquadStep],
    ):
        """Store the runner names, the previous step's sharding and the steps."""
        self.runners = runners
        self.prev_step_sharding = prev_step_sharding
        self.steps = steps

    def random_address(self) -> str:
        """Return a fresh random 12-hex-character name for a worker or mailbox."""
        import secrets

        return secrets.token_hex(6)

    async def handle_message(self, context, message):
        """Run the next step on ``message`` and reply with the pipeline result.

        Args:
            context: Message context; only ``context.reply(...)`` is used.
            message: Payload for the step. A JSON list when the previous step
                was sharded ``PerNode``, otherwise passed through untouched.

        Raises:
            IndexError: If no steps are left.
            ValueError: If the previous or current sharding is of an unknown
                type, or ``PerNode`` fan-out is requested without any runner.
        """
        import json

        # Consumes the step from this worker's own list.
        step = self.steps.pop(0)

        reply = await self._run_step(step, message)

        if not step.sharding:
            await context.reply(reply)
        elif isinstance(step.sharding, PerNode):
            results = await self._fan_out_per_node(step, json.loads(reply))
            await context.reply(json.dumps(results))
        elif isinstance(step.sharding, PerItem):
            results = await self._fan_out_per_item(step, json.loads(reply))
            await context.reply(json.dumps(results))
        else:
            raise ValueError(f"Invalid sharding: {step.sharding}")

    async def _run_step(self, step, message):
        """Run ``step`` in a temporary local worker and return its reply."""
        import json

        previous = self.prev_step_sharding
        if not previous or isinstance(previous, PerItem):
            item_by_item = False
        elif isinstance(previous, PerNode):
            item_by_item = True
        else:
            raise ValueError(f"Invalid sharding: {self.prev_step_sharding}")

        worker_name = self.random_address()
        worker = await step.worker_factory.create(self.node)
        await self.node.start_worker(worker_name, worker)
        try:
            if not item_by_item:
                return await self.node.send_and_receive(worker_name, message, timeout=600)

            # A PerNode chunk is a JSON list: feed the items one at a time and
            # flatten list replies into a single result list.
            replies = []
            for item in json.loads(message):
                raw = await self.node.send_and_receive(worker_name, json.dumps(item), timeout=600)
                decoded = json.loads(raw)
                if isinstance(decoded, list):
                    replies.extend(decoded)
                else:
                    replies.append(decoded)
            return json.dumps(replies)
        finally:
            # Stop the temporary worker even when the step fails.
            await self.node.stop_worker(worker_name)

    async def _fan_out_per_node(self, step, items):
        """Send one chunk of ``items`` to each runner and collect the results."""
        import json
        import ockam

        # step.sharding.n is not honored yet: only one worker per node.

        async def start(runner: RemoteNode, chunk):
            worker_name = self.random_address()
            # Copies keep the child workers from sharing mutable state with
            # this worker, matching what the PerItem fan-out does.
            worker = SquadWorker(self.runners.copy(), step.sharding, self.steps.copy())
            await runner.start_worker(worker_name, worker)
            try:
                mailbox_name = self.random_address()
                mailbox = await self.node.create_mailbox(mailbox_name)
                await mailbox.send_to_remote(runner.name, worker_name, json.dumps(chunk))
            except BaseException:
                # Do not leave the remote worker running if dispatch failed.
                await runner.stop_worker(worker_name)
                raise
            return runner, worker_name, mailbox

        async def finish(runner: RemoteNode, worker_name: str, mailbox):
            try:
                result = await mailbox.receive(timeout=1000)
            finally:
                await runner.stop_worker(worker_name)
            return json.loads(result)

        if items and not self.runners:
            # Previously surfaced as a ZeroDivisionError while chunking.
            raise ValueError("PerNode sharding requires at least one runner")

        # Never use more runners than there are items to hand out.
        runners = self.runners[: len(items)]
        # With no items there is nothing to chunk (and no runner selected).
        chunks = split_list_into_n_parts(items, len(runners)) if runners else []

        futures = [start(RemoteNode(self.node, runner), chunk) for runner, chunk in zip(runners, chunks)]
        handles = await ockam.gather(*futures, batch_size=10)

        futures = [finish(runner, worker_name, mailbox) for runner, worker_name, mailbox in handles]
        return await ockam.gather(*futures, batch_size=100)

    async def _fan_out_per_item(self, step, items):
        """Start one local follow-up worker per item and collect the results."""
        import json
        import ockam

        async def start(item):
            worker_name = self.random_address()
            # Each child pops from its own copy of the remaining steps.
            worker = SquadWorker(self.runners.copy(), step.sharding, self.steps.copy())
            worker.node = self.node
            await self.node.start_worker(worker_name, worker)
            try:
                mailbox_name = self.random_address()
                mailbox = await self.node.create_mailbox(mailbox_name)
                await mailbox.send(worker_name, json.dumps(item))
            except BaseException:
                # Do not leave the worker running if dispatch failed.
                await self.node.stop_worker(worker_name)
                raise
            return worker_name, mailbox

        async def finish(worker_name, mailbox):
            try:
                result = await mailbox.receive(timeout=1000)
            finally:
                await self.node.stop_worker(worker_name)
            return json.loads(result)

        futures = [start(item) for item in items]
        handles = await ockam.gather(*futures, batch_size=10)

        futures = [finish(worker_name, mailbox) for worker_name, mailbox in handles]
        return await ockam.gather(*futures, batch_size=100)
6157

7158

8159
class Squad:
    """Entry point for running a pipeline of ``SquadStep`` objects."""

    @staticmethod
    async def run(node: LocalNodeProtocol, runners: list[RemoteNode], data, steps: list[SquadStep]):
        """Run ``steps`` over ``data`` and return the pipeline's reply.

        A root ``SquadWorker`` is started on ``node`` under a random name,
        sent ``data`` encoded as JSON, and stopped again once it has replied
        (or failed).

        Args:
            node: Local node on which the root worker is started.
            runners: Remote nodes; only their ``name`` is passed on.
            data: JSON-serializable input for the first step.
            steps: Steps to execute in order. The list is not modified.

        Returns:
            ``[]`` when ``steps`` is empty; otherwise whatever
            ``node.send_and_receive`` returns for the root worker
            (NOTE(review): that is a reply payload rather than a list, so the
            two return shapes differ - confirm callers expect this).
        """
        if not steps:
            return []

        runner_names = [runner.name for runner in runners]

        # The worker pops steps off the list it is given, so hand it a copy
        # and leave the caller's list intact for reuse.
        worker = SquadWorker(runner_names, None, list(steps))
        worker.node = node
        worker_name = worker.random_address()

        await node.start_worker(worker_name, worker)
        try:
            return await node.send_and_receive(worker_name, json.dumps(data), timeout=1000)
        finally:
            # Stop the root worker even when the pipeline fails or times out.
            await node.stop_worker(worker_name)

0 commit comments

Comments
 (0)