Skip to content

Commit

Permalink
fix filter_detections_stage bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed May 16, 2024
1 parent a1590eb commit a8af1df
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
5 changes: 3 additions & 2 deletions morpheus/_lib/src/stages/filter_detections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ FilterDetectionsStage<InputT, OutputT>::subscribe_fn_t FilterDetectionsStage<Inp
else if constexpr (std::is_same_v<InputT, ControlMessage>)
{
auto meta = x->payload();
x->payload(meta->get_slice(slice_start, row));
output.on_next(x);
std::shared_ptr<ControlMessage> sliced_cm = std::make_shared<ControlMessage>(*x);
sliced_cm->payload(meta->get_slice(slice_start, row));
output.on_next(sliced_cm);
}
else
{
Expand Down
13 changes: 7 additions & 6 deletions morpheus/stages/postprocess/filter_detections_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,13 @@ def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) ->
self._controller.filter_source,
self._controller.field_name)

node = _stages.FilterDetectionsMultiMessageStage(builder,
self.unique_name,
self._controller.threshold,
self._copy,
self._controller.filter_source,
self._controller.field_name)
else:
node = _stages.FilterDetectionsMultiMessageStage(builder,
self.unique_name,
self._controller.threshold,
self._copy,
self._controller.filter_source,
self._controller.field_name)
else:

if self._copy:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_filter_detections_stage_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from _utils import assert_results
from _utils.dataset_manager import DatasetManager
from _utils.stages.conv_msg import ConvMsg
from morpheus.common import FilterSource
from morpheus.messages import ControlMessage
from morpheus.config import Config
from morpheus.messages import MessageMeta
from morpheus.messages import MultiMessage
Expand Down Expand Up @@ -92,6 +94,30 @@ def _test_filter_detections_stage_multi_segment_pipe(config: Config, dataset_pan
assert_results(comp_stage.get_results())


def _test_filter_detections_control_message_stage_multi_segment_pipe(config: Config,
dataset_pandas: DatasetManager,
copy: bool = True):
threshold = 0.75

input_df = dataset_pandas["filter_probs.csv"]
pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, [cudf.DataFrame(input_df)]))
pipe.add_segment_boundary(MessageMeta)
pipe.add_stage(DeserializeStage(config, message_type=ControlMessage))
pipe.add_segment_boundary(data_type=ControlMessage)
pipe.add_stage(ConvMsg(config, message_type=ControlMessage))
pipe.add_segment_boundary(ControlMessage)
pipe.add_stage(FilterDetectionsStage(config, threshold=threshold, copy=copy, filter_source=FilterSource.TENSOR))
pipe.add_segment_boundary(ControlMessage)
pipe.add_stage(SerializeStage(config))
pipe.add_segment_boundary(MessageMeta)
comp_stage = pipe.add_stage(
CompareDataFrameStage(config, build_expected(dataset_pandas["filter_probs.csv"], threshold)))
pipe.run()

assert_results(comp_stage.get_results())


@pytest.mark.slow
@pytest.mark.parametrize('order', ['F', 'C'])
@pytest.mark.parametrize('pipeline_batch_size', [256, 1024, 2048])
Expand All @@ -109,3 +135,11 @@ def test_filter_detections_stage_pipe(config: Config,
@pytest.mark.parametrize('do_copy', [True, False])
def test_filter_detections_stage_multi_segment_pipe(config: Config, dataset_pandas: DatasetManager, do_copy: bool):
return _test_filter_detections_stage_multi_segment_pipe(config, dataset_pandas, do_copy)


@pytest.mark.parametrize('do_copy', [True, False])
@pytest.mark.use_cpp
def test_filter_detections_control_message_stage_multi_segment_pipe(config: Config,
dataset_pandas: DatasetManager,
do_copy: bool):
return _test_filter_detections_control_message_stage_multi_segment_pipe(config, dataset_pandas, do_copy)

0 comments on commit a8af1df

Please sign in to comment.