Skip to content

Commit

Permalink
Implement copy constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Aug 20, 2024
1 parent b449bd4 commit c9ce404
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions python/morpheus/morpheus/messages/control_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
@dataclasses.dataclass(init=False)
class ControlMessage(MessageBase, cpp_class=_messages.ControlMessage):

def __init__(self, config: dict = None):
def __init__(self, config_or_message: typing.Union["ControlMessage", dict] = None):
super().__init__()

self._config: dict = {"metadata": {}}
Expand All @@ -49,7 +49,15 @@ def __init__(self, config: dict = None):
self._timestamps: dict[str, datetime] = {}
self._type: ControlMessageType = ControlMessageType.NONE

self.config(config)
if isinstance(config_or_message, dict):
self.config(config_or_message)
elif isinstance(config_or_message, ControlMessage):
self._copy_impl(config_or_message, self)
elif config_or_message is not None:
raise ValueError(f"Invalid argument type {type(config_or_message)}, value must be a dict or ControlMessage")

def copy(self) -> "ControlMessage":
return self._copy_impl(self)

def config(self, config: dict = None) -> dict:
if config is not None:
Expand All @@ -74,24 +82,6 @@ def config(self, config: dict = None) -> dict:

return self._config

def copy(self) -> "ControlMessage":
config = self._config.copy()
config["type"] = self.task_type().name

tasks = []
for (task_type, task_queue) in self.get_tasks().items():
for task in task_queue:
tasks.append({"type": task_type, "properties": task})

config["tasks"] = tasks

new_cm = ControlMessage(config)
new_cm._payload = self._payload
new_cm._tensors = self._tensors
new_cm._timestamps = self._timestamps.copy()

return new_cm

def has_task(self, task_type: str) -> bool:
"""
Return True if the control message has at least one task of the given type
Expand All @@ -100,7 +90,7 @@ def has_task(self, task_type: str) -> bool:
tasks = self._tasks.get(task_type, [])
return len(tasks) > 0

def add_task(self, task_type: str, properties: dict):
def add_task(self, task_type: str, task: dict):
if isinstance(task_type, str):
cm_type = get_enum_members(ControlMessageType).get(task_type, ControlMessageType.NONE)
if cm_type != ControlMessageType.NONE:
Expand All @@ -109,7 +99,7 @@ def add_task(self, task_type: str, properties: dict):
elif self._type != cm_type:
raise ValueError("Cannot mix different types of tasks on the same control message")

self._tasks[task_type].append(properties)
self._tasks[task_type].append(task)

def remove_task(self, task_type: str) -> dict:
tasks = self._tasks.get(task_type, [])
Expand Down Expand Up @@ -179,3 +169,25 @@ def filter_timestamp(self, regex_filter: str) -> dict[str, datetime]:
re_obj = re.compile(regex_filter)

return {key: value for key, value in self._timestamps.items() if re_obj.match(key)}

@classmethod
def _copy_impl(cls, src: "ControlMessage", dst: "ControlMessage" = None) -> "ControlMessage":
config = src.config().copy()
config["type"] = src.task_type().name

tasks = []
for (task_type, task_queue) in src.get_tasks().items():
for task in task_queue:
tasks.append({"type": task_type, "properties": task})

config["tasks"] = tasks

if dst is None:
dst = cls()

dst.config(config)
dst.payload(src.payload())
dst.tensors(src.tensors())
dst._timestamps = src._timestamps.copy()

return dst

0 comments on commit c9ce404

Please sign in to comment.