support release pipeline #2581

Open
wants to merge 4 commits into base: main
Changes from all commits
14 changes: 14 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -166,6 +166,7 @@ def __init__(self,
self.req_manager = self._bind_request_manager()

# create main thread
self._stop_flag = False
self._start_loop()
self._create_buffers()
self.engine_instance = self.create_instance()
@@ -241,6 +242,15 @@ def _bind_request_manager(self):
req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message)
return req_manager

def close(self):
self._stop_flag = True
self.req_manager.close()
self.model_agent.close()
self.model_agent = None
self._seq_length_buf = None
self._inputs = None
torch._C._cuda_clearCublasWorkspaces()
Collaborator:
Why is this needed? Can't torch manage it?


def _start_loop(self):
"""start loop."""
return self.req_manager.start_loop(self.async_loop)
@@ -930,6 +940,10 @@ async def __step():
out_que.task_done()

while True:
if self._stop_flag:
logger.info('Stop _async_loop')
loop_background.cancel()
break
if self.req_manager.has_requests():
self.req_manager.step()

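For context, a minimal self-contained sketch of the stop-flag shutdown pattern introduced above: close() flips a flag, the polling loop observes it, cancels its background task, and exits. All names below are illustrative stand-ins, not code from this PR.

import asyncio

class LoopOwner:
    """Illustrative stand-in for the engine's async-loop handling."""

    def __init__(self):
        self._stop_flag = False

    def close(self):
        # Cooperative shutdown: the running loop checks this flag.
        self._stop_flag = True

    async def run(self):
        background = asyncio.create_task(asyncio.sleep(3600))  # stand-in for loop_background
        while True:
            if self._stop_flag:
                background.cancel()
                break
            await asyncio.sleep(0.01)  # stand-in for one scheduler step

async def main():
    owner = LoopOwner()
    runner = asyncio.create_task(owner.run())
    await asyncio.sleep(0.05)
    owner.close()
    await runner  # returns once the flag is observed

asyncio.run(main())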
45 changes: 39 additions & 6 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -2,8 +2,11 @@
import asyncio
import atexit
import os
import sys
from datetime import timedelta
from functools import partial
from typing import Any, Callable, Dict, List
from weakref import ReferenceType, ref

import torch
import torch.distributed as dist
@@ -192,6 +195,10 @@ def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
raise NotImplementedError('Not implemented.')

def close(self):
"""release model."""
pass


class BaseModelAgent(AutoModelAgent):
"""Base model agent.
@@ -235,6 +242,9 @@ def __init__(self,

self.stream = torch.cuda.Stream()

def close(self):
del self.patched_model

def _build_model(self,
model_path: str,
adapters: Dict[str, str] = None,
@@ -540,10 +550,11 @@ def __init__(self,
trust_remote_code: bool = True) -> None:
import signal

def __signal_term_handler(sig, frame):
def __signal_term_handler(sig, frame, agent):
"""sigterm handler."""
if hasattr(self, 'mp_context'):
procs = self.mp_context.processes
agent = agent()
if hasattr(agent, 'mp_context'):
procs = agent.mp_context.processes
for p in procs:
if p.is_alive():
p.kill()
@@ -553,8 +564,10 @@ def __signal_term_handler(sig, frame):

super().__init__(model_config=model_config, cache_config=cache_config)

signal.signal(signal.SIGTERM, __signal_term_handler)
signal.signal(signal.SIGTERM,
partial(__signal_term_handler, agent=ref(self)))

self.old_sys_excepthook = sys.excepthook
self.mp_ctx = mp.get_context('spawn')
self.world_size = world_size
self.backend_config = backend_config
@@ -579,6 +592,22 @@ def __signal_term_handler(sig, frame):
self.cache_config = cache_config
self.cache_engine = cache_engine
self.stream = torch.cuda.Stream()
self.stop = False

def close(self):
_exit_by_sending_exit_flag(0, ref(self))
self.stop = True
procs: List[mp.Process] = self.mp_context.processes
for p in procs:
if p.is_alive():
logger.info(f'Terminate {p}')
p.terminate()
else:
logger.info(f'Close {p}')
p.close()
if dist.is_initialized():
dist.destroy_process_group()
sys.excepthook = self.old_sys_excepthook

def _start_sub_process(self, model_path: str, model_config: ModelConfig,
cache_config: CacheConfig,
@@ -627,7 +656,7 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
dist.destroy_process_group()
raise e
# Please see Note [Exit By Sending Exit Flag]
atexit.register(_exit_by_sending_exit_flag, rank, self)
atexit.register(_exit_by_sending_exit_flag, rank, ref(self))

@torch.inference_mode()
def _build_model(
@@ -715,10 +744,14 @@ def get_logits(self, hidden_states: torch.Tensor):
return self.patched_model.get_logits(hidden_states)


def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent):
def _exit_by_sending_exit_flag(rank: int,
agent: 'ReferenceType[TPModelAgent]'):
"""[Note] Exit By Sending Exit Flag: the registration to `atexit` of this
function should be called after importing torch.multiprocessing and the
initialization of distributed process group."""
agent = agent()
if agent is None or getattr(agent, 'stop', False):
return
if not hasattr(agent, 'stream'):
# agent is not initialized, just exits normally
if hasattr(agent, 'patched_model'):
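A small sketch, separate from the PR, of why the SIGTERM/atexit handlers above now receive weakref.ref(self) instead of self: a strong reference captured by signal.signal or atexit.register keeps the agent alive for the whole process lifetime, while a weak reference lets it be collected and turns the handler into a no-op once it is gone. Names are illustrative.

import gc
import weakref

class Agent:
    def __init__(self):
        self.stop = False

def _cleanup(agent_ref):
    """Handler-style callback that only holds a weak reference."""
    agent = agent_ref()                 # resolve the weak reference
    if agent is None or agent.stop:
        return                          # already released or closed explicitly
    print('cleaning up a live agent')

agent = Agent()
handler_arg = weakref.ref(agent)        # what signal/atexit would capture

_cleanup(handler_arg)                   # agent alive -> cleanup runs
del agent
gc.collect()
_cleanup(handler_arg)                   # agent collected -> handler is a no-op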
17 changes: 17 additions & 0 deletions lmdeploy/pytorch/engine/request.py
@@ -377,6 +377,23 @@ def __init__(self, thread_safe: bool = False):
if thread_safe:
self.thread_requests = Queue()

def close(self):
if not self._thread_safe:
if self._loop_task is not None:
_run_until_complete(self._loop_task)
else:
loop = self.event_loop
tasks = asyncio.all_tasks(loop=loop)

async def cancel_tasks():
for task in tasks:
task.cancel()

f = asyncio.run_coroutine_threadsafe(cancel_tasks(), loop=loop)
f.result()
loop.call_soon_threadsafe(loop.stop)
self._loop_thread.join()

def create_loop_task(self):
"""create coro task."""
logger.debug('creating engine loop task.')
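For reference, a self-contained sketch of the thread-safe branch of close() above: tasks living on an event loop owned by a worker thread must be cancelled from inside that loop (run_coroutine_threadsafe), and the loop is then stopped with call_soon_threadsafe before joining the thread. All names are illustrative.

import asyncio
import threading

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()

# A long-running task living on the worker thread's loop.
asyncio.run_coroutine_threadsafe(asyncio.sleep(3600), loop)

def shutdown():
    tasks = asyncio.all_tasks(loop=loop)

    async def cancel_tasks():
        for task in tasks:
            task.cancel()

    # Cancel from inside the loop's own thread, then stop the loop itself.
    asyncio.run_coroutine_threadsafe(cancel_tasks(), loop=loop).result()
    loop.call_soon_threadsafe(loop.stop)
    thread.join()
    loop.close()

shutdown()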
17 changes: 17 additions & 0 deletions lmdeploy/serve/async_engine.py
@@ -186,6 +186,23 @@ def __init__(self,
self._session_id = count(0)
self.request_logger = RequestLogger(max_log_len)

def close(self):
self.gens_set.clear()
if self.engine is not None:
if isinstance(self.backend_config, PytorchEngineConfig):
self.engine.close()
self.engine = None
import gc
gc.collect()
import torch
torch.cuda.empty_cache()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.close()

def _build_turbomind(
self,
model_path: str,
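A usage sketch of what the new close()/context-manager support enables at the user level, assuming pipeline() constructs this AsyncEngine; the model path is a placeholder:

from lmdeploy import PytorchEngineConfig, pipeline

# The engine is released when the `with` block exits instead of
# lingering until interpreter shutdown.
with pipeline('/path/to/model',                       # placeholder path
              backend_config=PytorchEngineConfig()) as pipe:
    print(pipe('Hello'))
# __exit__ -> close(): engine loops stop and GPU memory is freed.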
4 changes: 4 additions & 0 deletions lmdeploy/serve/vl_async_engine.py
@@ -32,6 +32,10 @@ def __init__(self, model_path: str, **kwargs) -> None:
self.vl_prompt_template = get_vl_prompt_template(
model_path, self.chat_template, self.model_name)

def close(self):
self.vl_encoder.close()
super().close()

def _convert_prompts(self,
prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]]):
8 changes: 7 additions & 1 deletion lmdeploy/turbomind/turbomind.py
@@ -4,6 +4,7 @@
import json
import os.path as osp
import sys
import weakref
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from itertools import repeat
@@ -130,6 +131,10 @@ def __init__(self,
self.session_len = self.config.session_len
self.eos_id = self.tokenizer.eos_token_id

def __del__(self):
"""release hardware resources."""
self.model_comm.destroy_nccl_params(self.nccl_params)

def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

@@ -314,7 +319,8 @@ def create_instance(self, cuda_stream_id=0):
Returns:
TurboMindInstance: an instance of turbomind
"""
return TurboMindInstance(self, self.config, cuda_stream_id)
return TurboMindInstance(weakref.proxy(self), self.config,
cuda_stream_id)


class TurboMindInstance:
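A small sketch, independent of the PR, of why create_instance now passes weakref.proxy(self): a strong back-reference from the instance to its parent creates a reference cycle that can delay the parent's __del__ (and the NCCL teardown it performs), while a proxy does not keep the parent alive. Class names are illustrative.

import weakref

class Parent:
    def __init__(self):
        self.children = []

    def create_child(self):
        child = Child(weakref.proxy(self))   # no strong back-reference
        self.children.append(child)
        return child

    def __del__(self):
        print('Parent.__del__: releasing hardware resources')

class Child:
    def __init__(self, parent):
        self.parent = parent

p = Parent()
c = p.create_child()
del p   # __del__ runs promptly; accessing c.parent now raises ReferenceError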
13 changes: 13 additions & 0 deletions lmdeploy/vl/engine.py
@@ -95,9 +95,20 @@ def __init__(self,
torch.cuda.empty_cache()
self._que: asyncio.Queue = None
self._loop_task: asyncio.Task = None
self._stop = False
if vision_config.thread_safe:
self._create_thread_safe_task()

def close(self):
if self.model is not None:
self._stop = True
if self.vision_config.thread_safe:
self._loop_thread.join()
else:
if hasattr(self, '_loop'):
self._loop.run_until_complete(self._loop_task)
self.model = None

def _create_thread_safe_task(self):
"""thread safe loop task."""
self._loop = asyncio.new_event_loop()
@@ -138,6 +149,8 @@ async def _forward_loop(self):
while record.total == 0 or (self._que.qsize() and
record.total < self.max_batch_size):
while self._que.qsize() == 0:
if self._stop and record.total == 0:
return
await asyncio.sleep(0.01)
item = await self._que.get()
record.enqueue(item[0], item[1], item[2])
1 change: 1 addition & 0 deletions src/turbomind/python/bind.cpp
@@ -411,6 +411,7 @@ PYBIND11_MODULE(_turbomind, m)
"node_id"_a,
"device_id_start"_a = 0,
"multi_node"_a = false)
.def("destroy_nccl_params", &AbstractTransformerModel::destroyNcclParams, "params"_a)
.def(
"create_custom_comms",
[](AbstractTransformerModel* model, int world_size) {
11 changes: 11 additions & 0 deletions src/turbomind/triton_backend/transformer_triton_backend.cpp
@@ -80,3 +80,14 @@ AbstractTransformerModel::createNcclParams(const int node_id, const int device_i
}
return std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>>(tensor_para_params, pipeline_para_params);
}

void AbstractTransformerModel::destroyNcclParams(
std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> params)
{
for (auto& param : params.first) {
ftNcclParamDestroy(param);
}
for (auto& param : params.second) {
ftNcclParamDestroy(param);
}
}
2 changes: 2 additions & 0 deletions src/turbomind/triton_backend/transformer_triton_backend.hpp
@@ -311,6 +311,8 @@ struct AbstractTransformerModel {
virtual std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>>
createNcclParams(const int node_id, const int device_id_start = 0, const bool multi_node = false);

virtual void destroyNcclParams(std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> params);

virtual void createCustomComms(std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms,
int world_size) = 0;

3 changes: 3 additions & 0 deletions src/turbomind/utils/allocator.h
@@ -210,6 +210,9 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
auto size_and_type = pointer_mapping_.begin()->second;
free(&ptr, size_and_type.second == MemoryType::HOST);
}
// release cuda memory to os
cudaMemPoolTrimTo(mempool_, 0);
check_cuda_error(cudaStreamSynchronize(stream_));
if (enable_peer_access_) { // We own the pool in this case
check_cuda_error(cudaMemPoolDestroy(mempool_));
mempool_ = {};
2 changes: 2 additions & 0 deletions src/turbomind/utils/cublasMMWrapper.cc
@@ -73,6 +73,8 @@ cublasMMWrapper::~cublasMMWrapper()
allocator_->free((void**)(&cublas_workspace_));
allocator_ = nullptr;
}
cublasDestroy(cublas_handle_);
cublasLtDestroy(cublaslt_handle_);
}

cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper):