diff --git a/exo/main.py b/exo/main.py
index e07a03b35..e93b8cc05 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -186,9 +186,15 @@ def configure_uvloop():
   default_model=args.default_model,
   system_prompt=args.system_prompt
 )
-node.on_token.register("update_topology_viz").on_next(
-  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
-)
+buffered_token_output = {}
+def update_topology_viz(req_id, token, __):
+  if not topology_viz: return
+  if req_id in buffered_token_output: buffered_token_output[req_id].append(token)
+  else: buffered_token_output[req_id] = [token]
+
+  if inference_engine.shard.model_id != 'stable-diffusion-2-1-base':
+    topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+node.on_token.register("update_topology_viz").on_next(update_topology_viz)

 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py
index da67d9c67..83cc0f01f 100644
--- a/exo/networking/grpc/grpc_server.py
+++ b/exo/networking/grpc/grpc_server.py
@@ -131,9 +131,9 @@ async def CollectTopology(self, request, context):
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

-  async def SendNewToken(self, request, context):
+  async def SendResult(self, request, context):
     request_id = request.request_id
-    token = request.token
+    result = request.result
     is_finished = request.is_finished
     img = request.tensor
     if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index 00453deba..9a10c126e 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -130,9 +130,9 @@ async def process_inference_result(
         self.buffered_token_output[request_id][0].append(token.item())
         is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-        asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
+        asyncio.create_task(self.broadcast_result(request_id, [self.buffered_token_output[request_id][0][-1]], is_finished))
         forward = token.reshape(1, -1)
-        intermediate_result = self.buffered_token_output[request_id][0]
+        intermediate_result = self.buffered_token_output[request_id][0][-1]
       else:
         forward = result
     else:
@@ -586,17 +586,17 @@ def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: b
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
     self.on_token.trigger_all(request_id, token, is_finished)

-  async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
-    async def send_new_token_to_peer(peer):
+  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    async def send_result_to_peer(peer):
       try:
-        await asyncio.wait_for(peer.send_new_token(request_id, token, is_finished), timeout=15.0)
+        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
       except asyncio.TimeoutError:
-        print(f"Timeout broadcasting new token to {peer.id()}")
+        print(f"Timeout broadcasting result to {peer.id()}")
       except Exception as e:
-        print(f"Error broadcasting new token to {peer.id()}: {e}")
+        print(f"Error broadcasting result to {peer.id()}: {e}")
         traceback.print_exc()

-    await asyncio.gather(*[send_new_token_to_peer(peer) for peer in self.peers], return_exceptions=True)
+    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)

   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
     if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")