Skip to content

Commit

Permalink
feat: added model quality metrics for current dataset for regression (#…
Browse files Browse the repository at this point in the history
…87)

* feat: added model quality metrics for current dataset for regression

* feat: fixed reference metrics with int values for prediction-target

* feat: changed var to variance in regression model quality

* feat: removed cast on ground-truth since it's there already
  • Loading branch information
SteZamboni authored Jul 8, 2024
1 parent e9d2b99 commit 6a63d26
Show file tree
Hide file tree
Showing 7 changed files with 3,599 additions and 10 deletions.
2 changes: 2 additions & 0 deletions spark/jobs/current_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ def main(
)
statistics = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality(is_current=True)
model_quality = metrics_service.calculate_model_quality()
complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
complete_record["MODEL_QUALITY"] = model_quality

schema = StructType(
[
Expand Down
20 changes: 15 additions & 5 deletions spark/jobs/metrics/model_quality_regression_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@

class ModelQualityRegressionCalculator:
@staticmethod
def __eval_model_quality_metric(
def eval_model_quality_metric(
model: ModelOut,
dataframe: DataFrame,
dataframe_count: int,
metric_name: RegressionMetricType,
) -> float:
try:
dataframe = dataframe.withColumn(
model.outputs.prediction.name,
F.col(model.outputs.prediction.name).cast("float"),
)
match metric_name:
case RegressionMetricType.ADJ_R2:
# Source: https://medium.com/analytics-vidhya/adjusted-r-squared-formula-explanation-1ce033e25699
Expand All @@ -25,7 +29,7 @@ def __eval_model_quality_metric(
p: float = len(model.features)
n: float = dataframe_count
r2: float = (
ModelQualityRegressionCalculator.__eval_model_quality_metric(
ModelQualityRegressionCalculator.eval_model_quality_metric(
model, dataframe, dataframe_count, RegressionMetricType.R2
)
)
Expand All @@ -44,19 +48,25 @@ def __eval_model_quality_metric(
),
)
return _dataframe.agg({"mape": "avg"}).collect()[0][0] * 100
case RegressionMetricType.VAR:
return RegressionEvaluator(
metricName="var",
labelCol=model.target.name,
predictionCol=model.outputs.prediction.name,
).evaluate(dataframe)
case (
RegressionMetricType.MAE
| RegressionMetricType.MSE
| RegressionMetricType.RMSE
| RegressionMetricType.R2
| RegressionMetricType.VAR
):
return RegressionEvaluator(
metricName=metric_name.value,
labelCol=model.target.name,
predictionCol=model.outputs.prediction.name,
).evaluate(dataframe)
except Exception:
except Exception as e:
print(e)
return float("nan")

@staticmethod
Expand All @@ -65,7 +75,7 @@ def __calc_mq_metrics(
) -> ModelQualityRegression:
return ModelQualityRegression(
**{
metric_name.value: ModelQualityRegressionCalculator.__eval_model_quality_metric(
metric_name.value: ModelQualityRegressionCalculator.eval_model_quality_metric(
model,
dataframe,
dataframe_count,
Expand Down
4 changes: 2 additions & 2 deletions spark/jobs/models/regression_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class RegressionMetricType(str, Enum):
RMSE = "rmse"
R2 = "r2"
ADJ_R2 = "adj_r2"
VAR = "var"
VAR = "variance"


class ModelQualityRegression(BaseModel):
Expand All @@ -20,4 +20,4 @@ class ModelQualityRegression(BaseModel):
rmse: float
r2: float
adj_r2: float
var: float
variance: float
87 changes: 87 additions & 0 deletions spark/jobs/utils/current_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

from metrics.data_quality_calculator import DataQualityCalculator
from models.current_dataset import CurrentDataset
Expand All @@ -11,6 +12,10 @@
RegressionDataQuality,
)
from models.reference_dataset import ReferenceDataset
from models.regression_model_quality import ModelQualityRegression, RegressionMetricType
from metrics.model_quality_regression_calculator import ModelQualityRegressionCalculator
from .misc import create_time_format
from .models import Granularity


class CurrentMetricsRegressionService:
Expand All @@ -24,6 +29,88 @@ def __init__(
self.current = current
self.reference = reference

def calculate_model_quality(self) -> dict:
    """Compute regression model-quality metrics for the current dataset.

    Returns a dict with two keys:
      * ``"global_metrics"``: metrics computed over the whole current
        dataset, produced by ``ModelQualityRegressionCalculator
        .numerical_metrics`` and dumped to a plain dict via
        ``model_dump(serialize_as_any=True)``.
      * ``"grouped_metrics"``: the same metrics computed per time bucket,
        see ``calculate_regression_model_quality_group_by_timestamp``.

    NOTE: the return annotation previously claimed
    ``ModelQualityRegression``, but the method has always returned a
    plain ``dict`` — the annotation is corrected here; runtime behavior
    and callers are unaffected.
    """
    return {
        "global_metrics": ModelQualityRegressionCalculator.numerical_metrics(
            model=self.current.model,
            dataframe=self.current.current,
            dataframe_count=self.current.current_count,
        ).model_dump(serialize_as_any=True),
        "grouped_metrics": (
            self.calculate_regression_model_quality_group_by_timestamp()
        ),
    }

def calculate_regression_model_quality_group_by_timestamp(self):
    """Compute every regression metric separately per time bucket.

    Rows of the current dataset are bucketed by the model timestamp
    truncated to the model granularity; for weekly granularity the
    bucket is additionally shifted via next_day("sunday") minus 7 days
    (presumably the week-start alignment — confirm boundary behavior).

    Returns a mapping ``metric_name -> [{"timestamp": bucket, "value": v}, ...]``
    with buckets in ascending order.
    """
    model = self.current.model

    # Build the bucket ("time_group") column once for both granularity cases.
    truncated = F.date_format(
        model.timestamp.name, create_time_format(model.granularity)
    )
    if model.granularity == Granularity.WEEK:
        truncated = F.date_sub(F.next_day(truncated, "sunday"), 7)
    bucket_col = F.date_format(
        F.to_timestamp(truncated), "yyyy-MM-dd HH:mm:ss"
    ).alias("time_group")

    dataset_with_group = self.current.current.select(
        [
            model.outputs.prediction.name,
            model.target.name,
            bucket_col,
        ]
    )

    # Distinct buckets, ascending; collect Rows and pull the column out.
    buckets = [
        row["time_group"]
        for row in dataset_with_group.select("time_group")
        .distinct()
        .orderBy(F.col("time_group").asc())
        .collect()
    ]
    # One (lazy) sub-frame per bucket, keyed in ascending bucket order.
    bucket_frames = {
        bucket: dataset_with_group.where(F.col("time_group") == bucket)
        for bucket in buckets
    }

    grouped = {}
    for metric in RegressionMetricType:
        grouped[metric.value] = [
            {
                "timestamp": bucket,
                "value": ModelQualityRegressionCalculator.eval_model_quality_metric(
                    model,
                    frame,
                    frame.count(),
                    metric,
                ),
            }
            for bucket, frame in bucket_frames.items()
        ]
    return grouped

def calculate_data_quality_numerical(self) -> List[NumericalFeatureMetrics]:
return DataQualityCalculator.calculate_combined_data_quality_numerical(
model=self.current.model,
Expand Down
76 changes: 75 additions & 1 deletion spark/tests/regression_current_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def model():
)
target = ColumnDefinition(name="ground_truth", type=SupportedTypes.int)
timestamp = ColumnDefinition(name="dteday", type=SupportedTypes.datetime)
granularity = Granularity.HOUR
granularity = Granularity.MONTH
features = [
ColumnDefinition(name="season", type=SupportedTypes.int),
ColumnDefinition(name="yr", type=SupportedTypes.int),
Expand Down Expand Up @@ -504,3 +504,77 @@ def expected_data_quality():
},
],
}


def test_model_quality(spark_fixture, current_dataset, reference_dataset):
    """End-to-end check of regression model-quality metrics on the current dataset.

    Verifies both the global metrics and the per-bucket grouped metrics
    (buckets are monthly — the model fixture uses Granularity.MONTH)
    against golden values computed from the test CSV.
    """
    metrics_service = CurrentMetricsRegressionService(
        spark_session=spark_fixture,
        current=current_dataset,
        reference=reference_dataset,
    )

    model_quality = metrics_service.calculate_model_quality()

    # Golden values below: regenerate if the underlying test dataset changes.
    # ignore_order / ignore_type_subclasses let DeepDiff compare the nested
    # dict regardless of list ordering and numpy/float subclass differences.
    assert not deepdiff.DeepDiff(
        model_quality,
        {
            "global_metrics": {
                "mae": 71.82559791564941,
                "mape": 64.05699022707124,
                "mse": 17820.506660010054,
                "rmse": 133.49347047706138,
                "r2": 0.8210737408739541,
                "adj_r2": 0.7987079584831984,
                "variance": 118288.02759401732,
            },
            "grouped_metrics": {
                "mae": [
                    {"timestamp": "2011-01-01 00:00:00", "value": 35.67896665375808},
                    {"timestamp": "2011-02-01 00:00:00", "value": 89.13965238373855},
                    {"timestamp": "2011-03-01 00:00:00", "value": 91.54030847549438},
                    {"timestamp": "2011-04-01 00:00:00", "value": 63.352996826171875},
                ],
                "mape": [
                    {"timestamp": "2011-01-01 00:00:00", "value": 106.34668638669385},
                    {"timestamp": "2011-02-01 00:00:00", "value": 50.266650033642435},
                    {"timestamp": "2011-03-01 00:00:00", "value": 53.63275529139244},
                    {"timestamp": "2011-04-01 00:00:00", "value": 14.766409719281478},
                ],
                "mse": [
                    {"timestamp": "2011-01-01 00:00:00", "value": 2848.1117152678507},
                    {"timestamp": "2011-02-01 00:00:00", "value": 21631.812814960613},
                    {"timestamp": "2011-03-01 00:00:00", "value": 31460.34954782362},
                    {"timestamp": "2011-04-01 00:00:00", "value": 6540.166909402423},
                ],
                "rmse": [
                    {"timestamp": "2011-01-01 00:00:00", "value": 53.36770292290882},
                    {"timestamp": "2011-02-01 00:00:00", "value": 147.07757414018158},
                    {"timestamp": "2011-03-01 00:00:00", "value": 177.37065582509305},
                    {"timestamp": "2011-04-01 00:00:00", "value": 80.87129842782556},
                ],
                "r2": [
                    {"timestamp": "2011-01-01 00:00:00", "value": 0.17834457710460982},
                    {"timestamp": "2011-02-01 00:00:00", "value": 0.3895389519246505},
                    {"timestamp": "2011-03-01 00:00:00", "value": 0.7043715304337479},
                    {"timestamp": "2011-04-01 00:00:00", "value": 0.9678020649997567},
                ],
                # NOTE(review): adj_r2 > 1 for 2011-04 looks suspicious for an
                # adjusted R² — likely small-sample artifact of the formula;
                # confirm against the calculator's ADJ_R2 implementation.
                "adj_r2": [
                    {"timestamp": "2011-01-01 00:00:00", "value": -0.3533148141806426},
                    {
                        "timestamp": "2011-02-01 00:00:00",
                        "value": -0.005465255653516854,
                    },
                    {"timestamp": "2011-03-01 00:00:00", "value": 0.5417758721723092},
                    {"timestamp": "2011-04-01 00:00:00", "value": 1.1448907075010948},
                ],
                "variance": [
                    {"timestamp": "2011-01-01 00:00:00", "value": 4720.867246089001},
                    {"timestamp": "2011-02-01 00:00:00", "value": 70942.48575413873},
                    {"timestamp": "2011-03-01 00:00:00", "value": 150522.0080596708},
                    {"timestamp": "2011-04-01 00:00:00", "value": 163422.9263027128},
                ],
            },
        },
        ignore_order=True,
        ignore_type_subclasses=True,
    )
77 changes: 75 additions & 2 deletions spark/tests/regression_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def reference_bike_nulls(spark_fixture, test_data_dir):
)


@pytest.fixture()
def reference_test_fe(spark_fixture, test_data_dir):
    """Raw Spark DataFrame read from the FE regression reference CSV."""
    csv_path = f"{test_data_dir}/reference/regression/regression_reference_test_FE.csv"
    frame = spark_fixture.read.csv(csv_path, header=True)
    yield frame


@pytest.fixture()
def expected_data_quality_json():
yield {
Expand Down Expand Up @@ -484,6 +492,50 @@ def reference_dataset_nulls(spark_fixture, reference_bike_nulls):
)


@pytest.fixture()
def reference_dataset_test_fe(spark_fixture, reference_test_fe):
    """ReferenceDataset over the FE CSV: int prediction/target regression model."""

    def _col(name, col_type):
        # Small shorthand for the repeated ColumnDefinition construction.
        return ColumnDefinition(name=name, type=col_type)

    outputs = OutputType(
        prediction=_col("prediction", SupportedTypes.int),
        prediction_proba=None,
        output=[_col("prediction", SupportedTypes.int)],
    )

    float_feature_names = (
        "Length",
        "Diameter",
        "Height",
        "Whole_weight",
        "Shucked_weight",
        "Viscera_weight",
        "Shell_weight",
    )
    features = (
        [_col("Sex", SupportedTypes.string)]
        + [_col(name, SupportedTypes.float) for name in float_feature_names]
        + [_col("pred_id", SupportedTypes.string)]
    )

    model = ModelOut(
        uuid=uuid.uuid4(),
        name="regression model",
        description="description",
        model_type=ModelType.REGRESSION,
        data_type=DataType.TABULAR,
        timestamp=_col("timestamp", SupportedTypes.datetime),
        granularity=Granularity.MONTH,
        outputs=outputs,
        target=_col("ground_truth", SupportedTypes.int),
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    yield ReferenceDataset(
        raw_dataframe=reference_test_fe,
        model=model,
    )


def test_model_quality_metrics(reference_dataset):
assert reference_dataset.reference_count == 731

Expand All @@ -498,7 +550,7 @@ def test_model_quality_metrics(reference_dataset):
"rmse": 202.2319475218869,
"r2": 0.9131323648676931,
"adj_r2": 0.9118033746222753,
"var": 393448.3132709007,
"variance": 393448.3132709007,
}
)

Expand Down Expand Up @@ -565,7 +617,28 @@ def test_model_quality_metrics_nulls(reference_dataset_nulls):
"rmse": 202.40446933755922,
"r2": 0.9130200184348737,
"adj_r2": 0.9116855975182538,
"var": 393588.541292358,
"variance": 393588.541292358,
}
)

assert model_quality_metrics.model_dump() == expected


def test_model_quality_metrics_test_int(reference_dataset_test_fe):
regression_service = ReferenceMetricsRegressionService(
reference=reference_dataset_test_fe
)
model_quality_metrics = regression_service.calculate_model_quality()

expected = my_approx(
{
"mae": 1.9473369239976062,
"mape": 32.80513741749641,
"mse": 27.478755236385403,
"rmse": 5.242018240752831,
"r2": 0.5846200997960207,
"adj_r2": 0.5834981252756619,
"variance": 66.40816713170544,
}
)

Expand Down
Loading

0 comments on commit 6a63d26

Please sign in to comment.