8 changes: 8 additions & 0 deletions src/parallax/launch.py
@@ -104,6 +104,10 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
logger.debug(f"nccl_port: {args.nccl_port}")

# Pipe for subprocess communication
conn1, conn2 = multiprocessing.Pipe()

if args.scheduler_addr is None:
if args.log_level != "DEBUG":
display_parallax_join(args.model_path)
@@ -138,6 +142,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
kvcache_mem_ratio=args.kvcache_mem_ratio,
shared_state=shared_state.dict, # Pass dict to subprocess
log_level=args.log_level,
conn=conn1,
)

# Launch all executor processes (including tp_rank=0)
@@ -149,6 +154,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
args=(
args_copy,
shared_state.dict, # Pass dict to subprocess
conn2, # Pipe connector
),
)
proc.start()
@@ -187,6 +193,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
kvcache_mem_ratio=args.kvcache_mem_ratio,
shared_state=shared_state.dict, # Pass dict to subprocess
log_level=args.log_level,
conn=conn1,
)

# Wait for layer allocation from scheduler (via shared state)
@@ -235,6 +242,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
args=(
args_copy,
shared_state.dict, # Pass dict to subprocess
conn2, # Pipe connector
),
)
proc.start()
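The launch-side change amounts to creating one multiprocessing.Pipe and handing conn1 to the P2P server subprocess and conn2 to the executor subprocess, so refitted weights can flow between them without touching disk. A minimal, runnable sketch of that wiring; fake_p2p_server and fake_executor are illustrative stand-ins, not code from this PR:

import multiprocessing

def fake_p2p_server(conn):
    # Stand-in for the gradient server: after merging weight partitions,
    # it pushes the resulting tensors through its end of the pipe.
    conn.send({"model.layers.0.mlp.up_proj.weight": [0.0, 1.0]})
    conn.close()

def fake_executor(conn):
    # Stand-in for the executor: block until the server side has sent
    # the refitted tensors, then apply them.
    tensors = conn.recv()
    print(f"executor received {len(tensors)} refit tensor(s)")

if __name__ == "__main__":
    conn1, conn2 = multiprocessing.Pipe()
    server = multiprocessing.Process(target=fake_p2p_server, args=(conn1,))
    executor = multiprocessing.Process(target=fake_executor, args=(conn2,))
    server.start()
    executor.start()
    server.join()
    executor.join()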
38 changes: 24 additions & 14 deletions src/parallax/p2p/server.py
@@ -16,7 +16,7 @@
import shutil
import threading
import time
from typing import List, Optional
from typing import Any, List, Optional

import dijkstar
import httpx
@@ -32,6 +32,7 @@
from parallax.utils.weight_refit_utils import (
calculate_cid_manual,
concat_weight_partition,
filer_weight_cid_list,
parse_safetensors_from_memory,
release_disk_storage,
)
@@ -256,21 +257,28 @@ def _download_weight_thread(weight_dir, cid)

# step1. Check weight refit trigger message
time_stamp = message.get("time_stamp", None)
cid_list = message.get("cid", None)
index_map = message.get("index_map", None)
weight_version = message.get("version", 0)
if time_stamp is None or cid_list is None:
if time_stamp is None or index_map is None:
return
if gradient_server.last_refit_time >= float(time_stamp):
# Weight already updated
return

cid_list = filer_weight_cid_list(
gradient_server.block_start_index,
gradient_server.block_end_index,
gradient_server.block_end_index,
index_map,
)
random.seed(time.time())
random.shuffle(cid_list)

# add sleep 30s for direct connection first
logger.info(f"Start dealing weight refit message: {message}.")
logger.info(f"Wait for lattica direct connection.")
time.sleep(30)
# add sleep 10s for direct connection first
logger.debug(f"Received weight refit message: {message}.")
logger.info(f"Start dealing weight refit version: {weight_version}.")
logger.info(f"Wait 10s for lattica direct connection.")
time.sleep(10)

# step2. download weight
weight_dir = os.path.join("/tmp", str(time_stamp))
@@ -298,13 +306,9 @@ def _download_weight_thread(weight_dir, cid)

# step3. concat weight
# workaround: create sub-process to avoid GIL issues for lattica
logger.info(f"Start sub-process to concat weight partitions in {weight_dir}")
process = multiprocessing.Process(
target=concat_weight_partition,
args=(weight_dir, tensors),
)
process.start()
process.join()
new_tensors = concat_weight_partition(tensors)
gradient_server.conn.send(new_tensors)
logger.info(f"New tensors sent to executor")

# step4. send ipc message to update weight
gradient_server.connection_handler.ipc_weight_refit(weight_dir, weight_version)
@@ -349,6 +353,7 @@ def __init__(
max_sequence_length: Optional[int] = None,
param_mem_ratio: float = 0.65,
kvcache_mem_ratio: float = 0.25,
conn: Any = None,
):
self.recv_from_peer_addr = recv_from_peer_addr
self.send_to_peer_addr = send_to_peer_addr
@@ -385,6 +390,7 @@ def __init__(
self.rtt_update_interval = 60
self.status = ServerState.JOINING
self.manual_layer_assignment = block_end_index is not None and block_start_index is not None
self.conn = conn

self.scheduler_stub = None
self.scheduler_peer_id = None
@@ -976,6 +982,7 @@ def _run_p2p_server_process(
kvcache_mem_ratio: float = 0.25,
shared_state: Optional[dict] = None,
log_level: str = "INFO",
conn: Any = None,
):
"""Run P2P server in subprocess"""
# Set log level in subprocess (spawn mode doesn't inherit log configuration)
@@ -1006,6 +1013,7 @@ def _run_p2p_server_process(
max_sequence_length=max_sequence_length,
param_mem_ratio=param_mem_ratio,
kvcache_mem_ratio=kvcache_mem_ratio,
conn=conn,
)
# Attach shared state to server for syncing layer allocation
if shared_state is not None:
@@ -1055,6 +1063,7 @@ def launch_p2p_server_process(
kvcache_mem_ratio: float = 0.25,
shared_state: Optional[dict] = None,
log_level: str = "INFO",
conn: Optional[Any] = None,
) -> multiprocessing.Process:
"""Launch P2P server as a subprocess and return the process object

@@ -1089,6 +1098,7 @@ def launch_p2p_server_process(
kvcache_mem_ratio,
shared_state,
log_level,
conn,
),
)
process.start()
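Note the ordering on the server side: the merged tensors are written to the pipe before the IPC refit message goes out, so by the time the executor reacts to that message and calls conn.recv(), the payload is already buffered. A toy reproduction of that handshake, using a multiprocessing.Queue as a stand-in for the IPC channel (names here are illustrative, not from this PR):

import multiprocessing

def executor_side(conn, signals):
    signals.get()                       # wait for the "refit now" signal
    tensors = conn.recv()               # payload was queued in the pipe earlier
    print("refitting", sorted(tensors))

if __name__ == "__main__":
    conn1, conn2 = multiprocessing.Pipe()
    signals = multiprocessing.Queue()
    proc = multiprocessing.Process(target=executor_side, args=(conn2, signals))
    proc.start()
    conn1.send({"model.layers.0.mlp.up_proj.weight": [1.0]})  # send tensors first
    signals.put("refit")                                      # then signal the refit
    proc.join()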
5 changes: 5 additions & 0 deletions src/parallax/server/executor/base_executor.py
@@ -87,6 +87,8 @@ def __init__(
shared_state: Optional[dict] = None,
# Weight Refit
enable_weight_refit: Optional[bool] = False,
# Pipe communication
conn: Optional[Any] = None,
):
# Backend
if device is not None:
@@ -110,6 +112,9 @@ def __init__(
self.enable_weight_refit = enable_weight_refit
self.weight_version = 0

# Pipe communication
self.conn = conn

self.is_first_peer = start_layer == 0
self.is_last_peer = end_layer == self.config.get("num_hidden_layers")
self.tp_size = tp_size
12 changes: 7 additions & 5 deletions src/parallax/server/executor/factory.py
@@ -3,15 +3,15 @@
"""

import argparse
from typing import Optional
from typing import Any, Optional

from parallax.utils.utils import get_current_device
from parallax_utils.logging_config import get_logger, set_log_level

logger = get_logger(__name__)


def create_executor_config(args: argparse.Namespace, shared_state=None):
def create_executor_config(args: argparse.Namespace, shared_state=None, conn=None):
"""Create executor configuration from command line arguments."""

config = {
@@ -41,6 +41,7 @@ def create_executor_config(args: argparse.Namespace, shared_state=None):
"dp_size": getattr(args, "dp_size", 1),
"nccl_port": args.nccl_port,
"shared_state": shared_state,
"conn": conn,
"use_hfcache": args.use_hfcache,
"enable_lora": args.enable_lora,
"max_lora_rank": args.max_lora_rank,
@@ -59,13 +60,14 @@
def create_from_args(
args,
shared_state: Optional[dict] = None,
conn: Optional[Any] = None,
device: Optional[str] = None,
):
"""
Create executor for different backends.
Lazy import here since CUDA modules cannot be imported without hardware support.
"""
config = create_executor_config(args, shared_state)
config = create_executor_config(args, shared_state, conn)
if device is None:
device = get_current_device()
if device is not None and device.startswith("cuda"):
@@ -88,12 +90,12 @@
return executor


def run_executor_process(args, shared_state=None):
def run_executor_process(args, shared_state=None, conn=None):
"""Run executor as a subprocess"""
set_log_level(args.log_level)
executor = None
try:
executor = create_from_args(args, shared_state)
executor = create_from_args(args, shared_state, conn)
executor.run_loop()
except KeyboardInterrupt:
logger.debug("Executor received interrupt signal, shutting down...")
8 changes: 3 additions & 5 deletions src/parallax/server/executor/mlx_executor.py
@@ -87,6 +87,8 @@ def __init__(
shared_state: Optional[dict] = None,
# Weight Refit
enable_weight_refit: Optional[bool] = False,
# Pipe communication
conn: Optional[Any] = None,
):
logger.debug(
f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})"
@@ -213,6 +215,7 @@ def __init__(
tp_size=tp_size,
shared_state=shared_state,
enable_weight_refit=enable_weight_refit,
conn=conn,
)

try:
@@ -333,11 +336,6 @@ def handle_input_requests(self, requests: List[Request]):
# This is an active request, add it to the scheduler queue to be processed.
self.scheduler.enque_request(req)

def check_and_refit_weight(self, refit_weight_path: str):
if refit_weight_path == "":
return
self.shard_loader.update_weight_from_disk(self.model_shard, refit_weight_path)

def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True):
"""Process a batch of requests in MLX."""
# Run model and get updated cache
6 changes: 5 additions & 1 deletion src/parallax/server/executor/sglang_executor.py
@@ -88,6 +88,8 @@ def __init__(
shared_state: Optional[dict] = None,
# Weight Refit
enable_weight_refit: Optional[bool] = False,
# Pipe communication
conn: Optional[Any] = None,
):

self.enable_lora = True if lora_paths is not None else enable_lora
@@ -172,6 +174,7 @@ def __init__(
dp_size=dp_size,
shared_state=shared_state,
enable_weight_refit=enable_weight_refit,
conn=conn,
)
self.cur_batch = None
self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False)
@@ -194,7 +197,8 @@ def __init__(
def check_and_refit_weight(self, refit_weight_path: str):
if refit_weight_path == "":
return
refit_sgl_model(self.model_runner, refit_weight_path)
tensors = self.conn.recv()
refit_sgl_model(self.model_runner, tensors)

def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
3 changes: 3 additions & 0 deletions src/parallax/server/executor/vllm_executor.py
@@ -82,6 +82,8 @@ def __init__(
shared_state: Optional[dict] = None,
# Weight Refit
enable_weight_refit: Optional[bool] = False,
# Pipe communication
conn: Optional[Any] = None,
):
model_runner_params = {
"model_repo": model_repo,
@@ -134,6 +136,7 @@ def __init__(
tp_size=tp_size,
shared_state=shared_state,
enable_weight_refit=enable_weight_refit,
conn=conn,
)

def handle_input_requests(self, requests: List[Request]):
10 changes: 6 additions & 4 deletions src/parallax/sglang/model_runner.py
@@ -379,9 +379,11 @@ def initialize_sgl_model_runner(

def refit_sgl_model(
model_runner: ParallaxModelRunner,
refit_weight_path: str,
tensors: dict,
):
"""Runtime weight refit from disk"""
logger.info(f"Begin refit weight from path: {refit_weight_path}")

model_runner.update_weights_from_disk(model_path=refit_weight_path, load_format="auto")
logger.info(f"Executor begins weight refit from host memory")
for x in tensors.keys():
refit_tensors = [(x, tensors.get(x))]
model_runner.update_weights_from_tensor(named_tensors=refit_tensors, load_format=None)
logger.info(f"Finish weight refit")
17 changes: 6 additions & 11 deletions src/parallax/utils/weight_refit_utils.py
@@ -8,7 +8,6 @@
from typing import Dict

import torch
from safetensors.torch import save_file

# CID constants
CIDV1 = 0x01
@@ -32,7 +31,7 @@ def inplace_insert_value_with_idx(tensor_list, value, idx):
tensor_list[idx] = value


def concat_weight_partition(refit_weight_path, original_tensors):
def concat_weight_partition(original_tensors):
"""
Concat partial weight into one safetensor.
Partitioned weight should be named in the following format:
@@ -42,9 +41,9 @@ def concat_weight_partition(refit_weight_path, original_tensors):
# Concatenate if needed and save the final tensors
sorted_keys = sorted(original_tensors.keys())
tensors = {}
res_tensors = {}
prev_key = None
concate_list = []
file_idx = 0
max_size = 1024 * 1024 * 1024 # max size 1GB
param_size = 0
for key in sorted_keys:
@@ -53,9 +52,7 @@
tensors[key] = val
param_size += val.numel() * val.element_size()
if param_size > max_size:
save_file_name = refit_weight_path + "/model_" + str(file_idx) + ".safetensors"
save_file(tensors, save_file_name)
file_idx += 1
res_tensors.update(tensors)
param_size = 0
tensors = {}
continue
@@ -78,9 +75,7 @@
tensors[final_key] = concate_result
param_size += val.numel() * val.element_size()
if param_size > max_size:
save_file_name = refit_weight_path + "/model_" + str(file_idx) + ".safetensors"
save_file(tensors, save_file_name)
file_idx += 1
res_tensors.update(tensors)
param_size = 0
tensors = {}

@@ -96,8 +91,8 @@ def concat_weight_partition(refit_weight_path, original_tensors):
final_key = ".".join(cur_name_list)
tensors[final_key] = concate_result

save_file_name = refit_weight_path + "/model_" + str(file_idx) + ".safetensors"
save_file(tensors, save_file_name)
res_tensors.update(tensors)
return res_tensors


def is_block_needed(key, is_first_shard, is_last_shard, start_layer, end_layer) -> bool:
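With this change, concat_weight_partition accumulates merged tensors into res_tensors and returns that dict instead of flushing 1 GB safetensors shards to disk; the caller then sends the dict straight through the pipe. The merge itself is the usual concatenation of a key's ordered partitions; a tiny illustration with made-up key names and a made-up split dimension (the real naming scheme and dimension handling live in concat_weight_partition and inplace_insert_value_with_idx):

import torch

# Two partitions of one logical weight, split along dim 0 here for illustration.
partitions = {
    "model.layers.0.mlp.up_proj.weight__0": torch.zeros(2, 4),
    "model.layers.0.mlp.up_proj.weight__1": torch.ones(2, 4),
}
ordered = [partitions[k] for k in sorted(partitions)]          # partition 0 first
merged = {"model.layers.0.mlp.up_proj.weight": torch.cat(ordered, dim=0)}
print(merged["model.layers.0.mlp.up_proj.weight"].shape)       # torch.Size([4, 4])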