Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/backend/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import threading
import time
import uuid

Expand All @@ -9,6 +10,7 @@
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles

from backend.server.data_transfer_handler import DataTransferHandler
from backend.server.request_handler import RequestHandler
from backend.server.scheduler_manage import SchedulerManage
from backend.server.server_args import parse_args
Expand Down Expand Up @@ -241,6 +243,14 @@ async def serve_index():
if model_name is not None and init_nodes_num is not None:
scheduler_manage.run(model_name, init_nodes_num, is_local_network)

if args.enable_weight_refit:
data_transfer_handler = DataTransferHandler()
data_transfer_handler.set_scheduler_manage(scheduler_manage)
threading.Thread(
target=data_transfer_handler.run,
daemon=True,
).start()

host = args.host
port = args.port

Expand Down
133 changes: 133 additions & 0 deletions src/backend/server/data_transfer_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os
import random
import time

from lattica import Lattica

from parallax.cli import (
ENDPOINT_PROTOCOL_VERSION,
PUBLIC_INITIAL_PEERS,
PUBLIC_RELAY_SERVERS,
)
from parallax.utils.weight_refit_utils import release_disk_storage
from parallax_utils.logging_config import get_logger

logger = get_logger(__name__)


class DataTransferHandler:
    """Bridge refit weight blocks from the endpoint P2P network to the scheduler.

    Runs a dedicated Lattica instance (the "endpoint lattica") that downloads
    full-weight blocks published by peers, re-publishes the raw bytes on the
    scheduler's own lattica, and then hands the rewritten refit request to the
    scheduler. `run()` is intended to execute in a daemon thread.
    """

    def __init__(self):
        # Set later via set_scheduler_manage(); run() requires it to be non-None.
        self.scheduler_manage = None
        self.endpoint_lattica = None
        self.initial_peers = PUBLIC_INITIAL_PEERS
        self.relay_servers = PUBLIC_RELAY_SERVERS
        self.protocol_version = ENDPOINT_PROTOCOL_VERSION
        # Listen on ephemeral TCP and QUIC-v1 ports on all interfaces.
        self.host_maddrs = [
            "/ip4/0.0.0.0/tcp/0",
            "/ip4/0.0.0.0/udp/0/quic-v1",
        ]
        self.block_path = "/tmp/endpoint"  # workaround to distinct with worker node path
        self.request_data = {}
        self._start_endpoint_lattica()

    def _start_endpoint_lattica(self):
        """
        Start another lattica instance data transfer within endpoints and/or with echo.
        """
        # Reuse existing Lattica if running
        if self.endpoint_lattica is not None:
            logger.debug("Endpoint lattica already running, reusing existing instance")
            return

        # NOTE(review): with_key_path(".") appears to be overridden by
        # with_key_path(self.block_path) below -- confirm the first call is needed.
        self.endpoint_lattica = (
            Lattica.builder().with_listen_addrs(self.host_maddrs).with_key_path(".")
        )

        if len(self.relay_servers) > 0:
            logger.info(f"Endpoint using relay servers: {self.relay_servers}")
            logger.info(f"Endpoint launched with protocol version: {self.protocol_version}")
            self.endpoint_lattica.with_relay_servers(self.relay_servers).with_dcutr(
                True
            ).with_protocol(self.protocol_version)

        if len(self.initial_peers) > 0:
            logger.info(f"Endpoint using initial peers: {self.initial_peers}")
            self.endpoint_lattica.with_bootstraps(self.initial_peers)

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.block_path, exist_ok=True)
        self.endpoint_lattica.with_storage_path(self.block_path)
        self.endpoint_lattica.with_dht_db_path(self.block_path)
        self.endpoint_lattica.with_key_path(self.block_path)

        self.endpoint_lattica.build()
        logger.debug("Endpoint lattica built")

    def _get_full_weight_blocks(self, message):
        """Download every cid in ``message["cid"]`` from the endpoint lattica
        and re-publish the raw data on the scheduler's lattica.

        Returns the list of re-published cids. The list may be shorter than the
        input when a download gives up after the 10-minute timeout.
        """

        def _download_weight(cid):
            # Fetch one block with 30s attempts, retrying for up to 10 minutes;
            # returns the re-published cid, or None on timeout.
            raw_data = None
            time_out = 10 * 60  # 10 minutes timeout
            time_begin_get_block = time.time()
            time_end_get_block = None
            peer_id = None
            while True:
                try:
                    cur_time = time.time()
                    if cur_time - time_begin_get_block > time_out:
                        logger.warning(f"Failed to get_block after 10 minutes! cid={cid}")
                        return None
                    peer_id, raw_data = self.endpoint_lattica.get_block(cid, timeout_secs=30)
                    time_end_get_block = time.time()
                    break
                except Exception:
                    logger.warning(f"Failed to get block: {cid}. Retry in 1 second.")
                    time.sleep(1)
            if raw_data is None:
                return None
            interval_get_block = time_end_get_block - time_begin_get_block
            logger.info(
                f"Finish download cid={cid}, get_block={interval_get_block}s, peer_id={peer_id}"
            )
            return self.scheduler_manage.lattica.put_block(raw_data)

        release_disk_storage(["/tmp/endpoint", "/tmp/scheduler"])

        # Copy so we never mutate the caller's list in place, and guard against
        # a missing/None "cid" entry (shuffling None would raise TypeError).
        cid_list = list(message.get("cid") or [])
        # Shuffle so concurrent endpoints pull blocks in different orders.
        # (No explicit re-seed: random is already seeded from OS entropy,
        # which is stronger than seeding with time.time().)
        random.shuffle(cid_list)
        new_cid_list = []

        # sleep 30s for lattica direct connection
        time.sleep(30)

        while cid_list:
            cid = cid_list.pop()
            logger.info(f"Start downloading refit weight {cid}")
            new_cid = _download_weight(cid)
            if new_cid is None:
                # NOTE(review): the partial list is still returned and fed to the
                # scheduler by run() -- confirm a partial refit is acceptable.
                logger.warning("Endpoint failed to download full weight")
                break
            new_cid_list.append(new_cid)
        return new_cid_list

    def set_scheduler_manage(self, scheduler_manage):
        """Attach the SchedulerManage whose lattica/scheduler this handler feeds."""
        self.scheduler_manage = scheduler_manage

    def run(self):
        """Poll ``scheduler_manage.refit_data`` forever; when a refit request
        appears, mirror its weight blocks and hand the rewritten request to the
        scheduler. Blocks forever -- run in a daemon thread.
        """
        while True:
            refit_data = self.scheduler_manage.refit_data
            if len(refit_data) > 0:
                new_cid_list = self._get_full_weight_blocks(refit_data)
                # replace cid_list and feed to scheduler
                refit_data["cid"] = new_cid_list
                self.scheduler_manage.scheduler.refit_request = refit_data
                self.scheduler_manage.scheduler.refit_set = set()
                self.scheduler_manage.refit_data = {}
            time.sleep(1)
