From c9ce404825349be1a3f423b722dd81b38cef1b02 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Tue, 20 Aug 2024 10:33:56 -0700 Subject: [PATCH] Implement copy constructor --- .../morpheus/messages/control_message.py | 56 +++++++++++-------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/python/morpheus/morpheus/messages/control_message.py b/python/morpheus/morpheus/messages/control_message.py index dc78a77f3b..dd4d393178 100644 --- a/python/morpheus/morpheus/messages/control_message.py +++ b/python/morpheus/morpheus/messages/control_message.py @@ -37,7 +37,7 @@ @dataclasses.dataclass(init=False) class ControlMessage(MessageBase, cpp_class=_messages.ControlMessage): - def __init__(self, config: dict = None): + def __init__(self, config_or_message: typing.Union["ControlMessage", dict] = None): super().__init__() self._config: dict = {"metadata": {}} @@ -49,7 +49,15 @@ def __init__(self, config: dict = None): self._timestamps: dict[str, datetime] = {} self._type: ControlMessageType = ControlMessageType.NONE - self.config(config) + if isinstance(config_or_message, dict): + self.config(config_or_message) + elif isinstance(config_or_message, ControlMessage): + self._copy_impl(config_or_message, self) + elif config_or_message is not None: + raise ValueError(f"Invalid argument type {type(config_or_message)}, value must be a dict or ControlMessage") + + def copy(self) -> "ControlMessage": + return self._copy_impl(self) def config(self, config: dict = None) -> dict: if config is not None: @@ -74,24 +82,6 @@ def config(self, config: dict = None) -> dict: return self._config - def copy(self) -> "ControlMessage": - config = self._config.copy() - config["type"] = self.task_type().name - - tasks = [] - for (task_type, task_queue) in self.get_tasks().items(): - for task in task_queue: - tasks.append({"type": task_type, "properties": task}) - - config["tasks"] = tasks - - new_cm = ControlMessage(config) - new_cm._payload = self._payload - new_cm._tensors = self._tensors - new_cm._timestamps = self._timestamps.copy() - - return new_cm - def has_task(self, task_type: str) -> bool: """ Return True if the control message has at least one task of the given type @@ -100,7 +90,7 @@ def has_task(self, task_type: str) -> bool: tasks = self._tasks.get(task_type, []) return len(tasks) > 0 - def add_task(self, task_type: str, properties: dict): + def add_task(self, task_type: str, task: dict): if isinstance(task_type, str): cm_type = get_enum_members(ControlMessageType).get(task_type, ControlMessageType.NONE) if cm_type != ControlMessageType.NONE: @@ -109,7 +99,7 @@ def add_task(self, task_type: str, properties: dict): elif self._type != cm_type: raise ValueError("Cannot mix different types of tasks on the same control message") - self._tasks[task_type].append(properties) + self._tasks[task_type].append(task) def remove_task(self, task_type: str) -> dict: tasks = self._tasks.get(task_type, []) @@ -179,3 +169,25 @@ def filter_timestamp(self, regex_filter: str) -> dict[str, datetime]: re_obj = re.compile(regex_filter) return {key: value for key, value in self._timestamps.items() if re_obj.match(key)} + + @classmethod + def _copy_impl(cls, src: "ControlMessage", dst: "ControlMessage" = None) -> "ControlMessage": + config = src.config().copy() + config["type"] = src.task_type().name + + tasks = [] + for (task_type, task_queue) in src.get_tasks().items(): + for task in task_queue: + tasks.append({"type": task_type, "properties": task}) + + config["tasks"] = tasks + + if dst is None: + dst = cls() + + dst.config(config) + dst.payload(src.payload()) + dst.tensors(src.tensors()) + dst._timestamps = src._timestamps.copy() + + return dst