Skip to content

Commit

Permalink
Ensure timestamps are copied in LLMEngineStage (#1975)
Browse files Browse the repository at this point in the history
Closes #1973

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1975
  • Loading branch information
dagardner-nv authored Oct 24, 2024
1 parent cc78e19 commit e351e07
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,13 @@ class MORPHEUS_EXPORT ControlMessage
*/
std::optional<time_point_t> get_timestamp(const std::string& key, bool fail_if_nonexist = false);

/**
* @brief Return a reference to the timestamps map
*
* @return A const map reference containing timestamps
*/
const std::map<std::string, time_point_t>& get_timestamps() const;

/**
* @brief Retrieves timestamps for all keys that match a regex pattern.
*
Expand Down Expand Up @@ -340,6 +347,13 @@ struct MORPHEUS_EXPORT ControlMessageProxy
*/
static pybind11::object get_timestamp(ControlMessage& self, const std::string& key, bool fail_if_nonexist = false);

/**
* @brief Return all timestamps
*
* @return A Python dictionary of timestamps
*/
static pybind11::dict get_timestamps(ControlMessage& self);

/**
* @brief Retrieves timestamps for all keys that match a regex pattern from the ControlMessage object.
*
Expand Down
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ControlMessage():
"""
Retrieve the timestamp for a given group and key. Returns None if the timestamp does not exist and fail_if_nonexist is False.
"""
def get_timestamps(self) -> dict: ...
def has_metadata(self, key: str) -> bool: ...
def has_task(self, task_type: str) -> bool: ...
def list_metadata(self) -> list: ...
Expand Down
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ PYBIND11_MODULE(messages, _module)
"fail_if_nonexist is False.",
py::arg("key"),
py::arg("fail_if_nonexist") = false)
.def("get_timestamps", &ControlMessageProxy::get_timestamps)
.def("set_timestamp",
&ControlMessageProxy::set_timestamp,
"Set a timestamp for a given key and group.",
Expand Down
10 changes: 10 additions & 0 deletions python/morpheus/morpheus/_lib/src/messages/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ void ControlMessage::set_timestamp(const std::string& key, time_point_t timestam
m_timestamps[key] = timestamp_ns;
}

const std::map<std::string, time_point_t>& ControlMessage::get_timestamps() const
{
return m_timestamps;
}

std::map<std::string, time_point_t> ControlMessage::filter_timestamp(const std::string& regex_filter)
{
std::map<std::string, time_point_t> matching_timestamps;
Expand Down Expand Up @@ -365,6 +370,11 @@ py::list ControlMessageProxy::list_metadata(ControlMessage& self)
return py_keys;
}

py::dict ControlMessageProxy::get_timestamps(ControlMessage& self)
{
return py::cast(self.get_timestamps());
}

py::dict ControlMessageProxy::filter_timestamp(ControlMessage& self, const std::string& regex_filter)
{
auto cpp_map = self.filter_timestamp(regex_filter);
Expand Down
3 changes: 3 additions & 0 deletions python/morpheus/morpheus/messages/control_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def get_timestamp(self, key: str, fail_if_nonexist: bool = False) -> datetime |
raise ValueError("Timestamp for the specified key does not exist.") from e
return None

def get_timestamps(self) -> dict[str, datetime]:
return self._timestamps

def filter_timestamp(self, regex_filter: str) -> dict[str, datetime]:
re_obj = re.compile(regex_filter)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def _copy_tasks_and_metadata(self,
for tv in task_value:
dst.add_task(task, tv)

timestamps = src.get_timestamps()
for (ts_key, ts) in timestamps.items():
dst.set_timestamp(key=ts_key, timestamp=ts)

def _cast_to_cpp_control_message(self, py_message: ControlMessage, *,
cpp_messages_lib: types.ModuleType) -> ControlMessage:
"""
Expand Down
16 changes: 16 additions & 0 deletions tests/morpheus/messages/test_control_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ def test_filter_timestamp():
assert result[f"{group}::key2"] == timestamp2, "The timestamp for key2 should match."


@pytest.mark.gpu_and_cpu_mode
def test_get_timestamps():
# Create a ControlMessage instance
msg = messages.ControlMessage()

# Setup test data
timestamp1 = datetime.datetime.now()
timestamp2 = timestamp1 + datetime.timedelta(seconds=1)
msg.set_timestamp("key1", timestamp1)
msg.set_timestamp("key2", timestamp2)

# Assert both keys are in the result and have correct timestamps
timestamps = msg.get_timestamps()
assert timestamps == {"key1": timestamp1, "key2": timestamp2}


@pytest.mark.gpu_and_cpu_modetest_tensor_manipulation_after_retrieval
def test_get_timestamp_fail_if_nonexist():
# Create a ControlMessage instance
Expand Down
42 changes: 40 additions & 2 deletions tests/morpheus_llm/stages/test_llm_engine_stage_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
# limitations under the License.

import os
from datetime import datetime

import pytest

from _utils import TEST_DIRS
from _utils import assert_results
from _utils.dataset_manager import DatasetManager
from morpheus.config import Config
from morpheus.messages import ControlMessage
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.pipeline.stage_decorator import stage
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus_llm.llm import LLMEngine
from morpheus_llm.llm.nodes.extracter_node import ExtracterNode
Expand All @@ -37,9 +43,10 @@ def _build_engine() -> LLMEngine:
return engine


def test_pipeline(config: Config, dataset_cudf: DatasetManager):
@pytest.mark.gpu_and_cpu_mode
def test_pipeline(config: Config, dataset: DatasetManager):
test_data = os.path.join(TEST_DIRS.validation_data_dir, 'root-cause-validation-data-input.jsonlines')
input_df = dataset_cudf[test_data]
input_df = dataset[test_data]
expected_df = input_df.copy(deep=True)
expected_df["response"] = expected_df['log']

Expand All @@ -53,3 +60,34 @@ def test_pipeline(config: Config, dataset_cudf: DatasetManager):
pipe.run()

assert_results(sink.get_results())


@pytest.mark.gpu_and_cpu_mode
def test_error_1973(config: Config, dataset: DatasetManager):
expected_timestamps: dict[str, datetime] = {}

@stage(execution_modes=(config.execution_mode, ))
def log_timestamp(msg: ControlMessage, *, timestamp_name: str) -> ControlMessage:
ts = datetime.now()
msg.set_timestamp(key=timestamp_name, timestamp=ts)
expected_timestamps[timestamp_name] = ts
return msg

task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": ['v1']}}
pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, dataframes=[dataset["filter_probs.csv"]]))
pipe.add_stage(DeserializeStage(config, task_type="llm_engine", task_payload=task_payload))
pipe.add_stage(log_timestamp(config, timestamp_name="pre_llm"))
pipe.add_stage(LLMEngineStage(config, engine=_build_engine()))
pipe.add_stage(log_timestamp(config, timestamp_name="post_llm"))
sink = pipe.add_stage(InMemorySinkStage(config))

pipe.run()

messages = sink.get_messages()
assert len(messages) == 1

msg = messages[0]
for (timestamp_name, expected_timestamp) in expected_timestamps.items():
actual_timestamp = msg.get_timestamp(timestamp_name, fail_if_nonexist=True)
assert actual_timestamp == expected_timestamp

0 comments on commit e351e07

Please sign in to comment.