From b86c5ecd7ce1e828f3e32079ba96e7bf68abd510 Mon Sep 17 00:00:00 2001
From: Stefano Zamboni <39366866+SteZamboni@users.noreply.github.com>
Date: Fri, 21 Jun 2024 15:31:15 +0200
Subject: [PATCH] fix: removed all .count() and replaced with attribute in the
 class definition (#16)

---
 spark/jobs/utils/current.py   | 29 ++++++++++++-----------------
 spark/jobs/utils/reference.py | 19 ++++++++-----------
 2 files changed, 20 insertions(+), 28 deletions(-)

diff --git a/spark/jobs/utils/current.py b/spark/jobs/utils/current.py
index 3d596353..41f1058c 100644
--- a/spark/jobs/utils/current.py
+++ b/spark/jobs/utils/current.py
@@ -67,12 +67,14 @@ def __init__(
         self.spark_session = spark_session
         self.current = current
         self.reference = reference
+        self.current_count = self.current.count()
+        self.reference_count = self.reference.count()
         self.model = model
 
     # FIXME use pydantic struct like data quality
     def calculate_statistics(self) -> dict[str, float]:
         number_of_variables = len(self.model.get_all_variables_current())
-        number_of_observations = self.current.count()
+        number_of_observations = self.current_count
         number_of_numerical = len(self.model.get_numerical_variables_current())
         number_of_categorical = len(self.model.get_categorical_variables_current())
         number_of_datetime = len(self.model.get_datetime_variables_current())
@@ -188,12 +190,11 @@ def split_dict(dictionary):
             for x in numerical_features
         ]
 
-        # FIXME maybe don't self.current.count()
         missing_values_perc_agg = [
             (
                 (
                     f.count(f.when(f.col(x).isNull() | f.isnan(x), x))
-                    / self.current.count()
+                    / self.current_count
                 )
                 * 100
             ).alias(f"{x}-missing_values_perc")
@@ -207,11 +208,10 @@ def split_dict(dictionary):
             for x in numerical_features
         ]
 
-        # FIXME don't use self.current.count()
         freq_agg = [
             (
                 f.count(f.when(f.col(x).isNotNull() & ~f.isnan(x), True))
-                / self.current.count()
+                / self.current_count
             ).alias(f"{x}-frequency")
             for x in numerical_features
         ]
@@ -329,11 +329,10 @@ def split_dict(dictionary):
             for x in categorical_features
         ]
 
-        # FIXME maybe don't self.current.count()
         missing_values_perc_agg = [
-            (
-                (f.count(f.when(f.col(x).isNull(), x)) / self.current.count()) * 100
-            ).alias(f"{x}-missing_values_perc")
+            ((f.count(f.when(f.col(x).isNull(), x)) / self.current_count) * 100).alias(
+                f"{x}-missing_values_perc"
+            )
             for x in categorical_features
         ]
 
@@ -351,7 +350,6 @@ def split_dict(dictionary):
 
         # FIXME by design this is not efficient
         # FIXME understand if we want to divide by whole or by number of not null
-        # FIXME don't use self.reference.count()
 
         count_distinct_categories = {
             column: dict(
@@ -361,7 +359,7 @@ def split_dict(dictionary):
                 .agg(*[f.count(check_not_null(column)).alias("count")])
                 .withColumn(
                     "freq",
-                    f.col("count") / self.current.count(),
+                    f.col("count") / self.current_count,
                 )
                 .toPandas()
                 .set_index(column)
@@ -393,7 +391,7 @@ def calculate_class_metrics(self) -> List[ClassMetrics]:
         number_of_true = number_true_and_false.get(1.0, 0)
         number_of_false = number_true_and_false.get(0.0, 0)
 
-        number_of_observations = self.current.count()
+        number_of_observations = self.current_count
 
         return [
             ClassMetrics(
@@ -415,7 +413,7 @@ def calculate_data_quality(self) -> BinaryClassDataQuality:
         if self.model.get_categorical_features():
             feature_metrics.extend(self.calculate_data_quality_categorical())
         return BinaryClassDataQuality(
-            n_observations=self.current.count(),
+            n_observations=self.current_count,
             class_metrics=self.calculate_class_metrics(),
             feature_metrics=feature_metrics,
         )
@@ -742,9 +740,6 @@ def calculate_drift(self):
         drift_result = dict()
         drift_result["feature_metrics"] = []
 
-        ref_count = self.reference.count()
-        cur_count = self.current.count()
-
        categorical_features = [
             categorical.name for categorical in self.model.get_categorical_features()
         ]
@@ -761,7 +756,7 @@ def calculate_drift(self):
                     "type": "CHI2",
                 },
             }
-            if ref_count > 5 and cur_count > 5:
+            if self.reference_count > 5 and self.current_count > 5:
                 result_tmp = chi2.test(column, column)
                 feature_dict_to_append["drift_calc"]["value"] = float(
                     result_tmp["pValue"]
diff --git a/spark/jobs/utils/reference.py b/spark/jobs/utils/reference.py
index 36dac983..5015d40e 100644
--- a/spark/jobs/utils/reference.py
+++ b/spark/jobs/utils/reference.py
@@ -54,6 +54,7 @@ class ReferenceMetricsService:
     def __init__(self, reference: DataFrame, model: ModelOut):
         self.model = model
         self.reference = reference
+        self.reference_count = self.reference.count()
 
     def __evaluate_binary_classification(
         self, dataset: DataFrame, metric_name: str
@@ -99,7 +100,7 @@ def __calc_mc_metrics(self) -> dict[str, float]:
     # FIXME use pydantic struct like data quality
     def calculate_statistics(self) -> dict[str, float]:
         number_of_variables = len(self.model.get_all_variables_reference())
-        number_of_observations = self.reference.count()
+        number_of_observations = self.reference_count
         number_of_numerical = len(self.model.get_numerical_variables_reference())
         number_of_categorical = len(self.model.get_categorical_variables_reference())
         number_of_datetime = len(self.model.get_datetime_variables_reference())
@@ -256,12 +257,11 @@ def split_dict(dictionary):
             for x in numerical_features
         ]
 
-        # FIXME maybe don't self.reference.count()
         missing_values_perc_agg = [
             (
                 (
                     f.count(f.when(f.col(x).isNull() | f.isnan(x), x))
-                    / self.reference.count()
+                    / self.reference_count
                 )
                 * 100
             ).alias(f"{x}-missing_values_perc")
@@ -275,11 +275,10 @@ def split_dict(dictionary):
             for x in numerical_features
         ]
 
-        # FIXME don't use self.reference.count()
         freq_agg = [
             (
                 f.count(f.when(f.col(x).isNotNull() & ~f.isnan(x), True))
-                / self.reference.count()
+                / self.reference_count
             ).alias(f"{x}-frequency")
             for x in numerical_features
         ]
@@ -403,10 +402,9 @@ def split_dict(dictionary):
             for x in categorical_features
         ]
 
-        # FIXME maybe don't self.reference.count()
         missing_values_perc_agg = [
             (
-                (f.count(f.when(f.col(x).isNull(), x)) / self.reference.count()) * 100
+                (f.count(f.when(f.col(x).isNull(), x)) / self.reference_count) * 100
             ).alias(f"{x}-missing_values_perc")
             for x in categorical_features
         ]
@@ -425,7 +423,6 @@ def split_dict(dictionary):
 
         # FIXME by design this is not efficient
         # FIXME understand if we want to divide by whole or by number of not null
-        # FIXME don't use self.reference.count()
 
         count_distinct_categories = {
             column: dict(
@@ -435,7 +432,7 @@ def split_dict(dictionary):
                 .agg(*[f.count(check_not_null(column)).alias("count")])
                 .withColumn(
                     "freq",
-                    f.col("count") / self.reference.count(),
+                    f.col("count") / self.reference_count,
                 )
                 .toPandas()
                 .set_index(column)
@@ -467,7 +464,7 @@ def calculate_class_metrics(self) -> List[ClassMetrics]:
         number_of_true = number_true_and_false.get(1.0, 0)
         number_of_false = number_true_and_false.get(0.0, 0)
 
-        number_of_observations = self.reference.count()
+        number_of_observations = self.reference_count
 
         return [
             ClassMetrics(
@@ -489,7 +486,7 @@ def calculate_data_quality(self) -> BinaryClassDataQuality:
         if self.model.get_categorical_features():
             feature_metrics.extend(self.calculate_data_quality_categorical())
         return BinaryClassDataQuality(
-            n_observations=self.reference.count(),
+            n_observations=self.reference_count,
             class_metrics=self.calculate_class_metrics(),
             feature_metrics=feature_metrics,
         )
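
The pattern this patch applies, as a minimal standalone sketch: DataFrame.count() is a
Spark action, so every call re-runs a full job over the data; computing the count once
in __init__ and reusing the resulting int avoids repeating that job in every metric.
The class and method names below are illustrative only, not the project's actual API.

    # Minimal sketch of the count-caching pattern (illustrative names,
    # not the real classes from spark/jobs/utils).
    from pyspark.sql import DataFrame, functions as f


    class MetricsService:
        def __init__(self, current: DataFrame, reference: DataFrame):
            self.current = current
            self.reference = reference
            # count() is an action: it triggers a full Spark job. Run it
            # once here and reuse the plain int everywhere else.
            self.current_count = self.current.count()
            self.reference_count = self.reference.count()

        def missing_values_perc(self, columns: list[str]) -> dict[str, float]:
            # Single aggregation pass over the data; dividing by the cached
            # count instead of calling count() saves one job per column.
            row = self.current.agg(
                *[
                    (
                        f.count(f.when(f.col(c).isNull(), c))
                        / self.current_count
                        * 100
                    ).alias(c)
                    for c in columns
                ]
            ).collect()[0]
            return row.asDict()

One trade-off worth noting: the counts are now computed eagerly at construction time
and become stale if the underlying DataFrames change, which is acceptable here because
the services treat current and reference as immutable snapshots.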