diff --git a/examples/ransomware_detection/run.py b/examples/ransomware_detection/run.py index 4887e7ff1b..410cb4386a 100644 --- a/examples/ransomware_detection/run.py +++ b/examples/ransomware_detection/run.py @@ -21,9 +21,12 @@ from stages.create_features import CreateFeaturesRWStage from stages.preprocessing import PreprocessingRWStage +from morpheus.common import TypeId from morpheus.config import Config from morpheus.config import PipelineModes +from morpheus.messages import MessageMeta from morpheus.pipeline.linear_pipeline import LinearPipeline +from morpheus.pipeline.stage_decorator import stage from morpheus.stages.general.monitor_stage import MonitorStage from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage from morpheus.stages.input.appshield_source_stage import AppShieldSourceStage @@ -61,6 +64,12 @@ type=click.IntRange(min=1), help="Max batch size to use for the model.", ) +@click.option( + "--pipeline_batch_size", + default=1024, + type=click.IntRange(min=1), + help=("Internal batch size for the pipeline. Can be much larger than the model batch size."), +) @click.option( "--conf_file", type=click.STRING, @@ -98,18 +107,19 @@ default="./ransomware_detection_output.jsonlines", help="The path to the file where the inference output will be saved.", ) -def run_pipeline(debug, - num_threads, - n_dask_workers, - threads_per_dask_worker, - model_max_batch_size, - conf_file, - model_name, - server_url, - sliding_window, - input_glob, - watch_directory, - output_file): +def run_pipeline(debug: bool, + num_threads: int, + n_dask_workers: int, + threads_per_dask_worker: int, + model_max_batch_size: int, + pipeline_batch_size: int, + conf_file: str, + model_name: str, + server_url: str, + sliding_window: int, + input_glob: str, + watch_directory: bool, + output_file: str): if debug: configure_logging(log_level=logging.DEBUG) @@ -125,6 +135,7 @@ def run_pipeline(debug, # Below properties are specified by the command line. config.num_threads = num_threads config.model_max_batch_size = model_max_batch_size + config.pipeline_batch_size = pipeline_batch_size config.feature_length = snapshot_fea_length * sliding_window config.class_labels = ["pred", "score"] @@ -222,6 +233,18 @@ def run_pipeline(debug, # This stage logs the metrics (msg/sec) from the above stage. pipeline.add_stage(MonitorStage(config, description="Serialize rate")) + @stage(needed_columns={'timestamp_process': TypeId.STRING}) + def concat_columns(msg: MessageMeta) -> MessageMeta: + """ + This stage concatinates the timestamp and pid_process columns to create a unique field. + """ + with msg.mutable_dataframe() as df: + df['timestamp_process'] = df['timestamp'] + df['pid_process'] + + return msg + + pipeline.add_stage(concat_columns(config)) + # Add a write file stage. # This stage writes all messages to a file. pipeline.add_stage(WriteToFileStage(config, filename=output_file, overwrite=True))