Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure timestamps are copied in LLMEngineStage #1975

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,13 @@ class MORPHEUS_EXPORT ControlMessage
*/
std::optional<time_point_t> get_timestamp(const std::string& key, bool fail_if_nonexist = false);

/**
* @brief Return a reference to the timestamps map
*
* @return A const map reference containing timestamps
*/
const std::map<std::string, time_point_t>& get_timestamps() const;

/**
* @brief Retrieves timestamps for all keys that match a regex pattern.
*
Expand Down Expand Up @@ -340,6 +347,13 @@ struct MORPHEUS_EXPORT ControlMessageProxy
*/
static pybind11::object get_timestamp(ControlMessage& self, const std::string& key, bool fail_if_nonexist = false);

/**
* @brief Return all timestamps
*
* @return A Python dictionary of timestamps
*/
static pybind11::dict get_timestamps(ControlMessage& self);

/**
* @brief Retrieves timestamps for all keys that match a regex pattern from the ControlMessage object.
*
Expand Down
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ControlMessage():
"""
Retrieve the timestamp for a given group and key. Returns None if the timestamp does not exist and fail_if_nonexist is False.
"""
def get_timestamps(self) -> dict: ...
def has_metadata(self, key: str) -> bool: ...
def has_task(self, task_type: str) -> bool: ...
def list_metadata(self) -> list: ...
Expand Down
1 change: 1 addition & 0 deletions python/morpheus/morpheus/_lib/messages/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ PYBIND11_MODULE(messages, _module)
"fail_if_nonexist is False.",
py::arg("key"),
py::arg("fail_if_nonexist") = false)
.def("get_timestamps", &ControlMessageProxy::get_timestamps)
.def("set_timestamp",
&ControlMessageProxy::set_timestamp,
"Set a timestamp for a given key and group.",
Expand Down
10 changes: 10 additions & 0 deletions python/morpheus/morpheus/_lib/src/messages/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ void ControlMessage::set_timestamp(const std::string& key, time_point_t timestam
m_timestamps[key] = timestamp_ns;
}

const std::map<std::string, time_point_t>& ControlMessage::get_timestamps() const
{
return m_timestamps;
}

std::map<std::string, time_point_t> ControlMessage::filter_timestamp(const std::string& regex_filter)
{
std::map<std::string, time_point_t> matching_timestamps;
Expand Down Expand Up @@ -365,6 +370,11 @@ py::list ControlMessageProxy::list_metadata(ControlMessage& self)
return py_keys;
}

py::dict ControlMessageProxy::get_timestamps(ControlMessage& self)
{
return py::cast(self.get_timestamps());
}

py::dict ControlMessageProxy::filter_timestamp(ControlMessage& self, const std::string& regex_filter)
{
auto cpp_map = self.filter_timestamp(regex_filter);
Expand Down
3 changes: 3 additions & 0 deletions python/morpheus/morpheus/messages/control_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def get_timestamp(self, key: str, fail_if_nonexist: bool = False) -> datetime |
raise ValueError("Timestamp for the specified key does not exist.") from e
return None

def get_timestamps(self) -> dict[str, datetime]:
return self._timestamps

def filter_timestamp(self, regex_filter: str) -> dict[str, datetime]:
re_obj = re.compile(regex_filter)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def _copy_tasks_and_metadata(self,
for tv in task_value:
dst.add_task(task, tv)

timestamps = src.get_timestamps()
for (ts_key, ts) in timestamps.items():
dst.set_timestamp(key=ts_key, timestamp=ts)

def _cast_to_cpp_control_message(self, py_message: ControlMessage, *,
cpp_messages_lib: types.ModuleType) -> ControlMessage:
"""
Expand Down
16 changes: 16 additions & 0 deletions tests/morpheus/messages/test_control_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ def test_filter_timestamp():
assert result[f"{group}::key2"] == timestamp2, "The timestamp for key2 should match."


@pytest.mark.gpu_and_cpu_mode
def test_get_timestamps():
# Create a ControlMessage instance
msg = messages.ControlMessage()

# Setup test data
timestamp1 = datetime.datetime.now()
timestamp2 = timestamp1 + datetime.timedelta(seconds=1)
msg.set_timestamp("key1", timestamp1)
msg.set_timestamp("key2", timestamp2)

# Assert both keys are in the result and have correct timestamps
timestamps = msg.get_timestamps()
assert timestamps == {"key1": timestamp1, "key2": timestamp2}


@pytest.mark.gpu_and_cpu_modetest_tensor_manipulation_after_retrieval
def test_get_timestamp_fail_if_nonexist():
# Create a ControlMessage instance
Expand Down
42 changes: 40 additions & 2 deletions tests/morpheus_llm/stages/test_llm_engine_stage_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
# limitations under the License.

import os
from datetime import datetime

import pytest

from _utils import TEST_DIRS
from _utils import assert_results
from _utils.dataset_manager import DatasetManager
from morpheus.config import Config
from morpheus.messages import ControlMessage
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.pipeline.stage_decorator import stage
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus_llm.llm import LLMEngine
from morpheus_llm.llm.nodes.extracter_node import ExtracterNode
Expand All @@ -37,9 +43,10 @@ def _build_engine() -> LLMEngine:
return engine


def test_pipeline(config: Config, dataset_cudf: DatasetManager):
@pytest.mark.gpu_and_cpu_mode
def test_pipeline(config: Config, dataset: DatasetManager):
test_data = os.path.join(TEST_DIRS.validation_data_dir, 'root-cause-validation-data-input.jsonlines')
input_df = dataset_cudf[test_data]
input_df = dataset[test_data]
expected_df = input_df.copy(deep=True)
expected_df["response"] = expected_df['log']

Expand All @@ -53,3 +60,34 @@ def test_pipeline(config: Config, dataset_cudf: DatasetManager):
pipe.run()

assert_results(sink.get_results())


@pytest.mark.gpu_and_cpu_mode
def test_error_1973(config: Config, dataset: DatasetManager):
expected_timestamps: dict[str, datetime] = {}

@stage(execution_modes=(config.execution_mode, ))
def log_timestamp(msg: ControlMessage, *, timestamp_name: str) -> ControlMessage:
ts = datetime.now()
msg.set_timestamp(key=timestamp_name, timestamp=ts)
expected_timestamps[timestamp_name] = ts
return msg

task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": ['v1']}}
pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, dataframes=[dataset["filter_probs.csv"]]))
pipe.add_stage(DeserializeStage(config, task_type="llm_engine", task_payload=task_payload))
pipe.add_stage(log_timestamp(config, timestamp_name="pre_llm"))
pipe.add_stage(LLMEngineStage(config, engine=_build_engine()))
pipe.add_stage(log_timestamp(config, timestamp_name="post_llm"))
sink = pipe.add_stage(InMemorySinkStage(config))

pipe.run()

messages = sink.get_messages()
assert len(messages) == 1

msg = messages[0]
for (timestamp_name, expected_timestamp) in expected_timestamps.items():
actual_timestamp = msg.get_timestamp(timestamp_name, fail_if_nonexist=True)
assert actual_timestamp == expected_timestamp
Loading