diff --git a/morpheus/_lib/cmake/libmorpheus.cmake b/morpheus/_lib/cmake/libmorpheus.cmake index 954620dfac..b4d3e5baaf 100644 --- a/morpheus/_lib/cmake/libmorpheus.cmake +++ b/morpheus/_lib/cmake/libmorpheus.cmake @@ -70,7 +70,7 @@ add_library(morpheus src/stages/add_scores.cpp src/stages/deserialize.cpp src/stages/file_source.cpp - src/stages/filter_detection.cpp + src/stages/filter_detections.cpp src/stages/http_server_source_stage.cpp src/stages/inference_client_stage.cpp src/stages/kafka_source.cpp diff --git a/morpheus/_lib/include/morpheus/stages/filter_detection.hpp b/morpheus/_lib/include/morpheus/stages/filter_detections.hpp similarity index 61% rename from morpheus/_lib/include/morpheus/stages/filter_detection.hpp rename to morpheus/_lib/include/morpheus/stages/filter_detections.hpp index 5e78a8322e..1e2d0cd86b 100644 --- a/morpheus/_lib/include/morpheus/stages/filter_detection.hpp +++ b/morpheus/_lib/include/morpheus/stages/filter_detections.hpp @@ -17,22 +17,22 @@ #pragma once -#include "morpheus/export.h" -#include "morpheus/messages/multi.hpp" -#include "morpheus/objects/dev_mem_info.hpp" // for DevMemInfo -#include "morpheus/objects/filter_source.hpp" +#include "morpheus/export.h" // for MORPHEUS_EXPORT +#include "morpheus/messages/control.hpp" // for ControlMessage +#include "morpheus/messages/multi.hpp" // for MultiMessage +#include "morpheus/objects/dev_mem_info.hpp" // for DevMemInfo +#include "morpheus/objects/filter_source.hpp" // for FilterSource -#include -#include -#include -#include -#include +#include // for cudaMemcpy +#include // for Builder +#include // for Object +#include // for PythonNode +#include // for observable_member, trace_activity, map, decay_t, from #include // for size_t -#include -#include -#include -#include +#include // for map +#include // for allocator, shared_ptr +#include // for string namespace morpheus { /****** Component public implementations *******************/ @@ -68,11 +68,12 @@ namespace morpheus { * Depending on the downstream stages, this can cause performance issues, especially if those stages need to acquire * the Python GIL. */ +template class MORPHEUS_EXPORT FilterDetectionsStage - : public mrc::pymrc::PythonNode, std::shared_ptr> + : public mrc::pymrc::PythonNode, std::shared_ptr> { public: - using base_t = mrc::pymrc::PythonNode, std::shared_ptr>; + using base_t = mrc::pymrc::PythonNode, std::shared_ptr>; using typename base_t::sink_type_t; using typename base_t::source_type_t; using typename base_t::subscribe_fn_t; @@ -90,8 +91,8 @@ class MORPHEUS_EXPORT FilterDetectionsStage private: subscribe_fn_t build_operator(); - DevMemInfo get_tensor_filter_source(const std::shared_ptr& x); - DevMemInfo get_column_filter_source(const std::shared_ptr& x); + DevMemInfo get_tensor_filter_source(const sink_type_t& x); + DevMemInfo get_column_filter_source(const sink_type_t& x); float m_threshold; bool m_copy; @@ -101,6 +102,11 @@ class MORPHEUS_EXPORT FilterDetectionsStage std::map m_idx2label; }; +using FilterDetectionsStageMM = // NOLINT(readability-identifier-naming) + FilterDetectionsStage; +using FilterDetectionsStageCM = // NOLINT(readability-identifier-naming) + FilterDetectionsStage; + /****** FilterDetectionStageInterfaceProxy******************/ /** * @brief Interface proxy, used to insulate python bindings. @@ -108,7 +114,27 @@ class MORPHEUS_EXPORT FilterDetectionsStage struct MORPHEUS_EXPORT FilterDetectionStageInterfaceProxy { /** - * @brief Create and initialize a FilterDetectionStage, and return the result + * @brief Create and initialize a FilterDetectionStage that receives MultiMessage and emits MultiMessage, and return + * the result + * + * @param builder : Pipeline context object reference + * @param name : Name of a stage reference + * @param threshold : Threshold to classify + * @param copy : Whether or not to perform a copy default=true + * @param filter_source : Indicate if the values used for filtering exist in either an output tensor + * (`FilterSource::TENSOR`) or a column in a Dataframe (`FilterSource::DATAFRAME`). + * @param field_name : Name of the tensor or Dataframe column to filter on default="probs" + * @return std::shared_ptr>> + */ + static std::shared_ptr> init_mm(mrc::segment::Builder& builder, + const std::string& name, + float threshold, + bool copy, + FilterSource filter_source, + std::string field_name); + /** + * @brief Create and initialize a FilterDetectionStage that receives ControlMessage and emits ControlMessage, and + * return the result * * @param builder : Pipeline context object reference * @param name : Name of a stage reference @@ -117,14 +143,15 @@ struct MORPHEUS_EXPORT FilterDetectionStageInterfaceProxy * @param filter_source : Indicate if the values used for filtering exist in either an output tensor * (`FilterSource::TENSOR`) or a column in a Dataframe (`FilterSource::DATAFRAME`). * @param field_name : Name of the tensor or Dataframe column to filter on default="probs" - * @return std::shared_ptr> + * @return std::shared_ptr>> */ - static std::shared_ptr> init(mrc::segment::Builder& builder, - const std::string& name, - float threshold, - bool copy, - FilterSource filter_source, - std::string field_name); + static std::shared_ptr> init_cm(mrc::segment::Builder& builder, + const std::string& name, + float threshold, + bool copy, + FilterSource filter_source, + std::string field_name); }; + /** @} */ // end of group } // namespace morpheus diff --git a/morpheus/_lib/src/stages/add_scores_stage_base.cpp b/morpheus/_lib/src/stages/add_scores_stage_base.cpp index b7ff58ca67..4bb76420de 100644 --- a/morpheus/_lib/src/stages/add_scores_stage_base.cpp +++ b/morpheus/_lib/src/stages/add_scores_stage_base.cpp @@ -18,26 +18,24 @@ #include "morpheus/stages/add_scores_stage_base.hpp" #include "morpheus/messages/memory/tensor_memory.hpp" // for TensorMemory -#include "morpheus/messages/meta.hpp" -#include "morpheus/messages/multi_response.hpp" // for MultiResponseMessage -#include "morpheus/objects/dtype.hpp" // for DType -#include "morpheus/objects/tensor.hpp" // for Tensor -#include "morpheus/objects/tensor_object.hpp" // for TensorObject -#include "morpheus/types.hpp" // for TensorIndex -#include "morpheus/utilities/matx_util.hpp" // for MatxUtil -#include "morpheus/utilities/string_util.hpp" // for StringUtil -#include "morpheus/utilities/tensor_util.hpp" // for TensorUtils - -#include // for CHECK, COMPACT_GOOGLE_LOG_FATAL, LogMessageFatal, COMP... +#include "morpheus/messages/meta.hpp" // for MessageMeta +#include "morpheus/messages/multi_response.hpp" // for MultiResponseMessage +#include "morpheus/objects/dtype.hpp" // for DType +#include "morpheus/objects/tensor.hpp" // for Tensor +#include "morpheus/objects/tensor_object.hpp" // for TensorObject +#include "morpheus/types.hpp" // for TensorIndex +#include "morpheus/utilities/matx_util.hpp" // for MatxUtil +#include "morpheus/utilities/string_util.hpp" // for StringUtil +#include "morpheus/utilities/tensor_util.hpp" // for TensorUtils + +#include // for CHECK, COMPACT_GOOGLE_LOG_FATAL, LogMessageFatal #include // for observable_member, trace_activity, decay_t, operator| #include // for size_t #include // for reverse_iterator #include // for shared_ptr, allocator, __shared_ptr_access #include // for basic_ostream, operator<<, basic_ostream::operator<< -#include // for runtime_error #include // for is_same_v -#include // for type_info #include // for move, pair #include // for vector // IWYU thinks we need __alloc_traits<>::value_type for vector assignments @@ -72,12 +70,10 @@ AddScoresStageBase::source_type_t AddScoresStageBaseon_control_message(x); } - // sink_type_t not supported else { - std::string error_msg{"AddScoresStageBase receives unsupported input type: " + std::string(typeid(x).name())}; - LOG(ERROR) << error_msg; - throw std::runtime_error(error_msg); + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), "AddScoresStageBase receives unsupported input type"); } return x; } diff --git a/morpheus/_lib/src/stages/filter_detection.cpp b/morpheus/_lib/src/stages/filter_detection.cpp deleted file mode 100644 index 199d716e5b..0000000000 --- a/morpheus/_lib/src/stages/filter_detection.cpp +++ /dev/null @@ -1,217 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "morpheus/stages/filter_detection.hpp" // IWYU pragma: accosiated - -#include "mrc/segment/builder.hpp" -#include "mrc/segment/object.hpp" -#include "pymrc/node.hpp" - -#include "morpheus/messages/multi_tensor.hpp" -#include "morpheus/objects/dev_mem_info.hpp" // for DevMemInfo -#include "morpheus/objects/dtype.hpp" // for DataType -#include "morpheus/objects/memory_descriptor.hpp" -#include "morpheus/objects/table_info.hpp" -#include "morpheus/objects/tensor_object.hpp" // for TensorIndex, TensorObject -#include "morpheus/types.hpp" // for RangeType -#include "morpheus/utilities/matx_util.hpp" -#include "morpheus/utilities/tensor_util.hpp" // for TensorUtils::get_element_stride - -#include // for cudaMemcpy, cudaMemcpyDeviceToDevice, cudaMemcpyDeviceToHost -#include -#include -#include // for CHECK, CHECK_NE -#include // for MRC_CHECK_CUDA -#include -#include // for device_buffer -#include - -#include -#include // for uint8_t -#include -#include -#include -#include // needed for glog -#include -#include // for pair -#include -// IWYU thinks we need ext/new_allocator.h for size_t for some reason -// IWYU pragma: no_include - -namespace morpheus { - -// Component public implementations -// ************ FilterDetectionStage **************************** // -FilterDetectionsStage::FilterDetectionsStage(float threshold, - bool copy, - FilterSource filter_source, - std::string field_name) : - PythonNode(base_t::op_factory_from_sub_fn(build_operator())), - m_threshold(threshold), - m_copy(copy), - m_filter_source(filter_source), - m_field_name(std::move(field_name)) -{ - CHECK(m_filter_source != FilterSource::Auto); // The python stage should determine this -} - -DevMemInfo FilterDetectionsStage::get_tensor_filter_source(const std::shared_ptr& x) -{ - // The pipeline build will check to ensure that our input is a MultiResponseMessage - const auto& filter_source = std::static_pointer_cast(x)->get_tensor(m_field_name); - CHECK(filter_source.rank() > 0 && filter_source.rank() <= 2) - << "C++ impl of the FilterDetectionsStage currently only supports one and two dimensional " - "arrays"; - - // Depending on the input the stride is given in bytes or elements, convert to elements - auto stride = morpheus::TensorUtils::get_element_stride(filter_source.get_stride()); - return {filter_source.data(), filter_source.dtype(), filter_source.get_memory(), filter_source.get_shape(), stride}; -} - -DevMemInfo FilterDetectionsStage::get_column_filter_source(const std::shared_ptr& x) -{ - auto table_info = x->get_meta(m_field_name); - - // since we only asked for one column, we know its the first - const auto& col = table_info.get_column(0); - auto dtype = morpheus::DType::from_cudf(col.type().id()); - auto num_rows = col.size(); - auto data = - const_cast(static_cast(col.head() + col.offset() * dtype.item_size())); - - return { - data, - std::move(dtype), - std::make_shared(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()), - {num_rows, 1}, - {1, 0}, - }; -} - -FilterDetectionsStage::subscribe_fn_t FilterDetectionsStage::build_operator() -{ - return [this](rxcpp::observable input, rxcpp::subscriber output) { - std::function& x)> get_filter_source; - - if (m_filter_source == FilterSource::TENSOR) - { - get_filter_source = [this](auto x) { - return get_tensor_filter_source(x); - }; - } - else - { - get_filter_source = [this](auto x) { - return get_column_filter_source(x); - }; - } - - return input.subscribe(rxcpp::make_observer( - [this, &output, &get_filter_source](sink_type_t x) { - auto tmp_buffer = get_filter_source(x); - - const auto num_rows = tmp_buffer.shape(0); - const auto num_columns = tmp_buffer.shape(1); - - bool by_row = (num_columns > 1); - - // Now call the threshold function - auto thresh_bool_buffer = MatxUtil::threshold(tmp_buffer, m_threshold, by_row); - - std::vector host_bool_values(num_rows); - - // Copy bools back to host - MRC_CHECK_CUDA(cudaMemcpy(host_bool_values.data(), - thresh_bool_buffer->data(), - thresh_bool_buffer->size(), - cudaMemcpyDeviceToHost)); - - // Only used when m_copy is true - std::vector selected_ranges; - std::size_t num_selected_rows = 0; - - // We are slicing by rows, using num_rows as our marker for undefined - std::size_t slice_start = num_rows; - for (std::size_t row = 0; row < num_rows; ++row) - { - bool above_threshold = host_bool_values[row]; - - if (above_threshold && slice_start == num_rows) - { - slice_start = row; - } - else if (!above_threshold && slice_start != num_rows) - { - if (m_copy) - { - selected_ranges.emplace_back(std::pair{slice_start, row}); - num_selected_rows += (row - slice_start); - } - else - { - output.on_next(x->get_slice(slice_start, row)); - } - - slice_start = num_rows; - } - } - - if (slice_start != num_rows) - { - // Last row was above the threshold - if (m_copy) - { - selected_ranges.emplace_back(std::pair{slice_start, num_rows}); - num_selected_rows += (num_rows - slice_start); - } - else - { - output.on_next(x->get_slice(slice_start, num_rows)); - } - } - - // num_selected_rows will always be 0 when m_copy is false, - // or when m_copy is true, but none of the rows matched the output - if (num_selected_rows > 0) - { - DCHECK(m_copy); - output.on_next(x->copy_ranges(selected_ranges, num_selected_rows)); - } - }, - [&](std::exception_ptr error_ptr) { - output.on_error(error_ptr); - }, - [&]() { - output.on_completed(); - })); - }; -} - -// ************ FilterDetectionStageInterfaceProxy ************* // -std::shared_ptr> FilterDetectionStageInterfaceProxy::init( - mrc::segment::Builder& builder, - const std::string& name, - float threshold, - bool copy, - FilterSource filter_source, - std::string field_name) -{ - auto stage = builder.construct_object(name, threshold, copy, filter_source, field_name); - - return stage; -} -} // namespace morpheus diff --git a/morpheus/_lib/src/stages/filter_detections.cpp b/morpheus/_lib/src/stages/filter_detections.cpp new file mode 100644 index 0000000000..c1cf8b7036 --- /dev/null +++ b/morpheus/_lib/src/stages/filter_detections.cpp @@ -0,0 +1,310 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "morpheus/stages/filter_detections.hpp" + +#include "mrc/segment/builder.hpp" // for Builder +#include "mrc/segment/object.hpp" // for Object + +#include "morpheus/messages/control.hpp" // for ControlMessage +#include "morpheus/messages/multi.hpp" // for MultiMessage +#include "morpheus/messages/multi_tensor.hpp" // for MultiTensorMessage +#include "morpheus/objects/dev_mem_info.hpp" // for DevMemInfo +#include "morpheus/objects/dtype.hpp" // for DType +#include "morpheus/objects/memory_descriptor.hpp" // for MemoryDescriptor +#include "morpheus/objects/table_info.hpp" // for TableInfo +#include "morpheus/types.hpp" // for RangeType +#include "morpheus/utilities/matx_util.hpp" // for MatxUtil +#include "morpheus/utilities/tensor_util.hpp" // for TensorUtils + +#include // for cudaMemcpy, cudaMemcpyKind +#include // for column_view +#include // for data_type +#include // for COMPACT_GOOGLE_LOG_FATAL, LogMessageFatal, CHECK, DCHECK +#include // for MRC_CHECK_CUDA +#include // for cuda_stream_per_thread +#include // for get_current_device_resource + +#include // for size_t +#include // for uint8_t +#include // for exception_ptr +#include // for function +#include // for make_shared, shared_ptr, __shared_ptr_access +#include // for operator<<, basic_ostream +#include // for char_traits, string +#include // for move, pair +#include // for vector + +namespace morpheus { + +// Component public implementations +// ************ FilterDetectionStage **************************** // +template +FilterDetectionsStage::FilterDetectionsStage(float threshold, + bool copy, + FilterSource filter_source, + std::string field_name) : + base_t(base_t::op_factory_from_sub_fn(build_operator())), + m_threshold(threshold), + m_copy(copy), + m_filter_source(filter_source), + m_field_name(std::move(field_name)) +{ + CHECK(m_filter_source != FilterSource::Auto); // The python stage should determine this +} + +template +DevMemInfo FilterDetectionsStage::get_tensor_filter_source(const sink_type_t& x) +{ + if constexpr (std::is_same_v) + { + // The pipeline build will check to ensure that our input is a MultiResponseMessage + const auto& filter_source = std::static_pointer_cast(x)->get_tensor(m_field_name); + CHECK(filter_source.rank() > 0 && filter_source.rank() <= 2) + << "C++ impl of the FilterDetectionsStage currently only supports one and two dimensional " + "arrays"; + + // Depending on the input the stride is given in bytes or elements, convert to elements + auto stride = TensorUtils::get_element_stride(filter_source.get_stride()); + return { + filter_source.data(), filter_source.dtype(), filter_source.get_memory(), filter_source.get_shape(), stride}; + } + else if constexpr (std::is_same_v) + { + const auto& filter_source = x->tensors()->get_tensor(m_field_name); + CHECK(filter_source.rank() > 0 && filter_source.rank() <= 2) + << "C++ impl of the FilterDetectionsStage currently only supports one and two dimensional " + "arrays"; + + // Depending on the input the stride is given in bytes or elements, convert to elements + auto stride = TensorUtils::get_element_stride(filter_source.get_stride()); + return { + filter_source.data(), filter_source.dtype(), filter_source.get_memory(), filter_source.get_shape(), stride}; + } + else + { + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), "FilterDetectionsStage receives unsupported input type"); + } +} + +template +DevMemInfo FilterDetectionsStage::get_column_filter_source(const sink_type_t& x) +{ + TableInfo table_info; + if constexpr (std::is_same_v) + { + table_info = x->get_meta(m_field_name); + } + else if constexpr (std::is_same_v) + { + table_info = x->payload()->get_info(m_field_name); + } + else + { + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), "FilterDetectionsStage receives unsupported input type"); + } + + // since we only asked for one column, we know its the first + const auto& col = table_info.get_column(0); + auto dtype = DType::from_cudf(col.type().id()); + auto num_rows = col.size(); + auto data = + const_cast(static_cast(col.head() + col.offset() * dtype.item_size())); + + return { + data, + std::move(dtype), + std::make_shared(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()), + {num_rows, 1}, + {1, 0}, + }; +} + +template +FilterDetectionsStage::subscribe_fn_t FilterDetectionsStage::build_operator() +{ + return [this](rxcpp::observable input, rxcpp::subscriber output) { + std::function get_filter_source; + + if (m_filter_source == FilterSource::TENSOR) + { + get_filter_source = [this](auto x) { + return get_tensor_filter_source(x); + }; + } + else + { + get_filter_source = [this](auto x) { + return get_column_filter_source(x); + }; + } + + return input.subscribe(rxcpp::make_observer( + [this, &output, &get_filter_source](sink_type_t x) { + auto tmp_buffer = get_filter_source(x); + + const auto num_rows = tmp_buffer.shape(0); + const auto num_columns = tmp_buffer.shape(1); + + bool by_row = (num_columns > 1); + + // Now call the threshold function + auto thresh_bool_buffer = MatxUtil::threshold(tmp_buffer, m_threshold, by_row); + + std::vector host_bool_values(num_rows); + + // Copy bools back to host + MRC_CHECK_CUDA(cudaMemcpy(host_bool_values.data(), + thresh_bool_buffer->data(), + thresh_bool_buffer->size(), + cudaMemcpyDeviceToHost)); + + // Only used when m_copy is true + std::vector selected_ranges; + std::size_t num_selected_rows = 0; + + // We are slicing by rows, using num_rows as our marker for undefined + std::size_t slice_start = num_rows; + for (std::size_t row = 0; row < num_rows; ++row) + { + bool above_threshold = host_bool_values[row]; + + if (above_threshold && slice_start == num_rows) + { + slice_start = row; + } + else if (!above_threshold && slice_start != num_rows) + { + if (m_copy) + { + selected_ranges.emplace_back(std::pair{slice_start, row}); + num_selected_rows += (row - slice_start); + } + else + { + if constexpr (std::is_same_v) + { + output.on_next(x->get_slice(slice_start, row)); + } + else if constexpr (std::is_same_v) + { + auto meta = x->payload(); + std::shared_ptr sliced_cm = std::make_shared(*x); + sliced_cm->payload(meta->get_slice(slice_start, row)); + output.on_next(sliced_cm); + } + else + { + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), + "FilterDetectionsStage receives unsupported input type"); + } + } + + slice_start = num_rows; + } + } + + if (slice_start != num_rows) + { + // Last row was above the threshold + if (m_copy) + { + selected_ranges.emplace_back(std::pair{slice_start, num_rows}); + num_selected_rows += (num_rows - slice_start); + } + else + { + if constexpr (std::is_same_v) + { + output.on_next(x->get_slice(slice_start, num_rows)); + } + else if constexpr (std::is_same_v) + { + auto meta = x->payload(); + x->payload(meta->get_slice(slice_start, num_rows)); + output.on_next(x); + } + else + { + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), + "FilterDetectionsStage receives unsupported input type"); + } + } + } + + // num_selected_rows will always be 0 when m_copy is false, + // or when m_copy is true, but none of the rows matched the output + if (num_selected_rows > 0) + { + DCHECK(m_copy); + if constexpr (std::is_same_v) + { + output.on_next(x->copy_ranges(selected_ranges, num_selected_rows)); + } + else if constexpr (std::is_same_v) + { + auto meta = x->payload(); + x->payload(meta->copy_ranges(selected_ranges)); + output.on_next(x); + } + else + { + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), "FilterDetectionsStage receives unsupported input type"); + } + } + }, + [&](std::exception_ptr error_ptr) { + output.on_error(error_ptr); + }, + [&]() { + output.on_completed(); + })); + }; +} + +// ************ FilterDetectionStageInterfaceProxy ************* // +std::shared_ptr> FilterDetectionStageInterfaceProxy::init_mm( + mrc::segment::Builder& builder, + const std::string& name, + float threshold, + bool copy, + FilterSource filter_source, + std::string field_name) +{ + auto stage = builder.construct_object(name, threshold, copy, filter_source, field_name); + + return stage; +} + +std::shared_ptr> FilterDetectionStageInterfaceProxy::init_cm( + mrc::segment::Builder& builder, + const std::string& name, + float threshold, + bool copy, + FilterSource filter_source, + std::string field_name) +{ + auto stage = builder.construct_object(name, threshold, copy, filter_source, field_name); + + return stage; +} +} // namespace morpheus diff --git a/morpheus/_lib/src/stages/preprocess_fil.cpp b/morpheus/_lib/src/stages/preprocess_fil.cpp index 978e7557eb..ad1e09c1b4 100644 --- a/morpheus/_lib/src/stages/preprocess_fil.cpp +++ b/morpheus/_lib/src/stages/preprocess_fil.cpp @@ -38,7 +38,6 @@ #include // for column_view #include // for type_id, data_type #include // for cast -#include // for COMPACT_GOOGLE_LOG_ERROR, LOG, LogMessage #include // for __check_cuda_errors, MRC_CHECK_CUDA #include // for Builder #include // for gil_scoped_acquire @@ -50,9 +49,7 @@ #include // for find #include // for size_t #include // for shared_ptr, __shared_ptr_access, allocator, mak... -#include // for runtime_error #include // for is_same_v -#include // for type_info #include // for move namespace morpheus { @@ -144,12 +141,10 @@ TableInfo PreprocessFILStage::fix_bad_columns(sink_type_t x) // Now re-get the meta return x->payload()->get_info(m_fea_cols); } - // sink_type_t not supported else { - std::string error_msg{"PreProcessFILStage receives unsupported input type: " + std::string(typeid(x).name())}; - LOG(ERROR) << error_msg; - throw std::runtime_error(error_msg); + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), "PreProcessFILStage receives unsupported input type"); } } @@ -164,12 +159,10 @@ PreprocessFILStage::source_type_t PreprocessFILStage // for column #include // for make_column_from_scalar @@ -42,7 +42,6 @@ #include // for table_view #include // for type_id, data_type #include // for cast -#include // for COMPACT_GOOGLE_LOG_ERROR, LOG, LogMessage #include // for Builder #include // for normalize_spaces #include // for tokenizer_result, load_vocabulary_file, subword_tok... @@ -52,9 +51,7 @@ #include // for uint32_t, int32_t #include // for shared_ptr, unique_ptr, __shared_ptr_access, make_s... -#include // for runtime_error #include // for is_same_v -#include // for type_info #include // for move #include // for vector @@ -100,12 +97,10 @@ PreprocessNLPStage::source_type_t PreprocessNLPStageon_control_message(x); } - // sink_type_t not supported else { - std::string error_msg{"PreProcessNLPStage receives unsupported input type: " + std::string(typeid(x).name())}; - LOG(ERROR) << error_msg; - throw std::runtime_error(error_msg); + // sink_type_t not supported + static_assert(!sizeof(sink_type_t), "PreProcessNLPStage receives unsupported input type"); } } diff --git a/morpheus/_lib/stages/__init__.pyi b/morpheus/_lib/stages/__init__.pyi index bfd66dcb64..8b8413b67e 100644 --- a/morpheus/_lib/stages/__init__.pyi +++ b/morpheus/_lib/stages/__init__.pyi @@ -21,7 +21,8 @@ __all__ = [ "DeserializeControlMessageStage", "DeserializeMultiMessageStage", "FileSourceStage", - "FilterDetectionsStage", + "FilterDetectionsControlMessageStage", + "FilterDetectionsMultiMessageStage", "FilterSource", "HttpServerSourceStage", "InferenceClientStageCM", @@ -64,7 +65,10 @@ class FileSourceStage(mrc.core.segment.SegmentObject): @typing.overload def __init__(self, builder: mrc.core.segment.Builder, name: str, filename: str, repeat: int, filter_null: bool, filter_null_columns: typing.List[str], parser_kwargs: dict) -> None: ... pass -class FilterDetectionsStage(mrc.core.segment.SegmentObject): +class FilterDetectionsControlMessageStage(mrc.core.segment.SegmentObject): + def __init__(self, builder: mrc.core.segment.Builder, name: str, threshold: float, copy: bool, filter_source: morpheus._lib.common.FilterSource, field_name: str = 'probs') -> None: ... + pass +class FilterDetectionsMultiMessageStage(mrc.core.segment.SegmentObject): def __init__(self, builder: mrc.core.segment.Builder, name: str, threshold: float, copy: bool, filter_source: morpheus._lib.common.FilterSource, field_name: str = 'probs') -> None: ... pass class HttpServerSourceStage(mrc.core.segment.SegmentObject): diff --git a/morpheus/_lib/stages/module.cpp b/morpheus/_lib/stages/module.cpp index 32c3c5e030..5b33f59179 100644 --- a/morpheus/_lib/stages/module.cpp +++ b/morpheus/_lib/stages/module.cpp @@ -25,7 +25,7 @@ #include "morpheus/stages/add_scores.hpp" #include "morpheus/stages/deserialize.hpp" #include "morpheus/stages/file_source.hpp" -#include "morpheus/stages/filter_detection.hpp" +#include "morpheus/stages/filter_detections.hpp" #include "morpheus/stages/http_server_source_stage.hpp" #include "morpheus/stages/inference_client_stage.hpp" #include "morpheus/stages/kafka_source.hpp" @@ -168,11 +168,23 @@ PYBIND11_MODULE(stages, _module) py::arg("filter_null_columns"), py::arg("parser_kwargs")); - py::class_, + py::class_, mrc::segment::ObjectProperties, - std::shared_ptr>>( - _module, "FilterDetectionsStage", py::multiple_inheritance()) - .def(py::init<>(&FilterDetectionStageInterfaceProxy::init), + std::shared_ptr>>( + _module, "FilterDetectionsMultiMessageStage", py::multiple_inheritance()) + .def(py::init<>(&FilterDetectionStageInterfaceProxy::init_mm), + py::arg("builder"), + py::arg("name"), + py::arg("threshold"), + py::arg("copy"), + py::arg("filter_source"), + py::arg("field_name") = "probs"); + + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>( + _module, "FilterDetectionsControlMessageStage", py::multiple_inheritance()) + .def(py::init<>(&FilterDetectionStageInterfaceProxy::init_cm), py::arg("builder"), py::arg("name"), py::arg("threshold"), diff --git a/morpheus/controllers/filter_detections_controller.py b/morpheus/controllers/filter_detections_controller.py index ecd38a59b3..167bef64fb 100644 --- a/morpheus/controllers/filter_detections_controller.py +++ b/morpheus/controllers/filter_detections_controller.py @@ -20,6 +20,7 @@ import typing_utils from morpheus.common import FilterSource +from morpheus.messages import ControlMessage from morpheus.messages import MultiMessage from morpheus.messages import MultiResponseMessage @@ -66,12 +67,18 @@ def field_name(self): """ return self._field_name - def _find_detections(self, x: MultiMessage) -> typing.Union[cp.ndarray, np.ndarray]: - # Determind the filter source - if self._filter_source == FilterSource.TENSOR: - filter_source = x.get_output(self._field_name) - else: - filter_source = x.get_meta(self._field_name).values + def _find_detections(self, x: MultiMessage | ControlMessage) -> typing.Union[cp.ndarray, np.ndarray]: + # Determine the filter source + if isinstance(x, MultiMessage): + if self._filter_source == FilterSource.TENSOR: + filter_source = x.get_output(self._field_name) + else: + filter_source = x.get_meta(self._field_name).values + elif isinstance(x, ControlMessage): + if self._filter_source == FilterSource.TENSOR: + filter_source = x.tensors().get_tensor(self._field_name) + else: + filter_source = x.payload().get_data(self._field_name).values if (isinstance(filter_source, np.ndarray)): array_mod = np @@ -89,7 +96,7 @@ def _find_detections(self, x: MultiMessage) -> typing.Union[cp.ndarray, np.ndarr return array_mod.where(detections[1:] != detections[:-1])[0].reshape((-1, 2)) - def filter_copy(self, x: MultiMessage) -> MultiMessage: + def filter_copy(self, x: MultiMessage | ControlMessage) -> MultiMessage | ControlMessage: """ This function uses a threshold value to filter the messages. @@ -113,9 +120,15 @@ def filter_copy(self, x: MultiMessage) -> MultiMessage: if (true_pairs.shape[0] == 0): return None - return x.copy_ranges(true_pairs) + if isinstance(x, MultiMessage): + return x.copy_ranges(true_pairs) + if isinstance(x, ControlMessage): + meta = x.payload() + x.payload(meta.copy_ranges(true_pairs)) + return x + raise TypeError(f"Unsupported message type: {type(x)}") - def filter_slice(self, x: MultiMessage) -> typing.List[MultiMessage]: + def filter_slice(self, x: MultiMessage | ControlMessage) -> typing.List[MultiMessage] | typing.List[ControlMessage]: """ This function uses a threshold value to filter the messages. @@ -134,10 +147,19 @@ def filter_slice(self, x: MultiMessage) -> typing.List[MultiMessage]: output_list = [] if x is not None: true_pairs = self._find_detections(x) - for pair in true_pairs: - pair = tuple(pair.tolist()) - if ((pair[1] - pair[0]) > 0): - output_list.append(x.get_slice(*pair)) + if isinstance(x, MultiMessage): + for pair in true_pairs: + pair = tuple(pair.tolist()) + if ((pair[1] - pair[0]) > 0): + output_list.append(x.get_slice(*pair)) + elif isinstance(x, ControlMessage): + for pair in true_pairs: + pair = tuple(pair.tolist()) + if ((pair[1] - pair[0]) > 0): + sliced_meta = x.payload().get_slice(*pair) + cm = ControlMessage(x) + cm.payload(sliced_meta) + output_list.append(cm) return output_list diff --git a/morpheus/stages/postprocess/filter_detections_stage.py b/morpheus/stages/postprocess/filter_detections_stage.py index c071ffb333..9cadb26290 100644 --- a/morpheus/stages/postprocess/filter_detections_stage.py +++ b/morpheus/stages/postprocess/filter_detections_stage.py @@ -23,6 +23,7 @@ from morpheus.common import FilterSource from morpheus.config import Config from morpheus.controllers.filter_detections_controller import FilterDetectionsController +from morpheus.messages import ControlMessage from morpheus.messages import MultiMessage from morpheus.messages import MultiResponseMessage from morpheus.pipeline.single_port_stage import SinglePortStage @@ -103,9 +104,9 @@ def accepted_types(self) -> typing.Tuple: """ if self._controller.filter_source == FilterSource.TENSOR: - return (MultiResponseMessage, ) + return (MultiResponseMessage, ControlMessage) - return (MultiMessage, ) + return (MultiMessage, ControlMessage) def compute_schema(self, schema: StageSchema): self._controller.update_filter_source(message_type=schema.input_type) @@ -117,12 +118,21 @@ def supports_cpp_node(self): def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: if self._build_cpp_node(): - node = _stages.FilterDetectionsStage(builder, - self.unique_name, - self._controller.threshold, - self._copy, - self._controller.filter_source, - self._controller.field_name) + if (self._schema.input_type == ControlMessage): + node = _stages.FilterDetectionsControlMessageStage(builder, + self.unique_name, + self._controller.threshold, + self._copy, + self._controller.filter_source, + self._controller.field_name) + + else: + node = _stages.FilterDetectionsMultiMessageStage(builder, + self.unique_name, + self._controller.threshold, + self._copy, + self._controller.filter_source, + self._controller.field_name) else: if self._copy: diff --git a/morpheus/stages/postprocess/generate_viz_frames_stage.py b/morpheus/stages/postprocess/generate_viz_frames_stage.py index cf60d638b8..b2d059666c 100644 --- a/morpheus/stages/postprocess/generate_viz_frames_stage.py +++ b/morpheus/stages/postprocess/generate_viz_frames_stage.py @@ -32,6 +32,7 @@ from morpheus.cli.register_stage import register_stage from morpheus.config import Config from morpheus.config import PipelineModes +from morpheus.messages import ControlMessage from morpheus.messages import MultiResponseMessage from morpheus.pipeline.pass_thru_type_mixin import PassThruTypeMixin from morpheus.pipeline.single_port_stage import SinglePortStage @@ -91,11 +92,11 @@ def accepted_types(self) -> typing.Tuple: Returns ------- - typing.Tuple[morpheus.pipeline.messages.MultiResponseMessage, ] + typing.Tuple[morpheus.pipeline.messages.MultiResponseMessage, ControlMessage] Accepted input types """ - return (MultiResponseMessage, ) + return (MultiResponseMessage, ControlMessage) def supports_cpp_node(self): return False @@ -118,7 +119,7 @@ def round_to_sec(x: int | float): """ return int(round(x / 1000.0) * 1000) - def _to_vis_df(self, x: MultiResponseMessage): + def _to_vis_df(self, x: MultiResponseMessage | ControlMessage): idx2label = { 0: 'address', @@ -133,7 +134,11 @@ def _to_vis_df(self, x: MultiResponseMessage): 9: 'user' } - df = x.get_meta(["timestamp", "src_ip", "dest_ip", "src_port", "dest_port", "data"]) + columns = ["timestamp", "src_ip", "dest_ip", "src_port", "dest_port", "data"] + if isinstance(x, MultiResponseMessage): + df = x.get_meta(columns) + elif isinstance(x, ControlMessage): + df = x.payload().get_data(columns) def indent_data(y: str): try: @@ -141,9 +146,16 @@ def indent_data(y: str): except Exception: return y + if isinstance(df, cudf.DataFrame): + df = df.to_pandas() + df["data"] = df["data"].apply(indent_data) - probs = x.get_probs_tensor() + if isinstance(x, MultiResponseMessage): + probs = x.get_probs_tensor() + elif isinstance(x, ControlMessage): + probs = x.tensors().get_tensor("probs") + pass_thresh = (probs >= 0.5).any(axis=1) max_arg = probs.argmax(axis=1) @@ -263,14 +275,21 @@ def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> def node_fn(input_obs, output_obs): - def write_batch(x: MultiResponseMessage): + def write_batch(x: MultiResponseMessage | ControlMessage): sink = pa.BufferOutputStream() # This is the timestamp of the earliest message - time0 = x.get_meta("timestamp").min() - - df = x.get_meta(["timestamp", "src_ip", "dest_ip", "secret_keys", "data"]) + if isinstance(x, MultiResponseMessage): + time0 = x.get_meta("timestamp").min() + elif isinstance(x, ControlMessage): + time0 = x.payload().get_data("timestamp").min() + + columns = ["timestamp", "src_ip", "dest_ip", "secret_keys", "data"] + if isinstance(x, MultiResponseMessage): + df = x.get_meta(columns) + elif isinstance(x, ControlMessage): + df = x.payload().get_data(columns) out_df = cudf.DataFrame() diff --git a/morpheus/stages/postprocess/ml_flow_drift_stage.py b/morpheus/stages/postprocess/ml_flow_drift_stage.py index 4e5974cf51..60434c3a42 100644 --- a/morpheus/stages/postprocess/ml_flow_drift_stage.py +++ b/morpheus/stages/postprocess/ml_flow_drift_stage.py @@ -24,6 +24,7 @@ from morpheus.cli.register_stage import register_stage from morpheus.config import Config from morpheus.config import PipelineModes +from morpheus.messages import ControlMessage from morpheus.messages import MultiResponseMessage from morpheus.pipeline.pass_thru_type_mixin import PassThruTypeMixin from morpheus.pipeline.single_port_stage import SinglePortStage @@ -119,27 +120,36 @@ def accepted_types(self) -> typing.Tuple: Returns ------- - typing.Tuple[`morpheus.pipeline.messages.MultiResponseMessage`, ] + typing.Tuple[`morpheus.pipeline.messages.MultiResponseMessage`, ControlMessage] Accepted input types. """ - return (MultiResponseMessage, ) + return (MultiResponseMessage, ControlMessage) def supports_cpp_node(self): return False - def _calc_drift(self, x: MultiResponseMessage): + def _calc_drift(self, x: MultiResponseMessage | ControlMessage): + if isinstance(x, MultiResponseMessage): + probs_tensor = x.get_probs_tensor() + elif isinstance(x, ControlMessage): + probs_tensor = x.tensors().get_tensor("probs") # All probs in a batch will be calculated - shifted = cp.abs(x.get_probs_tensor() - 0.5) + 0.5 + shifted = cp.abs(probs_tensor - 0.5) + 0.5 # Make sure the labels list is long enough for label in range(len(self._labels), shifted.shape[1]): self._labels.append(str(label)) - for i in list(range(0, x.count, self._batch_size)): + if isinstance(x, MultiResponseMessage): + count = x.count + elif isinstance(x, ControlMessage): + count = x.payload().count + + for i in list(range(0, count, self._batch_size)): start = i - end = min(start + self._batch_size, x.count) + end = min(start + self._batch_size, count) mean = cp.mean(shifted[start:end, :], axis=0, keepdims=True) # For each column, report the metric diff --git a/morpheus/stages/postprocess/timeseries_stage.py b/morpheus/stages/postprocess/timeseries_stage.py index 28c4b70f2c..5005114df3 100644 --- a/morpheus/stages/postprocess/timeseries_stage.py +++ b/morpheus/stages/postprocess/timeseries_stage.py @@ -28,8 +28,9 @@ from morpheus.cli.register_stage import register_stage from morpheus.config import Config from morpheus.config import PipelineModes +from morpheus.messages import ControlMessage +from morpheus.messages import MultiResponseAEMessage from morpheus.messages import MultiResponseMessage -from morpheus.messages.multi_ae_message import MultiMessage from morpheus.pipeline.pass_thru_type_mixin import PassThruTypeMixin from morpheus.pipeline.single_port_stage import SinglePortStage @@ -58,7 +59,6 @@ def calc_bin(obj: pd.Timestamp, time0: pd.Timestamp, resolution_sec: float) -> i """ Calculates the bin spacing between the start and stop timestamp at a specified resolution. """ - return round((round_seconds(obj) - time0).total_seconds()) // resolution_sec @@ -164,7 +164,7 @@ class _TimeSeriesAction: window_end: dt.datetime = None send_message: bool = False - message: MultiResponseMessage = None + message: MultiResponseMessage | ControlMessage = None class _UserTimeSeries: @@ -207,7 +207,8 @@ def __init__(self, self._holding_timestamps = deque() # Stateful members - self._pending_messages: deque[MultiResponseMessage] = deque() # Holds the existing messages pending + self._pending_messages: deque[MultiResponseMessage + | ControlMessage] = deque() # Holds the existing messages pending self._timeseries_data: pd.DataFrame = pd.DataFrame(columns=[self._timestamp_col ]) # Holds all available timeseries data @@ -263,16 +264,24 @@ def _determine_action(self, is_complete: bool) -> typing.Optional[_TimeSeriesAct if (len(self._pending_messages) == 0): return None - # Note: We calculate everything in bins to ensure 1) Full bins, and 2) Even binning + # Note: We calculate everything in bins to ensure 1) Full xbins, and 2) Even binning timeseries_start = self._timeseries_data["event_bin"].iloc[0] timeseries_end = self._timeseries_data["event_bin"].iloc[-1] # Peek the front message - x: MultiResponseMessage = self._pending_messages[0] + x: MultiResponseMessage | ControlMessage = self._pending_messages[0] # Get the first message timestamp - message_start = calc_bin(x.get_meta(self._timestamp_col).iloc[0], self._t0_epoch, self._resolution_sec) - message_end = calc_bin(x.get_meta(self._timestamp_col).iloc[-1], self._t0_epoch, self._resolution_sec) + if isinstance(x, MultiResponseMessage): + message_start = calc_bin(x.get_meta(self._timestamp_col).iloc[0], self._t0_epoch, self._resolution_sec) + message_end = calc_bin(x.get_meta(self._timestamp_col).iloc[-1], self._t0_epoch, self._resolution_sec) + elif isinstance(x, ControlMessage): + message_start = calc_bin(pd.Timestamp(x.payload().get_data(self._timestamp_col).iloc[0]), + self._t0_epoch, + self._resolution_sec) + message_end = calc_bin(pd.Timestamp(x.payload().get_data(self._timestamp_col).iloc[-1]), + self._t0_epoch, + self._resolution_sec) window_start = message_start - self._half_window_bins window_end = message_end + self._half_window_bins @@ -341,17 +350,23 @@ def _determine_action(self, is_complete: bool) -> typing.Optional[_TimeSeriesAct send_message=True, message=self._pending_messages.popleft()) - def _calc_timeseries(self, x: MultiResponseMessage, is_complete: bool): + def _calc_timeseries(self, x: MultiResponseMessage | ControlMessage, is_complete: bool): if (x is not None): # Ensure that we have the meta column set for all messages - x.set_meta("ts_anomaly", False) + if isinstance(x, MultiResponseMessage): + x.set_meta("ts_anomaly", False) + elif isinstance(x, ControlMessage): + x.payload().set_data("ts_anomaly", False) # Save this message in the pending queue self._pending_messages.append(x) - new_timedata = x.get_meta([self._timestamp_col]) + if isinstance(x, MultiResponseMessage): + new_timedata = x.get_meta([self._timestamp_col]) + elif isinstance(x, ControlMessage): + new_timedata = x.payload().get_data([self._timestamp_col]).to_pandas() # Save this message event times in the event list. Ensure the values are always sorted self._timeseries_data = pd.concat([self._timeseries_data, new_timedata]).sort_index() @@ -472,34 +487,38 @@ def accepted_types(self) -> typing.Tuple: Returns ------- - typing.Tuple[`morpheus.pipeline.messages.MultiResponseMessage`, ] + typing.Tuple[`morpheus.pipeline.messages.MultiResponseMessage`, ControlMessage] Accepted input types. """ - return (MultiMessage, ) + return (MultiResponseMessage, ControlMessage) def supports_cpp_node(self): return False - def _call_timeseries_user(self, x: MultiMessage): + def _call_timeseries_user(self, x: MultiResponseAEMessage | ControlMessage): + if isinstance(x, MultiResponseAEMessage): + user_id = x.user_id + elif isinstance(x, ControlMessage): + user_id = x.get_metadata("user_id") - if (x.user_id not in self._timeseries_per_user): - self._timeseries_per_user[x.user_id] = _UserTimeSeries(user_id=x.user_id, - timestamp_col=self._timestamp_col, - resolution=self._resolution, - min_window=self._min_window, - hot_start=self._hot_start, - cold_end=self._cold_end, - filter_percent=self._filter_percent, - zscore_threshold=self._zscore_threshold) + if (user_id not in self._timeseries_per_user): + self._timeseries_per_user[user_id] = _UserTimeSeries(user_id=user_id, + timestamp_col=self._timestamp_col, + resolution=self._resolution, + min_window=self._min_window, + hot_start=self._hot_start, + cold_end=self._cold_end, + filter_percent=self._filter_percent, + zscore_threshold=self._zscore_threshold) - return self._timeseries_per_user[x.user_id]._calc_timeseries(x, False) + return self._timeseries_per_user[user_id]._calc_timeseries(x, False) def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: - def on_next(x: MultiMessage): + def on_next(x: MultiResponseMessage | ControlMessage): - message_list: typing.List[MultiResponseMessage] = self._call_timeseries_user(x) + message_list: typing.List[MultiResponseMessage | ControlMessage] = self._call_timeseries_user(x) return message_list @@ -508,7 +527,8 @@ def on_completed(): to_send = [] for timestamp in self._timeseries_per_user.values(): - message_list: typing.List[MultiResponseMessage] = timestamp._calc_timeseries(None, True) + message_list: typing.List[MultiResponseMessage | ControlMessage] = timestamp._calc_timeseries( + None, True) to_send = to_send + message_list diff --git a/morpheus/stages/postprocess/validation_stage.py b/morpheus/stages/postprocess/validation_stage.py index 7ae46db06f..445e8b4be9 100644 --- a/morpheus/stages/postprocess/validation_stage.py +++ b/morpheus/stages/postprocess/validation_stage.py @@ -111,7 +111,7 @@ def accepted_types(self) -> typing.Tuple: Returns ------- - typing.Tuple(`morpheus.pipeline.messages.MultiMessage`, ) + typing.Tuple(`morpheus.pipeline.messages.MultiMessage`, ControlMessage) Accepted input types. """ diff --git a/morpheus/stages/preprocess/preprocess_ae_stage.py b/morpheus/stages/preprocess/preprocess_ae_stage.py index e96a527630..1cf7263d7c 100644 --- a/morpheus/stages/preprocess/preprocess_ae_stage.py +++ b/morpheus/stages/preprocess/preprocess_ae_stage.py @@ -22,9 +22,11 @@ from morpheus.cli.register_stage import register_stage from morpheus.config import Config from morpheus.config import PipelineModes +from morpheus.messages import ControlMessage from morpheus.messages import InferenceMemoryAE from morpheus.messages import MultiInferenceMessage from morpheus.messages import MultiMessage +from morpheus.messages import TensorMemory as CppTensorMemory from morpheus.messages.multi_ae_message import MultiAEMessage from morpheus.stages.inference.auto_encoder_inference_stage import MultiInferenceAEMessage from morpheus.stages.preprocess.preprocess_base_stage import PreprocessBaseStage @@ -58,14 +60,14 @@ def accepted_types(self) -> typing.Tuple: """ Returns accepted input types for this stage. """ - return (MultiAEMessage, ) + return (MultiAEMessage, ControlMessage) def supports_cpp_node(self): return False @staticmethod def pre_process_batch(x: MultiAEMessage, fea_len: int, - feature_columns: typing.List[str]) -> MultiInferenceAEMessage: + feature_columns: typing.List[str]) -> MultiInferenceAEMessage | ControlMessage: """ This function performs pre-processing for autoencoder. @@ -84,7 +86,42 @@ def pre_process_batch(x: MultiAEMessage, fea_len: int, Autoencoder inference message. """ + if isinstance(x, ControlMessage): + return PreprocessAEStage.process_control_message(x, fea_len, feature_columns) + if isinstance(x, MultiAEMessage): + return PreprocessAEStage.process_multi_ae_message(x, fea_len, feature_columns) + raise TypeError("Unsupported message type.") + @staticmethod + def process_control_message(x: ControlMessage, fea_len: int, feature_columns: typing.List[str]) -> ControlMessage: + meta_df = x.payload().get_data(x.payload().df.columns.intersection(feature_columns)) + + autoencoder = x.get_metadata("autoencoder") + scores_mean = x.get_metadata("train_scores_mean") + scores_std = x.get_metadata("train_scores_std") + count = len(meta_df.index) + + inputs = cp.zeros(meta_df.shape, dtype=cp.float32) + + if autoencoder is not None: + data = autoencoder.prepare_df(meta_df) + inputs = autoencoder.build_input_tensor(data) + inputs = cp.asarray(inputs.detach()) + count = inputs.shape[0] + + seg_ids = cp.zeros((count, 3), dtype=cp.uint32) + seg_ids[:, 0] = cp.arange(0, count, dtype=cp.uint32) + seg_ids[:, 2] = fea_len - 1 + + x.set_metadata("autoencoder", autoencoder) + x.set_metadata("train_scores_mean", scores_mean) + x.set_metadata("train_scores_std", scores_std) + x.tensors(CppTensorMemory(count=count, tensors={"input": inputs, "seq_ids": seg_ids})) + return x + + @staticmethod + def process_multi_ae_message(x: MultiAEMessage, fea_len: int, + feature_columns: typing.List[str]) -> MultiInferenceAEMessage: meta_df = x.get_meta(x.meta.df.columns.intersection(feature_columns)) autoencoder = x.model scores_mean = x.train_scores_mean @@ -117,7 +154,8 @@ def pre_process_batch(x: MultiAEMessage, fea_len: int, return infer_message - def _get_preprocess_fn(self) -> typing.Callable[[MultiMessage], MultiInferenceMessage]: + def _get_preprocess_fn( + self) -> typing.Callable[[MultiMessage | ControlMessage], MultiInferenceMessage | ControlMessage]: return partial(PreprocessAEStage.pre_process_batch, fea_len=self._fea_length, feature_columns=self._feature_columns) diff --git a/morpheus/stages/preprocess/preprocess_fil_stage.py b/morpheus/stages/preprocess/preprocess_fil_stage.py index cbfc6a581f..8ff369ebe7 100644 --- a/morpheus/stages/preprocess/preprocess_fil_stage.py +++ b/morpheus/stages/preprocess/preprocess_fil_stage.py @@ -67,7 +67,8 @@ def supports_cpp_node(self): return True @staticmethod - def pre_process_batch(x: MultiMessage, fea_len: int, fea_cols: typing.List[str]) -> MultiInferenceFILMessage: + def pre_process_batch(x: typing.Union[MultiMessage, ControlMessage], fea_len: int, + fea_cols: typing.List[str]) -> typing.Union[MultiMessage, ControlMessage]: """ For FIL category usecases, this function performs pre-processing. diff --git a/tests/stages/test_filter_detections_stage.py b/tests/stages/test_filter_detections_stage.py new file mode 100644 index 0000000000..4f9cb43cfe --- /dev/null +++ b/tests/stages/test_filter_detections_stage.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cupy as cp +import pytest + +import morpheus._lib.messages as _messages +from morpheus.common import FilterSource +from morpheus.messages import ControlMessage +from morpheus.messages import MultiResponseMessage +from morpheus.messages import ResponseMemory +from morpheus.messages.message_meta import MessageMeta +from morpheus.stages.postprocess.filter_detections_stage import FilterDetectionsStage + + +def _make_multi_response_message(df, probs): + df_ = df[0:len(probs)] + mem = ResponseMemory(count=len(df_), tensors={'probs': probs}) + + return MultiResponseMessage(meta=MessageMeta(df_), memory=mem) + + +def _make_control_message(df, probs): + df_ = df[0:len(probs)] + cm = ControlMessage() + cm.payload(MessageMeta(df_)) + cm.tensors(_messages.TensorMemory(count=len(df_), tensors={'probs': probs})) + + return cm + + +def test_constructor(config): + fds = FilterDetectionsStage(config) + assert fds.name == "filter" + + # Just ensure that we get a valid non-empty tuple + accepted_types = fds.accepted_types() + assert isinstance(accepted_types, tuple) + assert len(accepted_types) > 0 + + +@pytest.mark.use_cudf +def test_filter_copy(config, filter_probs_df): + fds = FilterDetectionsStage(config, threshold=0.5, filter_source=FilterSource.TENSOR) + + probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + mock_control_message = _make_control_message(filter_probs_df, probs) + + # All values are at or below the threshold so nothing should be returned + output_multi_response_message = fds._controller.filter_copy(mock_multi_response_message) + assert output_multi_response_message is None + output_control_message = fds._controller.filter_copy(mock_control_message) + assert output_control_message is None + + # Only one row has a value above the threshold + probs = cp.array([ + [0.2, 0.4, 0.3], + [0.1, 0.5, 0.8], + [0.2, 0.4, 0.3], + ]) + + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + output_multi_response_message = fds._controller.filter_copy(mock_multi_response_message) + assert output_multi_response_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[1:1, :].to_cupy().tolist() + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_copy(mock_control_message) + assert output_control_message.payload().get_data().to_cupy().tolist() == output_multi_response_message.get_meta( + ).to_cupy().tolist() + + # Two adjacent rows have a value above the threashold + probs = cp.array([ + [0.2, 0.4, 0.3], + [0.1, 0.2, 0.3], + [0.1, 0.5, 0.8], + [0.1, 0.9, 0.2], + [0.2, 0.4, 0.3], + ]) + + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + output_multi_response_message = fds._controller.filter_copy(mock_multi_response_message) + assert output_multi_response_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[2:3, :].to_cupy().tolist() + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_copy(mock_control_message) + assert output_control_message.payload().get_data().to_cupy().tolist() == output_multi_response_message.get_meta( + ).to_cupy().tolist() + + # Two non-adjacent rows have a value above the threashold + probs = cp.array([ + [0.2, 0.4, 0.3], + [0.1, 0.2, 0.3], + [0.1, 0.5, 0.8], + [0.4, 0.3, 0.2], + [0.1, 0.9, 0.2], + [0.2, 0.4, 0.3], + ]) + + mask = cp.zeros(len(filter_probs_df), dtype=cp.bool_) + mask[2] = True + mask[4] = True + + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + output_multi_response_message = fds._controller.filter_copy(mock_multi_response_message) + assert output_multi_response_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[ + mask, :].to_cupy().tolist() + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_copy(mock_control_message) + assert output_control_message.payload().get_data().to_cupy().tolist() == output_multi_response_message.get_meta( + ).to_cupy().tolist() + + +@pytest.mark.use_cudf +@pytest.mark.parametrize('do_copy', [True, False]) +@pytest.mark.parametrize('threshold', [0.1, 0.5, 0.8]) +@pytest.mark.parametrize('field_name', ['v1', 'v2', 'v3', 'v4']) +def test_filter_column(config, filter_probs_df, do_copy, threshold, field_name): + fds = FilterDetectionsStage(config, + threshold=threshold, + copy=do_copy, + filter_source=FilterSource.DATAFRAME, + field_name=field_name) + expected_df = filter_probs_df.to_pandas() + expected_df = expected_df[expected_df[field_name] > threshold] + + probs = cp.zeros([len(filter_probs_df), 3], 'float') + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + # All values are at or below the threshold + output_multi_response_message = fds._controller.filter_copy(mock_multi_response_message) + assert output_multi_response_message.get_meta().to_cupy().tolist() == expected_df.to_numpy().tolist() + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_copy(mock_control_message) + assert output_control_message.payload().get_data().to_cupy().tolist() == output_multi_response_message.get_meta( + ).to_cupy().tolist() + + +@pytest.mark.use_cudf +def test_filter_slice(config, filter_probs_df): + fds = FilterDetectionsStage(config, threshold=0.5, filter_source=FilterSource.TENSOR) + + probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + + # All values are at or below the threshold + output_multi_response_messages = fds._controller.filter_slice(mock_multi_response_message) + assert len(output_multi_response_messages) == 0 + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_slice(mock_control_message) + assert len(output_control_message) == len(output_multi_response_messages) + + # Only one row has a value above the threshold + probs = cp.array([ + [0.2, 0.4, 0.3], + [0.1, 0.5, 0.8], + [0.2, 0.4, 0.3], + ]) + + mock_multi_response_message: MultiResponseMessage = _make_multi_response_message(filter_probs_df, probs) + + output_multi_response_messages = fds._controller.filter_slice(mock_multi_response_message) + assert len(output_multi_response_messages) == 1 + assert output_multi_response_messages[0].get_meta().to_cupy().tolist() == filter_probs_df.loc[ + 1:1, :].to_cupy().tolist() + + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_slice(mock_control_message) + assert len(output_control_message) == len(output_multi_response_messages) + assert output_control_message[0].payload().get_data().to_cupy().tolist( + ) == output_multi_response_messages[0].get_meta().to_cupy().tolist() + + # Two adjacent rows have a value above the threashold + probs = cp.array([ + [0.2, 0.4, 0.3], + [0.1, 0.2, 0.3], + [0.1, 0.5, 0.8], + [0.1, 0.9, 0.2], + [0.2, 0.4, 0.3], + ]) + + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + + output_multi_response_messages = fds._controller.filter_slice(mock_multi_response_message) + assert len(output_multi_response_messages) == 1 + assert output_multi_response_messages[0].offset == 2 + assert output_multi_response_messages[0].count == 2 + assert output_multi_response_messages[0].get_meta().to_cupy().tolist() == filter_probs_df.loc[ + 2:3, :].to_cupy().tolist() + + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_slice(mock_control_message) + assert len(output_control_message) == len(output_multi_response_messages) + assert output_control_message[0].payload().get_data().to_cupy().tolist( + ) == output_multi_response_messages[0].get_meta().to_cupy().tolist() + + # Two non-adjacent rows have a value above the threashold + probs = cp.array([ + [0.2, 0.4, 0.3], + [0.1, 0.2, 0.3], + [0.1, 0.5, 0.8], + [0.4, 0.3, 0.2], + [0.1, 0.9, 0.2], + [0.2, 0.4, 0.3], + ]) + + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + + output_multi_response_messages = fds._controller.filter_slice(mock_multi_response_message) + assert len(output_multi_response_messages) == 2 + + # pylint: disable=unbalanced-tuple-unpacking + (multi_response_msg1, multi_response_msg2) = output_multi_response_messages + assert multi_response_msg1.offset == 2 + assert multi_response_msg1.count == 1 + + assert multi_response_msg2.offset == 4 + assert multi_response_msg2.count == 1 + + assert multi_response_msg1.get_meta().to_cupy().tolist() == filter_probs_df.loc[2:2, :].to_cupy().tolist() + assert multi_response_msg2.get_meta().to_cupy().tolist() == filter_probs_df.loc[4:4, :].to_cupy().tolist() + + mock_control_message = _make_control_message(filter_probs_df, probs) + output_control_message = fds._controller.filter_slice(mock_control_message) + assert len(output_control_message) == len(output_multi_response_messages) + (control_msg1, control_msg2) = output_control_message # pylint: disable=unbalanced-tuple-unpacking + assert control_msg1.payload().count == multi_response_msg1.count + assert control_msg2.payload().count == multi_response_msg2.count + + assert control_msg1.payload().get_data().to_cupy().tolist() == multi_response_msg1.get_meta().to_cupy().tolist() + assert control_msg2.payload().get_data().to_cupy().tolist() == multi_response_msg2.get_meta().to_cupy().tolist() diff --git a/tests/stages/test_generate_viz_frames_stage.py b/tests/stages/test_generate_viz_frames_stage.py new file mode 100644 index 0000000000..7cc79c2d31 --- /dev/null +++ b/tests/stages/test_generate_viz_frames_stage.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import cupy as cp +import typing_utils + +import cudf + +import morpheus._lib.messages as _messages +from morpheus.config import Config +from morpheus.messages import ControlMessage +from morpheus.messages import MessageMeta +from morpheus.messages import MultiResponseMessage +from morpheus.messages import ResponseMemory +from morpheus.stages.postprocess.generate_viz_frames_stage import GenerateVizFramesStage + + +def _make_multi_response_message(df, probs): + df_ = df[0:len(probs)] + mem = ResponseMemory(count=len(df_), tensors={'probs': probs}) + + return MultiResponseMessage(meta=MessageMeta(df_), memory=mem) + + +def _make_control_message(df, probs): + df_ = df[0:len(probs)] + cm = ControlMessage() + cm.payload(MessageMeta(df_)) + cm.tensors(_messages.TensorMemory(count=len(df_), tensors={'probs': probs})) + + return cm + + +def test_constructor(config: Config): + stage = GenerateVizFramesStage(config) + assert stage.name == "gen_viz" + + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiResponseMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) + + +def test_process_control_message_and_multi_message(config: Config): + stage = GenerateVizFramesStage(config) + + df = cudf.DataFrame({ + "timestamp": [1616380971990, 1616380971991], + "src_ip": ["10.20.16.248", "10.244.0.1"], + "dest_ip": ["10.244.0.59", "10.244.0.25"], + "src_port": ["50410", "50410"], + "dest_port": ["80", "80"], + "data": ["a", "b"] + }) + + probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) + mock_multi_response_message = _make_multi_response_message(df, probs) + mock_control_message = _make_control_message(df, probs) + + output_multi_response_message_list = stage._to_vis_df(mock_multi_response_message) + output_control_message_list = stage._to_vis_df(mock_control_message) + for output_multi_response_message, output_control_message in zip(output_multi_response_message_list, + output_control_message_list): + assert output_multi_response_message[1].equals(output_control_message[1]) diff --git a/tests/stages/test_ml_flow_drift_stage.py b/tests/stages/test_ml_flow_drift_stage.py new file mode 100644 index 0000000000..3f41683315 --- /dev/null +++ b/tests/stages/test_ml_flow_drift_stage.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +from unittest.mock import patch + +import cupy as cp +import pytest +import typing_utils + +import morpheus._lib.messages as _messages +from morpheus.messages import ControlMessage +from morpheus.messages import MultiResponseMessage +from morpheus.messages import ResponseMemory +from morpheus.messages.message_meta import MessageMeta +from morpheus.stages.postprocess.ml_flow_drift_stage import MLFlowDriftStage + + +def _make_multi_response_message(df, probs): + df_ = df[0:len(probs)] + mem = ResponseMemory(count=len(df_), tensors={'probs': probs}) + + return MultiResponseMessage(meta=MessageMeta(df_), count=len(df_), memory=mem) + + +def _make_control_message(df, probs): + df_ = df[0:len(probs)] + cm = ControlMessage() + cm.payload(MessageMeta(df_)) + cm.tensors(_messages.TensorMemory(count=len(df_), tensors={'probs': probs})) + + return cm + + +def test_constructor(config): + with patch("morpheus.stages.postprocess.ml_flow_drift_stage.mlflow.start_run"): + stage = MLFlowDriftStage(config) + assert stage.name == "mlflow_drift" + + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiResponseMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) + + +@pytest.mark.use_cudf +@pytest.mark.use_python +def test_calc_drift(config, filter_probs_df): + with patch("morpheus.stages.postprocess.ml_flow_drift_stage.mlflow.start_run"): + labels = ["a", "b", "c"] + stage = MLFlowDriftStage(config, labels=labels, batch_size=1) + + probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) + mock_multi_response_message = _make_multi_response_message(filter_probs_df, probs) + mock_control_message = _make_control_message(filter_probs_df, probs) + + expected_metrics = [{ + 'a': 0.9, 'b': 0.5, 'c': 0.7, 'total': 0.6999999999999998 + }, { + 'a': 0.8, 'b': 0.7, 'c': 0.6, 'total': 0.7000000000000001 + }] + + multi_response_message_metrics = [] + with patch("morpheus.stages.postprocess.ml_flow_drift_stage.mlflow.log_metrics") as mock_log_metrics: + stage._calc_drift(mock_multi_response_message) + for call_arg in mock_log_metrics.call_args_list: + multi_response_message_metrics.append(call_arg[0][0]) + assert multi_response_message_metrics == expected_metrics + + control_message_metrics = [] + with patch("morpheus.stages.postprocess.ml_flow_drift_stage.mlflow.log_metrics") as mock_log_metrics: + stage._calc_drift(mock_control_message) + for call_arg in mock_log_metrics.call_args_list: + control_message_metrics.append(call_arg[0][0]) + assert control_message_metrics == multi_response_message_metrics diff --git a/tests/stages/test_preprocess_ae_stage.py b/tests/stages/test_preprocess_ae_stage.py new file mode 100644 index 0000000000..d702abee54 --- /dev/null +++ b/tests/stages/test_preprocess_ae_stage.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import cupy as cp +import pytest +import typing_utils + +import cudf + +from morpheus.config import Config +from morpheus.config import ConfigAutoEncoder +from morpheus.messages import ControlMessage +from morpheus.messages import MessageMeta +from morpheus.messages import MultiAEMessage +from morpheus.stages.preprocess.preprocess_ae_stage import PreprocessAEStage + + +@pytest.fixture(name='config') +def fixture_config(config: Config): + config.feature_length = 256 + config.ae = ConfigAutoEncoder() + config.ae.feature_columns = ["data"] + yield config + + +def test_constructor(config: Config): + stage = PreprocessAEStage(config) + assert stage.name == "preprocess-ae" + + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiAEMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) + + +def test_process_control_message_and_multi_message(config: Config): + stage = PreprocessAEStage(config) + + df = cudf.DataFrame({"data": ["a", "b", "c"]}) + meta = MessageMeta(df) + + input_multi_ae_message = MultiAEMessage(meta=meta, + mess_offset=0, + mess_count=3, + model=None, + train_scores_mean=0.0, + train_scores_std=1.0) + + output_multi_inference_ae_message = stage.pre_process_batch(input_multi_ae_message, + fea_len=256, + feature_columns=["data"]) + + input_control_message = ControlMessage() + input_control_message.payload(meta) + + output_control_message = stage.pre_process_batch(input_control_message, fea_len=256, feature_columns=["data"]) + + # Check if each tensor in the control message is equal to the corresponding tensor in the inference message + for tensor_key in output_control_message.tensors().tensor_names: + assert cp.array_equal(output_control_message.tensors().get_tensor(tensor_key), + getattr(output_multi_inference_ae_message, tensor_key)) diff --git a/tests/stages/test_preprocess_fil_stage.py b/tests/stages/test_preprocess_fil_stage.py index 638fcaa994..cdbe66dafe 100644 --- a/tests/stages/test_preprocess_fil_stage.py +++ b/tests/stages/test_preprocess_fil_stage.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import cupy as cp import pytest +import typing_utils import cudf @@ -40,9 +43,9 @@ def test_constructor(config: Config): assert stage._fea_length == config.feature_length assert stage.features == config.fil.feature_columns - accepted_types = stage.accepted_types() - assert isinstance(accepted_types, tuple) - assert len(accepted_types) > 0 + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) def test_process_control_message(config: Config): diff --git a/tests/stages/test_preprocess_nlp_stage.py b/tests/stages/test_preprocess_nlp_stage.py index 22fc99e04a..9c202a168d 100644 --- a/tests/stages/test_preprocess_nlp_stage.py +++ b/tests/stages/test_preprocess_nlp_stage.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing from unittest.mock import Mock from unittest.mock import patch import cupy as cp import pytest +import typing_utils import cudf @@ -61,9 +63,9 @@ def test_constructor(config: Config): assert stage._do_lower_case is False assert stage._add_special_tokens is False - accepted_types = stage.accepted_types() - assert isinstance(accepted_types, tuple) - assert len(accepted_types) > 0 + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) @patch("morpheus.stages.preprocess.preprocess_nlp_stage.tokenize_text_series") diff --git a/tests/stages/test_timeseries_stage.py b/tests/stages/test_timeseries_stage.py new file mode 100644 index 0000000000..8babd0e752 --- /dev/null +++ b/tests/stages/test_timeseries_stage.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import cupy as cp +import pandas as pd +import pytest +import typing_utils + +import morpheus._lib.messages as _messages +from morpheus.config import Config +from morpheus.config import ConfigAutoEncoder +from morpheus.messages import ControlMessage +from morpheus.messages import MultiResponseAEMessage +from morpheus.messages import MultiResponseMessage +from morpheus.messages import ResponseMemory +from morpheus.messages.message_meta import MessageMeta +from morpheus.stages.postprocess.timeseries_stage import TimeSeriesStage + + +@pytest.fixture(name='config') +def fixture_config(config: Config): + config.feature_length = 256 + config.ae = ConfigAutoEncoder() + config.ae.feature_columns = ["data"] + config.ae.timestamp_column_name = "ts" + yield config + + +def _make_multi_response_ae_message(df, probs): + df_ = df[0:len(probs)] + mem = ResponseMemory(count=len(df_), tensors={'probs': probs}) + + return MultiResponseAEMessage(meta=MessageMeta(df_), count=len(df_), memory=mem, user_id="test_user_id") + + +def _make_control_message(df, probs): + df_ = df[0:len(probs)] + cm = ControlMessage() + cm.payload(MessageMeta(df_)) + cm.tensors(_messages.TensorMemory(count=len(df_), tensors={'probs': probs})) + cm.set_metadata("user_id", "test_user_id") + + return cm + + +def test_constructor(config): + stage = TimeSeriesStage(config) + assert stage.name == "timeseries" + + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiResponseMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) + + +@pytest.mark.use_cudf +@pytest.mark.use_python +def test_call_timeseries_user(config): + stage = TimeSeriesStage(config) + + df = pd.DataFrame({"ts": pd.date_range(start='01-01-2022', periods=5)}) + probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) + mock_multi_response_ae_message = _make_multi_response_ae_message(df, probs) + mock_control_message = _make_control_message(df, probs) + + assert stage._call_timeseries_user(mock_multi_response_ae_message)[0].user_id == "test_user_id" + assert stage._call_timeseries_user(mock_control_message)[0].get_metadata("user_id") == "test_user_id" diff --git a/tests/stages/test_validation_stage.py b/tests/stages/test_validation_stage.py new file mode 100644 index 0000000000..8f15799b63 --- /dev/null +++ b/tests/stages/test_validation_stage.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +import pandas as pd +import typing_utils + +from morpheus.messages import ControlMessage +from morpheus.messages import MultiMessage +from morpheus.messages.message_meta import MessageMeta +from morpheus.stages.postprocess.validation_stage import ValidationStage + + +def _make_multi_message(df): + return MultiMessage(meta=MessageMeta(df)) + + +def _make_control_message(df): + cm = ControlMessage() + cm.payload(MessageMeta(df)) + + return cm + + +def test_constructor(config): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + stage = ValidationStage(config, val_file_name=df) + assert stage.name == "validation" + + # Just ensure that we get a valid non-empty tuple + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) + + +def test_do_comparison(config): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + stage = ValidationStage(config, val_file_name=df) + + mm = _make_multi_message(df) + cm = _make_control_message(df) + + stage._append_message(mm) + mm_results = stage.get_results(clear=True) + stage._append_message(cm) + cm_results = stage.get_results(clear=True) + assert mm_results == cm_results diff --git a/tests/test_add_classifications_stage.py b/tests/test_add_classifications_stage.py index 80091f3dc5..e3bbf70c1a 100755 --- a/tests/test_add_classifications_stage.py +++ b/tests/test_add_classifications_stage.py @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import cupy as cp import pytest +import typing_utils import cudf @@ -43,10 +46,9 @@ def test_constructor(config: Config): assert stage._idx2label == {0: 'frogs', 1: 'lizards', 2: 'toads'} assert stage.name == "add-class" - # Just ensure that we get a valid non-empty tuple - accepted_types = stage.accepted_types() - assert isinstance(accepted_types, tuple) - assert len(accepted_types) > 0 + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiResponseMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) def test_constructor_explicit_labels(config: Config): diff --git a/tests/test_add_scores_stage.py b/tests/test_add_scores_stage.py index e454a0e35f..0e347c7d78 100755 --- a/tests/test_add_scores_stage.py +++ b/tests/test_add_scores_stage.py @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import cupy as cp import pytest +import typing_utils import cudf @@ -44,10 +47,9 @@ def test_constructor(config: Config): assert stage._idx2label == {0: 'frogs', 1: 'lizards', 2: 'toads'} assert stage.name == "add-scores" - # Just ensure that we get a valid non-empty tuple - accepted_types = stage.accepted_types() - assert isinstance(accepted_types, tuple) - assert len(accepted_types) > 0 + accepted_union = typing.Union[stage.accepted_types()] + assert typing_utils.issubtype(MultiResponseMessage, accepted_union) + assert typing_utils.issubtype(ControlMessage, accepted_union) def test_constructor_explicit_labels(config: Config): diff --git a/tests/test_filter_detections_stage.py b/tests/test_filter_detections_stage.py deleted file mode 100755 index ba8ed0591f..0000000000 --- a/tests/test_filter_detections_stage.py +++ /dev/null @@ -1,192 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cupy as cp -import pytest - -from morpheus.common import FilterSource -from morpheus.messages import MultiResponseMessage -from morpheus.messages import ResponseMemory -from morpheus.messages.message_meta import MessageMeta -from morpheus.stages.postprocess.filter_detections_stage import FilterDetectionsStage - - -def _make_message(df, probs): - df_ = df[0:len(probs)] - mem = ResponseMemory(count=len(df_), tensors={'probs': probs}) - return MultiResponseMessage(meta=MessageMeta(df_), memory=mem) - - -def test_constructor(config): - fds = FilterDetectionsStage(config) - assert fds.name == "filter" - - # Just ensure that we get a valid non-empty tuple - accepted_types = fds.accepted_types() - assert isinstance(accepted_types, tuple) - assert len(accepted_types) > 0 - - fds = FilterDetectionsStage(config, threshold=0.2) - assert fds._controller._threshold == 0.2 - - -@pytest.mark.use_cudf -@pytest.mark.use_python -def test_filter_copy(config, filter_probs_df): - fds = FilterDetectionsStage(config, threshold=0.5, filter_source=FilterSource.TENSOR) - - probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) - mock_message = _make_message(filter_probs_df, probs) - - # All values are at or below the threshold so nothing should be returned - output_message = fds._controller.filter_copy(mock_message) - assert output_message is None - - # Only one row has a value above the threshold - probs = cp.array([ - [0.2, 0.4, 0.3], - [0.1, 0.5, 0.8], - [0.2, 0.4, 0.3], - ]) - - mock_message = _make_message(filter_probs_df, probs) - - output_message = fds._controller.filter_copy(mock_message) - assert output_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[1:1, :].to_cupy().tolist() - - # Two adjacent rows have a value above the threashold - probs = cp.array([ - [0.2, 0.4, 0.3], - [0.1, 0.2, 0.3], - [0.1, 0.5, 0.8], - [0.1, 0.9, 0.2], - [0.2, 0.4, 0.3], - ]) - - mock_message = _make_message(filter_probs_df, probs) - - output_message = fds._controller.filter_copy(mock_message) - assert output_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[2:3, :].to_cupy().tolist() - - # Two non-adjacent rows have a value above the threashold - probs = cp.array([ - [0.2, 0.4, 0.3], - [0.1, 0.2, 0.3], - [0.1, 0.5, 0.8], - [0.4, 0.3, 0.2], - [0.1, 0.9, 0.2], - [0.2, 0.4, 0.3], - ]) - - mock_message = _make_message(filter_probs_df, probs) - - output_message = fds._controller.filter_copy(mock_message) - mask = cp.zeros(len(filter_probs_df), dtype=cp.bool_) - mask[2] = True - mask[4] = True - assert output_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[mask, :].to_cupy().tolist() - - -@pytest.mark.use_cudf -@pytest.mark.use_python -@pytest.mark.parametrize('do_copy', [True, False]) -@pytest.mark.parametrize('threshold', [0.1, 0.5, 0.8]) -@pytest.mark.parametrize('field_name', ['v1', 'v2', 'v3', 'v4']) -def test_filter_column(config, filter_probs_df, do_copy, threshold, field_name): - fds = FilterDetectionsStage(config, - threshold=threshold, - copy=do_copy, - filter_source=FilterSource.DATAFRAME, - field_name=field_name) - expected_df = filter_probs_df.to_pandas() - expected_df = expected_df[expected_df[field_name] > threshold] - - probs = cp.zeros([len(filter_probs_df), 3], 'float') - mock_message = _make_message(filter_probs_df, probs) - - # All values are at or below the threshold - output_message = fds._controller.filter_copy(mock_message) - - assert output_message.get_meta().to_cupy().tolist() == expected_df.to_numpy().tolist() - - -@pytest.mark.use_cudf -@pytest.mark.use_python -def test_filter_slice(config, filter_probs_df): - fds = FilterDetectionsStage(config, threshold=0.5, filter_source=FilterSource.TENSOR) - - probs = cp.array([[0.1, 0.5, 0.3], [0.2, 0.3, 0.4]]) - mock_message = _make_message(filter_probs_df, probs) - - # All values are at or below the threshold - output_messages = fds._controller.filter_slice(mock_message) - assert len(output_messages) == 0 - - # Only one row has a value above the threshold - probs = cp.array([ - [0.2, 0.4, 0.3], - [0.1, 0.5, 0.8], - [0.2, 0.4, 0.3], - ]) - - mock_message = _make_message(filter_probs_df, probs) - - output_messages = fds._controller.filter_slice(mock_message) - assert len(output_messages) == 1 - output_message = output_messages[0] - assert output_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[1:1, :].to_cupy().tolist() - - # Two adjacent rows have a value above the threashold - probs = cp.array([ - [0.2, 0.4, 0.3], - [0.1, 0.2, 0.3], - [0.1, 0.5, 0.8], - [0.1, 0.9, 0.2], - [0.2, 0.4, 0.3], - ]) - - mock_message = _make_message(filter_probs_df, probs) - - output_messages = fds._controller.filter_slice(mock_message) - assert len(output_messages) == 1 - output_message = output_messages[0] - assert output_message.offset == 2 - assert output_message.count == 2 - assert output_message.get_meta().to_cupy().tolist() == filter_probs_df.loc[2:3, :].to_cupy().tolist() - - # Two non-adjacent rows have a value above the threashold - probs = cp.array([ - [0.2, 0.4, 0.3], - [0.1, 0.2, 0.3], - [0.1, 0.5, 0.8], - [0.4, 0.3, 0.2], - [0.1, 0.9, 0.2], - [0.2, 0.4, 0.3], - ]) - - mock_message = _make_message(filter_probs_df, probs) - - output_messages = fds._controller.filter_slice(mock_message) - assert len(output_messages) == 2 - (msg1, msg2) = output_messages # pylint: disable=unbalanced-tuple-unpacking - assert msg1.offset == 2 - assert msg1.count == 1 - - assert msg2.offset == 4 - assert msg2.count == 1 - - assert msg1.get_meta().to_cupy().tolist() == filter_probs_df.loc[2:2, :].to_cupy().tolist() - assert msg2.get_meta().to_cupy().tolist() == filter_probs_df.loc[4:4, :].to_cupy().tolist() diff --git a/tests/test_filter_detections_stage_pipe.py b/tests/test_filter_detections_stage_pipe.py index e90ea13b3f..15e36bd244 100755 --- a/tests/test_filter_detections_stage_pipe.py +++ b/tests/test_filter_detections_stage_pipe.py @@ -24,7 +24,9 @@ from _utils import assert_results from _utils.dataset_manager import DatasetManager from _utils.stages.conv_msg import ConvMsg +from morpheus.common import FilterSource from morpheus.config import Config +from morpheus.messages import ControlMessage from morpheus.messages import MessageMeta from morpheus.messages import MultiMessage from morpheus.messages import MultiResponseMessage @@ -92,6 +94,30 @@ def _test_filter_detections_stage_multi_segment_pipe(config: Config, dataset_pan assert_results(comp_stage.get_results()) +def _test_filter_detections_control_message_stage_multi_segment_pipe(config: Config, + dataset_pandas: DatasetManager, + copy: bool = True): + threshold = 0.75 + + input_df = dataset_pandas["filter_probs.csv"] + pipe = LinearPipeline(config) + pipe.set_source(InMemorySourceStage(config, [cudf.DataFrame(input_df)])) + pipe.add_segment_boundary(MessageMeta) + pipe.add_stage(DeserializeStage(config, message_type=ControlMessage)) + pipe.add_segment_boundary(data_type=ControlMessage) + pipe.add_stage(ConvMsg(config, message_type=ControlMessage)) + pipe.add_segment_boundary(ControlMessage) + pipe.add_stage(FilterDetectionsStage(config, threshold=threshold, copy=copy, filter_source=FilterSource.TENSOR)) + pipe.add_segment_boundary(ControlMessage) + pipe.add_stage(SerializeStage(config)) + pipe.add_segment_boundary(MessageMeta) + comp_stage = pipe.add_stage( + CompareDataFrameStage(config, build_expected(dataset_pandas["filter_probs.csv"], threshold))) + pipe.run() + + assert_results(comp_stage.get_results()) + + @pytest.mark.slow @pytest.mark.parametrize('order', ['F', 'C']) @pytest.mark.parametrize('pipeline_batch_size', [256, 1024, 2048]) @@ -109,3 +135,10 @@ def test_filter_detections_stage_pipe(config: Config, @pytest.mark.parametrize('do_copy', [True, False]) def test_filter_detections_stage_multi_segment_pipe(config: Config, dataset_pandas: DatasetManager, do_copy: bool): return _test_filter_detections_stage_multi_segment_pipe(config, dataset_pandas, do_copy) + + +@pytest.mark.parametrize('do_copy', [True, False]) +def test_filter_detections_control_message_stage_multi_segment_pipe(config: Config, + dataset_pandas: DatasetManager, + do_copy: bool): + return _test_filter_detections_control_message_stage_multi_segment_pipe(config, dataset_pandas, do_copy)