From 32c982e9b9f1edbda52a7ca0043b7b0ec25bf01c Mon Sep 17 00:00:00 2001
From: Yuchen Zhang <134643420+yczhang-nv@users.noreply.github.com>
Date: Thu, 5 Dec 2024 12:35:36 -0800
Subject: [PATCH 1/3] Add `tensor_count` property for ControlMessage (#2078)

For ControlMessage, `msg.tensors().count` is a common pattern, but calling
`msg.tensors()` is more expensive than it appears. Add a `tensor_count`
property so the count can be read without fetching the tensor memory.

Closes #1876

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Yuchen Zhang (https://github.com/yczhang-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: https://github.com/nv-morpheus/Morpheus/pull/2078
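
[Editor's note] A minimal sketch of the API change in use, not part of the
patch; the import path is assumed from the Morpheus package layout:

    import cupy as cp
    from morpheus.messages import ControlMessage, TensorMemory

    msg = ControlMessage()
    msg.tensors(TensorMemory(count=3, tensors={"probs": cp.zeros((3, 2))}))

    n = msg.tensors().count  # before: fetches the TensorMemory just to read count
    n = msg.tensor_count()   # after: returns the cached count directly

Despite the word "property" in the title, `tensor_count()` is called as a
method in both the C++ and Python bindings; the setter overload of
`tensors()` caches the count when the tensor memory is assigned.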
---
 examples/log_parsing/inference.py              | 24 +++++++++----------
 .../include/morpheus/messages/control.hpp      |  9 +++++++
 .../morpheus/_lib/messages/__init__.pyi        |  1 +
 .../morpheus/_lib/messages/module.cpp          |  1 +
 .../morpheus/_lib/src/messages/control.cpp     | 15 ++++++++----
 .../src/stages/inference_client_stage.cpp      |  4 ++--
 .../tests/messages/test_control_message.cpp    |  5 +++-
 .../morpheus/messages/control_message.py       |  5 ++++
 .../inference/identity_inference_stage.py      |  4 ++--
 .../stages/inference/inference_stage.py        |  6 ++---
 .../inference/pytorch_inference_stage.py       |  4 ++--
 .../inference/triton_inference_stage.py        |  2 +-
 .../stages/postprocess/ml_flow_drift_stage.py  |  2 +-
 .../test_abp_pcap_preprocessing.py             |  2 +-
 tests/examples/log_parsing/test_inference.py   |  4 ++--
 .../morpheus/messages/test_control_message.py  | 16 ++++++++-----
 tests/morpheus/stages/test_inference_stage.py  |  2 +-
 17 files changed, 68 insertions(+), 38 deletions(-)

diff --git a/examples/log_parsing/inference.py b/examples/log_parsing/inference.py
index c815389e5f..27c83cf59d 100644
--- a/examples/log_parsing/inference.py
+++ b/examples/log_parsing/inference.py
@@ -57,16 +57,16 @@ class TritonInferenceLogParsing(TritonInferenceWorker):
     """
 
     def build_output_message(self, msg: ControlMessage) -> ControlMessage:
-        seq_ids = cp.zeros((msg.tensors().count, 3), dtype=cp.uint32)
-        seq_ids[:, 0] = cp.arange(0, msg.tensors().count, dtype=cp.uint32)
+        seq_ids = cp.zeros((msg.tensor_count(), 3), dtype=cp.uint32)
+        seq_ids[:, 0] = cp.arange(0, msg.tensor_count(), dtype=cp.uint32)
         seq_ids[:, 2] = msg.tensors().get_tensor('seq_ids')[:, 2]
 
         memory = TensorMemory(
-            count=msg.tensors().count,
+            count=msg.tensor_count(),
             tensors={
-                'confidences': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
-                'labels': cp.zeros((msg.tensors().count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
-                'input_ids': cp.zeros((msg.tensors().count, msg.tensors().get_tensor('input_ids').shape[1])),
+                'confidences': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
+                'labels': cp.zeros((msg.tensor_count(), self._inputs[list(self._inputs.keys())[0]].shape[1])),
+                'input_ids': cp.zeros((msg.tensor_count(), msg.tensors().get_tensor('input_ids').shape[1])),
                 'seq_ids': seq_ids
             })
 
@@ -154,19 +154,19 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
         seq_offset = seq_ids[0, 0].item()
         seq_count = seq_ids[-1, 0].item() + 1 - seq_offset
 
-        input_ids[batch_offset:inf.tensors().count + batch_offset, :] = inf.tensors().get_tensor('input_ids')
-        out_seq_ids[batch_offset:inf.tensors().count + batch_offset, :] = seq_ids
+        input_ids[batch_offset:inf.tensor_count() + batch_offset, :] = inf.tensors().get_tensor('input_ids')
+        out_seq_ids[batch_offset:inf.tensor_count() + batch_offset, :] = seq_ids
 
         resp_confidences = res.get_tensor('confidences')
         resp_labels = res.get_tensor('labels')
 
         # Two scenarios:
-        if (inf.payload().count == inf.tensors().count):
+        if (inf.payload().count == inf.tensor_count()):
             assert seq_count == res.count
-            confidences[batch_offset:inf.tensors().count + batch_offset, :] = resp_confidences
-            labels[batch_offset:inf.tensors().count + batch_offset, :] = resp_labels
+            confidences[batch_offset:inf.tensor_count() + batch_offset, :] = resp_confidences
+            labels[batch_offset:inf.tensor_count() + batch_offset, :] = resp_labels
         else:
-            assert inf.tensors().count == res.count
+            assert inf.tensor_count() == res.count
 
             mess_ids = seq_ids[:, 0].get().tolist()
 
diff --git a/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp b/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp
index 6f14d93037..9aa431c950 100644
--- a/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp
+++ b/python/morpheus/morpheus/_lib/include/morpheus/messages/control.hpp
@@ -19,6 +19,7 @@
 
 #include "morpheus/export.h"          // for MORPHEUS_EXPORT
 #include "morpheus/messages/meta.hpp" // for MessageMeta
+#include "morpheus/types.hpp"
 #include "morpheus/utilities/json_types.hpp" // for json_t
 
 #include <pybind11/pytypes.h> // for object, dict, list
@@ -197,6 +198,13 @@ class MORPHEUS_EXPORT ControlMessage
      */
     void tensors(const std::shared_ptr<TensorMemory>& tensor_memory);
 
+    /**
+     * @brief Get the length of tensors in the tensor memory.
+     *
+     * @return The length of tensors in the tensor memory.
+     */
+    TensorIndex tensor_count();
+
     /**
      * @brief Get the type of task associated with the control message.
      * @return An enum value indicating the task type.
@@ -262,6 +270,7 @@ class MORPHEUS_EXPORT ControlMessage
     ControlMessageType m_cm_type{ControlMessageType::NONE};
     std::shared_ptr<MessageMeta> m_payload{nullptr};
     std::shared_ptr<TensorMemory> m_tensors{nullptr};
+    TensorIndex m_tensor_count{0};
 
     morpheus::utilities::json_t m_tasks{};
     morpheus::utilities::json_t m_config{};
 
diff --git a/python/morpheus/morpheus/_lib/messages/__init__.pyi b/python/morpheus/morpheus/_lib/messages/__init__.pyi
index 11fba00aee..50bbea7829 100644
--- a/python/morpheus/morpheus/_lib/messages/__init__.pyi
+++ b/python/morpheus/morpheus/_lib/messages/__init__.pyi
@@ -73,6 +73,7 @@ class ControlMessage():
     def task_type(self) -> ControlMessageType: ...
     @typing.overload
     def task_type(self, task_type: ControlMessageType) -> None: ...
+    def tensor_count(self) -> int: ...
     @typing.overload
     def tensors(self) -> TensorMemory: ...
     @typing.overload
diff --git a/python/morpheus/morpheus/_lib/messages/module.cpp b/python/morpheus/morpheus/_lib/messages/module.cpp
index fed31a6d11..992b99e422 100644
--- a/python/morpheus/morpheus/_lib/messages/module.cpp
+++ b/python/morpheus/morpheus/_lib/messages/module.cpp
@@ -290,6 +290,7 @@ PYBIND11_MODULE(messages, _module)
              py::arg("meta"))
         .def("tensors", pybind11::overload_cast<>(&ControlMessage::tensors))
         .def("tensors", pybind11::overload_cast<const std::shared_ptr<TensorMemory>&>(&ControlMessage::tensors))
+        .def("tensor_count", &ControlMessage::tensor_count)
         .def("remove_task", &ControlMessage::remove_task, py::arg("task_type"))
         .def("set_metadata", &ControlMessage::set_metadata, py::arg("key"), py::arg("value"))
         .def("task_type", pybind11::overload_cast<>(&ControlMessage::task_type))
diff --git a/python/morpheus/morpheus/_lib/src/messages/control.cpp b/python/morpheus/morpheus/_lib/src/messages/control.cpp
index c1a85dbcba..34141085b2 100644
--- a/python/morpheus/morpheus/_lib/src/messages/control.cpp
+++ b/python/morpheus/morpheus/_lib/src/messages/control.cpp
@@ -59,9 +59,10 @@ ControlMessage::ControlMessage(const morpheus::utilities::json_t& _config) :
 
 ControlMessage::ControlMessage(const ControlMessage& other)
 {
-    m_cm_type = other.m_cm_type;
-    m_payload = other.m_payload;
-    m_tensors = other.m_tensors;
+    m_cm_type      = other.m_cm_type;
+    m_payload      = other.m_payload;
+    m_tensors      = other.m_tensors;
+    m_tensor_count = other.m_tensor_count;
 
     m_config = other.m_config;
     m_tasks  = other.m_tasks;
@@ -256,7 +257,13 @@ std::shared_ptr<TensorMemory> ControlMessage::tensors()
 
 void ControlMessage::tensors(const std::shared_ptr<TensorMemory>& tensors)
 {
-    m_tensors = tensors;
+    m_tensors      = tensors;
+    m_tensor_count = tensors ? tensors->count : 0;
+}
+
+TensorIndex ControlMessage::tensor_count()
+{
+    return m_tensor_count;
 }
 
 ControlMessageType ControlMessage::task_type()
diff --git a/python/morpheus/morpheus/_lib/src/stages/inference_client_stage.cpp b/python/morpheus/morpheus/_lib/src/stages/inference_client_stage.cpp
index c5baa2fa25..76022eeda2 100644
--- a/python/morpheus/morpheus/_lib/src/stages/inference_client_stage.cpp
+++ b/python/morpheus/morpheus/_lib/src/stages/inference_client_stage.cpp
@@ -58,7 +58,7 @@ static ShapeType get_seq_ids(const std::shared_ptr<ControlMessage>& message)
     auto seq_ids         = message->tensors()->get_tensor("seq_ids");
     const auto item_size = seq_ids.dtype().item_size();
 
-    ShapeType host_seq_ids(message->tensors()->count);
+    ShapeType host_seq_ids(message->tensor_count());
     MRC_CHECK_CUDA(cudaMemcpy2D(host_seq_ids.data(),
                                 item_size,
                                 seq_ids.data(),
@@ -82,7 +82,7 @@ static TensorObject get_tensor(std::shared_ptr<ControlMessage> message, std::str
 
 static void reduce_outputs(std::shared_ptr<ControlMessage> const& message, TensorMap& output_tensors)
 {
-    if (message->payload()->count() == message->tensors()->count)
+    if (message->payload()->count() == message->tensor_count())
     {
         return;
     }
diff --git a/python/morpheus/morpheus/_lib/tests/messages/test_control_message.cpp b/python/morpheus/morpheus/_lib/tests/messages/test_control_message.cpp
index 642660fcdc..2f02a71adc 100644
--- a/python/morpheus/morpheus/_lib/tests/messages/test_control_message.cpp
+++ b/python/morpheus/morpheus/_lib/tests/messages/test_control_message.cpp
@@ -21,6 +21,7 @@
 #include "morpheus/messages/control.hpp"              // for ControlMessage
 #include "morpheus/messages/memory/tensor_memory.hpp" // for TensorMemory
 #include "morpheus/messages/meta.hpp"                 // for MessageMeta
+#include "morpheus/types.hpp"
 #include "morpheus/utilities/json_types.hpp" // for PythonByteContainer
 
 #include <gtest/gtest.h> // for Message, TestPartResult, AssertionResult, TestInfo
@@ -298,7 +299,8 @@ TEST_F(TestControlMessage, SetAndGetTensorMemory)
 {
     auto msg = ControlMessage();
 
-    auto tensorMemory = std::make_shared<TensorMemory>(0);
+    TensorIndex count = 5;
+    auto tensorMemory = std::make_shared<TensorMemory>(count);
     // Optionally, modify tensorMemory here if it has any mutable state to test
 
     // Set the tensor memory
@@ -309,6 +311,7 @@ TEST_F(TestControlMessage, SetAndGetTensorMemory)
 
     // Verify that the retrieved tensor memory matches what was set
     EXPECT_EQ(tensorMemory, retrievedTensorMemory);
+    EXPECT_EQ(count, msg.tensor_count());
 }
 
 // Test setting TensorMemory to nullptr
diff --git a/python/morpheus/morpheus/messages/control_message.py b/python/morpheus/morpheus/messages/control_message.py
index a2c4f35496..a30df06638 100644
--- a/python/morpheus/morpheus/messages/control_message.py
+++ b/python/morpheus/morpheus/messages/control_message.py
@@ -46,6 +46,7 @@ def __init__(self, config_or_message: typing.Union["ControlMessage", dict] = Non
 
         self._payload: MessageMeta = None
         self._tensors: TensorMemory = None
+        self._tensor_count: int = 0
 
         self._tasks: dict[str, deque] = defaultdict(deque)
         self._timestamps: dict[str, datetime] = {}
@@ -147,9 +148,13 @@ def payload(self, payload: MessageMeta = None) -> MessageMeta | None:
     def tensors(self, tensors: TensorMemory = None) -> TensorMemory | None:
         if tensors is not None:
             self._tensors = tensors
+            self._tensor_count = tensors.count
 
         return self._tensors
 
+    def tensor_count(self) -> int:
+        return self._tensor_count
+
     def task_type(self, new_task_type: ControlMessageType = None) -> ControlMessageType:
         if new_task_type is not None:
             self._type = new_task_type
diff --git a/python/morpheus/morpheus/stages/inference/identity_inference_stage.py b/python/morpheus/morpheus/stages/inference/identity_inference_stage.py
index 0edde48ced..dfa500eb29 100644
--- a/python/morpheus/morpheus/stages/inference/identity_inference_stage.py
+++ b/python/morpheus/morpheus/stages/inference/identity_inference_stage.py
@@ -45,12 +45,12 @@ def __init__(self, inf_queue: ProducerConsumerQueue, c: Config):
         self._seq_length = c.feature_length
 
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
-        return (msg.tensors().count, self._seq_length)
+        return (msg.tensor_count(), self._seq_length)
 
     def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):
 
         def tmp(batch: ControlMessage, f):
-            count = batch.tensors().count
+            count = batch.tensor_count()
             f(TensorMemory(
                 count=count,
                 tensors={'probs': cp.zeros((count, self._seq_length), dtype=cp.float32)},
diff --git a/python/morpheus/morpheus/stages/inference/inference_stage.py b/python/morpheus/morpheus/stages/inference/inference_stage.py
index f235e12fc4..3e95a34e30 100644
--- a/python/morpheus/morpheus/stages/inference/inference_stage.py
+++ b/python/morpheus/morpheus/stages/inference/inference_stage.py
@@ -244,7 +244,7 @@ def set_output_fut(resp: TensorMemory, inner_batch, batch_future: mrc.Future):
                 nonlocal outstanding_requests
                 nonlocal batch_offset
                 mess = self._convert_one_response(output_message, inner_batch, resp, batch_offset)
-                batch_offset += inner_batch.tensors().count
+                batch_offset += inner_batch.tensor_count()
                 outstanding_requests -= 1
 
                 batch_future.set_result(mess)
@@ -359,13 +359,13 @@ def _convert_one_response(output: ControlMessage, inf: ControlMessage, res: Tens
         seq_count = seq_ids[-1, 0].item() + 1 - seq_offset
 
         # Two scenarios:
-        if (inf.payload().count == inf.tensors().count):
+        if (inf.payload().count == inf.tensor_count()):
             assert seq_count == res.count
 
             # In message and out message have same count. Just use probs as is
             probs[seq_offset:seq_offset + seq_count, :] = resp_probs
         else:
-            assert inf.tensors().count == res.count
+            assert inf.tensor_count() == res.count
 
             mess_ids = seq_ids[:, 0].get().tolist()
 
diff --git a/python/morpheus/morpheus/stages/inference/pytorch_inference_stage.py b/python/morpheus/morpheus/stages/inference/pytorch_inference_stage.py
index fc05d1ef38..3f9dd7bc81 100644
--- a/python/morpheus/morpheus/stages/inference/pytorch_inference_stage.py
+++ b/python/morpheus/morpheus/stages/inference/pytorch_inference_stage.py
@@ -73,7 +73,7 @@ def init(self):
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
         input_ids = msg.tensors().get_tensor("input_ids")
         input_mask = msg.tensors().get_tensor("input_mask")
-        count = msg.tensors().count
+        count = msg.tensor_count()
         # If we haven't cached the output dimension, do that here
         if (not self._output_size):
             test_intput_ids_shape = (self._max_batch_size, ) + input_ids.shape[1:]
@@ -91,7 +91,7 @@ def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
     def process(self, batch: ControlMessage, callback: typing.Callable[[TensorMemory], None]):
         input_ids = batch.tensors().get_tensor("input_ids")
         input_mask = batch.tensors().get_tensor("input_mask")
-        count = batch.tensors().count
+        count = batch.tensor_count()
         # convert from cupy to torch tensor using dlpack
         input_ids = from_dlpack(input_ids.astype(cp.float).toDlpack()).type(torch.long)
diff --git a/python/morpheus/morpheus/stages/inference/triton_inference_stage.py b/python/morpheus/morpheus/stages/inference/triton_inference_stage.py
index 62f0a51d8e..9da8538e69 100644
--- a/python/morpheus/morpheus/stages/inference/triton_inference_stage.py
+++ b/python/morpheus/morpheus/stages/inference/triton_inference_stage.py
@@ -568,7 +568,7 @@ def create_wrapper():
             raise ex
 
     def calc_output_dims(self, msg: ControlMessage) -> typing.Tuple:
-        return (msg.tensors().count, self._outputs[list(self._outputs.keys())[0]].shape[1])
+        return (msg.tensor_count(), self._outputs[list(self._outputs.keys())[0]].shape[1])
 
     def _build_response(
             self,
diff --git a/python/morpheus/morpheus/stages/postprocess/ml_flow_drift_stage.py b/python/morpheus/morpheus/stages/postprocess/ml_flow_drift_stage.py
index fddf84cc72..3c62a4b2c9 100644
--- a/python/morpheus/morpheus/stages/postprocess/ml_flow_drift_stage.py
+++ b/python/morpheus/morpheus/stages/postprocess/ml_flow_drift_stage.py
@@ -138,7 +138,7 @@ def _calc_drift(self, msg: ControlMessage):
         for label in range(len(self._labels), shifted.shape[1]):
             self._labels.append(str(label))
 
-        count = msg.tensors().count
+        count = msg.tensor_count()
         for i in list(range(0, count, self._batch_size)):
             start = i
diff --git a/tests/examples/abp_pcap_detection/test_abp_pcap_preprocessing.py b/tests/examples/abp_pcap_detection/test_abp_pcap_preprocessing.py
index c791ea8265..4c46e0379b 100755
--- a/tests/examples/abp_pcap_detection/test_abp_pcap_preprocessing.py
+++ b/tests/examples/abp_pcap_detection/test_abp_pcap_preprocessing.py
@@ -41,7 +41,7 @@ def check_inf_message(msg: ControlMessage, expected_input__0: cp.ndarray):
     assert isinstance(msg, ControlMessage)
 
     assert msg.payload().count == expected_mess_count
-    assert msg.tensors().count == expected_count
+    assert msg.tensor_count() == expected_count
 
     df = msg.payload().get_data()
     assert 'flow_id' in df
diff --git a/tests/examples/log_parsing/test_inference.py b/tests/examples/log_parsing/test_inference.py
index 271a6a0ace..1e6e9d812d 100644
--- a/tests/examples/log_parsing/test_inference.py
+++ b/tests/examples/log_parsing/test_inference.py
@@ -138,7 +138,7 @@ def test_log_parsing_triton_inference_log_parsing_build_output_message(config: C
     msg = worker.build_output_message(input_msg)
     assert msg.payload() is input_msg.payload()
     assert msg.payload().count == mess_count
-    assert msg.tensors().count == count
+    assert msg.tensor_count() == count
 
     assert set(msg.tensors().tensor_names).issuperset(('confidences', 'labels', 'input_ids', 'seq_ids'))
     assert msg.tensors().get_tensor('confidences').shape == (count, 2)
@@ -187,7 +187,7 @@ def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.Lis
     assert isinstance(output_msg, ControlMessage)
     assert output_msg.payload() is input_inf.payload()
     assert output_msg.payload().count == mess_count
-    assert output_msg.tensors().count == count
+    assert output_msg.tensor_count() == count
 
     assert (output_msg.tensors().get_tensor('seq_ids') == input_inf.tensors().get_tensor('seq_ids')).all()
     assert (output_msg.tensors().get_tensor('input_ids') == input_inf.tensors().get_tensor('input_ids')).all()
diff --git a/tests/morpheus/messages/test_control_message.py b/tests/morpheus/messages/test_control_message.py
index dfd9218938..9f8c22fa5c 100644
--- a/tests/morpheus/messages/test_control_message.py
+++ b/tests/morpheus/messages/test_control_message.py
@@ -38,7 +38,7 @@ def _verify_metadata(msg: messages.ControlMessage, metadata: dict):
 
 @pytest.mark.gpu_and_cpu_mode
 def test_control_message_init(dataset: DatasetManager):
-    # Explicitly performing copies of the metadata, config and the dataframe, to ensure tha the original data is not
+    # Explicitly performing copies of the metadata, config and the dataframe, to ensure that the original data is not
     # being modified in place in some way.
     msg = messages.ControlMessage()
     assert msg.get_metadata() == {}  # pylint: disable=use-implicit-booleaness-not-comparison
@@ -318,9 +318,9 @@ def test_tensors_setting_and_getting(config: Config):
 
     message.tensors(tensor_memory)
 
-    retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == data["input_ids"].shape[0], "Tensor count mismatch."
+    assert message.tensor_count() == data["input_ids"].shape[0], "Tensor count mismatch."
 
+    retrieved_tensors = message.tensors()
     for key, val in data.items():
         assert array_pkg.allclose(retrieved_tensors.get_tensor(key), val), f"Mismatch in tensor data for {key}."
 
@@ -363,6 +363,7 @@ def test_tensor_manipulation_after_retrieval(config: Config):
     new_tensor = array_pkg.array([4, 5, 6])
     retrieved_tensors.set_tensor("new_tensor", new_tensor)
 
+    assert message.tensor_count() == tokenized_data["input_ids"].shape[0], "Tensor count mismatch"
     assert array_pkg.allclose(retrieved_tensors.get_tensor("new_tensor"), new_tensor), "New tensor data mismatch."
 
 
@@ -389,8 +390,9 @@ def test_tensor_update(config: Config):
 
     tensor_memory.set_tensors(new_tensors)
 
-    updated_tensors = message.tensors()
+    assert message.tensor_count() == tokenized_data["input_ids"].shape[0], "Tensor count mismatch"
 
+    updated_tensors = message.tensors()
     for key, val in new_tensors.items():
         assert array_pkg.allclose(updated_tensors.get_tensor(key), val), f"Mismatch in updated tensor data for {key}."
@@ -408,6 +410,7 @@ def test_update_individual_tensor(config: Config):
     tensor_memory.set_tensor("input_ids", update_data["input_ids"])
     retrieved_tensors = message.tensors()
 
+    assert message.tensor_count() == initial_data["input_ids"].shape[0], "Tensor count mismatch"
     # Check updated tensor
     assert array_pkg.allclose(retrieved_tensors.get_tensor("input_ids"),
                               update_data["input_ids"]), "Input IDs update mismatch."
@@ -422,8 +425,9 @@ def test_behavior_with_empty_tensors():
     tensor_memory = TensorMemory(count=0)
     message.tensors(tensor_memory)
 
+    assert message.tensor_count() == 0, "Tensor count should be 0 for empty tensor memory."
+
     retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == 0, "Tensor count should be 0 for empty tensor memory."
     assert len(retrieved_tensors.tensor_names) == 0, "There should be no tensor names for empty tensor memory."
 
 
@@ -442,8 +446,8 @@ def test_consistency_after_multiple_operations(config: Config):
     new_tensor = {"new_tensor": array_pkg.array([7, 8, 9])}
     tensor_memory.set_tensor("new_tensor", new_tensor["new_tensor"])
 
+    assert message.tensor_count() == initial_data["input_ids"].shape[0], "Tensor count mismatch."
     retrieved_tensors = message.tensors()
-    assert retrieved_tensors.count == 3, "Tensor count mismatch after multiple operations."
     assert array_pkg.allclose(retrieved_tensors.get_tensor("input_ids"),
                               array_pkg.array([4, 5, 6])), "Mismatch in input_ids after update."
     assert array_pkg.allclose(retrieved_tensors.get_tensor("new_tensor"),
diff --git a/tests/morpheus/stages/test_inference_stage.py b/tests/morpheus/stages/test_inference_stage.py
index e030b7bd19..135e0d8eb2 100755
--- a/tests/morpheus/stages/test_inference_stage.py
+++ b/tests/morpheus/stages/test_inference_stage.py
@@ -110,7 +110,7 @@ def test_convert_one_response():
     cm = InferenceStageT._convert_one_response(output, inf, res, batch_offset)
     assert cm.payload() == inf.payload()
     assert cm.payload().count == 4
-    assert cm.tensors().count == 4
+    assert cm.tensor_count() == 4
     assert cp.all(cm.tensors().get_tensor("probs") == res.get_tensor("probs"))
 
     # Test for the second branch

From 5f178efbee02fee1c38ba5aa2b133eabc7a00f41 Mon Sep 17 00:00:00 2001
From: David Gardner <96306125+dagardner-nv@users.noreply.github.com>
Date: Mon, 16 Dec 2024 10:46:37 -0800
Subject: [PATCH 2/3] Fix openai validation error (#2083)

* Work around a known openai incompatibility with httpx v0.28
  (openai/openai-python#1915)
* Templatize the llm pip dependencies
* Replace deprecated imports of openai from langchain and langchain_community

Closes #2084

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah)

URL: https://github.com/nv-morpheus/Morpheus/pull/2083
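
[Editor's note] A minimal before/after sketch of the import migration below,
not part of the patch; it assumes `langchain-openai==0.1.3` is installed and
that `OPENAI_API_KEY` is set in the environment:

    # Deprecated paths, dropped from langchain/langchain_community:
    #   from langchain.llms.openai import OpenAI
    #   from langchain_community.chat_models.openai import ChatOpenAI

    # Replacement: the langchain-openai partner package
    from langchain_openai import OpenAI, ChatOpenAI

    llm = OpenAI(temperature=0.0)

The `httpx>=0.23,<0.28` pin is the separate work-around: per
openai/openai-python#1915, openai releases of the time passed a `proxies`
argument that httpx 0.28 removed, so httpx is held back until the upstream fix.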
---
 ci/conda/recipes/morpheus-libs/meta.yaml       |  1 +
 ci/conda/recipes/morpheus/meta.yaml            |  1 +
 .../all_cuda-125_arch-x86_64.yaml              |  2 ++
 .../examples_cuda-125_arch-x86_64.yaml         |  2 ++
 dependencies.yaml                              | 23 +++++++++++--------
 docker/Dockerfile                              |  5 ++++
 examples/llm/agents/common.py                  |  2 +-
 .../requirements_morpheus_llm.txt              |  1 +
 tests/_utils/llm.py                            | 12 ++++------
 .../llm/nodes/test_langchain_agent_node.py     |  2 +-
 .../llm/test_agents_simple_pipe.py             |  4 ++--
 11 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/ci/conda/recipes/morpheus-libs/meta.yaml b/ci/conda/recipes/morpheus-libs/meta.yaml
index f6d88716bf..3fa359d0b1 100644
--- a/ci/conda/recipes/morpheus-libs/meta.yaml
+++ b/ci/conda/recipes/morpheus-libs/meta.yaml
@@ -64,6 +64,7 @@ outputs:
         - cudf {{ rapids_version }}
         - cython 3.0.*
         - glog >=0.7.1,<0.8
+        - indicators=2.3
        - libcudf {{ rapids_version }}
        - librdkafka >=1.9.2,<1.10.0a0
        - mrc {{ minor_version }}
diff --git a/ci/conda/recipes/morpheus/meta.yaml b/ci/conda/recipes/morpheus/meta.yaml
index fd60e49243..0e72cd17a3 100644
--- a/ci/conda/recipes/morpheus/meta.yaml
+++ b/ci/conda/recipes/morpheus/meta.yaml
@@ -69,6 +69,7 @@ outputs:
         - cudf {{ rapids_version }}
         - cython 3.0.*
         - glog >=0.7.1,<0.8
+        - indicators=2.3
        - libcudf {{ rapids_version }}
        - librdkafka >=1.9.2,<1.10.0a0
        - mrc {{ minor_version }}
diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml
index ebbb3dffe0..513b2dd157 100644
--- a/conda/environments/all_cuda-125_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-125_arch-x86_64.yaml
@@ -51,6 +51,7 @@ dependencies:
 - grpcio-status
 - gtest=1.14
 - gxx=12.1
+- httpx>=0.23,<0.28
 - huggingface_hub=0.20.2
 - include-what-you-use=0.20
 - indicators=2.3
@@ -136,6 +137,7 @@ dependencies:
   - faiss-cpu
   - google-search-results==2.4
   - langchain-nvidia-ai-endpoints==0.0.11
+  - langchain-openai==0.1.3
   - langchain==0.1.16
   - milvus==2.3.5
   - nemollm==0.3.5
diff --git a/conda/environments/examples_cuda-125_arch-x86_64.yaml b/conda/environments/examples_cuda-125_arch-x86_64.yaml
index 0ffb592b7d..646069c124 100644
--- a/conda/environments/examples_cuda-125_arch-x86_64.yaml
+++ b/conda/environments/examples_cuda-125_arch-x86_64.yaml
@@ -25,6 +25,7 @@ dependencies:
 - feedparser=6.0
 - grpcio
 - grpcio-status
+- httpx>=0.23,<0.28
 - huggingface_hub=0.20.2
 - jsonpatch>=1.33
 - kfp
@@ -73,6 +74,7 @@ dependencies:
   - faiss-cpu
   - google-search-results==2.4
   - langchain-nvidia-ai-endpoints==0.0.11
+  - langchain-openai==0.1.3
   - langchain==0.1.16
   - milvus==2.3.5
   - nemollm==0.3.5
diff --git a/dependencies.yaml b/dependencies.yaml
index 1f9941ad3b..477633593d 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -430,13 +430,14 @@ dependencies:
     common:
       - output_types: [requirements]
         packages:
-          - faiss-cpu
-          - google-search-results==2.4
-          - langchain==0.1.16
-          - langchain-nvidia-ai-endpoints==0.0.11
+          - &faiss-cpu faiss-cpu
+          - &google-search-results google-search-results==2.4
+          - &langchain langchain==0.1.16
+          - &langchain-nvidia-ai-endpoints langchain-nvidia-ai-endpoints==0.0.11
+          - &langchain-openai langchain-openai==0.1.3
           - milvus==2.3.5 # update to match pymilvus when available
           - pymilvus==2.3.6
-          - nemollm==0.3.5
+          - &nemollm nemollm==0.3.5
 
   example-dfp-prod:
     common:
@@ -487,6 +488,7 @@ dependencies:
           - &transformers transformers=4.36.2 # newer versions are incompatible with our pinned version of huggingface_hub
          - anyio>=3.7
          - arxiv=1.4
+         - httpx>=0.23,<0.28 # work-around for https://github.com/openai/openai-python/issues/1915
          - huggingface_hub=0.20.2 # work-around for https://github.com/UKPLab/sentence-transformers/issues/1762
          - jsonpatch>=1.33
          - newspaper3k=0.2
@@ -499,11 +501,12 @@ dependencies:
           - requests-toolbelt=1.0 # Transitive dep needed by nemollm, specified here to ensure we get a compatible version
           - pip
           - pip:
-              - langchain==0.1.16
-              - langchain-nvidia-ai-endpoints==0.0.11
-              - faiss-cpu
-              - google-search-results==2.4
-              - nemollm==0.3.5
+              - *faiss-cpu
+              - *google-search-results
+              - *langchain
+              - *langchain-nvidia-ai-endpoints
+              - *langchain-openai
+              - *nemollm
               - sentence-transformers==2.7 # using pip now instead of conda to avoid install of pytorch cpu
 
   model-training-tuning:
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 929d4fd005..cdd7693b11 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -185,6 +185,7 @@ COPY . ./
 
 RUN --mount=type=cache,id=workspace_cache,target=/workspace/.cache,sharing=locked \
     --mount=type=cache,id=conda_pkgs,target=/opt/conda/pkgs,sharing=locked \
+    --mount=type=cache,id=pip_cache,target=/root/.cache/pip,sharing=locked \
     # Install git-lfs before running the build to avoid errors during conda build
     /opt/conda/bin/mamba install -y -n base -c conda-forge "git-lfs" &&\
     source activate base &&\
@@ -227,6 +228,7 @@ COPY ${MORPHEUS_ROOT_HOST}/conda/environments/dev_cuda-${CUDA_MAJOR_VER}${CUDA_M
 
 # Update the morpheus environment
 RUN --mount=type=cache,id=conda_pkgs,target=/opt/conda/pkgs,sharing=locked \
+    --mount=type=cache,id=pip_cache,target=/root/.cache/pip,sharing=locked \
     # Temp add channel_alias to get around conda 404 errors
     conda config --env --set channel_alias ${CONDA_CHANNEL_ALIAS} &&\
     /opt/conda/bin/conda env update --solver=libmamba -n morpheus --file conda/environments/dev.yaml &&\
@@ -263,6 +265,7 @@ COPY . ./
 
 RUN --mount=type=cache,id=workspace_cache,target=/workspace/.cache,sharing=locked \
     --mount=type=cache,id=conda_pkgs,target=/opt/conda/pkgs,sharing=locked \
+    --mount=type=cache,id=pip_cache,target=/root/.cache/pip,sharing=locked \
     # Install git-lfs before running the build to avoid errors during conda build
     /opt/conda/bin/mamba install -y -n base -c conda-forge "git-lfs" &&\
     source activate base &&\
@@ -285,6 +288,7 @@ COPY . ./
 RUN --mount=type=cache,id=workspace_cache,target=/workspace/.cache,sharing=locked \
     --mount=type=bind,from=conda_bld_morpheus,source=/opt/conda/conda-bld,target=/opt/conda/conda-bld \
     --mount=type=cache,id=conda_pkgs,target=/opt/conda/pkgs,sharing=locked \
+    --mount=type=cache,id=pip_cache,target=/root/.cache/pip,sharing=locked \
     source activate morpheus &&\
     CONDA_ALWAYS_YES=true /opt/conda/bin/mamba install -n morpheus \
         -c local \
@@ -314,6 +318,7 @@ COPY "${MORPHEUS_ROOT_HOST}/conda/environments/runtime_cuda-${CUDA_MAJOR_VER}${C
 # Mount Morpheus conda package build in `conda_bld_morpheus`
 RUN --mount=type=bind,from=conda_bld_morpheus,source=/opt/conda/conda-bld,target=/opt/conda/conda-bld \
     --mount=type=cache,id=conda_pkgs,target=/opt/conda/pkgs,sharing=locked \
+    --mount=type=cache,id=pip_cache,target=/root/.cache/pip,sharing=locked \
     # CVE-2018-20225 for the base pip, not the env one
     # conda will ignore the request to remove pip
     python -m pip uninstall -y pip && \
diff --git a/examples/llm/agents/common.py b/examples/llm/agents/common.py
index 528291b857..40af0d5de6 100644
--- a/examples/llm/agents/common.py
+++ b/examples/llm/agents/common.py
@@ -18,7 +18,7 @@
 from langchain.agents import initialize_agent
 from langchain.agents import load_tools
 from langchain.agents.agent import AgentExecutor
-from langchain.llms.openai import OpenAI
+from langchain_openai import OpenAI
 
 from morpheus.config import Config
 from morpheus.pipeline.linear_pipeline import LinearPipeline
diff --git a/python/morpheus_llm/morpheus_llm/requirements_morpheus_llm.txt b/python/morpheus_llm/morpheus_llm/requirements_morpheus_llm.txt
index 8f9a9620b9..d8f16a5a37 100644
--- a/python/morpheus_llm/morpheus_llm/requirements_morpheus_llm.txt
+++ b/python/morpheus_llm/morpheus_llm/requirements_morpheus_llm.txt
@@ -4,6 +4,7 @@
 faiss-cpu
 google-search-results==2.4
 langchain-nvidia-ai-endpoints==0.0.11
+langchain-openai==0.1.3
 langchain==0.1.16
 milvus==2.3.5
 nemollm==0.3.5
diff --git a/tests/_utils/llm.py b/tests/_utils/llm.py
index 49f087b9b3..ccd1ed6195 100644
--- a/tests/_utils/llm.py
+++ b/tests/_utils/llm.py
@@ -89,13 +89,11 @@ def mk_mock_openai_response(messages: list[str]) -> mock.MagicMock:
     response = mock.MagicMock()
     response.choices = [_mk_mock_choice(message) for message in messages]
-    response.dict.return_value = {
-        "choices": [{
-            'message': {
-                'role': 'assistant', 'content': message
-            }
-        } for message in messages]
-    }
+
+    response_dict = {"choices": [{'message': {'role': 'assistant', 'content': message}} for message in messages]}
+
+    response.dict.return_value = response_dict
+    response.model_dump.return_value = response_dict
 
     return response
diff --git a/tests/morpheus_llm/llm/nodes/test_langchain_agent_node.py b/tests/morpheus_llm/llm/nodes/test_langchain_agent_node.py
index 0779b11604..ad4e131ace 100644
--- a/tests/morpheus_llm/llm/nodes/test_langchain_agent_node.py
+++ b/tests/morpheus_llm/llm/nodes/test_langchain_agent_node.py
@@ -32,8 +32,8 @@
     from langchain.agents import initialize_agent
     from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
     from langchain.callbacks.manager import CallbackManagerForToolRun
-    from langchain_community.chat_models.openai import ChatOpenAI
     from langchain_core.tools import BaseTool
+    from langchain_openai import ChatOpenAI
 except ImportError:
     pass
diff --git a/tests/morpheus_llm/llm/test_agents_simple_pipe.py b/tests/morpheus_llm/llm/test_agents_simple_pipe.py
index 5d33dacb03..d219a9780e 100644
--- a/tests/morpheus_llm/llm/test_agents_simple_pipe.py
+++ b/tests/morpheus_llm/llm/test_agents_simple_pipe.py
@@ -42,8 +42,8 @@
     from langchain.agents.tools import Tool
     from langchain.schema import Generation
     from langchain.schema import LLMResult
-    from langchain_community.llms import OpenAI  # pylint: disable=no-name-in-module
     from langchain_community.utilities import serpapi
+    from langchain_openai import OpenAI  # pylint: disable=no-name-in-module
 except ImportError:
     pass
 
@@ -129,7 +129,7 @@ def test_agents_simple_pipe_integration_openai(config: Config, questions: list[s
 
 @pytest.mark.usefixtures("openai", "restore_environ")
 @mock.patch("langchain_community.utilities.serpapi.SerpAPIWrapper.aresults")
-@mock.patch("langchain_community.llms.OpenAI._agenerate",
+@mock.patch("langchain_openai.OpenAI._agenerate",
             autospec=True)  # autospec is needed as langchain will inspect the function
 def test_agents_simple_pipe(mock_openai_agenerate: mock.AsyncMock,
                             mock_serpapi_aresults: mock.AsyncMock,

From 5e1116d494c849c5a7937d60c7d8ae1c63b56ba4 Mon Sep 17 00:00:00 2001
From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
Date: Mon, 16 Dec 2024 10:56:16 -0800
Subject: [PATCH 3/3] Remove cudf._lib.utils usage in favor of pylibcudf (#2082)

In anticipation of a downstream cuDF PR removing some functionality from
`cudf._lib.utils` (https://github.com/rapidsai/cudf/pull/17586), this PR
replaces that usage with the equivalent from the stable `pylibcudf` API.

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: https://github.com/nv-morpheus/Morpheus/pull/2082
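
[Editor's note] A pure-Python sketch of the round-trip the Cython below
implements, not part of the patch; `_columns`, `_column_names` and
`_from_data` are internal cudf attributes, shown here only because the
patch itself relies on them:

    import cudf
    import pylibcudf as plc
    from cudf.core.column import ColumnBase

    df = cudf.DataFrame({"a": [1, 2, 3]})

    # cudf -> pylibcudf (replaces cudf._lib.utils.table_view_from_table)
    plc_table = plc.Table([col.to_pylibcudf(mode="read") for col in df._columns])

    # pylibcudf -> cudf (replaces cudf._lib.utils.data_from_unique_ptr)
    data = {
        name: ColumnBase.from_pylibcudf(col)
        for name, col in zip(df._column_names, plc_table.columns())
    }
    roundtripped = cudf.DataFrame._from_data(data)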
---
 .../morpheus/morpheus/_lib/cudf_helpers.pyx | 43 +++++++++++++++++--
 1 file changed, 39 insertions(+), 4 deletions(-)

diff --git a/python/morpheus/morpheus/_lib/cudf_helpers.pyx b/python/morpheus/morpheus/_lib/cudf_helpers.pyx
index 6c23a1b543..fe0e96536c 100644
--- a/python/morpheus/morpheus/_lib/cudf_helpers.pyx
+++ b/python/morpheus/morpheus/_lib/cudf_helpers.pyx
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
+
 import cudf
+from cudf.core.column import ColumnBase
 from cudf.core.dtypes import StructDtype
 
 from libcpp.string cimport string
@@ -26,8 +29,6 @@
 from pylibcudf.libcudf.table.table_view cimport table_view
 from pylibcudf.libcudf.types cimport size_type
 from cudf._lib.column cimport Column
-from cudf._lib.utils cimport data_from_unique_ptr
-from cudf._lib.utils cimport table_view_from_table
 
 ##### THE FOLLOWING CODE IS COPIED FROM CUDF AND SHOULD BE REMOVED WHEN UPDATING TO cudf>=24.12 #####
 # see https://github.com/rapidsai/cudf/pull/17193 for details
@@ -39,6 +40,7 @@
 cimport pylibcudf.libcudf.copying as cpp_copying
 from pylibcudf.libcudf.column.column_view cimport column_view
 from libcpp.memory cimport make_unique, unique_ptr
 from pylibcudf.libcudf.scalar.scalar cimport scalar
+from pylibcudf cimport Table as plc_Table
 from cudf._lib.scalar cimport DeviceScalar
 
 # imports needed for from_column_view_with_fix
@@ -289,8 +291,35 @@
         index_names = schema_infos[0:index_col_count] if index_col_count > 0 else None
         column_names = schema_infos[index_col_count:]
 
-        data, index = data_from_unique_ptr(move(table.tbl), column_names=column_names, index_names=index_names)
+        plc_table = plc_Table.from_libcudf(move(table.tbl))
 
+        if index_names is None:
+            index = None
+            data = {
+                col_name: ColumnBase.from_pylibcudf(col)
+                for col_name, col in zip(
+                    column_names, plc_table.columns()
+                )
+            }
+        else:
+            result_columns = [
+                ColumnBase.from_pylibcudf(col)
+                for col in plc_table.columns()
+            ]
+            index = cudf.Index._from_data(
+                dict(
+                    zip(
+                        index_names,
+                        result_columns[: len(index_names)],
+                    )
+                )
+            )
+            data = dict(
+                zip(
+                    column_names,
+                    result_columns[len(index_names) :],
+                )
+            )
         df = cudf.DataFrame._from_data(data, index)
 
         # Update the struct field names after the DataFrame is created
@@ -356,7 +385,13 @@
 
         cdef vector[string] temp_col_names = get_column_names(table, True)
 
-        cdef table_view input_table_view = table_view_from_table(table, ignore_index=False)
+        cdef plc_Table plc_table = plc_Table(
+            [
+                col.to_pylibcudf(mode="read")
+                for col in itertools.chain(table.index._columns, table._columns)
+            ]
+        )
+        cdef table_view input_table_view = plc_table.view()
         cdef vector[string] index_names
         cdef vector[string] column_names