Commit 0913d14
Avoid cudf warning that reading a json file uses pandas, create fixtures for the input_df
dagardner-nv committed Jan 12, 2024
1 parent cb66ea0 commit 0913d14
Showing 1 changed file with 30 additions and 18 deletions.
tests/test_column_info.py (30 additions, 18 deletions)
@@ -18,6 +18,7 @@
 import json
 import os
 from datetime import datetime
+from datetime import timezone
 from functools import partial
 
 import numpy as np
@@ -27,6 +28,8 @@
 import cudf
 
 from _utils import TEST_DIRS
+from morpheus.common import FileTypes
+from morpheus.io.deserializers import read_file_to_df
 from morpheus.utils.column_info import ColumnInfo
 from morpheus.utils.column_info import CustomColumn
 from morpheus.utils.column_info import DataFrameInputSchema
@@ -37,13 +40,26 @@
 from morpheus.utils.nvt.schema_converters import create_and_attach_nvt_workflow
 from morpheus.utils.schema_transforms import process_dataframe
 
 
-@pytest.mark.use_python
-def test_dataframe_input_schema_with_json_cols():
+@pytest.fixture(name="_azure_ad_logs_pdf", scope="module")
+def fixture__azure_ad_logs_pdf():
+    # Explicitly reading this in to ensure that lines=False.
+    # Using pandas since the C++ impl for read_file_to_df doesn't support parser_kwargs, this also avoids a warning
+    # that cudf.read_json uses pandas.read_json under the hood.
     src_file = os.path.join(TEST_DIRS.tests_data_dir, "azure_ad_logs.json")
+    yield read_file_to_df(src_file, df_type='pandas', parser_kwargs={'lines': False})
 
-    input_df = cudf.read_json(src_file)
+@pytest.fixture(name="azure_ad_logs_pdf", scope="function")
+def fixture_azure_ad_logs_pdf(_azure_ad_logs_pdf: pd.DataFrame):
+    yield _azure_ad_logs_pdf.copy(deep=True)
+
+@pytest.fixture(name="azure_ad_logs_cdf", scope="function")
+def fixture_azure_ad_logs_cdf(azure_ad_logs_pdf: pd.DataFrame):
+    # cudf.from_pandas essentially does a deep copy, so we can use this to ensure that the source pandas df is not
+    # modified
+    yield cudf.from_pandas(azure_ad_logs_pdf)
+
+@pytest.mark.use_python
+def test_dataframe_input_schema_with_json_cols(azure_ad_logs_cdf: cudf.DataFrame):
     raw_data_columns = [
         'time',
         'resourceId',
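The three fixtures above form a small caching pattern: the file is read once per module, and every test receives a private deep copy, so no test can corrupt another test's input. Here is a minimal, self-contained sketch of that pattern; the fixture names and the toy DataFrame are illustrative only, not part of the commit:

import pandas as pd
import pytest


@pytest.fixture(name="source_df", scope="module")
def fixture_source_df():
    # Expensive setup (e.g. file I/O) runs once for the whole module.
    yield pd.DataFrame({"time": ["2024-01-12"], "userId": ["u1"]})


@pytest.fixture(name="df", scope="function")
def fixture_df(source_df: pd.DataFrame):
    # Each test gets its own deep copy, so in-place mutations cannot
    # leak between tests through the shared module-scoped object.
    yield source_df.copy(deep=True)


def test_mutation_is_isolated(df: pd.DataFrame):
    df.loc[0, "userId"] = "changed"
    assert df.loc[0, "userId"] == "changed"


def test_sees_pristine_data(df: pd.DataFrame):
    # Passes regardless of test order, because df is a fresh copy.
    assert df.loc[0, "userId"] == "u1"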
@@ -63,8 +79,8 @@ def test_dataframe_input_schema_with_json_cols():
         'properties'
     ]
 
-    assert len(input_df.columns) == 16
-    assert list(input_df.columns) == raw_data_columns
+    assert len(azure_ad_logs_cdf.columns) == 16
+    assert list(azure_ad_logs_cdf.columns) == raw_data_columns
 
     column_info = [
         DateTimeColumn(name="timestamp", dtype='datetime64[ns]', input_name="time"),
@@ -89,28 +105,24 @@ def test_dataframe_input_schema_with_json_cols():
 
     schema = DataFrameInputSchema(json_columns=["properties"], column_info=column_info)
 
-    df_processed_schema = process_dataframe(input_df, schema)
+    df_processed_schema = process_dataframe(azure_ad_logs_cdf, schema)
     processed_df_cols = df_processed_schema.columns
 
-    assert len(input_df) == len(df_processed_schema)
+    assert len(azure_ad_logs_cdf) == len(df_processed_schema)
     assert len(processed_df_cols) == len(column_info)
     assert "timestamp" in processed_df_cols
    assert "userId" in processed_df_cols
     assert "time" not in processed_df_cols
     assert "properties.userPrincipalName" not in processed_df_cols
 
     nvt_workflow = create_and_attach_nvt_workflow(schema)
-    df_processed_workflow = process_dataframe(input_df, nvt_workflow)
+    df_processed_workflow = process_dataframe(azure_ad_logs_cdf, nvt_workflow)
     assert df_processed_schema.equals(df_processed_workflow)
 
 
 @pytest.mark.use_python
-def test_dataframe_input_schema_without_json_cols():
-    src_file = os.path.join(TEST_DIRS.tests_data_dir, "azure_ad_logs.json")
-
-    input_df = pd.read_json(src_file)
-
-    assert len(input_df.columns) == 16
+def test_dataframe_input_schema_without_json_cols(azure_ad_logs_pdf: pd.DataFrame):
+    assert len(azure_ad_logs_pdf.columns) == 16
 
     column_info = [
         DateTimeColumn(name="timestamp", dtype='datetime64[ns]', input_name="time"),
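The assertions in the hunk above imply that json_columns=["properties"] makes the schema parse each JSON string in that column and expand its keys into "properties.<key>" columns before column_info renames or drops them, which is why "properties.userPrincipalName" is absent after processing. Roughly the same flattening can be sketched in plain pandas; this illustration is an assumption about the behavior, not Morpheus's implementation:

import json

import pandas as pd

df = pd.DataFrame({
    "time": ["2024-01-12T00:00:00Z"],
    "properties": ['{"userPrincipalName": "user@example.com"}'],
})

# Parse the JSON strings, then expand each key into a "properties.<key>" column.
flattened = pd.json_normalize(df["properties"].map(json.loads))
flattened.columns = [f"properties.{col}" for col in flattened.columns]
df = pd.concat([df.drop(columns="properties"), flattened], axis=1)

assert "properties.userPrincipalName" in df.columns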
@@ -119,10 +131,10 @@ def test_dataframe_input_schema_without_json_cols():
 
     schema = DataFrameInputSchema(column_info=column_info)
 
-    df_processed = process_dataframe(input_df, schema)
+    df_processed = process_dataframe(azure_ad_logs_pdf, schema)
     processed_df_cols = df_processed.columns
 
-    assert len(input_df) == len(df_processed)
+    assert len(azure_ad_logs_pdf) == len(df_processed)
     assert len(processed_df_cols) == len(column_info)
     assert "timestamp" in processed_df_cols
     assert "time" not in processed_df_cols
@@ -152,7 +164,7 @@ def test_dataframe_input_schema_without_json_cols():
 
     # When trying to concat columns that don't exist in the dataframe, an exception is raised.
     with pytest.raises(Exception):
-        process_dataframe(input_df, schema2)
+        process_dataframe(azure_ad_logs_pdf, schema2)
 
 
 @pytest.mark.use_python

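As the comment in the azure_ad_logs_cdf fixture notes, cudf.from_pandas copies the data onto the GPU rather than aliasing host memory, so mutating the cudf frame leaves the pandas source intact. A quick standalone check of that property, assuming a CUDA-capable machine with cudf installed:

import cudf
import pandas as pd

pdf = pd.DataFrame({"x": [1, 2, 3]})
cdf = cudf.from_pandas(pdf)

# Mutate the GPU copy only.
cdf["x"] = cdf["x"] * 10

assert pdf["x"].tolist() == [1, 2, 3]                 # host frame unchanged
assert cdf["x"].to_pandas().tolist() == [10, 20, 30]  # device frame updated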