From b86c5ecd7ce1e828f3e32079ba96e7bf68abd510 Mon Sep 17 00:00:00 2001 From: Stefano Zamboni <39366866+SteZamboni@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:31:15 +0200 Subject: [PATCH 1/2] fix: removed all .count() and replaced with attribute in the class definition (#16) --- spark/jobs/utils/current.py | 29 ++++++++++++----------------- spark/jobs/utils/reference.py | 19 ++++++++----------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/spark/jobs/utils/current.py b/spark/jobs/utils/current.py index 3d596353..41f1058c 100644 --- a/spark/jobs/utils/current.py +++ b/spark/jobs/utils/current.py @@ -67,12 +67,14 @@ def __init__( self.spark_session = spark_session self.current = current self.reference = reference + self.current_count = self.current.count() + self.reference_count = self.reference.count() self.model = model # FIXME use pydantic struct like data quality def calculate_statistics(self) -> dict[str, float]: number_of_variables = len(self.model.get_all_variables_current()) - number_of_observations = self.current.count() + number_of_observations = self.current_count number_of_numerical = len(self.model.get_numerical_variables_current()) number_of_categorical = len(self.model.get_categorical_variables_current()) number_of_datetime = len(self.model.get_datetime_variables_current()) @@ -188,12 +190,11 @@ def split_dict(dictionary): for x in numerical_features ] - # FIXME maybe don't self.current.count() missing_values_perc_agg = [ ( ( f.count(f.when(f.col(x).isNull() | f.isnan(x), x)) - / self.current.count() + / self.current_count ) * 100 ).alias(f"{x}-missing_values_perc") @@ -207,11 +208,10 @@ def split_dict(dictionary): for x in numerical_features ] - # FIXME don't use self.current.count() freq_agg = [ ( f.count(f.when(f.col(x).isNotNull() & ~f.isnan(x), True)) - / self.current.count() + / self.current_count ).alias(f"{x}-frequency") for x in numerical_features ] @@ -329,11 +329,10 @@ def split_dict(dictionary): for x in categorical_features ] - # FIXME maybe don't self.current.count() missing_values_perc_agg = [ - ( - (f.count(f.when(f.col(x).isNull(), x)) / self.current.count()) * 100 - ).alias(f"{x}-missing_values_perc") + ((f.count(f.when(f.col(x).isNull(), x)) / self.current_count) * 100).alias( + f"{x}-missing_values_perc" + ) for x in categorical_features ] @@ -351,7 +350,6 @@ def split_dict(dictionary): # FIXME by design this is not efficient # FIXME understand if we want to divide by whole or by number of not null - # FIXME don't use self.reference.count() count_distinct_categories = { column: dict( @@ -361,7 +359,7 @@ def split_dict(dictionary): .agg(*[f.count(check_not_null(column)).alias("count")]) .withColumn( "freq", - f.col("count") / self.current.count(), + f.col("count") / self.current_count, ) .toPandas() .set_index(column) @@ -393,7 +391,7 @@ def calculate_class_metrics(self) -> List[ClassMetrics]: number_of_true = number_true_and_false.get(1.0, 0) number_of_false = number_true_and_false.get(0.0, 0) - number_of_observations = self.current.count() + number_of_observations = self.current_count return [ ClassMetrics( @@ -415,7 +413,7 @@ def calculate_data_quality(self) -> BinaryClassDataQuality: if self.model.get_categorical_features(): feature_metrics.extend(self.calculate_data_quality_categorical()) return BinaryClassDataQuality( - n_observations=self.current.count(), + n_observations=self.current_count, class_metrics=self.calculate_class_metrics(), feature_metrics=feature_metrics, ) @@ -742,9 +740,6 @@ def calculate_drift(self): 
drift_result = dict() drift_result["feature_metrics"] = [] - ref_count = self.reference.count() - cur_count = self.current.count() - categorical_features = [ categorical.name for categorical in self.model.get_categorical_features() ] @@ -761,7 +756,7 @@ def calculate_drift(self): "type": "CHI2", }, } - if ref_count > 5 and cur_count > 5: + if self.reference_count > 5 and self.current_count > 5: result_tmp = chi2.test(column, column) feature_dict_to_append["drift_calc"]["value"] = float( result_tmp["pValue"] diff --git a/spark/jobs/utils/reference.py b/spark/jobs/utils/reference.py index 36dac983..5015d40e 100644 --- a/spark/jobs/utils/reference.py +++ b/spark/jobs/utils/reference.py @@ -54,6 +54,7 @@ class ReferenceMetricsService: def __init__(self, reference: DataFrame, model: ModelOut): self.model = model self.reference = reference + self.reference_count = self.reference.count() def __evaluate_binary_classification( self, dataset: DataFrame, metric_name: str @@ -99,7 +100,7 @@ def __calc_mc_metrics(self) -> dict[str, float]: # FIXME use pydantic struct like data quality def calculate_statistics(self) -> dict[str, float]: number_of_variables = len(self.model.get_all_variables_reference()) - number_of_observations = self.reference.count() + number_of_observations = self.reference_count number_of_numerical = len(self.model.get_numerical_variables_reference()) number_of_categorical = len(self.model.get_categorical_variables_reference()) number_of_datetime = len(self.model.get_datetime_variables_reference()) @@ -256,12 +257,11 @@ def split_dict(dictionary): for x in numerical_features ] - # FIXME maybe don't self.reference.count() missing_values_perc_agg = [ ( ( f.count(f.when(f.col(x).isNull() | f.isnan(x), x)) - / self.reference.count() + / self.reference_count ) * 100 ).alias(f"{x}-missing_values_perc") @@ -275,11 +275,10 @@ def split_dict(dictionary): for x in numerical_features ] - # FIXME don't use self.reference.count() freq_agg = [ ( f.count(f.when(f.col(x).isNotNull() & ~f.isnan(x), True)) - / self.reference.count() + / self.reference_count ).alias(f"{x}-frequency") for x in numerical_features ] @@ -403,10 +402,9 @@ def split_dict(dictionary): for x in categorical_features ] - # FIXME maybe don't self.reference.count() missing_values_perc_agg = [ ( - (f.count(f.when(f.col(x).isNull(), x)) / self.reference.count()) * 100 + (f.count(f.when(f.col(x).isNull(), x)) / self.reference_count) * 100 ).alias(f"{x}-missing_values_perc") for x in categorical_features ] @@ -425,7 +423,6 @@ def split_dict(dictionary): # FIXME by design this is not efficient # FIXME understand if we want to divide by whole or by number of not null - # FIXME don't use self.reference.count() count_distinct_categories = { column: dict( @@ -435,7 +432,7 @@ def split_dict(dictionary): .agg(*[f.count(check_not_null(column)).alias("count")]) .withColumn( "freq", - f.col("count") / self.reference.count(), + f.col("count") / self.reference_count, ) .toPandas() .set_index(column) @@ -467,7 +464,7 @@ def calculate_class_metrics(self) -> List[ClassMetrics]: number_of_true = number_true_and_false.get(1.0, 0) number_of_false = number_true_and_false.get(0.0, 0) - number_of_observations = self.reference.count() + number_of_observations = self.reference_count return [ ClassMetrics( @@ -489,7 +486,7 @@ def calculate_data_quality(self) -> BinaryClassDataQuality: if self.model.get_categorical_features(): feature_metrics.extend(self.calculate_data_quality_categorical()) return BinaryClassDataQuality( - 
n_observations=self.reference.count(), + n_observations=self.reference_count, class_metrics=self.calculate_class_metrics(), feature_metrics=feature_metrics, ) From 7513b39aed144d6c14cf60aaeabdffae1fa8b3f0 Mon Sep 17 00:00:00 2001 From: Mauro Cortellazzi Date: Fri, 21 Jun 2024 15:34:14 +0200 Subject: [PATCH 2/2] feat(sdk): implement get current dataset drift (#15) --- .../apis/model_current_dataset.py | 58 ++++++++- .../models/__init__.py | 16 +++ .../models/dataset_drift.py | 43 +++++++ sdk/tests/apis/model_current_dataset_test.py | 120 +++++++++++++++++- 4 files changed, 233 insertions(+), 4 deletions(-) create mode 100644 sdk/radicalbit_platform_sdk/models/dataset_drift.py diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 0c18f928..fc5f4c14 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -4,6 +4,8 @@ CurrentFileUpload, JobStatus, DatasetStats, + Drift, + BinaryClassDrift, ) from radicalbit_platform_sdk.errors import ClientError from pydantic import ValidationError @@ -96,9 +98,59 @@ def __callback( return self.__statistics - def drift(self): - # TODO: implement get drift - pass + def drift(self) -> Optional[Drift]: + """ + Get drift about the current dataset + + :return: The `Drift` if exists + """ + + def __callback( + response: requests.Response, + ) -> tuple[JobStatus, Optional[Drift]]: + try: + response_json = response.json() + job_status = JobStatus(response_json["jobStatus"]) + if "drift" in response_json: + if self.__model_type is ModelType.BINARY: + return ( + job_status, + BinaryClassDrift.model_validate(response_json["drift"]), + ) + else: + raise ClientError( + "Unable to parse get metrics for not binary models" + ) + else: + return job_status, None + except KeyError as _: + raise ClientError(f"Unable to parse response: {response.text}") + except ValidationError as _: + raise ClientError(f"Unable to parse response: {response.text}") + + match self.__status: + case JobStatus.ERROR: + self.__drift = None + case JobStatus.SUCCEEDED: + if self.__drift is None: + _, drift = invoke( + method="GET", + url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift", + valid_response_code=200, + func=__callback, + ) + self.__drift = drift + case JobStatus.IMPORTING: + status, drift = invoke( + method="GET", + url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift", + valid_response_code=200, + func=__callback, + ) + self.__status = status + self.__drift = drift + + return self.__drift def data_quality(self): # TODO: implement get data quality diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py index 418b8ef0..c05b67ad 100644 --- a/sdk/radicalbit_platform_sdk/models/__init__.py +++ b/sdk/radicalbit_platform_sdk/models/__init__.py @@ -30,6 +30,15 @@ CategoryFrequency, CategoricalFeatureMetrics, ) +from .dataset_drift import ( + DriftAlgorithm, + FeatureDriftCalculation, + FeatureDrift, + Drift, + BinaryClassDrift, + MultiClassDrift, + RegressionDrift, +) from .column_definition import ColumnDefinition from .aws_credentials import AwsCredentials @@ -59,6 +68,13 @@ "NumericalFeatureMetrics", "CategoryFrequency", "CategoricalFeatureMetrics", + "DriftAlgorithm", + "FeatureDriftCalculation", + "FeatureDrift", + "Drift", + "BinaryClassDrift", + "MultiClassDrift", + "RegressionDrift", 
"PaginatedModelDefinitions", "ReferenceFileUpload", "CurrentFileUpload", diff --git a/sdk/radicalbit_platform_sdk/models/dataset_drift.py b/sdk/radicalbit_platform_sdk/models/dataset_drift.py new file mode 100644 index 00000000..c69524bb --- /dev/null +++ b/sdk/radicalbit_platform_sdk/models/dataset_drift.py @@ -0,0 +1,43 @@ +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel + + +class DriftAlgorithm(str, Enum): + KS = "KS" + CHI2 = "CHI2" + + +class FeatureDriftCalculation(BaseModel): + type: DriftAlgorithm + value: Optional[float] = None + has_drift: bool + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class FeatureDrift(BaseModel): + feature_name: str + drift_calc: FeatureDriftCalculation + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class Drift(BaseModel): + pass + + +class BinaryClassDrift(Drift): + feature_metrics: List[FeatureDrift] + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class MultiClassDrift(Drift): + pass + + +class RegressionDrift(BaseModel): + pass diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index b547986b..9e7ef6a6 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -1,5 +1,5 @@ from radicalbit_platform_sdk.apis import ModelCurrentDataset -from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus +from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus, DriftAlgorithm from radicalbit_platform_sdk.errors import ClientError import responses import unittest @@ -129,3 +129,121 @@ def test_statistics_key_error(self): with self.assertRaises(ClientError): model_reference_dataset.statistics() + + @responses.activate + def test_drift_ok(self): + base_url = "http://api:9000" + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_reference_dataset = ModelCurrentDataset( + base_url, + model_id, + ModelType.BINARY, + CurrentFileUpload( + uuid=import_uuid, + path="s3://bucket/file.csv", + date="2014", + correlation_id_column="column", + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + **{ + "method": responses.GET, + "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift", + "status": 200, + "body": """{ + "jobStatus": "SUCCEEDED", + "drift": { + "featureMetrics": [ + { + "featureName": "gender", + "driftCalc": {"type": "CHI2", "value": 0.87, "hasDrift": true} + }, + { + "featureName": "city", + "driftCalc": {"type": "CHI2", "value": 0.12, "hasDrift": false} + }, + { + "featureName": "age", + "driftCalc": {"type": "KS", "value": 0.92, "hasDrift": true} + } + ] + } + }""", + } + ) + + drift = model_reference_dataset.drift() + + assert len(drift.feature_metrics) == 3 + assert drift.feature_metrics[1].feature_name == "city" + assert drift.feature_metrics[1].drift_calc.type == DriftAlgorithm.CHI2 + assert drift.feature_metrics[1].drift_calc.value == 0.12 + assert drift.feature_metrics[1].drift_calc.has_drift is False + assert drift.feature_metrics[2].feature_name == "age" + assert drift.feature_metrics[2].drift_calc.type == DriftAlgorithm.KS + assert drift.feature_metrics[2].drift_calc.value == 0.92 + assert drift.feature_metrics[2].drift_calc.has_drift is True + assert model_reference_dataset.status() == JobStatus.SUCCEEDED + + @responses.activate + def 
test_drift_validation_error(self): + base_url = "http://api:9000" + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_reference_dataset = ModelCurrentDataset( + base_url, + model_id, + ModelType.BINARY, + CurrentFileUpload( + uuid=import_uuid, + path="s3://bucket/file.csv", + date="2014", + correlation_id_column="column", + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + **{ + "method": responses.GET, + "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift", + "status": 200, + "body": '{"statistics": "wrong"}', + } + ) + + with self.assertRaises(ClientError): + model_reference_dataset.drift() + + @responses.activate + def test_drift_key_error(self): + base_url = "http://api:9000" + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_reference_dataset = ModelCurrentDataset( + base_url, + model_id, + ModelType.BINARY, + CurrentFileUpload( + uuid=import_uuid, + path="s3://bucket/file.csv", + date="2014", + correlation_id_column="column", + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + **{ + "method": responses.GET, + "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift", + "status": 200, + "body": '{"wrong": "json"}', + } + ) + + with self.assertRaises(ClientError): + model_reference_dataset.drift()
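
The first commit applies one pattern throughout: DataFrame.count() is a Spark action that triggers a job, so the row count is computed once in __init__ and the cached attribute is reused by every aggregate instead of re-counting per metric. A minimal standalone sketch of that pattern; the MetricsService class and missing_values_perc helper below are hypothetical names used only for illustration:

from pyspark.sql import DataFrame
import pyspark.sql.functions as f


class MetricsService:
    def __init__(self, current: DataFrame):
        self.current = current
        # count() runs a Spark job; compute it once and reuse the cached value.
        self.current_count = self.current.count()

    def missing_values_perc(self, column: str) -> float:
        # Percentage of null values in `column`, dividing by the cached row count
        # rather than calling self.current.count() again inside the aggregation.
        row = self.current.agg(
            (
                f.count(f.when(f.col(column).isNull(), column))
                / self.current_count
                * 100
            ).alias("missing_values_perc")
        ).collect()[0]
        return float(row["missing_values_perc"])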
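
The second commit exposes the current dataset's drift metrics through the SDK. A usage sketch mirroring the construction used in the new tests; the base URL, UUIDs, and file path are placeholders, and the drift() call requires a reachable platform API:

import uuid

from radicalbit_platform_sdk.apis import ModelCurrentDataset
from radicalbit_platform_sdk.models import CurrentFileUpload, JobStatus, ModelType

# Placeholder values; in practice these identify an existing binary-classification
# model and one of its imported current datasets.
base_url = "http://api:9000"
model_id = uuid.uuid4()
dataset_id = uuid.uuid4()

dataset = ModelCurrentDataset(
    base_url,
    model_id,
    ModelType.BINARY,
    CurrentFileUpload(
        uuid=dataset_id,
        path="s3://bucket/file.csv",
        date="2014",
        correlation_id_column="column",
        status=JobStatus.IMPORTING,
    ),
)

# drift() fetches the /drift endpoint and returns a BinaryClassDrift once the
# job has succeeded, or None while drift is not yet available or the job failed.
drift = dataset.drift()
if drift is not None:
    for feature in drift.feature_metrics:
        calc = feature.drift_calc
        print(
            f"{feature.feature_name}: {calc.type.value} "
            f"value={calc.value} has_drift={calc.has_drift}"
        )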