diff --git a/lightllm/server/router/dynamic_prompt/shared_arr.py b/lightllm/server/router/dynamic_prompt/shared_arr.py index e01630ed..9f567756 100644 --- a/lightllm/server/router/dynamic_prompt/shared_arr.py +++ b/lightllm/server/router/dynamic_prompt/shared_arr.py @@ -4,20 +4,33 @@ import numpy as np import multiprocessing as mp from multiprocessing import shared_memory +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class SharedArray: def __init__(self, name, shape, dtype): dtype_byte_num = np.array([1], dtype=dtype).dtype.itemsize + dest_size = np.prod(shape) * dtype_byte_num try: - shm = shared_memory.SharedMemory(name=name, create=True, size=np.prod(shape) * dtype_byte_num) - print(f"create shm {name}") + shm = shared_memory.SharedMemory(name=name, create=True, size=dest_size) + logger.info(f"create shm {name}") except: - shm = shared_memory.SharedMemory(name=name, create=False, size=np.prod(shape) * dtype_byte_num) - assert ( - len(shm.buf) == np.prod(shape) * dtype_byte_num - ), f"{len(shm.buf)} is not equal to {np.prod(shape) * dtype_byte_num}" - print(f"link shm {name}") + shm = shared_memory.SharedMemory(name=name, create=False, size=dest_size) + logger.info(f"link shm {name}") + + if shm.size != dest_size: + logger.info(f"size not same, unlink shm {name} and create again") + shm.unlink() + shm.close() + try: + shm = shared_memory.SharedMemory(name=name, create=True, size=dest_size) + logger.info(f"create shm {name}") + except Exception as e: + shm = shared_memory.SharedMemory(name=name, create=False, size=dest_size) + logger.info(f"error {str(e)} to link shm {name}") + self.shm = shm # SharedMemory 对象一定要被持有,否则会被释放 self.arr = np.ndarray(shape, dtype=dtype, buffer=self.shm.buf)