diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py
index e21ed735..44678383 100644
--- a/api/app/models/metrics/model_quality_dto.py
+++ b/api/app/models/metrics/model_quality_dto.py
@@ -9,7 +9,7 @@ from app.models.model_dto import ModelType
 
 
-class BinaryClassModelQuality(BaseModel):
+class MetricsBase(BaseModel):
     f1: Optional[float] = None
     accuracy: Optional[float] = None
     precision: Optional[float] = None
@@ -22,10 +22,6 @@ class BinaryClassModelQuality(BaseModel):
     weighted_false_positive_rate: Optional[float] = None
     true_positive_rate: Optional[float] = None
     false_positive_rate: Optional[float] = None
-    true_positive_count: int
-    false_positive_count: int
-    true_negative_count: int
-    false_negative_count: int
     area_under_roc: Optional[float] = None
     area_under_pr: Optional[float] = None
 
@@ -34,12 +30,19 @@ class BinaryClassModelQuality(BaseModel):
     )
 
 
+class BinaryClassificationModelQuality(MetricsBase):
+    true_positive_count: int
+    false_positive_count: int
+    true_negative_count: int
+    false_negative_count: int
+
+
 class Distribution(BaseModel):
     timestamp: str
     value: Optional[float] = None
 
 
-class GroupedBinaryClassModelQuality(BaseModel):
+class GroupedMetricsBase(BaseModel):
     f1: List[Distribution]
     accuracy: List[Distribution]
     precision: List[Distribution]
@@ -55,21 +58,38 @@ class GroupedBinaryClassModelQuality(BaseModel):
     area_under_roc: Optional[List[Distribution]] = None
     area_under_pr: Optional[List[Distribution]] = None
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
-class CurrentBinaryClassModelQuality(BaseModel):
-    global_metrics: BinaryClassModelQuality
-    grouped_metrics: GroupedBinaryClassModelQuality
+class CurrentBinaryClassificationModelQuality(BaseModel):
+    global_metrics: BinaryClassificationModelQuality
+    grouped_metrics: GroupedMetricsBase
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class ClassMetrics(BaseModel):
+    class_name: str
+    metrics: MetricsBase
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
-class MultiClassModelQuality(BaseModel):
+class GlobalMetrics(MetricsBase):
+    confusion_matrix: List[List[int]]
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class MultiClassificationModelQuality(BaseModel):
+    classes: List[str]
+    class_metrics: List[ClassMetrics]
+    global_metrics: GlobalMetrics
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class CurrentMultiClassificationModelQuality(BaseModel):
     pass
 
 
@@ -80,15 +100,16 @@ class RegressionModelQuality(BaseModel):
 class ModelQualityDTO(BaseModel):
     job_status: JobStatus
     model_quality: Optional[
-        BinaryClassModelQuality
-        | CurrentBinaryClassModelQuality
-        | MultiClassModelQuality
+        BinaryClassificationModelQuality
+        | CurrentBinaryClassificationModelQuality
+        | MultiClassificationModelQuality
+        | CurrentMultiClassificationModelQuality
         | RegressionModelQuality
     ]
 
     model_config = ConfigDict(
         populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
     )
 
     @staticmethod
     def from_dict(
@@ -121,9 +142,10 @@ def _create_model_quality(
         dataset_type: DatasetType,
         model_quality_data: Dict,
     ) -> (
-        BinaryClassModelQuality
-        | CurrentBinaryClassModelQuality
-        | MultiClassModelQuality
+        BinaryClassificationModelQuality
+        | CurrentBinaryClassificationModelQuality
+        | MultiClassificationModelQuality
+        | CurrentMultiClassificationModelQuality
         | RegressionModelQuality
    ):
         """Create a specific model quality instance based on model type and dataset type."""
@@ -133,7 +155,10 @@ def _create_model_quality(
                 model_quality_data=model_quality_data,
             )
         if model_type == ModelType.MULTI_CLASS:
-            return MultiClassModelQuality(**model_quality_data)
+            return ModelQualityDTO._create_multiclass_model_quality(
+                dataset_type=dataset_type,
+                model_quality_data=model_quality_data,
+            )
         if model_type == ModelType.REGRESSION:
             return RegressionModelQuality(**model_quality_data)
         raise MetricsInternalError(f'Invalid model type {model_type}')
@@ -142,10 +167,22 @@ def _create_binary_model_quality(
         dataset_type: DatasetType,
         model_quality_data: Dict,
-    ) -> BinaryClassModelQuality | CurrentBinaryClassModelQuality:
+    ) -> BinaryClassificationModelQuality | CurrentBinaryClassificationModelQuality:
         """Create a binary model quality instance based on dataset type."""
         if dataset_type == DatasetType.REFERENCE:
-            return BinaryClassModelQuality(**model_quality_data)
+            return BinaryClassificationModelQuality(**model_quality_data)
+        if dataset_type == DatasetType.CURRENT:
+            return CurrentBinaryClassificationModelQuality(**model_quality_data)
+        raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
+
+    @staticmethod
+    def _create_multiclass_model_quality(
+        dataset_type: DatasetType,
+        model_quality_data: Dict,
+    ) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality:
+        """Create a multiclass model quality instance based on dataset type."""
+        if dataset_type == DatasetType.REFERENCE:
+            return MultiClassificationModelQuality(**model_quality_data)
         if dataset_type == DatasetType.CURRENT:
-            return CurrentBinaryClassModelQuality(**model_quality_data)
+            return CurrentMultiClassificationModelQuality(**model_quality_data)
         raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
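The factory above is the single entry point the services use; its keyword signature (`dataset_type`, `model_type`, `job_status`, `model_quality_data`) is confirmed by the test changes further down. A minimal sketch of how a multiclass reference payload now resolves; it assumes the `api` package is importable, and the module paths for `DatasetType` and `JobStatus` are guesses (only `ModelType`'s path appears in the diff):

```python
from app.models.model_dto import ModelType
# The two import paths below are assumptions; the diff confirms the names
# but not the modules they live in.
from app.models.dataset_dto import DatasetType
from app.models.job_status import JobStatus
from app.models.metrics.model_quality_dto import (
    ModelQualityDTO,
    MultiClassificationModelQuality,
)

# Sketch only, not part of the PR: the payload mirrors the shape of
# multiclass_model_quality_dict in db_mock.py below.
dto = ModelQualityDTO.from_dict(
    dataset_type=DatasetType.REFERENCE,
    model_type=ModelType.MULTI_CLASS,
    job_status=JobStatus.SUCCEEDED,
    model_quality_data={
        'classes': ['classA', 'classB'],
        'classMetrics': [
            {'className': 'classA', 'metrics': {'accuracy': 0.98}},
            {'className': 'classB', 'metrics': {'recall': 0.23}},
        ],
        'globalMetrics': {'accuracy': 0.9, 'confusionMatrix': [[5, 1], [2, 4]]},
    },
)
assert isinstance(dto.model_quality, MultiClassificationModelQuality)
```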
diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py
index c3c25f92..bd27b34f 100644
--- a/api/tests/commons/db_mock.py
+++ b/api/tests/commons/db_mock.py
@@ -128,7 +128,7 @@ def get_sample_current_dataset(
     'datetime': 1,
 }
 
-model_quality_dict = {
+model_quality_base_dict = {
     'f1': None,
     'accuracy': 0.90,
     'precision': 0.88,
@@ -141,16 +141,20 @@ def get_sample_current_dataset(
     'weightedFalsePositiveRate': 0.10,
     'truePositiveRate': 0.87,
     'falsePositiveRate': 0.13,
+    'areaUnderRoc': 0.92,
+    'areaUnderPr': 0.91,
+}
+
+binary_model_quality_dict = {
     'truePositiveCount': 870,
     'falsePositiveCount': 130,
     'trueNegativeCount': 820,
     'falseNegativeCount': 180,
-    'areaUnderRoc': 0.92,
-    'areaUnderPr': 0.91,
+    **model_quality_base_dict,
 }
 
-current_model_quality_dict = {
-    'globalMetrics': model_quality_dict,
+binary_current_model_quality_dict = {
+    'globalMetrics': binary_model_quality_dict,
     'groupedMetrics': {
         'f1': [
             {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
@@ -211,6 +215,37 @@ def get_sample_current_dataset(
     },
 }
 
+multiclass_model_quality_dict = {
+    'classes': [
+        'classA',
+        'classB',
+        'classC',
+    ],
+    'classMetrics': [
+        {
+            'className': 'classA',
+            'metrics': model_quality_base_dict,
+        },
+        {
+            'className': 'classB',
+            'metrics': model_quality_base_dict,
+        },
+        {
+            'className': 'classC',
+            'metrics': model_quality_base_dict,
+        },
+    ],
+    'globalMetrics': {
+        # one integer row and column per entry in 'classes'
+        'confusionMatrix': [
+            [3, 0, 0],
+            [0, 2, 1],
+            [1, 0, 2],
+        ],
+        **model_quality_base_dict,
+    },
+}
+
 data_quality_dict = {
     'nObservations': 200,
     'classMetrics': [
@@ -278,7 +313,7 @@ def get_sample_current_dataset(
 
 def get_sample_reference_metrics(
     reference_uuid: uuid.UUID = REFERENCE_UUID,
-    model_quality: Dict = model_quality_dict,
+    model_quality: Dict = binary_model_quality_dict,
     data_quality: Dict = data_quality_dict,
     statistics: Dict = statistics_dict,
 ) -> ReferenceDatasetMetrics:
@@ -292,7 +327,7 @@ def get_sample_current_metrics(
     current_uuid: uuid.UUID = CURRENT_UUID,
-    model_quality: Dict = model_quality_dict,
+    model_quality: Dict = binary_current_model_quality_dict,
     data_quality: Dict = data_quality_dict,
     statistics: Dict = statistics_dict,
     drift: Dict = drift_dict,
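The mocks stay camelCase end to end because every DTO involved sets `alias_generator=to_camel` with `populate_by_name=True`, and the shared base dict merges into both shapes without key collisions. A quick check that the composed dicts still satisfy the new models; this sketch assumes it runs inside the api test environment, where `tests.commons` is importable:

```python
from app.models.metrics.model_quality_dto import (
    BinaryClassificationModelQuality,
    MultiClassificationModelQuality,
)
from tests.commons import db_mock

# camelCase keys are consumed via the to_camel aliases
binary = BinaryClassificationModelQuality.model_validate(
    db_mock.binary_model_quality_dict
)
assert binary.true_positive_count == 870
assert binary.area_under_roc == 0.92  # merged in from model_quality_base_dict

multiclass = MultiClassificationModelQuality.model_validate(
    db_mock.multiclass_model_quality_dict
)
assert [cm.class_name for cm in multiclass.class_metrics] == [
    'classA',
    'classB',
    'classC',
]
assert multiclass.global_metrics.confusion_matrix[0] == [3, 0, 0]
```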
diff --git a/api/tests/routes/metrics_route_test.py b/api/tests/routes/metrics_route_test.py
index 18856306..062afb58 100644
--- a/api/tests/routes/metrics_route_test.py
+++ b/api/tests/routes/metrics_route_test.py
@@ -159,7 +159,7 @@ def test_get_current_model_quality_by_model_by_uuid(self):
         model_uuid = uuid.uuid4()
         current_uuid = uuid.uuid4()
         current_metrics = db_mock.get_sample_current_metrics(
-            model_quality=db_mock.current_model_quality_dict
+            model_quality=db_mock.binary_current_model_quality_dict
         )
         model_quality = ModelQualityDTO.from_dict(
             dataset_type=DatasetType.CURRENT,
diff --git a/api/tests/services/metrics_service_test.py b/api/tests/services/metrics_service_test.py
index faf6a5b9..ff2af654 100644
--- a/api/tests/services/metrics_service_test.py
+++ b/api/tests/services/metrics_service_test.py
@@ -15,6 +15,7 @@
 from app.models.metrics.drift_dto import DriftDTO
 from app.models.metrics.model_quality_dto import ModelQualityDTO
 from app.models.metrics.statistics_dto import StatisticsDTO
+from app.models.model_dto import ModelType
 from app.services.metrics_service import MetricsService
 from app.services.model_service import ModelService
 from tests.commons import db_mock
@@ -176,6 +177,37 @@ def test_get_empty_reference_model_quality_by_model_by_uuid(self):
             model_quality_data=None,
         )
 
+    def test_get_reference_multiclass_model_quality_by_model_by_uuid(self):
+        status = JobStatus.SUCCEEDED
+        reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
+        reference_metrics = db_mock.get_sample_reference_metrics(
+            model_quality=db_mock.multiclass_model_quality_dict
+        )
+        model = db_mock.get_sample_model(model_type=ModelType.MULTI_CLASS)
+        self.model_service.get_model_by_uuid = MagicMock(return_value=model)
+        self.reference_dataset_dao.get_reference_dataset_by_model_uuid = MagicMock(
+            return_value=reference_dataset
+        )
+        self.reference_metrics_dao.get_reference_metrics_by_model_uuid = MagicMock(
+            return_value=reference_metrics
+        )
+        res = self.metrics_service.get_reference_model_quality_by_model_by_uuid(
+            model_uuid
+        )
+        self.reference_dataset_dao.get_reference_dataset_by_model_uuid.assert_called_once_with(
+            model_uuid
+        )
+        self.reference_metrics_dao.get_reference_metrics_by_model_uuid.assert_called_once_with(
+            model_uuid
+        )
+
+        assert res == ModelQualityDTO.from_dict(
+            dataset_type=DatasetType.REFERENCE,
+            model_type=model.model_type,
+            job_status=reference_dataset.status,
+            model_quality_data=reference_metrics.model_quality,
+        )
+
     def test_get_reference_binary_class_data_quality_by_model_by_uuid(self):
         status = JobStatus.SUCCEEDED
         reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
@@ -225,6 +257,34 @@ def test_get_empty_reference_data_quality_by_model_by_uuid(self):
             data_quality_data=None,
         )
 
+    def test_get_reference_multiclass_data_quality_by_model_by_uuid(self):
+        status = JobStatus.SUCCEEDED
+        reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
+        reference_metrics = db_mock.get_sample_reference_metrics()
+        model = db_mock.get_sample_model(model_type=ModelType.MULTI_CLASS)
+        self.model_service.get_model_by_uuid = MagicMock(return_value=model)
+        self.reference_dataset_dao.get_reference_dataset_by_model_uuid = MagicMock(
+            return_value=reference_dataset
+        )
+        self.reference_metrics_dao.get_reference_metrics_by_model_uuid = MagicMock(
+            return_value=reference_metrics
+        )
+        res = self.metrics_service.get_reference_data_quality_by_model_by_uuid(
+            model_uuid
+        )
+        self.reference_dataset_dao.get_reference_dataset_by_model_uuid.assert_called_once_with(
+            model_uuid
+        )
+        self.reference_metrics_dao.get_reference_metrics_by_model_uuid.assert_called_once_with(
+            model_uuid
+        )
+
+        assert res == DataQualityDTO.from_dict(
+            model_type=model.model_type,
+            job_status=reference_dataset.status,
+            data_quality_data=reference_metrics.data_quality,
+        )
+
     def test_get_current_statistics_by_model_by_uuid(self):
         status = JobStatus.SUCCEEDED
         current_dataset = db_mock.get_sample_current_dataset(status=status.value)
@@ -444,9 +504,7 @@ def test_get_empty_current_data_quality_by_model_by_uuid(self):
     def test_get_current_binary_class_model_quality_by_model_by_uuid(self):
         status = JobStatus.SUCCEEDED
         current_dataset = db_mock.get_sample_current_dataset(status=status.value)
-        current_metrics = db_mock.get_sample_current_metrics(
-            model_quality=db_mock.current_model_quality_dict
-        )
+        current_metrics = db_mock.get_sample_current_metrics()
         model = db_mock.get_sample_model()
         self.model_service.get_model_by_uuid = MagicMock(return_value=model)
         self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock(
diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
index 75966d5b..4c262bed 100644
--- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
+++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
@@ -8,18 +8,17 @@
 from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
     BinaryClassDrift,
-    BinaryClassificationDataQuality,
+    ClassificationDataQuality,
     CurrentBinaryClassificationModelQuality,
     CurrentFileUpload,
+    CurrentMultiClassificationModelQuality,
     DataQuality,
     DatasetStats,
     Drift,
     JobStatus,
     ModelQuality,
     ModelType,
-    MultiClassDataQuality,
     MultiClassDrift,
-    MultiClassModelQuality,
     RegressionDataQuality,
     RegressionDrift,
     RegressionModelQuality,
@@ -188,17 +187,10 @@ def __callback(
         job_status = JobStatus(response_json['jobStatus'])
         if 'dataQuality' in response_json:
             match self.__model_type:
-                case ModelType.BINARY:
-                    return (
-                        job_status,
-                        BinaryClassificationDataQuality.model_validate(
-                            response_json['dataQuality']
-                        ),
-                    )
-                case ModelType.MULTI_CLASS:
+                case ModelType.BINARY | ModelType.MULTI_CLASS:
                     return (
                         job_status,
-                        MultiClassDataQuality.model_validate(
+                        ClassificationDataQuality.model_validate(
                             response_json['dataQuality']
                         ),
                     )
@@ -268,7 +260,7 @@ def __callback(
                 case ModelType.MULTI_CLASS:
                     return (
                         job_status,
-                        MultiClassModelQuality.model_validate(
+                        CurrentMultiClassificationModelQuality.model_validate(
                             response_json['modelQuality']
                         ),
                     )
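Both `__callback` handlers now route binary and multiclass data-quality payloads through a single arm; only the model-quality arms still differ per type. The merged arm is a standard `match` or-pattern, illustrated here with a self-contained sketch (the `Kind` enum is a hypothetical stand-in, not the SDK's `ModelType`):

```python
from enum import Enum


class Kind(str, Enum):
    """Stand-in for ModelType, just for this sketch."""

    BINARY = 'BINARY'
    MULTI_CLASS = 'MULTI_CLASS'
    REGRESSION = 'REGRESSION'


def data_quality_model(kind: Kind) -> str:
    match kind:
        case Kind.BINARY | Kind.MULTI_CLASS:
            # one arm serves both, as ClassificationDataQuality does above
            return 'ClassificationDataQuality'
        case Kind.REGRESSION:
            return 'RegressionDataQuality'
    raise ValueError(f'unhandled kind {kind}')


assert data_quality_model(Kind.BINARY) == data_quality_model(Kind.MULTI_CLASS)
```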
diff --git a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
index 8c4183dd..f52779fd 100644
--- a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
+++ b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
@@ -7,15 +7,14 @@
 from radicalbit_platform_sdk.commons import invoke
 from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
-    BinaryClassificationDataQuality,
     BinaryClassificationModelQuality,
+    ClassificationDataQuality,
     DataQuality,
     DatasetStats,
     JobStatus,
     ModelQuality,
     ModelType,
-    MultiClassDataQuality,
-    MultiClassModelQuality,
+    MultiClassificationModelQuality,
     ReferenceFileUpload,
     RegressionDataQuality,
     RegressionModelQuality,
@@ -114,17 +113,10 @@ def __callback(
         job_status = JobStatus(response_json['jobStatus'])
         if 'dataQuality' in response_json:
             match self.__model_type:
-                case ModelType.BINARY:
-                    return (
-                        job_status,
-                        BinaryClassificationDataQuality.model_validate(
-                            response_json['dataQuality']
-                        ),
-                    )
-                case ModelType.MULTI_CLASS:
+                case ModelType.BINARY | ModelType.MULTI_CLASS:
                     return (
                         job_status,
-                        MultiClassDataQuality.model_validate(
+                        ClassificationDataQuality.model_validate(
                             response_json['dataQuality']
                         ),
                     )
@@ -194,7 +186,7 @@ def __callback(
                 case ModelType.MULTI_CLASS:
                     return (
                         job_status,
-                        MultiClassModelQuality.model_validate(
+                        MultiClassificationModelQuality.model_validate(
                             response_json['modelQuality']
                         ),
                     )
diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py
index 248a26e7..6e7be01a 100644
--- a/sdk/radicalbit_platform_sdk/models/__init__.py
+++ b/sdk/radicalbit_platform_sdk/models/__init__.py
@@ -2,7 +2,7 @@
 from .column_definition import ColumnDefinition
 from .data_type import DataType
 from .dataset_data_quality import (
-    BinaryClassificationDataQuality,
+    ClassificationDataQuality,
     CategoricalFeatureMetrics,
     CategoryFrequency,
     ClassMedianMetrics,
@@ -11,7 +11,6 @@
     FeatureMetrics,
     MedianMetrics,
     MissingValue,
-    MultiClassDataQuality,
     NumericalFeatureMetrics,
     RegressionDataQuality,
 )
@@ -28,7 +27,8 @@
     BinaryClassificationModelQuality,
     CurrentBinaryClassificationModelQuality,
     ModelQuality,
-    MultiClassModelQuality,
+    MultiClassificationModelQuality,
+    CurrentMultiClassificationModelQuality,
     RegressionModelQuality,
 )
 from .dataset_stats import DatasetStats
@@ -55,11 +55,11 @@
     'ModelQuality',
     'BinaryClassificationModelQuality',
     'CurrentBinaryClassificationModelQuality',
-    'MultiClassModelQuality',
+    'MultiClassificationModelQuality',
+    'CurrentMultiClassificationModelQuality',
     'RegressionModelQuality',
     'DataQuality',
-    'BinaryClassificationDataQuality',
-    'MultiClassDataQuality',
+    'ClassificationDataQuality',
     'RegressionDataQuality',
     'ClassMetrics',
     'MedianMetrics',
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
index 2f890443..075bde7b 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
@@ -84,7 +84,7 @@ class DataQuality(BaseModel):
     pass
 
 
-class BinaryClassificationDataQuality(DataQuality):
+class ClassificationDataQuality(DataQuality):
     n_observations: int
     class_metrics: List[ClassMetrics]
     feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]]
@@ -96,9 +96,5 @@ class BinaryClassificationDataQuality(DataQuality):
     )
 
 
-class MultiClassDataQuality(DataQuality):
-    pass
-
-
 class RegressionDataQuality(DataQuality):
     pass
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
index 6e2d4877..d0216952 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
@@ -8,7 +8,7 @@ class ModelQuality(BaseModel):
     pass
 
 
-class BinaryClassificationModelQuality(ModelQuality):
+class MetricsBase(BaseModel):
     f1: Optional[float] = None
     accuracy: Optional[float] = None
     precision: Optional[float] = None
@@ -21,10 +21,6 @@ class BinaryClassificationModelQuality(ModelQuality):
     weighted_false_positive_rate: Optional[float] = None
     true_positive_rate: Optional[float] = None
     false_positive_rate: Optional[float] = None
-    true_positive_count: int
-    false_positive_count: int
-    true_negative_count: int
-    false_negative_count: int
     area_under_roc: Optional[float] = None
     area_under_pr: Optional[float] = None
 
@@ -33,12 +29,19 @@ class BinaryClassificationModelQuality(ModelQuality):
     )
 
 
+class BinaryClassificationModelQuality(ModelQuality, MetricsBase):
+    true_positive_count: int
+    false_positive_count: int
+    true_negative_count: int
+    false_negative_count: int
+
+
 class Distribution(BaseModel):
     timestamp: str
     value: Optional[float] = None
 
 
-class GroupedBinaryClassModelQuality(BaseModel):
+class GroupedMetricsBase(BaseModel):
     f1: List[Distribution]
     accuracy: List[Distribution]
     precision: List[Distribution]
@@ -57,14 +60,35 @@ class GroupedBinaryClassModelQuality(BaseModel):
     model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
-class CurrentBinaryClassificationModelQuality(BaseModel):
+class CurrentBinaryClassificationModelQuality(ModelQuality):
     global_metrics: BinaryClassificationModelQuality
-    grouped_metrics: GroupedBinaryClassModelQuality
+    grouped_metrics: GroupedMetricsBase
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class ClassMetrics(BaseModel):
+    class_name: str
+    metrics: MetricsBase
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class GlobalMetrics(MetricsBase):
+    confusion_matrix: List[List[int]]
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class MultiClassificationModelQuality(ModelQuality):
+    classes: List[str]
+    class_metrics: List[ClassMetrics]
+    global_metrics: GlobalMetrics
 
     model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
-class MultiClassModelQuality(ModelQuality):
+class CurrentMultiClassificationModelQuality(ModelQuality):
     pass
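`BinaryClassificationModelQuality` now inherits from both `ModelQuality` (so the SDK's type-based returns keep working) and `MetricsBase` (for the shared fields), while the multiclass model composes `MetricsBase` per class. A sketch of validating a camelCase wire payload into the new model, assuming the SDK package is installed; the payload itself is made up for illustration:

```python
from radicalbit_platform_sdk.models import MultiClassificationModelQuality

payload = {  # hypothetical payload, in the camelCase the API emits
    'classes': ['classA', 'classB'],
    'classMetrics': [
        {'className': 'classA', 'metrics': {'accuracy': 0.98}},
        {'className': 'classB', 'metrics': {'fMeasure': 9.33}},
    ],
    'globalMetrics': {
        'accuracy': 0.9,
        'confusionMatrix': [[5, 1], [2, 4]],
    },
}

# to_camel aliases plus populate_by_name=True let the camelCase keys
# validate straight into the snake_case fields
quality = MultiClassificationModelQuality.model_validate(payload)
assert quality.class_metrics[0].metrics.accuracy == 0.98
assert quality.global_metrics.confusion_matrix == [[5, 1], [2, 4]]
```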
diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py
index eb370f3f..ac82a50d 100644
--- a/sdk/tests/apis/model_current_dataset_test.py
+++ b/sdk/tests/apis/model_current_dataset_test.py
@@ -8,15 +8,14 @@
 from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
     BinaryClassDrift,
-    BinaryClassificationDataQuality,
+    ClassificationDataQuality,
     CurrentBinaryClassificationModelQuality,
     CurrentFileUpload,
+    CurrentMultiClassificationModelQuality,
     DriftAlgorithm,
     JobStatus,
     ModelType,
-    MultiClassDataQuality,
     MultiClassDrift,
-    MultiClassModelQuality,
     RegressionDataQuality,
     RegressionDrift,
     RegressionModelQuality,
@@ -399,7 +398,7 @@ def test_binary_class_data_quality_ok(self):
 
         metrics = model_current_dataset.data_quality()
 
-        assert isinstance(metrics, BinaryClassificationDataQuality)
+        assert isinstance(metrics, ClassificationDataQuality)
         assert metrics.n_observations == 200
         assert len(metrics.class_metrics) == 2
@@ -438,16 +437,73 @@ def test_multi_class_data_quality_ok(self):
             url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
             status=200,
             body="""{
-                "datetime": "something_not_used",
-                "jobStatus": "SUCCEEDED",
-                "dataQuality": {}
-            }""",
+                "datetime": "something_not_used",
+                "jobStatus": "SUCCEEDED",
+                "dataQuality": {
+                    "nObservations": 200,
+                    "classMetrics": [
+                        {"name": "classA", "count": 100, "percentage": 50.0},
+                        {"name": "classB", "count": 100, "percentage": 50.0}
+                    ],
+                    "featureMetrics": [
+                        {
+                            "featureName": "age",
+                            "type": "numerical",
+                            "mean": 29.5,
+                            "std": 5.2,
+                            "min": 18,
+                            "max": 45,
+                            "medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
+                            "missingValue": {"count": 2, "percentage": 0.02},
+                            "classMedianMetrics": [
+                                {
+                                    "name": "classA",
+                                    "mean": 30.0,
+                                    "medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
+                                },
+                                {
+                                    "name": "classB",
+                                    "mean": 29.0,
+                                    "medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
+                                }
+                            ],
+                            "histogram": {
+                                "buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
+                                "referenceValues": [50, 150, 200, 150, 50],
+                                "currentValues": [45, 140, 210, 145, 60]
+                            }
+                        },
+                        {
+                            "featureName": "gender",
+                            "type": "categorical",
+                            "distinctValue": 2,
+                            "categoryFrequency": [
+                                {"name": "male", "count": 90, "frequency": 0.45},
+                                {"name": "female", "count": 110, "frequency": 0.55}
+                            ],
+                            "missingValue": {"count": 0, "percentage": 0.0}
+                        }
+                    ]
+                }
+            }""",
         )
 
         metrics = model_current_dataset.data_quality()
 
-        assert isinstance(metrics, MultiClassDataQuality)
-        # TODO: add asserts to properties
+        assert isinstance(metrics, ClassificationDataQuality)
+
+        assert metrics.n_observations == 200
+        assert len(metrics.class_metrics) == 2
+        assert metrics.class_metrics[0].name == 'classA'
+        assert metrics.class_metrics[0].count == 100
+        assert metrics.class_metrics[0].percentage == 50.0
+        assert len(metrics.feature_metrics) == 2
+        assert metrics.feature_metrics[0].feature_name == 'age'
+        assert metrics.feature_metrics[0].type == 'numerical'
+        assert metrics.feature_metrics[0].mean == 29.5
+        assert metrics.feature_metrics[1].feature_name == 'gender'
+        assert metrics.feature_metrics[1].type == 'categorical'
+        assert metrics.feature_metrics[1].distinct_value == 2
         assert model_current_dataset.status() == JobStatus.SUCCEEDED
 
     @responses.activate
@@ -695,7 +751,7 @@ def test_multi_class_model_quality_ok(self):
 
         metrics = model_current_dataset.model_quality()
 
-        assert isinstance(metrics, MultiClassModelQuality)
+        assert isinstance(metrics, CurrentMultiClassificationModelQuality)
         # TODO: add asserts to properties
         assert model_current_dataset.status() == JobStatus.SUCCEEDED
diff --git a/sdk/tests/apis/model_reference_dataset_test.py b/sdk/tests/apis/model_reference_dataset_test.py
index b970a73b..845c6498 100644
--- a/sdk/tests/apis/model_reference_dataset_test.py
+++ b/sdk/tests/apis/model_reference_dataset_test.py
@@ -7,12 +7,11 @@
 from radicalbit_platform_sdk.apis import ModelReferenceDataset
 from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
-    BinaryClassificationDataQuality,
     BinaryClassificationModelQuality,
+    ClassificationDataQuality,
     JobStatus,
     ModelType,
-    MultiClassDataQuality,
-    MultiClassModelQuality,
+    MultiClassificationModelQuality,
     ReferenceFileUpload,
     RegressionDataQuality,
     RegressionModelQuality,
@@ -228,6 +227,25 @@ def test_multi_class_model_metrics_ok(self):
         base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
+        f1 = 0.75
+        accuracy = 0.98
+        recall = 0.23
+        weighted_precision = 0.15
+        weighted_true_positive_rate = 0.01
+        weighted_false_positive_rate = 0.23
+        weighted_f_measure = 2.45
+        true_positive_rate = 4.12
+        false_positive_rate = 5.89
+        precision = 2.33
+        weighted_recall = 4.22
+        f_measure = 9.33
+        confusion_matrix = [
+            [3, 0, 0, 0],
+            [0, 2, 1, 0],
+            [0, 0, 1, 2],
+            [1, 0, 0, 0],
+        ]
+
         model_reference_dataset = ModelReferenceDataset(
             base_url,
             model_id,
@@ -244,17 +262,89 @@ def test_multi_class_model_metrics_ok(self):
             method=responses.GET,
             url=f'{base_url}/api/models/{str(model_id)}/reference/model-quality',
             status=200,
-            body="""{
+            body=f"""{{
                 "datetime": "something_not_used",
                 "jobStatus": "SUCCEEDED",
-                "modelQuality": {}
-            }""",
+                "modelQuality": {{
+                    "classes": ["classA", "classB", "classC", "classD"],
+                    "classMetrics": [
+                        {{
+                            "className": "classA",
+                            "metrics": {{
+                                "accuracy": {accuracy}
+                            }}
+                        }},
+                        {{
+                            "className": "classB",
+                            "metrics": {{
+                                "fMeasure": {f_measure}
+                            }}
+                        }},
+                        {{
+                            "className": "classC",
+                            "metrics": {{
+                                "recall": {recall}
+                            }}
+                        }},
+                        {{
+                            "className": "classD",
+                            "metrics": {{
+                                "truePositiveRate": {true_positive_rate},
+                                "falsePositiveRate": {false_positive_rate}
+                            }}
+                        }}
+                    ],
+                    "globalMetrics": {{
+                        "f1": {f1},
+                        "accuracy": {accuracy},
+                        "precision": {precision},
+                        "recall": {recall},
+                        "fMeasure": {f_measure},
+                        "weightedPrecision": {weighted_precision},
+                        "weightedRecall": {weighted_recall},
+                        "weightedFMeasure": {weighted_f_measure},
+                        "weightedTruePositiveRate": {weighted_true_positive_rate},
+                        "weightedFalsePositiveRate": {weighted_false_positive_rate},
+                        "truePositiveRate": {true_positive_rate},
+                        "falsePositiveRate": {false_positive_rate},
+                        "confusionMatrix": {confusion_matrix}
+                    }}
+                }}
+            }}""",
         )
 
         metrics = model_reference_dataset.model_quality()
 
-        assert isinstance(metrics, MultiClassModelQuality)
-        # TODO: add asserts to properties
+        assert isinstance(metrics, MultiClassificationModelQuality)
+        assert metrics.classes == ['classA', 'classB', 'classC', 'classD']
+        assert metrics.global_metrics.accuracy == accuracy
+        assert metrics.global_metrics.weighted_precision == weighted_precision
+        assert metrics.global_metrics.weighted_recall == weighted_recall
+        assert (
+            metrics.global_metrics.weighted_true_positive_rate
+            == weighted_true_positive_rate
+        )
+        assert (
+            metrics.global_metrics.weighted_false_positive_rate
+            == weighted_false_positive_rate
+        )
+        assert metrics.global_metrics.weighted_f_measure == weighted_f_measure
+        assert metrics.global_metrics.true_positive_rate == true_positive_rate
+        assert metrics.global_metrics.false_positive_rate == false_positive_rate
+        assert metrics.global_metrics.precision == precision
+        assert metrics.global_metrics.f_measure == f_measure
+        assert metrics.class_metrics[0].class_name == 'classA'
+        assert metrics.class_metrics[0].metrics.accuracy == accuracy
+        assert metrics.class_metrics[1].class_name == 'classB'
+        assert metrics.class_metrics[1].metrics.f_measure == f_measure
+        assert metrics.class_metrics[2].class_name == 'classC'
+        assert metrics.class_metrics[2].metrics.recall == recall
+        assert metrics.class_metrics[3].class_name == 'classD'
+        assert (
+            metrics.class_metrics[3].metrics.false_positive_rate == false_positive_rate
+        )
+        assert metrics.class_metrics[3].metrics.true_positive_rate == true_positive_rate
+
         assert model_reference_dataset.status() == JobStatus.SUCCEEDED
 
     @responses.activate
@@ -420,7 +510,7 @@ def test_binary_class_data_quality_ok(self):
 
         metrics = model_reference_dataset.data_quality()
 
-        assert isinstance(metrics, BinaryClassificationDataQuality)
+        assert isinstance(metrics, ClassificationDataQuality)
         assert metrics.n_observations == 200
         assert len(metrics.class_metrics) == 2
@@ -458,16 +548,73 @@ def test_multi_class_data_quality_ok(self):
             url=f'{base_url}/api/models/{str(model_id)}/reference/data-quality',
             status=200,
             body="""{
-                "datetime": "something_not_used",
-                "jobStatus": "SUCCEEDED",
-                "dataQuality": {}
-            }""",
+                "datetime": "something_not_used",
+                "jobStatus": "SUCCEEDED",
+                "dataQuality": {
+                    "nObservations": 200,
+                    "classMetrics": [
+                        {"name": "classA", "count": 100, "percentage": 50.0},
+                        {"name": "classB", "count": 100, "percentage": 50.0}
+                    ],
+                    "featureMetrics": [
+                        {
+                            "featureName": "age",
+                            "type": "numerical",
+                            "mean": 29.5,
+                            "std": 5.2,
+                            "min": 18,
+                            "max": 45,
+                            "medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
+                            "missingValue": {"count": 2, "percentage": 0.02},
+                            "classMedianMetrics": [
+                                {
+                                    "name": "classA",
+                                    "mean": 30.0,
+                                    "medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
+                                },
+                                {
+                                    "name": "classB",
+                                    "mean": 29.0,
+                                    "medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
+                                }
+                            ],
+                            "histogram": {
+                                "buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
+                                "referenceValues": [50, 150, 200, 150, 50],
+                                "currentValues": [45, 140, 210, 145, 60]
+                            }
+                        },
+                        {
+                            "featureName": "gender",
+                            "type": "categorical",
+                            "distinctValue": 2,
+                            "categoryFrequency": [
+                                {"name": "male", "count": 90, "frequency": 0.45},
+                                {"name": "female", "count": 110, "frequency": 0.55}
+                            ],
+                            "missingValue": {"count": 0, "percentage": 0.0}
+                        }
+                    ]
+                }
+            }""",
         )
 
         metrics = model_reference_dataset.data_quality()
 
-        assert isinstance(metrics, MultiClassDataQuality)
-        # TODO: add asserts to properties
+        assert isinstance(metrics, ClassificationDataQuality)
+
+        assert metrics.n_observations == 200
+        assert len(metrics.class_metrics) == 2
+        assert metrics.class_metrics[0].name == 'classA'
+        assert metrics.class_metrics[0].count == 100
+        assert metrics.class_metrics[0].percentage == 50.0
+        assert len(metrics.feature_metrics) == 2
+        assert metrics.feature_metrics[0].feature_name == 'age'
+        assert metrics.feature_metrics[0].type == 'numerical'
+        assert metrics.feature_metrics[0].mean == 29.5
+        assert metrics.feature_metrics[1].feature_name == 'gender'
+        assert metrics.feature_metrics[1].type == 'categorical'
+        assert metrics.feature_metrics[1].distinct_value == 2
         assert model_reference_dataset.status() == JobStatus.SUCCEEDED
 
     @responses.activate
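For SDK consumers, the visible effect of the renames is in the types returned by `model_quality()` and `data_quality()`. A usage sketch under the same assumptions as the tests above, with `model_reference_dataset` standing in for an already-constructed `ModelReferenceDataset`:

```python
from radicalbit_platform_sdk.models import (
    BinaryClassificationModelQuality,
    ClassificationDataQuality,
    MultiClassificationModelQuality,
)

# model_reference_dataset is assumed to exist, built as in the tests above
metrics = model_reference_dataset.model_quality()
if isinstance(metrics, MultiClassificationModelQuality):
    for class_metrics in metrics.class_metrics:
        print(class_metrics.class_name, class_metrics.metrics.accuracy)
elif isinstance(metrics, BinaryClassificationModelQuality):
    print(metrics.true_positive_count, metrics.false_negative_count)

# Data quality now comes back as one classification model for both
# binary and multiclass models.
data_quality = model_reference_dataset.data_quality()
if isinstance(data_quality, ClassificationDataQuality):
    print(data_quality.n_observations)
```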