First round cherry-pick to integrate pre- and post-proc schema handling
drobison00 committed Oct 9, 2023
1 parent 1cda209 commit 18eb1f3
Showing 8 changed files with 215 additions and 102 deletions.
---- changed file 1 of 8 ----
@@ -18,6 +18,7 @@
import warnings
from collections import namedtuple
from datetime import datetime
from datetime import timezone

import fsspec
import mrc
@@ -121,6 +122,9 @@ def on_data(self, file_objects: fsspec.core.OpenFiles) -> typing.List[typing.Tup
ts = self._date_conversion_func(file_object)

# Exclude any files outside the time window
if (ts.tzinfo is None):
ts = ts.replace(tzinfo=timezone.utc)

if ((self._start_time is not None and ts < self._start_time)
or (self._end_time is not None and ts > self._end_time)):
continue
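An illustrative aside (not part of this diff): the new tzinfo check means naive timestamps parsed from filenames are treated as UTC, so comparing them against the timezone-aware start/end window no longer raises.

from datetime import datetime, timezone

ts = datetime(2023, 10, 9, 12, 0, 0)                       # naive timestamp from a filename
start_time = datetime(2023, 10, 9, tzinfo=timezone.utc)    # aware window bound

if ts.tzinfo is None:
    ts = ts.replace(tzinfo=timezone.utc)                   # assume UTC, as the stage now does

print(ts < start_time)  # False; comparing naive and aware datetimes would raise TypeError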
@@ -162,17 +166,25 @@ def on_data(self, file_objects: fsspec.core.OpenFiles) -> typing.List[typing.Tup
# Now group the rows by the period
resampled = df.resample(self._period)

n_groups = len(resampled)
n_groups = 0

output_batches = []

for _, period_df in resampled:

file_list = period_df["objects"].to_list()

if (len(file_list) == 0):
continue

obj_list = fsspec.core.OpenFiles(period_df["objects"].to_list(), mode=file_objects.mode, fs=file_objects.fs)

output_batches.append((obj_list, n_groups))
output_batches.append(obj_list)

n_groups += len(file_list)

return output_batches
# Append the batch count with each item
return [(x, n_groups) for x in output_batches]

def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair:
stream = builder.make_node(self.unique_name, ops.map(self.on_data), ops.flatten())
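A self-contained sketch of the revised batching contract (plain pandas with illustrative names, not the Morpheus stage itself): files are bucketed by resample period, empty periods are skipped, and every emitted batch is paired with the total file count so downstream stages can track progress.

import pandas as pd

def batch_by_period(timestamps, period="D"):
    ts = pd.to_datetime(timestamps, utc=True)            # normalize everything to UTC
    df = pd.DataFrame({"objects": range(len(ts))}, index=ts)

    output_batches = []
    n_groups = 0
    for _, period_df in df.resample(period):
        file_list = period_df["objects"].to_list()
        if len(file_list) == 0:
            continue                                      # skip empty periods
        output_batches.append(file_list)
        n_groups += len(file_list)

    # Pair each batch with the running total, as on_data() now does.
    return [(batch, n_groups) for batch in output_batches]

print(batch_by_period(["2023-10-08T01:00:00", "2023-10-10T02:00:00", "2023-10-10T03:00:00"]))
# [([0], 3), ([1, 2], 3)]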
---- changed file 2 of 8 ----
@@ -93,17 +93,18 @@ def on_data(self, message: MultiDFPMessage) -> MultiDFPMessage:
model_cache = self.get_model(user_id)

if (model_cache is None):
raise RuntimeError(f"Could not find model for user {user_id}")
logger.warning("Could not find model for user %s", user_id)
return None
# raise RuntimeError("Could not find model for user {}".format(user_id))

loaded_model = model_cache.load_model(self._client)

except Exception:
logger.exception("Error trying to get model", exc_info=True)
return None

post_model_time = time.time()
post_model_time = time.time()

results_df = loaded_model.get_results(df_user, return_abs=True)
results_df = loaded_model.get_results(df_user, return_abs=True)
except Exception as e: # TODO
logger.exception(f"({user_id}) Error trying to get model: {str(e)}")
return None

# Create an output message to allow setting meta
output_message = MultiDFPMessage(meta=message.meta,
@@ -128,7 +129,7 @@ def on_data(self, message: MultiDFPMessage) -> MultiDFPMessage:
return output_message

def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair:
node = builder.make_node(self.unique_name, ops.map(self.on_data))
node = builder.make_node(self.unique_name, ops.map(self.on_data), ops.filter(lambda x: x is not None))
builder.make_edge(input_stream[0], node)

# node.launch_options.pe_count = self._config.num_threads
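The rewired node pairs ops.map with ops.filter so users without a cached model are logged and dropped rather than aborting the pipeline. A plain-Python sketch of the same behavior (infer_one and the model dict are illustrative stand-ins, not the Morpheus API):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

model_cache = {"alice": "autoencoder-alice"}     # stand-in for the MLflow model cache

def infer_one(user_id):
    model = model_cache.get(user_id)
    if model is None:
        logger.warning("Could not find model for user %s", user_id)
        return None                              # the downstream filter drops these
    return (user_id, f"scored with {model}")

# ops.map(self.on_data) followed by ops.filter(lambda x: x is not None):
results = [r for r in map(infer_one, ["alice", "bob"]) if r is not None]
print(results)  # [('alice', 'scored with autoencoder-alice')]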
---- changed file 3 of 8 ----
@@ -146,7 +146,7 @@ def _build_window(self, message: DFPMessageMeta) -> MultiDFPMessage:
match = train_df[train_df["_row_hash"] == incoming_hash.iloc[0]]

if (len(match) == 0):
raise RuntimeError("Invalid rolling window")
raise RuntimeError(f"Invalid rolling window for user {user_id}")

first_row_idx = match.index[0].item()
last_row_idx = train_df[train_df["_row_hash"] == incoming_hash.iloc[-1]].index[-1].item()
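For context on the condition the clearer error message reports, here is a toy pandas version (illustrative only) of the row-hash lookup that locates the incoming batch inside the cached training window:

import pandas as pd

train_df = pd.DataFrame({"_row_hash": [11, 22, 33, 44]})
incoming_hash = pd.Series([22, 33])              # hashes of the incoming batch

match = train_df[train_df["_row_hash"] == incoming_hash.iloc[0]]
if len(match) == 0:
    raise RuntimeError("Invalid rolling window for user example_user")

first_row_idx = match.index[0]
last_row_idx = train_df[train_df["_row_hash"] == incoming_hash.iloc[-1]].index[-1]
print(first_row_idx, last_row_idx)  # 1 2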
---- changed file 4 of 8 ----
@@ -66,7 +66,8 @@ def __init__(self, c: Config, model_kwargs: dict = None, epochs=30, validation_s
"scaler": 'standard', # feature scaling method
"min_cats": 1, # cut off for minority categories
"progress_bar": False,
"device": "cuda"
"device": "cuda",
"patience": -1,
}

# Update the defaults
@@ -167,6 +168,9 @@ def on_data(self, message):

def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair:
stream = builder.make_node(self.unique_name, ops.map(self.on_data), ops.filter(lambda x: x is not None))

stream.launch_options.pe_count = self._config.num_threads

builder.make_edge(input_stream[0], stream)

return_type = input_stream[1]
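A small sketch of the constructor's default-merging pattern; the dict literal mirrors the defaults above, including the new "patience": -1 entry, while build_model_kwargs is an illustrative name rather than the stage's actual method:

def build_model_kwargs(user_kwargs: dict = None) -> dict:
    defaults = {
        "scaler": "standard",
        "min_cats": 1,
        "progress_bar": False,
        "device": "cuda",
        "patience": -1,
    }
    # Caller-supplied values override the defaults.
    return {**defaults, **(user_kwargs or {})}

print(build_model_kwargs({"device": "cpu"}))
# {'scaler': 'standard', 'min_cats': 1, 'progress_bar': False, 'device': 'cpu', 'patience': -1}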
---- changed file 5 of 8 ----
@@ -17,7 +17,25 @@

# pylint: disable=invalid-name
iso_date_regex_pattern = (
# YYYY-MM-DD
r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
r"T(?P<hour>\d{1,2})(:|_|\.)(?P<minute>\d{1,2})(:|_|\.)(?P<second>\d{1,2})(?P<microsecond>\.\d{1,6})?Z")
# Start of time group (must match everything to add fractional days)
r"(?:T"
# HH
r"(?P<hour>\d{1,2})"
# : or _ or .
r"(?::|_|\.)"
# MM
r"(?P<minute>\d{1,2})"
# : or _ or .
r"(?::|_|\.)"
# SS
r"(?P<second>\d{1,2})"
# Optional microseconds (don't capture the period)
r"(?:\.(?P<microsecond>\d{0,6}))?"
# End of time group (optional)
r")?"
# Optional Zulu time
r"(?P<zulu>Z)?")

iso_date_regex = re.compile(iso_date_regex_pattern)
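A quick illustrative check of what the expanded pattern accepts: the time portion and the trailing Z are now optional, so date-only filenames match in addition to the full timestamps the old pattern required (pattern repeated inline in single-line form):

import re

iso_date_regex = re.compile(
    r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
    r"(?:T(?P<hour>\d{1,2})(?::|_|\.)(?P<minute>\d{1,2})(?::|_|\.)(?P<second>\d{1,2})"
    r"(?:\.(?P<microsecond>\d{0,6}))?)?(?P<zulu>Z)?")

for name in ("DUO_2022-08-01.json", "AZUREAD_2023-10-09T12_34_56.123456Z.json"):
    found = iso_date_regex.search(name)
    print({k: v for k, v in found.groupdict().items() if v is not None})
# {'year': '2022', 'month': '08', 'day': '01'}
# {'year': '2023', 'month': '10', 'day': '09', 'hour': '12', 'minute': '34',
#  'second': '56', 'microsecond': '123456', 'zulu': 'Z'}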
---- changed file 6 of 8 ----
@@ -137,14 +137,14 @@ def _file_type_name_to_enum(file_type: str) -> FileTypes:
"For example, to make a local cache of an s3 bucket, use `filecache::s3://mybucket/*`. "
"Refer to fsspec documentation for list of possible options."),
)
@click.option(
"--file_type_override",
"-t",
type=click.Choice(["AUTO", "JSON", "CSV", "PARQUET"], case_sensitive=False),
default="JSON",
help="Override the detected file type. Values can be 'AUTO', 'JSON', 'CSV', or 'PARQUET'.",
callback=lambda _, __, value: None if value is None else _file_type_name_to_enum(value)
)
@click.option("--file_type_override",
"-t",
type=click.Choice(["AUTO", "JSON", "CSV", "PARQUET"], case_sensitive=False),
default="JSON",
help="Override the detected file type. Values can be 'AUTO', 'JSON', 'CSV', or 'PARQUET'.",
callback=lambda _,
__,
value: None if value is None else _file_type_name_to_enum(value))
@click.option('--watch_inputs',
type=bool,
is_flag=True,
@@ -168,6 +168,7 @@ def _file_type_name_to_enum(file_type: str) -> FileTypes:
type=str,
default="DFP-azure-{user_id}",
help="The MLflow model name template to use when logging models. ")
@click.option('--use_postproc_schema', is_flag=True, help='Assume that input data has already been preprocessed.')
def run_pipeline(train_users,
skip_user: typing.Tuple[str],
only_user: typing.Tuple[str],
@@ -180,6 +181,7 @@ def run_pipeline(train_users,
mlflow_experiment_name_template,
mlflow_model_name_template,
file_type_override,
use_postproc_schema,
**kwargs):
"""Runs the DFP pipeline."""
# To include the generic, we must be training all or generic
@@ -188,7 +190,7 @@ def run_pipeline(train_users,
# To include individual, we must be either training or inferring
include_individual = train_users != "generic"

# None indicates we aren't training anything
# None indicates we arent training anything
is_training = train_users != "none"

skip_users = list(skip_user)
@@ -238,58 +240,131 @@ def run_pipeline(train_users,
config.ae.timestamp_column_name = "timestamp"

# Specify the column names to ensure all data is uniform
source_column_info = [
DateTimeColumn(name=config.ae.timestamp_column_name, dtype=datetime, input_name="time"),
RenameColumn(name=config.ae.userid_column_name, dtype=str, input_name="properties.userPrincipalName"),
RenameColumn(name="appDisplayName", dtype=str, input_name="properties.appDisplayName"),
ColumnInfo(name="category", dtype=str),
RenameColumn(name="clientAppUsed", dtype=str, input_name="properties.clientAppUsed"),
RenameColumn(name="deviceDetailbrowser", dtype=str, input_name="properties.deviceDetail.browser"),
RenameColumn(name="deviceDetaildisplayName", dtype=str, input_name="properties.deviceDetail.displayName"),
RenameColumn(name="deviceDetailoperatingSystem",
dtype=str,
input_name="properties.deviceDetail.operatingSystem"),
StringCatColumn(name="location",
dtype=str,
input_columns=[
"properties.location.city",
"properties.location.countryOrRegion",
],
sep=", "),
RenameColumn(name="statusfailureReason", dtype=str, input_name="properties.status.failureReason"),
]

source_schema = DataFrameInputSchema(json_columns=["properties"], column_info=source_column_info)

# Preprocessing schema
preprocess_column_info = [
ColumnInfo(name=config.ae.timestamp_column_name, dtype=datetime),
ColumnInfo(name=config.ae.userid_column_name, dtype=str),
ColumnInfo(name="appDisplayName", dtype=str),
ColumnInfo(name="clientAppUsed", dtype=str),
ColumnInfo(name="deviceDetailbrowser", dtype=str),
ColumnInfo(name="deviceDetaildisplayName", dtype=str),
ColumnInfo(name="deviceDetailoperatingSystem", dtype=str),
ColumnInfo(name="statusfailureReason", dtype=str),

# Derived columns
IncrementColumn(name="logcount",
dtype=int,
input_name=config.ae.timestamp_column_name,
groupby_column=config.ae.userid_column_name),
DistinctIncrementColumn(name="locincrement",
dtype=int,
input_name="location",
groupby_column=config.ae.userid_column_name,
timestamp_column=config.ae.timestamp_column_name),
DistinctIncrementColumn(name="appincrement",
dtype=int,
input_name="appDisplayName",
groupby_column=config.ae.userid_column_name,
timestamp_column=config.ae.timestamp_column_name)
]

preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])
if (use_postproc_schema):

source_column_info = [
ColumnInfo(name="autonomousSystemNumber", dtype=str),
ColumnInfo(name="location_geoCoordinates_latitude", dtype=float),
ColumnInfo(name="location_geoCoordinates_longitude", dtype=float),
ColumnInfo(name="resourceDisplayName", dtype=str),
ColumnInfo(name="travel_speed_kmph", dtype=float),
DateTimeColumn(name=config.ae.timestamp_column_name, dtype=datetime, input_name="time"),
ColumnInfo(name="appDisplayName", dtype=str),
ColumnInfo(name="clientAppUsed", dtype=str),
RenameColumn(name=config.ae.userid_column_name, dtype=str, input_name="userPrincipalName"),
RenameColumn(name="deviceDetailbrowser", dtype=str, input_name="deviceDetail_browser"),
RenameColumn(name="deviceDetaildisplayName", dtype=str, input_name="deviceDetail_displayName"),
RenameColumn(name="deviceDetailoperatingSystem", dtype=str, input_name="deviceDetail_operatingSystem"),

# RenameColumn(name="location_country", dtype=str, input_name="location_countryOrRegion"),
ColumnInfo(name="location_city_state_country", dtype=str),
ColumnInfo(name="location_state_country", dtype=str),
ColumnInfo(name="location_country", dtype=str),

# Non-features
ColumnInfo(name="is_corp_vpn", dtype=bool),
ColumnInfo(name="distance_km", dtype=float),
ColumnInfo(name="ts_delta_hour", dtype=float),
]
source_schema = DataFrameInputSchema(column_info=source_column_info)

preprocess_column_info = [
ColumnInfo(name=config.ae.timestamp_column_name, dtype=datetime),
ColumnInfo(name=config.ae.userid_column_name, dtype=str),

# Resource access
ColumnInfo(name="appDisplayName", dtype=str),
ColumnInfo(name="resourceDisplayName", dtype=str),
ColumnInfo(name="clientAppUsed", dtype=str),

# Device detail
ColumnInfo(name="deviceDetailbrowser", dtype=str),
ColumnInfo(name="deviceDetaildisplayName", dtype=str),
ColumnInfo(name="deviceDetailoperatingSystem", dtype=str),

# Location information
ColumnInfo(name="autonomousSystemNumber", dtype=str),
ColumnInfo(name="location_geoCoordinates_latitude", dtype=float),
ColumnInfo(name="location_geoCoordinates_longitude", dtype=float),
ColumnInfo(name="location_city_state_country", dtype=str),
ColumnInfo(name="location_state_country", dtype=str),
ColumnInfo(name="location_country", dtype=str),

# Derived information
ColumnInfo(name="travel_speed_kmph", dtype=float),

# Non-features
ColumnInfo(name="is_corp_vpn", dtype=bool),
ColumnInfo(name="distance_km", dtype=float),
ColumnInfo(name="ts_delta_hour", dtype=float),
]

preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])

exclude_from_training = [
config.ae.userid_column_name,
config.ae.timestamp_column_name,
"is_corp_vpn",
"distance_km",
"ts_delta_hour",
]

config.ae.feature_columns = [
name for (name, dtype) in preprocess_schema.output_columns if name not in exclude_from_training
]
else:
source_column_info = [
DateTimeColumn(name=config.ae.timestamp_column_name, dtype=datetime, input_name="time"),
RenameColumn(name=config.ae.userid_column_name, dtype=str, input_name="properties.userPrincipalName"),
RenameColumn(name="appDisplayName", dtype=str, input_name="properties.appDisplayName"),
ColumnInfo(name="category", dtype=str),
RenameColumn(name="clientAppUsed", dtype=str, input_name="properties.clientAppUsed"),
RenameColumn(name="deviceDetailbrowser", dtype=str, input_name="properties.deviceDetail.browser"),
RenameColumn(name="deviceDetaildisplayName", dtype=str, input_name="properties.deviceDetail.displayName"),
RenameColumn(name="deviceDetailoperatingSystem",
dtype=str,
input_name="properties.deviceDetail.operatingSystem"),
StringCatColumn(name="location",
dtype=str,
input_columns=[
"properties.location.city",
"properties.location.countryOrRegion",
],
sep=", "),
RenameColumn(name="statusfailureReason", dtype=str, input_name="properties.status.failureReason"),
]

source_schema = DataFrameInputSchema(json_columns=["properties"], column_info=source_column_info)

# Preprocessing schema
preprocess_column_info = [
ColumnInfo(name=config.ae.timestamp_column_name, dtype=datetime),
ColumnInfo(name=config.ae.userid_column_name, dtype=str),
ColumnInfo(name="appDisplayName", dtype=str),
ColumnInfo(name="clientAppUsed", dtype=str),
ColumnInfo(name="deviceDetailbrowser", dtype=str),
ColumnInfo(name="deviceDetaildisplayName", dtype=str),
ColumnInfo(name="deviceDetailoperatingSystem", dtype=str),
ColumnInfo(name="statusfailureReason", dtype=str),

# Derived columns
IncrementColumn(name="logcount",
dtype=int,
input_name=config.ae.timestamp_column_name,
groupby_column=config.ae.userid_column_name),
DistinctIncrementColumn(name="locincrement",
dtype=int,
input_name="location",
groupby_column=config.ae.userid_column_name,
timestamp_column=config.ae.timestamp_column_name),
DistinctIncrementColumn(name="appincrement",
dtype=int,
input_name="appDisplayName",
groupby_column=config.ae.userid_column_name,
timestamp_column=config.ae.timestamp_column_name)
]

preprocess_schema = DataFrameInputSchema(column_info=preprocess_column_info, preserve_columns=["_batch_id"])

# Create a linear pipeline object
pipeline = LinearPipeline(config)
@@ -314,11 +389,12 @@ def run_pipeline(train_users,
parser_kwargs = {"lines": False, "orient": "records"}
# Output is a list of fsspec files. Convert to DataFrames. This caches downloaded data
pipeline.add_stage(
DFPFileToDataFrameStage(config,
schema=source_schema,
file_type=file_type_override,
parser_kwargs=parser_kwargs, # TODO(Devin) probably should be configurable too
cache_dir=cache_dir))
DFPFileToDataFrameStage(
config,
schema=source_schema,
file_type=file_type_override,
parser_kwargs=parser_kwargs, # TODO(Devin) probably should be configurable too
cache_dir=cache_dir))
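The parser_kwargs above are presumably handed through to the underlying DataFrame reader; for a records-oriented JSON payload they behave like this (plain pandas, shown purely for illustration):

import io

import pandas as pd

payload = '[{"time": "2023-10-09T00:00:00Z", "userPrincipalName": "alice"}]'
df = pd.read_json(io.StringIO(payload), lines=False, orient="records")
print(df.columns.to_list())  # ['time', 'userPrincipalName']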

pipeline.add_stage(MonitorStage(config, description="Input data rate"))

@@ -348,7 +424,7 @@ def run_pipeline(train_users,

if (is_training):
# Finally, perform training which will output a model
pipeline.add_stage(DFPTraining(config, validation_size=0.10))
pipeline.add_stage(DFPTraining(config, epochs=100, validation_size=0.15))

pipeline.add_stage(MonitorStage(config, description="Training rate", smoothing=0.001))

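Earlier in this file, the post-proc branch derives config.ae.feature_columns by filtering the preprocessing schema's output columns against the exclusion list; a toy version of that filtering with plain lists in place of DataFrameInputSchema:

output_columns = [("timestamp", "datetime"), ("username", "str"),
                  ("appDisplayName", "str"), ("travel_speed_kmph", "float"),
                  ("is_corp_vpn", "bool"), ("distance_km", "float")]

exclude_from_training = ["username", "timestamp", "is_corp_vpn", "distance_km"]

feature_columns = [name for (name, dtype) in output_columns if name not in exclude_from_training]
print(feature_columns)  # ['appDisplayName', 'travel_speed_kmph']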
(changed files 7 and 8 not loaded)
