|
4 | 4 | import numpy as np
|
5 | 5 | import multiprocessing as mp
|
6 | 6 | from multiprocessing import shared_memory
|
| 7 | +from lightllm.utils.log_utils import init_logger |
| 8 | + |
| 9 | +logger = init_logger(__name__) |
7 | 10 |
|
8 | 11 |
|
9 | 12 | class SharedArray:
|
10 | 13 | def __init__(self, name, shape, dtype):
|
11 | 14 | dtype_byte_num = np.array([1], dtype=dtype).dtype.itemsize
|
| 15 | + dest_size = np.prod(shape) * dtype_byte_num |
12 | 16 | try:
|
13 |
| - shm = shared_memory.SharedMemory(name=name, create=True, size=np.prod(shape) * dtype_byte_num) |
14 |
| - print(f"create shm {name}") |
| 17 | + shm = shared_memory.SharedMemory(name=name, create=True, size=dest_size) |
| 18 | + logger.info(f"create shm {name}") |
15 | 19 | except:
|
16 |
| - shm = shared_memory.SharedMemory(name=name, create=False, size=np.prod(shape) * dtype_byte_num) |
17 |
| - assert ( |
18 |
| - len(shm.buf) == np.prod(shape) * dtype_byte_num |
19 |
| - ), f"{len(shm.buf)} is not equal to {np.prod(shape) * dtype_byte_num}" |
20 |
| - print(f"link shm {name}") |
| 20 | + shm = shared_memory.SharedMemory(name=name, create=False, size=dest_size) |
| 21 | + logger.info(f"link shm {name}") |
| 22 | + |
| 23 | + if shm.size != dest_size: |
| 24 | + logger.info(f"size not same, unlink shm {name} and create again") |
| 25 | + shm.unlink() |
| 26 | + shm.close() |
| 27 | + try: |
| 28 | + shm = shared_memory.SharedMemory(name=name, create=True, size=dest_size) |
| 29 | + logger.info(f"create shm {name}") |
| 30 | + except Exception as e: |
| 31 | + shm = shared_memory.SharedMemory(name=name, create=False, size=dest_size) |
| 32 | + logger.info(f"error {str(e)} to link shm {name}") |
| 33 | + |
21 | 34 | self.shm = shm # SharedMemory 对象一定要被持有,否则会被释放
|
22 | 35 | self.arr = np.ndarray(shape, dtype=dtype, buffer=self.shm.buf)
|
23 | 36 |
|
|
0 commit comments