diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index 52cd5d96..869f2df6 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -104,6 +104,10 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
     logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
     logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
     logger.debug(f"nccl_port: {args.nccl_port}")
+
+    # Pipe for subprocess communication
+    conn1, conn2 = multiprocessing.Pipe()
+
     if args.scheduler_addr is None:
         if args.log_level != "DEBUG":
             display_parallax_join(args.model_path)
@@ -138,6 +142,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
             kvcache_mem_ratio=args.kvcache_mem_ratio,
             shared_state=shared_state.dict,  # Pass dict to subprocess
             log_level=args.log_level,
+            conn=conn1,
         )
 
         # Launch all executor processes (including tp_rank=0)
@@ -149,6 +154,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
                 args=(
                     args_copy,
                     shared_state.dict,  # Pass dict to subprocess
+                    conn2,  # Pipe connection
                 ),
             )
             proc.start()
@@ -187,6 +193,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
             kvcache_mem_ratio=args.kvcache_mem_ratio,
             shared_state=shared_state.dict,  # Pass dict to subprocess
             log_level=args.log_level,
+            conn=conn1,
         )
 
         # Wait for layer allocation from scheduler (via shared state)
@@ -235,6 +242,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
                 args=(
                     args_copy,
                     shared_state.dict,  # Pass dict to subprocess
+                    conn2,  # Pipe connection
                 ),
             )
             proc.start()
diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py
index 9bf6e72e..86c0f12e 100644
--- a/src/parallax/p2p/server.py
+++ b/src/parallax/p2p/server.py
@@ -16,7 +16,7 @@
 import shutil
 import threading
 import time
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import dijkstar
 import httpx
@@ -32,6 +32,7 @@
 from parallax.utils.weight_refit_utils import (
     calculate_cid_manual,
     concat_weight_partition,
+    filer_weight_cid_list,
     parse_safetensors_from_memory,
     release_disk_storage,
 )
@@ -256,21 +257,28 @@ def _download_weight_thread(weight_dir, cid):
 
         # step1. Check weight refit trigger message
        time_stamp = message.get("time_stamp", None)
-        cid_list = message.get("cid", None)
+        index_map = message.get("index_map", None)
         weight_version = message.get("version", 0)
-        if time_stamp is None or cid_list is None:
+        if time_stamp is None or index_map is None:
             return
         if gradient_server.last_refit_time >= float(time_stamp):
             # Weight already updated
             return
+        cid_list = filer_weight_cid_list(
+            gradient_server.block_start_index,
+            gradient_server.block_end_index,
+            gradient_server.block_end_index,
+            index_map,
+        )
         random.seed(time.time())
         random.shuffle(cid_list)
 
-        # add sleep 30s for direct connection first
-        logger.info(f"Start dealing weight refit message: {message}.")
-        logger.info(f"Wait for lattica direct connection.")
-        time.sleep(30)
+        # sleep 10s to let lattica direct connections establish first
+        logger.debug(f"Received weight refit message: {message}.")
+        logger.info(f"Start handling weight refit version: {weight_version}.")
+        logger.info("Waiting 10s for lattica direct connection.")
+        time.sleep(10)
 
         # step2. download weight
         weight_dir = os.path.join("/tmp", str(time_stamp))
@@ -298,13 +306,9 @@ def _download_weight_thread(weight_dir, cid):
 
         # step3. concat weight
         # workaround: create sub-process to avoid GIL issues for lattica
-        logger.info(f"Start sub-process to concat weight partitions in {weight_dir}")
-        process = multiprocessing.Process(
-            target=concat_weight_partition,
-            args=(weight_dir, tensors),
-        )
-        process.start()
-        process.join()
+        new_tensors = concat_weight_partition(tensors)
+        gradient_server.conn.send(new_tensors)
+        logger.info("New tensors sent to executor.")
 
         # step4. send ipc message to update weight
         gradient_server.connection_handler.ipc_weight_refit(weight_dir, weight_version)
@@ -349,6 +353,7 @@
         max_sequence_length: Optional[int] = None,
         param_mem_ratio: float = 0.65,
         kvcache_mem_ratio: float = 0.25,
+        conn: Any = None,
     ):
         self.recv_from_peer_addr = recv_from_peer_addr
         self.send_to_peer_addr = send_to_peer_addr
@@ -385,6 +390,7 @@ def __init__(
         self.rtt_update_interval = 60
         self.status = ServerState.JOINING
         self.manual_layer_assignment = block_end_index is not None and block_start_index is not None
+        self.conn = conn
 
         self.scheduler_stub = None
         self.scheduler_peer_id = None
@@ -976,6 +982,7 @@ def _run_p2p_server_process(
     kvcache_mem_ratio: float = 0.25,
     shared_state: Optional[dict] = None,
     log_level: str = "INFO",
+    conn: Any = None,
 ):
     """Run P2P server in subprocess"""
     # Set log level in subprocess (spawn mode doesn't inherit log configuration)
@@ -1006,6 +1013,7 @@ def _run_p2p_server_process(
         max_sequence_length=max_sequence_length,
         param_mem_ratio=param_mem_ratio,
         kvcache_mem_ratio=kvcache_mem_ratio,
+        conn=conn,
     )
     # Attach shared state to server for syncing layer allocation
     if shared_state is not None:
@@ -1055,6 +1063,7 @@ def launch_p2p_server_process(
     kvcache_mem_ratio: float = 0.25,
     shared_state: Optional[dict] = None,
     log_level: str = "INFO",
+    conn: Optional[Any] = None,
 ) -> multiprocessing.Process:
     """Launch P2P server as a subprocess and return the process object
 
@@ -1089,6 +1098,7 @@ def launch_p2p_server_process(
             kvcache_mem_ratio,
             shared_state,
             log_level,
+            conn,
         ),
     )
     process.start()
diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py
index 1420be44..eaeb95d0 100755
--- a/src/parallax/server/executor/base_executor.py
+++ b/src/parallax/server/executor/base_executor.py
@@ -87,6 +87,8 @@ def __init__(
         shared_state: Optional[dict] = None,
         # Weight Refit
         enable_weight_refit: Optional[bool] = False,
+        # Pipe communication
+        conn: Optional[Any] = None,
     ):
         # Backend
         if device is not None:
@@ -110,6 +112,9 @@ def __init__(
         self.enable_weight_refit = enable_weight_refit
         self.weight_version = 0
 
+        # Pipe communication
+        self.conn = conn
+
         self.is_first_peer = start_layer == 0
         self.is_last_peer = end_layer == self.config.get("num_hidden_layers")
         self.tp_size = tp_size
diff --git a/src/parallax/server/executor/factory.py b/src/parallax/server/executor/factory.py
index 5ca549da..a6442793 100755
--- a/src/parallax/server/executor/factory.py
+++ b/src/parallax/server/executor/factory.py
@@ -3,7 +3,7 @@
 """
 
 import argparse
-from typing import Optional
+from typing import Any, Optional
 
 from parallax.utils.utils import get_current_device
 from parallax_utils.logging_config import get_logger, set_log_level
@@ -11,7 +11,7 @@
 logger = get_logger(__name__)
 
 
-def create_executor_config(args: argparse.Namespace, shared_state=None):
+def create_executor_config(args: argparse.Namespace, shared_state=None, conn=None):
     """Create executor configuration from command line arguments."""
 
     config = {
@@ -41,6 +41,7 @@ def create_executor_config(args: argparse.Namespace, shared_state=None):
         "dp_size": getattr(args, "dp_size", 1),
         "nccl_port": args.nccl_port,
         "shared_state": shared_state,
+        "conn": conn,
         "use_hfcache": args.use_hfcache,
         "enable_lora": args.enable_lora,
         "max_lora_rank": args.max_lora_rank,
@@ -59,13 +60,14 @@
 def create_from_args(
     args,
     shared_state: Optional[dict] = None,
+    conn: Optional[Any] = None,
     device: Optional[str] = None,
 ):
     """
     Creat executor for different backend.
     Lazy import here since CUDA modules cannot be import withough hardware support.
     """
-    config = create_executor_config(args, shared_state)
+    config = create_executor_config(args, shared_state, conn)
     if device is None:
         device = get_current_device()
     if device is not None and device.startswith("cuda"):
@@ -88,12 +90,12 @@ def create_from_args(
     return executor
 
 
-def run_executor_process(args, shared_state=None):
+def run_executor_process(args, shared_state=None, conn=None):
     """Run executor as a subprocess"""
     set_log_level(args.log_level)
     executor = None
     try:
-        executor = create_from_args(args, shared_state)
+        executor = create_from_args(args, shared_state, conn)
         executor.run_loop()
     except KeyboardInterrupt:
         logger.debug("Executor received interrupt signal, shutting down...")
diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py
index e623e64e..462b51b5 100755
--- a/src/parallax/server/executor/mlx_executor.py
+++ b/src/parallax/server/executor/mlx_executor.py
@@ -87,6 +87,8 @@ def __init__(
         shared_state: Optional[dict] = None,
         # Weight Refit
         enable_weight_refit: Optional[bool] = False,
+        # Pipe communication
+        conn: Optional[Any] = None,
     ):
         logger.debug(
             f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})"
@@ -213,6 +215,7 @@ def __init__(
             tp_size=tp_size,
             shared_state=shared_state,
             enable_weight_refit=enable_weight_refit,
+            conn=conn,
         )
 
         try:
@@ -333,11 +336,6 @@ def handle_input_requests(self, requests: List[Request]):
                 # This is an active request, add it to the scheduler queue to be processed.
                 self.scheduler.enque_request(req)
 
-    def check_and_refit_weight(self, refit_weight_path: str):
-        if refit_weight_path == "":
-            return
-        self.shard_loader.update_weight_from_disk(self.model_shard, refit_weight_path)
-
     def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True):
         """Process a batch of requests in MLX."""
         # Run model and get updated cache
diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py
index 10582121..af9288b1 100755
--- a/src/parallax/server/executor/sglang_executor.py
+++ b/src/parallax/server/executor/sglang_executor.py
@@ -88,6 +88,8 @@ def __init__(
         shared_state: Optional[dict] = None,
         # Weight Refit
         enable_weight_refit: Optional[bool] = False,
+        # Pipe communication
+        conn: Optional[Any] = None,
     ):
         self.enable_lora = True if lora_paths is not None else enable_lora
 
@@ -172,6 +174,7 @@ def __init__(
             dp_size=dp_size,
             shared_state=shared_state,
             enable_weight_refit=enable_weight_refit,
+            conn=conn,
         )
         self.cur_batch = None
         self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False)
@@ -194,7 +197,8 @@ def __init__(
     def check_and_refit_weight(self, refit_weight_path: str):
         if refit_weight_path == "":
             return
-        refit_sgl_model(self.model_runner, refit_weight_path)
+        tensors = self.conn.recv()
+        refit_sgl_model(self.model_runner, tensors)
 
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
diff --git a/src/parallax/server/executor/vllm_executor.py b/src/parallax/server/executor/vllm_executor.py
index c88d94db..3724c5e6 100755
--- a/src/parallax/server/executor/vllm_executor.py
+++ b/src/parallax/server/executor/vllm_executor.py
@@ -82,6 +82,8 @@ def __init__(
         shared_state: Optional[dict] = None,
         # Weight Refit
         enable_weight_refit: Optional[bool] = False,
+        # Pipe communication
+        conn: Optional[Any] = None,
     ):
         model_runner_params = {
             "model_repo": model_repo,
@@ -134,6 +136,7 @@ def __init__(
             tp_size=tp_size,
             shared_state=shared_state,
             enable_weight_refit=enable_weight_refit,
+            conn=conn,
         )
 
     def handle_input_requests(self, requests: List[Request]):
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 82b8227a..a899e33d 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -379,9 +379,11 @@ def initialize_sgl_model_runner(
 
 def refit_sgl_model(
     model_runner: ParallaxModelRunner,
-    refit_weight_path: str,
+    tensors: dict,
 ):
-    """Runtime weight refit from disk"""
-    logger.info(f"Begin refit weight from path: {refit_weight_path}")
-
-    model_runner.update_weights_from_disk(model_path=refit_weight_path, load_format="auto")
+    """Runtime weight refit from host memory"""
+    logger.info("Executor begins weight refit from host memory")
+    for name, tensor in tensors.items():
+        refit_tensors = [(name, tensor)]
+        model_runner.update_weights_from_tensor(named_tensors=refit_tensors, load_format=None)
+    logger.info("Finished weight refit")
diff --git a/src/parallax/utils/weight_refit_utils.py b/src/parallax/utils/weight_refit_utils.py
index b80ece31..dc4a90ed 100644
--- a/src/parallax/utils/weight_refit_utils.py
+++ b/src/parallax/utils/weight_refit_utils.py
@@ -8,7 +8,6 @@
 from typing import Dict
 
 import torch
-from safetensors.torch import save_file
 
 # CID constants
 CIDV1 = 0x01
@@ -32,7 +31,7 @@ def inplace_insert_value_with_idx(tensor_list, value, idx):
     tensor_list[idx] = value
 
 
-def concat_weight_partition(refit_weight_path, original_tensors):
+def concat_weight_partition(original_tensors):
     """
     Concat partial weight into one safetensor.
     Partitioned weight should be named in the following format:
@@ -42,9 +41,9 @@ def concat_weight_partition(refit_weight_path, original_tensors):
     # Concatenate if needed and save the final tensors
     sorted_keys = sorted(original_tensors.keys())
     tensors = {}
+    res_tensors = {}
     prev_key = None
     concate_list = []
-    file_idx = 0
     max_size = 1024 * 1024 * 1024  # max size 1GB
     param_size = 0
     for key in sorted_keys:
@@ -53,9 +52,7 @@ def concat_weight_partition(refit_weight_path, original_tensors):
             tensors[key] = val
             param_size += val.numel() * val.element_size()
             if param_size > max_size:
-                save_file_name = refit_weight_path + "/model_" + str(file_idx) + ".safetensors"
-                save_file(tensors, save_file_name)
-                file_idx += 1
+                res_tensors.update(tensors)
                 param_size = 0
                 tensors = {}
             continue
@@ -78,9 +75,7 @@ def concat_weight_partition(refit_weight_path, original_tensors):
             tensors[final_key] = concate_result
             param_size += val.numel() * val.element_size()
             if param_size > max_size:
-                save_file_name = refit_weight_path + "/model_" + str(file_idx) + ".safetensors"
-                save_file(tensors, save_file_name)
-                file_idx += 1
+                res_tensors.update(tensors)
                 param_size = 0
                 tensors = {}
 
@@ -96,8 +91,8 @@ def concat_weight_partition(refit_weight_path, original_tensors):
         final_key = ".".join(cur_name_list)
         tensors[final_key] = concate_result
 
-    save_file_name = refit_weight_path + "/model_" + str(file_idx) + ".safetensors"
-    save_file(tensors, save_file_name)
+    res_tensors.update(tensors)
+    return res_tensors
 
 
 def is_block_needed(key, is_first_shard, is_last_shard, start_layer, end_layer) -> bool:
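
Below is a minimal, self-contained sketch of the handoff this patch switches to: the launcher creates a multiprocessing.Pipe(), the P2P server process concatenates the downloaded weight partitions in memory and send()s the resulting tensor dict through its end, and the executor process recv()s the tensors and applies them, replacing the earlier safetensors round trip through /tmp. It assumes plain CPU torch tensors; the function names here are illustrative stand-ins, not Parallax APIs.

import multiprocessing

import torch


def p2p_server_side(conn):
    # Stand-in for the GradientServer path: concat_weight_partition() now returns
    # the merged tensors instead of writing model_*.safetensors files to disk.
    refit_tensors = {
        "model.layers.0.mlp.up_proj.weight": torch.randn(8, 4),
        "model.layers.0.mlp.down_proj.weight": torch.randn(4, 8),
    }
    conn.send(refit_tensors)  # tensors are pickled and streamed through the pipe
    conn.close()


def executor_side(conn):
    # Stand-in for the executor path: block on recv(), then apply the update one
    # named tensor at a time, as refit_sgl_model does via update_weights_from_tensor.
    tensors = conn.recv()
    for name, tensor in tensors.items():
        print(f"refit {name}: shape={tuple(tensor.shape)}")
    conn.close()


if __name__ == "__main__":
    # Mirrors launch.py: the parent creates the pipe and hands one end to each subprocess.
    conn1, conn2 = multiprocessing.Pipe()
    server = multiprocessing.Process(target=p2p_server_side, args=(conn1,))
    executor = multiprocessing.Process(target=executor_side, args=(conn2,))
    server.start()
    executor.start()
    server.join()
    executor.join()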