From 0db1dfe365adb7a1973b6302359bc18c158bf720 Mon Sep 17 00:00:00 2001 From: Mauro Cortellazzi Date: Fri, 21 Jun 2024 12:07:44 +0200 Subject: [PATCH] feat(sdk): align reference metrics business models with API (#11) --- .../apis/model_reference_dataset.py | 4 +- .../models/dataset_data_quality.py | 66 ++++---- .../models/dataset_model_quality.py | 24 +-- .../models/dataset_stats.py | 5 +- .../apis/model_reference_dataset_test.py | 149 ++++++++++++++++++ 5 files changed, 197 insertions(+), 51 deletions(-) diff --git a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py index da485b35..2d99e2bb 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py @@ -102,7 +102,9 @@ def data_quality(self) -> Optional[DataQuality]: :return: The `DataQuality` if exists """ - def __callback(response: requests.Response) -> Optional[DataQuality]: + def __callback( + response: requests.Response, + ) -> tuple[JobStatus, Optional[DataQuality]]: try: response_json = response.json() job_status = JobStatus(response_json["jobStatus"]) diff --git a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py index 2bb7f6e8..c437394b 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py @@ -1,43 +1,37 @@ from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_camel -from typing import List +from typing import List, Optional, Union class ClassMetrics(BaseModel): name: str count: int - percentage: float + percentage: Optional[float] = None - model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + model_config = ConfigDict(populate_by_name=True) class MedianMetrics(BaseModel): - perc_25: float - median: float - perc_75: float + perc_25: Optional[float] = None + median: Optional[float] = None + perc_75: Optional[float] = None - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class MissingValue(BaseModel): count: int - percentage: float + percentage: Optional[float] = None - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True) class ClassMedianMetrics(BaseModel): name: str - mean: float + mean: Optional[float] = None median_metrics: MedianMetrics - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class FeatureMetrics(BaseModel): @@ -45,33 +39,36 @@ class FeatureMetrics(BaseModel): type: str missing_value: MissingValue - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class Histogram(BaseModel): + buckets: List[float] + reference_values: List[int] + current_values: Optional[List[int]] = None + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class NumericalFeatureMetrics(FeatureMetrics): type: str = "numerical" - mean: float - std: float - min: float - max: float + mean: Optional[float] = None + std: Optional[float] = None + min: Optional[float] = None + max: Optional[float] = None median_metrics: MedianMetrics class_median_metrics: List[ClassMedianMetrics] + histogram: Histogram - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class CategoryFrequency(BaseModel): name: str count: int - frequency: float + frequency: Optional[float] = None - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True) class CategoricalFeatureMetrics(FeatureMetrics): @@ -79,9 +76,7 @@ class CategoricalFeatureMetrics(FeatureMetrics): category_frequency: List[CategoryFrequency] distinct_value: int - model_config = ConfigDict( - populate_by_name=True, alias_generator=to_camel, protected_namespaces=() - ) + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class DataQuality(BaseModel): @@ -91,13 +86,12 @@ class DataQuality(BaseModel): class BinaryClassificationDataQuality(DataQuality): n_observations: int class_metrics: List[ClassMetrics] - feature_metrics: List[FeatureMetrics] + feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]] model_config = ConfigDict( arbitrary_types_allowed=True, populate_by_name=True, alias_generator=to_camel, - protected_namespaces=(), ) diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py index 9300a226..b5c8bab2 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py @@ -8,18 +8,18 @@ class ModelQuality(BaseModel): class BinaryClassificationModelQuality(ModelQuality): - f1: float - accuracy: float - precision: float - recall: float - f_measure: float - weighted_precision: float - weighted_recall: float - weighted_f_measure: float - weighted_true_positive_rate: float - weighted_false_positive_rate: float - true_positive_rate: float - false_positive_rate: float + f1: Optional[float] = None + accuracy: Optional[float] = None + precision: Optional[float] = None + recall: Optional[float] = None + f_measure: Optional[float] = None + weighted_precision: Optional[float] = None + weighted_recall: Optional[float] = None + weighted_f_measure: Optional[float] = None + weighted_true_positive_rate: Optional[float] = None + weighted_false_positive_rate: Optional[float] = None + true_positive_rate: Optional[float] = None + false_positive_rate: Optional[float] = None true_positive_count: int false_positive_count: int true_negative_count: int diff --git a/sdk/radicalbit_platform_sdk/models/dataset_stats.py b/sdk/radicalbit_platform_sdk/models/dataset_stats.py index d835ac78..115dd496 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_stats.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_stats.py @@ -1,3 +1,4 @@ +from typing import Optional from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_camel @@ -6,9 +7,9 @@ class DatasetStats(BaseModel): n_variables: int n_observations: int missing_cells: int - missing_cells_perc: float + missing_cells_perc: Optional[float] duplicate_rows: int - duplicate_rows_perc: float + duplicate_rows_perc: Optional[float] numeric: int categorical: int datetime: int diff --git a/sdk/tests/apis/model_reference_dataset_test.py b/sdk/tests/apis/model_reference_dataset_test.py index d6c3db8c..052ef3ef 100644 --- a/sdk/tests/apis/model_reference_dataset_test.py +++ b/sdk/tests/apis/model_reference_dataset_test.py @@ -273,3 +273,152 @@ def test_model_metrics_key_error(self): with self.assertRaises(ClientError): model_reference_dataset.model_quality() + + @responses.activate + def test_data_quality_ok(self): + base_url = "http://api:9000" + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_reference_dataset = ModelReferenceDataset( + base_url, + model_id, + ModelType.BINARY, + ReferenceFileUpload( + uuid=import_uuid, + path="s3://bucket/file.csv", + date="2014", + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + **{ + "method": responses.GET, + "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality", + "status": 200, + "body": """{ + "datetime": "something_not_used", + "jobStatus": "SUCCEEDED", + "dataQuality": { + "nObservations": 200, + "classMetrics": [ + {"name": "classA", "count": 100, "percentage": 50.0}, + {"name": "classB", "count": 100, "percentage": 50.0} + ], + "featureMetrics": [ + { + "featureName": "age", + "type": "numerical", + "mean": 29.5, + "std": 5.2, + "min": 18, + "max": 45, + "medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0}, + "missingValue": {"count": 2, "percentage": 0.02}, + "classMedianMetrics": [ + { + "name": "classA", + "mean": 30.0, + "medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0} + }, + { + "name": "classB", + "mean": 29.0, + "medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0} + } + ], + "histogram": { + "buckets": [40.0, 45.0, 50.0, 55.0, 60.0], + "referenceValues": [50, 150, 200, 150, 50], + "currentValues": [45, 140, 210, 145, 60] + } + }, + { + "featureName": "gender", + "type": "categorical", + "distinctValue": 2, + "categoryFrequency": [ + {"name": "male", "count": 90, "frequency": 0.45}, + {"name": "female", "count": 110, "frequency": 0.55} + ], + "missingValue": {"count": 0, "percentage": 0.0} + } + ] + } + }""", + } + ) + + metrics = model_reference_dataset.data_quality() + + assert metrics.n_observations == 200 + assert len(metrics.class_metrics) == 2 + assert metrics.class_metrics[0].name == "classA" + assert metrics.class_metrics[0].count == 100 + assert metrics.class_metrics[0].percentage == 50.0 + assert len(metrics.feature_metrics) == 2 + assert metrics.feature_metrics[0].feature_name == "age" + assert metrics.feature_metrics[0].type == "numerical" + assert metrics.feature_metrics[0].mean == 29.5 + assert metrics.feature_metrics[1].feature_name == "gender" + assert metrics.feature_metrics[1].type == "categorical" + assert metrics.feature_metrics[1].distinct_value == 2 + assert model_reference_dataset.status() == JobStatus.SUCCEEDED + + @responses.activate + def test_data_quality_validation_error(self): + base_url = "http://api:9000" + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_reference_dataset = ModelReferenceDataset( + base_url, + model_id, + ModelType.BINARY, + ReferenceFileUpload( + uuid=import_uuid, + path="s3://bucket/file.csv", + date="2014", + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + **{ + "method": responses.GET, + "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality", + "status": 200, + "body": '{"dataQuality": "wrong"}', + } + ) + + with self.assertRaises(ClientError): + model_reference_dataset.data_quality() + + @responses.activate + def test_data_quality_key_error(self): + base_url = "http://api:9000" + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_reference_dataset = ModelReferenceDataset( + base_url, + model_id, + ModelType.BINARY, + ReferenceFileUpload( + uuid=import_uuid, + path="s3://bucket/file.csv", + date="2014", + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + **{ + "method": responses.GET, + "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality", + "status": 200, + "body": '{"wrong": "json"}', + } + ) + + with self.assertRaises(ClientError): + model_reference_dataset.data_quality()