14 changes: 12 additions & 2 deletions src/backend/server/scheduler_manage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import threading
import time
from typing import List
Expand Down Expand Up @@ -44,6 +45,8 @@ def __init__(
self.http_port = http_port
self.use_hfcache = use_hfcache
self.enable_weight_refit = enable_weight_refit
self.refit_data = {}
self.block_path = "/tmp/scheduler"
self.model_name = None
self.init_nodes_num = None
self.scheduler = None
Expand Down Expand Up @@ -122,8 +125,7 @@ def weight_refit(self, request_data):
"""
if self.scheduler is None:
return False
self.scheduler.refit_request = request_data
self.scheduler.refit_set = set()
self.refit_data = request_data
return True

def get_last_refit_time(self):
Expand Down Expand Up @@ -237,6 +239,14 @@ def _start_lattica(self):
logger.info(f"Using initial peers: {self.initial_peers}")
self.lattica.with_bootstraps(self.initial_peers)

if self.enable_weight_refit:
folder = os.path.exists(self.block_path)
if not folder:
os.makedirs(self.block_path)
self.lattica.with_storage_path(self.block_path)
self.lattica.with_dht_db_path(self.block_path)
self.lattica.with_key_path(self.block_path)

self.lattica.build()
logger.debug("Lattica node built")

Expand Down
2 changes: 2 additions & 0 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
"/dns4/relay-lattica-eu.gradient.network/tcp/18080/p2p/12D3KooWRAuR7rMNA7Yd4S1vgKS6akiJfQoRNNexTtzWxYPiWfG5",
]

ENDPOINT_PROTOCOL_VERSION = "/gradient.network.parallax.protocal_version_0"


def check_python_version():
"""Check if Python version is 3.11 or higher."""
Expand Down
6 changes: 3 additions & 3 deletions src/parallax/p2p/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from parallax.utils.weight_refit_utils import (
calculate_cid_manual,
concat_weight_partition,
filer_weight_cid_list,
filter_weight_cid_list,
parse_safetensors_from_memory,
release_disk_storage,
)
Expand Down Expand Up @@ -253,7 +253,7 @@ def _download_weight_thread(weight_dir, cid):
return True, tensors

# step0. Release lattica disk storage
release_disk_storage()
release_disk_storage(["/tmp"])

# step1. Check weight refit trigger message
time_stamp = message.get("time_stamp", None)
Expand All @@ -265,7 +265,7 @@ def _download_weight_thread(weight_dir, cid):
# Weight already updated
return

cid_list = filer_weight_cid_list(
cid_list = filter_weight_cid_list(
gradient_server.block_start_index,
gradient_server.block_end_index,
gradient_server.block_end_index,
Expand Down
15 changes: 9 additions & 6 deletions src/parallax/utils/weight_refit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def is_block_needed(key, is_first_shard, is_last_shard, start_layer, end_layer)
return False


def filer_weight_cid_list(start_layer, end_layer, hidden_layers, index_map):
def filter_weight_cid_list(start_layer, end_layer, hidden_layers, index_map):
"""
Filters block cids that worker node needs to download.
Arguments:
Expand Down Expand Up @@ -148,12 +148,15 @@ def remove_list_dirs(dir_list):
continue


def release_disk_storage():
def release_disk_storage(base_dirs):
"""Remove lattica storage files before get blocks"""
storage_dir = "/tmp"
storage_dirs = glob.glob(storage_dir + "/*.storage")
key_dirs = glob.glob(storage_dir + "/*.key")
dht_dirs = glob.glob(storage_dir + "/*.dht")
storage_dirs = []
key_dirs = []
dht_dirs = []
for cache_dir in base_dirs:
storage_dirs.extend(glob.glob(cache_dir + "/*.storage"))
key_dirs.extend(glob.glob(cache_dir + "/*.key"))
dht_dirs.extend(glob.glob(cache_dir + "/*.dht"))
remove_list_dirs(storage_dirs)
remove_list_dirs(key_dirs)
remove_list_dirs(dht_dirs)
Expand Down
58 changes: 51 additions & 7 deletions src/router/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ async def broadcast_raw(
path: str,
headers: Dict[str, str],
body: bytes,
method: str,
) -> List[Dict[str, Any]]:
endpoints = await self._snapshot_endpoints()
if not endpoints:
Expand All @@ -509,7 +510,10 @@ async def _one(ep: Endpoint) -> Dict[str, Any]:
url = _join_url(ep.base_url, path)
try:
logger.info(f"Broadcasting raw to {url}")
resp = await client.post(url, headers=headers, content=body)
if method == "get":
resp = await client.get(url)
else:
resp = await client.post(url, headers=headers, content=body)
content_type = resp.headers.get("content-type", "")
if "application/json" in content_type.lower():
body_out: Any = resp.json()
Expand Down Expand Up @@ -796,6 +800,7 @@ async def health() -> JSONResponse:
"/endpoints",
"/v1/chat/completions",
"/weight/refit",
"/weight/refit/timestamp",
],
}
)
Expand Down Expand Up @@ -929,14 +934,53 @@ async def weight_refit(raw_request: Request) -> JSONResponse:
"""
headers = _filter_forward_headers(dict(raw_request.headers))
body = await raw_request.body()
results = await registry.broadcast_raw(path="/weight/refit", headers=headers, body=body)
ok = all(r.get("ok") is True for r in results)
results = await registry.broadcast_raw(
path="/weight/refit", headers=headers, body=body, method="post"
)
ok = any(r.get("ok") is True for r in results)
if ok:
return JSONResponse(
content={
"type": "weight_refit",
"data": None,
},
status_code=200,
)
else:
return JSONResponse(
content={
"type": "weight_refit",
"data": "Sever not ready",
},
status_code=500,
)


@app.get("/weight/refit/timestamp")
async def weight_refit(raw_request: Request) -> JSONResponse:
"""
Example:
curl -sS -X GET http://127.0.0.1:3001/weight/refit/timestamp
"""
results = await registry.broadcast_raw(
path="/weight/refit/timestamp", headers=None, body=None, method="get"
)
min_timestamp = None
for r in results:
if r.get("ok") == True:
body_out = r.get("response", None)
if body_out is not None:
ep_timestamp = body_out.get("latest_timestamp", 0.0)
if min_timestamp is None:
min_timestamp = ep_timestamp
else:
min_timestamp = min(min_timestamp, ep_timestamp)
if min_timestamp is None:
min_timestamp = 0.0
return JSONResponse(
status_code=200 if ok else 207,
status_code=200,
content={
"ok": ok,
"broadcast_count": len(results),
"results": results,
"latest_timestamp": min_timestamp,
},
)

Expand Down
2 changes: 1 addition & 1 deletion src/scheduling/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def max_requests(self) -> int:
if self.max_concurrent_requests is None:
return derived_max
else:
return min(self.max_concurrent_requests, derived_max)
return max(self.max_concurrent_requests, derived_max)

@property
def num_current_layers(self) -> int:
Expand Down
3 changes: 1 addition & 2 deletions src/scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def __init__(
self.model_info = model_info
self.num_layers = model_info.num_layers
self.routing_strategy: Literal["rr", "dp"] = routing_strategy
self.enable_weight_refit = enable_weight_refit
self.refit_request = {}
self.node_manager = NodeManager(initial_nodes=nodes)

allocator_class = (
Expand Down Expand Up @@ -119,6 +117,7 @@ def __init__(
)

# Weight refit
self.enable_weight_refit = enable_weight_refit
self.refit_request = {}
self.refit_set = set()
self.last_refit_time = 0.0
Expand Down