diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 23623af4..4c262bed 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -8,7 +8,7 @@ from radicalbit_platform_sdk.errors import ClientError from radicalbit_platform_sdk.models import ( BinaryClassDrift, - BinaryClassificationDataQuality, + ClassificationDataQuality, CurrentBinaryClassificationModelQuality, CurrentFileUpload, CurrentMultiClassificationModelQuality, @@ -18,7 +18,6 @@ JobStatus, ModelQuality, ModelType, - MultiClassDataQuality, MultiClassDrift, RegressionDataQuality, RegressionDrift, @@ -188,17 +187,10 @@ def __callback( job_status = JobStatus(response_json['jobStatus']) if 'dataQuality' in response_json: match self.__model_type: - case ModelType.BINARY: - return ( - job_status, - BinaryClassificationDataQuality.model_validate( - response_json['dataQuality'] - ), - ) - case ModelType.MULTI_CLASS: + case ModelType.BINARY | ModelType.MULTI_CLASS: return ( job_status, - MultiClassDataQuality.model_validate( + ClassificationDataQuality.model_validate( response_json['dataQuality'] ), ) diff --git a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py index d8f6d031..b8339d66 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py @@ -7,14 +7,13 @@ from radicalbit_platform_sdk.commons import invoke from radicalbit_platform_sdk.errors import ClientError from radicalbit_platform_sdk.models import ( - BinaryClassificationDataQuality, + ClassificationDataQuality, BinaryClassificationModelQuality, DataQuality, DatasetStats, JobStatus, ModelQuality, ModelType, - MultiClassDataQuality, MultiClassificationModelQuality, ReferenceFileUpload, RegressionDataQuality, @@ -114,17 +113,10 @@ def __callback( job_status = JobStatus(response_json['jobStatus']) if 'dataQuality' in response_json: match self.__model_type: - case ModelType.BINARY: - return ( - job_status, - BinaryClassificationDataQuality.model_validate( - response_json['dataQuality'] - ), - ) - case ModelType.MULTI_CLASS: + case ModelType.BINARY | ModelType.MULTI_CLASS: return ( job_status, - MultiClassDataQuality.model_validate( + ClassificationDataQuality.model_validate( response_json['dataQuality'] ), ) diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py index 4cb44dac..6e7be01a 100644 --- a/sdk/radicalbit_platform_sdk/models/__init__.py +++ b/sdk/radicalbit_platform_sdk/models/__init__.py @@ -2,7 +2,7 @@ from .column_definition import ColumnDefinition from .data_type import DataType from .dataset_data_quality import ( - BinaryClassificationDataQuality, + ClassificationDataQuality, CategoricalFeatureMetrics, CategoryFrequency, ClassMedianMetrics, @@ -11,7 +11,6 @@ FeatureMetrics, MedianMetrics, MissingValue, - MultiClassDataQuality, NumericalFeatureMetrics, RegressionDataQuality, ) @@ -60,8 +59,7 @@ 'CurrentMultiClassificationModelQuality', 'RegressionModelQuality', 'DataQuality', - 'BinaryClassificationDataQuality', - 'MultiClassDataQuality', + 'ClassificationDataQuality', 'RegressionDataQuality', 'ClassMetrics', 'MedianMetrics', diff --git a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py index 2f890443..075bde7b 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py @@ -84,7 +84,7 @@ class DataQuality(BaseModel): pass -class BinaryClassificationDataQuality(DataQuality): +class ClassificationDataQuality(DataQuality): n_observations: int class_metrics: List[ClassMetrics] feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]] @@ -96,9 +96,5 @@ class BinaryClassificationDataQuality(DataQuality): ) -class MultiClassDataQuality(DataQuality): - pass - - class RegressionDataQuality(DataQuality): pass diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index b5de421c..ac82a50d 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -8,21 +8,18 @@ from radicalbit_platform_sdk.errors import ClientError from radicalbit_platform_sdk.models import ( BinaryClassDrift, - BinaryClassificationDataQuality, + ClassificationDataQuality, CurrentBinaryClassificationModelQuality, CurrentFileUpload, + CurrentMultiClassificationModelQuality, DriftAlgorithm, JobStatus, ModelType, - MultiClassDataQuality, MultiClassDrift, RegressionDataQuality, RegressionDrift, RegressionModelQuality, ) -from radicalbit_platform_sdk.models.dataset_model_quality import ( - CurrentMultiClassificationModelQuality, -) class ModelCurrentDatasetTest(unittest.TestCase): @@ -401,7 +398,7 @@ def test_binary_class_data_quality_ok(self): metrics = model_current_dataset.data_quality() - assert isinstance(metrics, BinaryClassificationDataQuality) + assert isinstance(metrics, ClassificationDataQuality) assert metrics.n_observations == 200 assert len(metrics.class_metrics) == 2 @@ -440,16 +437,73 @@ def test_multi_class_data_quality_ok(self): url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality', status=200, body="""{ - "datetime": "something_not_used", - "jobStatus": "SUCCEEDED", - "dataQuality": {} - }""", + "datetime": "something_not_used", + "jobStatus": "SUCCEEDED", + "dataQuality": { + "nObservations": 200, + "classMetrics": [ + {"name": "classA", "count": 100, "percentage": 50.0}, + {"name": "classB", "count": 100, "percentage": 50.0} + ], + "featureMetrics": [ + { + "featureName": "age", + "type": "numerical", + "mean": 29.5, + "std": 5.2, + "min": 18, + "max": 45, + "medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0}, + "missingValue": {"count": 2, "percentage": 0.02}, + "classMedianMetrics": [ + { + "name": "classA", + "mean": 30.0, + "medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0} + }, + { + "name": "classB", + "mean": 29.0, + "medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0} + } + ], + "histogram": { + "buckets": [40.0, 45.0, 50.0, 55.0, 60.0], + "referenceValues": [50, 150, 200, 150, 50], + "currentValues": [45, 140, 210, 145, 60] + } + }, + { + "featureName": "gender", + "type": "categorical", + "distinctValue": 2, + "categoryFrequency": [ + {"name": "male", "count": 90, "frequency": 0.45}, + {"name": "female", "count": 110, "frequency": 0.55} + ], + "missingValue": {"count": 0, "percentage": 0.0} + } + ] + } + }""", ) metrics = model_current_dataset.data_quality() - assert isinstance(metrics, MultiClassDataQuality) - # TODO: add asserts to properties + assert isinstance(metrics, ClassificationDataQuality) + + assert metrics.n_observations == 200 + assert len(metrics.class_metrics) == 2 + assert metrics.class_metrics[0].name == 'classA' + assert metrics.class_metrics[0].count == 100 + assert metrics.class_metrics[0].percentage == 50.0 + assert len(metrics.feature_metrics) == 2 + assert metrics.feature_metrics[0].feature_name == 'age' + assert metrics.feature_metrics[0].type == 'numerical' + assert metrics.feature_metrics[0].mean == 29.5 + assert metrics.feature_metrics[1].feature_name == 'gender' + assert metrics.feature_metrics[1].type == 'categorical' + assert metrics.feature_metrics[1].distinct_value == 2 assert model_current_dataset.status() == JobStatus.SUCCEEDED @responses.activate diff --git a/sdk/tests/apis/model_reference_dataset_test.py b/sdk/tests/apis/model_reference_dataset_test.py index 1d0dd4fe..b303b42e 100644 --- a/sdk/tests/apis/model_reference_dataset_test.py +++ b/sdk/tests/apis/model_reference_dataset_test.py @@ -7,11 +7,10 @@ from radicalbit_platform_sdk.apis import ModelReferenceDataset from radicalbit_platform_sdk.errors import ClientError from radicalbit_platform_sdk.models import ( - BinaryClassificationDataQuality, + ClassificationDataQuality, BinaryClassificationModelQuality, JobStatus, ModelType, - MultiClassDataQuality, MultiClassificationModelQuality, ReferenceFileUpload, RegressionDataQuality, @@ -511,7 +510,7 @@ def test_binary_class_data_quality_ok(self): metrics = model_reference_dataset.data_quality() - assert isinstance(metrics, BinaryClassificationDataQuality) + assert isinstance(metrics, ClassificationDataQuality) assert metrics.n_observations == 200 assert len(metrics.class_metrics) == 2 @@ -549,16 +548,73 @@ def test_multi_class_data_quality_ok(self): url=f'{base_url}/api/models/{str(model_id)}/reference/data-quality', status=200, body="""{ - "datetime": "something_not_used", - "jobStatus": "SUCCEEDED", - "dataQuality": {} - }""", + "datetime": "something_not_used", + "jobStatus": "SUCCEEDED", + "dataQuality": { + "nObservations": 200, + "classMetrics": [ + {"name": "classA", "count": 100, "percentage": 50.0}, + {"name": "classB", "count": 100, "percentage": 50.0} + ], + "featureMetrics": [ + { + "featureName": "age", + "type": "numerical", + "mean": 29.5, + "std": 5.2, + "min": 18, + "max": 45, + "medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0}, + "missingValue": {"count": 2, "percentage": 0.02}, + "classMedianMetrics": [ + { + "name": "classA", + "mean": 30.0, + "medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0} + }, + { + "name": "classB", + "mean": 29.0, + "medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0} + } + ], + "histogram": { + "buckets": [40.0, 45.0, 50.0, 55.0, 60.0], + "referenceValues": [50, 150, 200, 150, 50], + "currentValues": [45, 140, 210, 145, 60] + } + }, + { + "featureName": "gender", + "type": "categorical", + "distinctValue": 2, + "categoryFrequency": [ + {"name": "male", "count": 90, "frequency": 0.45}, + {"name": "female", "count": 110, "frequency": 0.55} + ], + "missingValue": {"count": 0, "percentage": 0.0} + } + ] + } + }""", ) metrics = model_reference_dataset.data_quality() - assert isinstance(metrics, MultiClassDataQuality) - # TODO: add asserts to properties + assert isinstance(metrics, ClassificationDataQuality) + + assert metrics.n_observations == 200 + assert len(metrics.class_metrics) == 2 + assert metrics.class_metrics[0].name == 'classA' + assert metrics.class_metrics[0].count == 100 + assert metrics.class_metrics[0].percentage == 50.0 + assert len(metrics.feature_metrics) == 2 + assert metrics.feature_metrics[0].feature_name == 'age' + assert metrics.feature_metrics[0].type == 'numerical' + assert metrics.feature_metrics[0].mean == 29.5 + assert metrics.feature_metrics[1].feature_name == 'gender' + assert metrics.feature_metrics[1].type == 'categorical' + assert metrics.feature_metrics[1].distinct_value == 2 assert model_reference_dataset.status() == JobStatus.SUCCEEDED @responses.activate