From 13681ea215fd49098eca5084fd9a5c208961b40a Mon Sep 17 00:00:00 2001 From: David Gardner Date: Tue, 20 Aug 2024 08:02:50 -0700 Subject: [PATCH] Adjust python API for ControlMessage to match that of the bindings for the C++ impl --- .../morpheus/messages/control_message.py | 64 +++++++++---------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/python/morpheus/morpheus/messages/control_message.py b/python/morpheus/morpheus/messages/control_message.py index 90763b7350..b179fb464e 100644 --- a/python/morpheus/morpheus/messages/control_message.py +++ b/python/morpheus/morpheus/messages/control_message.py @@ -47,32 +47,30 @@ def __init__(self, config: dict = None): self._timestamps: dict[str, datetime] = {} self._type: ControlMessageType = ControlMessageType.NONE - self.config: dict = config + self.config(config) - @property - def config(self) -> dict: - return self._config + def config(self, config: dict = None) -> dict: + if config is not None: + cm_type: str | ControlMessageType = config.get("type") + if cm_type is not None: + if isinstance(cm_type, str): + try: + cm_type = get_enum_members(ControlMessageType)[cm_type] + except KeyError as exc: + raise ValueError( + f"Invalid ControlMessageType: {cm_type}, supported types: {get_enum_keys(ControlMessageType)}" + ) from exc - @config.setter - def config(self, config: dict): - cm_type: str | ControlMessageType = config.get("type") - if cm_type is not None: - if isinstance(cm_type, str): - try: - cm_type = get_enum_members(ControlMessageType)[cm_type] - except KeyError as exc: - raise ValueError( - f"Invalid ControlMessageType: {cm_type}, supported types: {get_enum_keys(ControlMessageType)}" - ) from exc + self._type = cm_type - self._type = cm_type + tasks = config.get("tasks") + if tasks is not None: + for task in tasks: + self.add_task(task["type"], task["properties"]) - tasks = config.get("tasks") - if tasks is not None: - for task in tasks: - self.add_task(task["type"], task["properties"]) + self._config = {"metadata": config.get("metadata", {}).copy()} - self._config = {"metadata": config.get("metadata", {}).copy()} + return self._config def has_task(self, task_type: str) -> bool: """ @@ -124,28 +122,24 @@ def get_metadata(self, key: str = None, fail_on_nonexist: bool = False) -> typin def list_metadata(self) -> list[str]: return sorted(self._config["metadata"].keys()) - def payload(self) -> MessageMeta | None: + def payload(self, payload: MessageMeta = None) -> MessageMeta | None: + if payload is not None: + self._payload = payload + return self._payload - def set_payload(self, payload: MessageMeta): - self._payload = payload + def tensors(self, tensors: TensorMemory = None) -> TensorMemory | None: + if tensors is not None: + self._tensors = tensors - @property - def tensors(self) -> TensorMemory | None: return self._tensors - @tensors.setter - def tensors(self, tensors: TensorMemory): - self._tensors = tensors + def task_type(self, new_task_type: ControlMessageType = None) -> ControlMessageType: + if new_task_type is not None: + self._type = new_task_type - @property - def task_type(self) -> ControlMessageType: return self._type - @task_type.setter - def task_type(self, task_type: ControlMessageType): - self._type = task_type - def set_timestamp(self, key: str, timestamp: datetime): self._timestamps[key] = timestamp