diff --git a/examples/digital_fingerprinting/production/Dockerfile b/examples/digital_fingerprinting/production/Dockerfile
index dbefaf8db3..5c65e6e752 100644
--- a/examples/digital_fingerprinting/production/Dockerfile
+++ b/examples/digital_fingerprinting/production/Dockerfile
@@ -63,7 +63,7 @@ RUN source activate morpheus &&\
         # notebook v7 is incompatible with jupyter_contrib_nbextensions
         notebook=6 &&\
     jupyter contrib nbextension install --user &&\
-    pip install jupyterlab_nvdashboard==0.7.0
+    pip install jupyterlab_nvdashboard==0.9

 # Launch jupyter
 CMD ["jupyter-lab", "--ip=0.0.0.0", "--no-browser", "--allow-root"]
diff --git a/examples/digital_fingerprinting/production/conda_env.yml b/examples/digital_fingerprinting/production/conda_env.yml
index b19e4baa2f..777fb6dbaa 100644
--- a/examples/digital_fingerprinting/production/conda_env.yml
+++ b/examples/digital_fingerprinting/production/conda_env.yml
@@ -28,7 +28,7 @@ dependencies:
   - kfp
   - librdkafka
   - mlflow>=2.2.1,<3
-  - nodejs=17.4.0
+  - nodejs=18.*
   - nvtabular=23.06
   - papermill
   - s3fs>=2023.6
diff --git a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_file_batcher_stage.py b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_file_batcher_stage.py
index a3c61d149f..af8c4caa12 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_file_batcher_stage.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_file_batcher_stage.py
@@ -18,6 +18,7 @@
 import warnings
 from collections import namedtuple
 from datetime import datetime
+from datetime import timezone

 import fsspec
 import mrc
@@ -124,6 +125,9 @@ def on_data(self, file_objects: fsspec.core.OpenFiles) -> typing.List[typing.Tup
             ts = self._date_conversion_func(file_object)

             # Exclude any files outside the time window
+            if (ts.tzinfo is None):
+                ts = ts.replace(tzinfo=timezone.utc)
+
             if ((self._start_time is not None and ts < self._start_time)
                     or (self._end_time is not None and ts > self._end_time)):
                 continue
@@ -171,7 +175,12 @@ def on_data(self, file_objects: fsspec.core.OpenFiles) -> typing.List[typing.Tup

         for _, period_df in resampled:

-            obj_list = fsspec.core.OpenFiles(period_df["objects"].to_list(), mode=file_objects.mode, fs=file_objects.fs)
+            file_list = period_df["objects"].to_list()
+
+            if (len(file_list) == 0):
+                continue
+
+            obj_list = fsspec.core.OpenFiles(file_list, mode=file_objects.mode, fs=file_objects.fs)

             output_batches.append((obj_list, n_groups))

diff --git a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_inference_stage.py b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_inference_stage.py
index 79b6ea2da2..db7fb1b086 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_inference_stage.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_inference_stage.py
@@ -131,7 +131,7 @@ def on_data(self, message: MultiDFPMessage) -> MultiDFPMessage:
         return output_message

     def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
-        node = builder.make_node(self.unique_name, ops.map(self.on_data))
+        node = builder.make_node(self.unique_name, ops.map(self.on_data), ops.filter(lambda x: x is not None))
         builder.make_edge(input_node, node)

         # node.launch_options.pe_count = self._config.num_threads
diff --git a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_rolling_window_stage.py b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_rolling_window_stage.py
index 775853640b..36d4c2eba1 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_rolling_window_stage.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_rolling_window_stage.py
@@ -149,7 +149,7 @@ def _build_window(self, message: DFPMessageMeta) -> MultiDFPMessage:
             match = train_df[train_df["_row_hash"] == incoming_hash.iloc[0]]

             if (len(match) == 0):
-                raise RuntimeError("Invalid rolling window")
+                raise RuntimeError(f"Invalid rolling window for user {user_id}")

             first_row_idx = match.index[0].item()
             last_row_idx = train_df[train_df["_row_hash"] == incoming_hash.iloc[-1]].index[-1].item()
diff --git a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_training.py b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_training.py
index a346d1f4a7..efaaf89372 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_training.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp/stages/dfp_training.py
@@ -66,7 +66,8 @@ def __init__(self, c: Config, model_kwargs: dict = None, epochs=30, validation_s
             "scaler": 'standard',  # feature scaling method
             "min_cats": 1,  # cut off for minority categories
             "progress_bar": False,
-            "device": "cuda"
+            "device": "cuda",
+            "patience": -1,
         }

         # Update the defaults
diff --git a/examples/digital_fingerprinting/production/morpheus/dfp/stages/multi_file_source.py b/examples/digital_fingerprinting/production/morpheus/dfp/stages/multi_file_source.py
index e74a1f99e0..0829ecabf3 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp/stages/multi_file_source.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp/stages/multi_file_source.py
@@ -60,13 +60,42 @@ def __init__(

         self._batch_size = c.pipeline_batch_size

-        self._filenames = filenames
+        self._filenames = self._expand_directories(filenames)
+        # Support directory expansion

         self._input_count = None
         self._max_concurrent = c.num_threads
         self._watch = watch
         self._watch_interval = watch_interval

+    @staticmethod
+    def _expand_directories(filenames: typing.List[str]) -> typing.List[str]:
+        """
+        Expand to glob all files in any directories in the input filenames,
+        provided they actually exist.
+
+        ex. /path/to/dir -> /path/to/dir/*
+        """
+        updated_list = []
+        for file_name in filenames:
+            # Skip any filenames that already contain wildcards
+            if '*' in file_name or '?' in file_name:
+                updated_list.append(file_name)
+                continue
+
+            # Check if the file or directory actually exists
+            fs_spec = fsspec.filesystem(protocol='file')
+            if not fs_spec.exists(file_name):
+                updated_list.append(file_name)
+                continue
+
+            if fs_spec.isdir(file_name):
+                updated_list.append(f"{file_name}/*")
+            else:
+                updated_list.append(file_name)
+
+        return updated_list
+
     @property
     def name(self) -> str:
         """Return the name of the stage."""
diff --git a/examples/digital_fingerprinting/production/morpheus/dfp/utils/regex_utils.py b/examples/digital_fingerprinting/production/morpheus/dfp/utils/regex_utils.py
index e3afcb155d..ca306d9acc 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp/utils/regex_utils.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp/utils/regex_utils.py
@@ -17,7 +17,25 @@

 # pylint: disable=invalid-name
 iso_date_regex_pattern = (
+    # YYYY-MM-DD
     r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
-    r"T(?P<hour>\d{1,2})(:|_|\.)(?P<minute>\d{1,2})(:|_|\.)(?P<second>\d{1,2})(?P<microsecond>\.\d{1,6})?Z")
+    # Start of time group (must match everything to add fractional days)
+    r"(?:T"
+    # HH
+    r"(?P<hour>\d{1,2})"
+    # : or _ or .
+    r"(?::|_|\.)"
+    # MM
+    r"(?P<minute>\d{1,2})"
+    # : or _ or .
+    r"(?::|_|\.)"
+    # SS
+    r"(?P<second>\d{1,2})"
+    # Optional microseconds (don't capture the period)
+    r"(?:\.(?P<microsecond>\d{0,6}))?"
+    # End of time group (optional)
+    r")?"
+    # Optional Zulu time
+    r"(?P<zulu>Z)?")

 iso_date_regex = re.compile(iso_date_regex_pattern)
diff --git a/examples/digital_fingerprinting/production/morpheus/dfp_azure_pipeline.py b/examples/digital_fingerprinting/production/morpheus/dfp_azure_pipeline.py
index 4785fc8fa3..f03359c871 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp_azure_pipeline.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp_azure_pipeline.py
@@ -61,6 +61,18 @@
 from morpheus.utils.logger import configure_logging


+def _file_type_name_to_enum(file_type: str) -> FileTypes:
+    """Converts a file type name to a FileTypes enum."""
+    if (file_type == "JSON"):
+        return FileTypes.JSON
+    if (file_type == "CSV"):
+        return FileTypes.CSV
+    if (file_type == "PARQUET"):
+        return FileTypes.PARQUET
+
+    return FileTypes.Auto
+
+
 @click.command()
 @click.option(
     "--train_users",
@@ -125,6 +137,14 @@
           "For example, to make a local cache of an s3 bucket, use `filecache::s3://mybucket/*`. "
           "Refer to fsspec documentation for list of possible options."),
 )
+@click.option("--file_type_override",
+              "-t",
+              type=click.Choice(["AUTO", "JSON", "CSV", "PARQUET"], case_sensitive=False),
+              default="JSON",
+              help="Override the detected file type. Values can be 'AUTO', 'JSON', 'CSV', or 'PARQUET'.",
+              callback=lambda _,
+              __,
+              value: None if value is None else _file_type_name_to_enum(value))
 @click.option('--watch_inputs',
               type=bool,
               is_flag=True,
@@ -148,6 +168,7 @@
               type=str,
               default="DFP-azure-{user_id}",
               help="The MLflow model name template to use when logging models. ")
+@click.option('--use_postproc_schema', is_flag=True, help='Assume that input data has already been preprocessed.')
 @click.option('--inference_detection_file_name', type=str, default="dfp_detections_azure.csv")
 def run_pipeline(train_users,
                  skip_user: typing.Tuple[str],
@@ -160,6 +181,8 @@ def run_pipeline(train_users,
                  filter_threshold,
                  mlflow_experiment_name_template,
                  mlflow_model_name_template,
+                 file_type_override,
+                 use_postproc_schema,
                  inference_detection_file_name,
                  **kwargs):
     """Runs the DFP pipeline."""
@@ -219,58 +242,131 @@ def run_pipeline(train_users,
     config.ae.timestamp_column_name = "timestamp"

     # Specify the column names to ensure all data is uniform
-    source_column_info = [
-        DateTimeColumn(name=config.ae.timestamp_column_name, dtype=datetime, input_name="time"),
-        RenameColumn(name=config.ae.userid_column_name, dtype=str, input_name="properties.userPrincipalName"),
-        RenameColumn(name="appDisplayName", dtype=str, input_name="properties.appDisplayName"),
-        ColumnInfo(name="category", dtype=str),
-        RenameColumn(name="clientAppUsed", dtype=str, input_name="properties.clientAppUsed"),
-        RenameColumn(name="deviceDetailbrowser", dtype=str, input_name="properties.deviceDetail.browser"),
-        RenameColumn(name="deviceDetaildisplayName", dtype=str, input_name="properties.deviceDetail.displayName"),
-        RenameColumn(name="deviceDetailoperatingSystem",
-                     dtype=str,
-                     input_name="properties.deviceDetail.operatingSystem"),
-        StringCatColumn(name="location",
-                        dtype=str,
-                        input_columns=[
-                            "properties.location.city",
-                            "properties.location.countryOrRegion",
-                        ],
-                        sep=", "),
-        RenameColumn(name="statusfailureReason", dtype=str, input_name="properties.status.failureReason"),
-    ]
-
-    source_schema = DataFrameInputSchema(json_columns=["properties"], column_info=source_column_info)
-
-    # Preprocessing schema
-    preprocess_column_info = [
-        ColumnInfo(name=config.ae.timestamp_column_name, dtype=datetime),
-        ColumnInfo(name=config.ae.userid_column_name, dtype=str),
-        ColumnInfo(name="appDisplayName", dtype=str),
-        ColumnInfo(name="clientAppUsed", dtype=str),
-        ColumnInfo(name="deviceDetailbrowser", dtype=str),
-        ColumnInfo(name="deviceDetaildisplayName", dtype=str),
-        ColumnInfo(name="deviceDetailoperatingSystem", dtype=str),
-        ColumnInfo(name="statusfailureReason", dtype=str),
-
-        # Derived columns
-        IncrementColumn(name="logcount",
-                        dtype=int,
-                        input_name=config.ae.timestamp_column_name,
-                        groupby_column=config.ae.userid_column_name),
-        DistinctIncrementColumn(name="locincrement",
-                                dtype=int,
-                                input_name="location",
-                                groupby_column=config.ae.userid_column_name,
-                                timestamp_column=config.ae.timestamp_column_name),
-        DistinctIncrementColumn(name="appincrement",
-                                dtype=int,
-                                input_name="appDisplayName",
-                                groupby_column=config.ae.userid_column_name,
-                                timestamp_column=config.ae.timestamp_column_name)
-    ]
-
-    preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])
+    if (use_postproc_schema):
+
+        source_column_info = [
+            ColumnInfo(name="autonomousSystemNumber", dtype=str),
+            ColumnInfo(name="location_geoCoordinates_latitude", dtype=float),
+            ColumnInfo(name="location_geoCoordinates_longitude", dtype=float),
+            ColumnInfo(name="resourceDisplayName", dtype=str),
+            ColumnInfo(name="travel_speed_kmph", dtype=float),
+            DateTimeColumn(name=config.ae.timestamp_column_name, dtype=datetime, input_name="time"),
+            ColumnInfo(name="appDisplayName", dtype=str),
+            ColumnInfo(name="clientAppUsed", dtype=str),
+            RenameColumn(name=config.ae.userid_column_name, dtype=str, input_name="userPrincipalName"),
+            RenameColumn(name="deviceDetailbrowser", dtype=str, input_name="deviceDetail_browser"),
+            RenameColumn(name="deviceDetaildisplayName", dtype=str, input_name="deviceDetail_displayName"),
+            RenameColumn(name="deviceDetailoperatingSystem", dtype=str, input_name="deviceDetail_operatingSystem"),
+
+            # RenameColumn(name="location_country", dtype=str, input_name="location_countryOrRegion"),
+            ColumnInfo(name="location_city_state_country", dtype=str),
+            ColumnInfo(name="location_state_country", dtype=str),
+            ColumnInfo(name="location_country", dtype=str),
+
+            # Non-features
+            ColumnInfo(name="is_corp_vpn", dtype=bool),
+            ColumnInfo(name="distance_km", dtype=float),
+            ColumnInfo(name="ts_delta_hour", dtype=float),
+        ]
+        source_schema = DataFrameInputSchema(column_info=source_column_info)
+
+        preprocess_column_info = [
+            ColumnInfo(name=config.ae.timestamp_column_name, dtype=datetime),
+            ColumnInfo(name=config.ae.userid_column_name, dtype=str),
+
+            # Resource access
+            ColumnInfo(name="appDisplayName", dtype=str),
+            ColumnInfo(name="resourceDisplayName", dtype=str),
+            ColumnInfo(name="clientAppUsed", dtype=str),
+
+            # Device detail
+            ColumnInfo(name="deviceDetailbrowser", dtype=str),
+            ColumnInfo(name="deviceDetaildisplayName", dtype=str),
+            ColumnInfo(name="deviceDetailoperatingSystem", dtype=str),
+
+            # Location information
+            ColumnInfo(name="autonomousSystemNumber", dtype=str),
+            ColumnInfo(name="location_geoCoordinates_latitude", dtype=float),
+            ColumnInfo(name="location_geoCoordinates_longitude", dtype=float),
+            ColumnInfo(name="location_city_state_country", dtype=str),
+            ColumnInfo(name="location_state_country", dtype=str),
+            ColumnInfo(name="location_country", dtype=str),
+
+            # Derived information
+            ColumnInfo(name="travel_speed_kmph", dtype=float),
+
+            # Non-features
+            ColumnInfo(name="is_corp_vpn", dtype=bool),
+            ColumnInfo(name="distance_km", dtype=float),
+            ColumnInfo(name="ts_delta_hour", dtype=float),
+        ]
+
+        preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])
+
+        exclude_from_training = [
+            config.ae.userid_column_name,
+            config.ae.timestamp_column_name,
+            "is_corp_vpn",
+            "distance_km",
+            "ts_delta_hour",
+        ]
+
+        config.ae.feature_columns = [
+            name for (name, dtype) in preprocess_schema.output_columns if name not in exclude_from_training
+        ]
+    else:
+        source_column_info = [
+            DateTimeColumn(name=config.ae.timestamp_column_name, dtype=datetime, input_name="time"),
+            RenameColumn(name=config.ae.userid_column_name, dtype=str, input_name="properties.userPrincipalName"),
+            RenameColumn(name="appDisplayName", dtype=str, input_name="properties.appDisplayName"),
+            ColumnInfo(name="category", dtype=str),
+            RenameColumn(name="clientAppUsed", dtype=str, input_name="properties.clientAppUsed"),
+            RenameColumn(name="deviceDetailbrowser", dtype=str, input_name="properties.deviceDetail.browser"),
+            RenameColumn(name="deviceDetaildisplayName", dtype=str, input_name="properties.deviceDetail.displayName"),
+            RenameColumn(name="deviceDetailoperatingSystem",
+                         dtype=str,
+                         input_name="properties.deviceDetail.operatingSystem"),
+            StringCatColumn(name="location",
+                            dtype=str,
+                            input_columns=[
+                                "properties.location.city",
+                                "properties.location.countryOrRegion",
+                            ],
+                            sep=", "),
+            RenameColumn(name="statusfailureReason", dtype=str, input_name="properties.status.failureReason"),
+        ]
+
+        source_schema = DataFrameInputSchema(json_columns=["properties"], column_info=source_column_info)
+
+        # Preprocessing schema
+        preprocess_column_info = [
+            ColumnInfo(name=config.ae.timestamp_column_name, dtype=datetime),
+            ColumnInfo(name=config.ae.userid_column_name, dtype=str),
+            ColumnInfo(name="appDisplayName", dtype=str),
+            ColumnInfo(name="clientAppUsed", dtype=str),
+            ColumnInfo(name="deviceDetailbrowser", dtype=str),
+            ColumnInfo(name="deviceDetaildisplayName", dtype=str),
+            ColumnInfo(name="deviceDetailoperatingSystem", dtype=str),
+            ColumnInfo(name="statusfailureReason", dtype=str),
+
+            # Derived columns
+            IncrementColumn(name="logcount",
+                            dtype=int,
+                            input_name=config.ae.timestamp_column_name,
+                            groupby_column=config.ae.userid_column_name),
+            DistinctIncrementColumn(name="locincrement",
+                                    dtype=int,
+                                    input_name="location",
+                                    groupby_column=config.ae.userid_column_name,
+                                    timestamp_column=config.ae.timestamp_column_name),
+            DistinctIncrementColumn(name="appincrement",
+                                    dtype=int,
+                                    input_name="appDisplayName",
+                                    groupby_column=config.ae.userid_column_name,
+                                    timestamp_column=config.ae.timestamp_column_name)
+        ]
+
+        preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])

     # Create a linear pipeline object
     pipeline = LinearPipeline(config)
@@ -290,15 +386,17 @@ def run_pipeline(train_users,
                             start_time=start_time,
                             end_time=end_time))

+    parser_kwargs = None
+    if (file_type_override == FileTypes.JSON):
+        parser_kwargs = {"lines": False, "orient": "records"}
     # Output is a list of fsspec files. Convert to DataFrames. This caches downloaded data
     pipeline.add_stage(
-        DFPFileToDataFrameStage(config,
-                                schema=source_schema,
-                                file_type=FileTypes.JSON,
-                                parser_kwargs={
-                                    "lines": False, "orient": "records"
-                                },
-                                cache_dir=cache_dir))
+        DFPFileToDataFrameStage(
+            config,
+            schema=source_schema,
+            file_type=file_type_override,
+            parser_kwargs=parser_kwargs,  # TODO(Devin) probably should be configurable too
+            cache_dir=cache_dir))

     pipeline.add_stage(MonitorStage(config, description="Input data rate"))

@@ -328,7 +426,7 @@ def run_pipeline(train_users,
     if (is_training):

         # Finally, perform training which will output a model
-        pipeline.add_stage(DFPTraining(config, validation_size=0.10))
+        pipeline.add_stage(DFPTraining(config, epochs=100, validation_size=0.15))

         pipeline.add_stage(MonitorStage(config, description="Training rate", smoothing=0.001))

diff --git a/examples/digital_fingerprinting/production/morpheus/dfp_duo_pipeline.py b/examples/digital_fingerprinting/production/morpheus/dfp_duo_pipeline.py
index a9d588fae1..9af7d2c2fe 100644
--- a/examples/digital_fingerprinting/production/morpheus/dfp_duo_pipeline.py
+++ b/examples/digital_fingerprinting/production/morpheus/dfp_duo_pipeline.py
@@ -62,6 +62,18 @@
 from morpheus.utils.logger import configure_logging


+def _file_type_name_to_enum(file_type: str) -> FileTypes:
+    """Converts a file type name to a FileTypes enum."""
+    if (file_type == "JSON"):
+        return FileTypes.JSON
+    if (file_type == "CSV"):
+        return FileTypes.CSV
+    if (file_type == "PARQUET"):
+        return FileTypes.PARQUET
+
+    return FileTypes.Auto
+
+
 @click.command()
 @click.option(
     "--train_users",
@@ -126,6 +138,14 @@
           "For example, to make a local cache of an s3 bucket, use `filecache::s3://mybucket/*`. "
           "Refer to fsspec documentation for list of possible options."),
 )
+@click.option("--file_type_override",
+              "-t",
+              type=click.Choice(["AUTO", "JSON", "CSV", "PARQUET"], case_sensitive=False),
+              default="JSON",
+              help="Override the detected file type. Values can be 'AUTO', 'JSON', 'CSV', or 'PARQUET'.",
+              callback=lambda _,
+              __,
+              value: None if value is None else _file_type_name_to_enum(value))
 @click.option('--watch_inputs',
               type=bool,
               is_flag=True,
@@ -160,6 +180,7 @@ def run_pipeline(train_users,
                  filter_threshold,
                  mlflow_experiment_name_template,
                  mlflow_model_name_template,
+                 file_type_override,
                  **kwargs):
     """Runs the DFP pipeline."""
     # To include the generic, we must be training all or generic
@@ -283,14 +304,16 @@ def run_pipeline(train_users,
                             start_time=start_time,
                             end_time=end_time))

+    parser_kwargs = None
+    if (file_type_override == FileTypes.JSON):
+        parser_kwargs = {"lines": False, "orient": "records"}
+
     # Output is a list of fsspec files. Convert to DataFrames. This caches downloaded data
     pipeline.add_stage(
         DFPFileToDataFrameStage(config,
                                 schema=source_schema,
-                                file_type=FileTypes.JSON,
-                                parser_kwargs={
-                                    "lines": False, "orient": "records"
-                                },
+                                file_type=file_type_override,
+                                parser_kwargs=parser_kwargs,
                                 cache_dir=cache_dir))

     pipeline.add_stage(MonitorStage(config, description="Input data rate"))
diff --git a/morpheus/service/milvus_vector_db_service.py b/morpheus/service/milvus_vector_db_service.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/morpheus/utils/file_utils.py b/morpheus/utils/file_utils.py
index 3451ccba5f..4b0191939e 100644
--- a/morpheus/utils/file_utils.py
+++ b/morpheus/utils/file_utils.py
@@ -55,7 +55,6 @@ def get_data_file_path(data_filename: str) -> str:

     # If the file relative to our package exists, use that instead
     if (os.path.exists(value_abs_to_root)):
-
         return value_abs_to_root

     return data_filename
@@ -93,33 +92,36 @@ def date_extractor(file_object: fsspec.core.OpenFile, filename_regex: re.Pattern
     Returns
     -------
-    int
-        Timestamp
+    datetime
+        Extracted timestamp
     """
-    assert isinstance(file_object, fsspec.core.OpenFile)
+    if not isinstance(file_object, fsspec.core.OpenFile):
+        raise ValueError("file_object must be an instance of fsspec.core.OpenFile")

     file_path = file_object.path

-    # Match regex with the pathname since that can be more accurate
+    # Match regex with the pathname
     match = filename_regex.search(file_path)

-    if (match):
-        # Convert the regex match
+    if match:
         groups = match.groupdict()

-        if ("microsecond" in groups and groups["microsecond"] is not None):
-            groups["microsecond"] = int(float(groups["microsecond"]) * 1000000)
+        # Convert the microsecond value if present
+        if groups.get("microsecond"):
+            groups["microsecond"] = min(int(float(groups["microsecond"]) * 1000000), 999999)

-        groups = {key: int(value) for key, value in groups.items() if value is not None}
+        # Filter out any None values and convert the rest to integers
+        groups = {key: int(value) for key, value in groups.items() if value and key != "zulu"}

+        # Assign timezone
         groups["tzinfo"] = timezone.utc

         ts_object = datetime(**groups)
     else:
-        # Otherwise, fallback to the file modified (created?) time
+        # Fallback to the file modified time
         ts_object = file_object.fs.modified(file_object.path)

-        # Assume that its using the same timez
-        ts_object.replace(tzinfo=datetime.now().astimezone().tzinfo)
+        # Set the timezone to the current system's timezone
+        ts_object = ts_object.replace(tzinfo=datetime.now().astimezone().tzinfo)

     return ts_object

diff --git a/tests/examples/digital_fingerprinting/test_dfp_file_batcher_stage.py b/tests/examples/digital_fingerprinting/test_dfp_file_batcher_stage.py
index d60afb6323..ef965793cd 100644
--- a/tests/examples/digital_fingerprinting/test_dfp_file_batcher_stage.py
+++ b/tests/examples/digital_fingerprinting/test_dfp_file_batcher_stage.py
@@ -92,7 +92,7 @@ def test_on_data(config: Config, date_conversion_func: typing.Callable, file_spe
     from dfp.stages.dfp_file_batcher_stage import DFPFileBatcherStage
     stage = DFPFileBatcherStage(config, date_conversion_func)

-    assert stage.on_data([]) == []
+    assert not stage.on_data([])

     # With a one-day batch all files will fit in the batch
     batches = stage.on_data(file_specs)
@@ -118,7 +118,7 @@ def test_on_data_two_batches(config: Config,
     expected_10_26_files = sorted(f.path
                                   for f in fsspec.open_files(os.path.join(test_data_dir, '*_2022-01-30_10-26*.json')))

-    (batch1, batch2) = batches
+    (batch1, batch2) = batches[0], batches[1]  # Make pylint happy. It doesn't like ambiguous unpacking

     assert sorted(f.path for f in batch1[0]) == expected_10_25_files
     assert batch1[1] == 2
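
For reference, a minimal sketch (not part of the patch, using a made-up filename) of how the reworked `iso_date_regex` behaves, assuming the group names reconstructed above (`year`, `month`, `day`, `hour`, `minute`, `second`, `microsecond`, `zulu`); the time component and the trailing `Z` are now optional, so date-only filenames still produce a match instead of forcing the fallback to the file's modified time:

```python
from dfp.utils.regex_utils import iso_date_regex  # module path as used in this patch

# Full ISO-style timestamp: every group is populated.
full = iso_date_regex.search("AZUREAD_2022-08-01T03_05_06.123456Z.json")
print(full.group("hour"), full.group("zulu"))  # -> 03 Z

# Date-only filename: the optional time and zulu groups are simply None.
date_only = iso_date_regex.search("AZUREAD_2022-08-01.json")
print(date_only.group("year"), date_only.group("hour"))  # -> 2022 None
```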
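Similarly, a rough illustration (again not part of the patch, with hypothetical local paths) of the directory expansion that `MultiFileSource._expand_directories` now applies to its `filenames` argument before the rest of the source stage sees them:

```python
from dfp.stages.multi_file_source import MultiFileSource

# Existing directories become "<dir>/*"; globs and paths that do not exist pass through unchanged.
expanded = MultiFileSource._expand_directories([
    "/var/log/duo",        # a directory (assuming it exists locally) -> "/var/log/duo/*"
    "/data/azure/*.json",  # already contains a wildcard -> unchanged
    "missing.json",        # does not exist -> unchanged
])
print(expanded)
```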