From 51fc3eaf181324319b58eb3f57a27b62e4fe48ba Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 8 Jul 2024 15:35:07 +0200 Subject: [PATCH] feat: add current model quality regression --- api/app/models/metrics/model_quality_dto.py | 61 ++++++++++++++--- api/tests/commons/db_mock.py | 37 ++++++++++ api/tests/routes/upload_dataset_route_test.py | 8 +-- api/tests/services/file_service_test.py | 3 +- api/tests/services/metrics_service_test.py | 31 +++++++++ .../models/dataset_model_quality.py | 25 ++++++- sdk/tests/apis/model_current_dataset_test.py | 68 +++++++++++++++++-- .../apis/model_reference_dataset_test.py | 6 +- 8 files changed, 215 insertions(+), 24 deletions(-) diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py index d4d17278..5998c752 100644 --- a/api/app/models/metrics/model_quality_dto.py +++ b/api/app/models/metrics/model_quality_dto.py @@ -90,11 +90,11 @@ class MultiClassificationModelQuality(BaseModel): model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class RegressionModelQuality(BaseModel): +class RegressionMetricsBase(BaseModel): r2: Optional[float] = None mae: Optional[float] = None mse: Optional[float] = None - var: Optional[float] = None + variance: Optional[float] = None mape: Optional[float] = None rmse: Optional[float] = None adj_r2: Optional[float] = None @@ -102,6 +102,41 @@ class RegressionModelQuality(BaseModel): model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) +class BaseRegressionMetrics(BaseModel): + r2: Optional[float] = None + mae: Optional[float] = None + mse: Optional[float] = None + variance: Optional[float] = None + mape: Optional[float] = None + rmse: Optional[float] = None + adj_r2: Optional[float] = None + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class GroupedBaseRegressionMetrics(BaseModel): + r2: List[Distribution] + mae: List[Distribution] + mse: List[Distribution] + variance: 
List[Distribution] + mape: List[Distribution] + rmse: List[Distribution] + adj_r2: List[Distribution] + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class RegressionModelQuality(BaseRegressionMetrics): + pass + + +class CurrentRegressionModelQuality(BaseModel): + global_metrics: BaseRegressionMetrics + grouped_metrics: GroupedBaseRegressionMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + class ModelQualityDTO(BaseModel): job_status: JobStatus model_quality: Optional[ @@ -109,6 +144,7 @@ class ModelQualityDTO(BaseModel): | CurrentBinaryClassificationModelQuality | MultiClassificationModelQuality | RegressionModelQuality + | CurrentRegressionModelQuality ] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -143,11 +179,6 @@ def _create_model_quality( model_type: ModelType, dataset_type: DatasetType, model_quality_data: Dict, - ) -> ( - BinaryClassificationModelQuality - | CurrentBinaryClassificationModelQuality - | MultiClassificationModelQuality - | RegressionModelQuality ): """Create a specific model quality instance based on model type and dataset type.""" if model_type == ModelType.BINARY: @@ -158,7 +189,9 @@ def _create_model_quality( if model_type == ModelType.MULTI_CLASS: return MultiClassificationModelQuality(**model_quality_data) if model_type == ModelType.REGRESSION: - return RegressionModelQuality(**model_quality_data) + return ModelQualityDTO._create_regression_model_quality( + dataset_type=dataset_type, model_quality_data=model_quality_data + ) raise MetricsInternalError(f'Invalid model type {model_type}') @staticmethod @@ -172,3 +205,15 @@ def _create_binary_model_quality( if dataset_type == DatasetType.CURRENT: return CurrentBinaryClassificationModelQuality(**model_quality_data) raise MetricsInternalError(f'Invalid dataset type {dataset_type}') + + @staticmethod + def _create_regression_model_quality( + dataset_type: DatasetType, + 
model_quality_data: Dict, + ) -> RegressionModelQuality | CurrentRegressionModelQuality: + """Create a regression model quality instance based on dataset type.""" + if dataset_type == DatasetType.REFERENCE: + return RegressionModelQuality(**model_quality_data) + if dataset_type == DatasetType.CURRENT: + return CurrentRegressionModelQuality(**model_quality_data) + raise MetricsInternalError(f'Invalid dataset type {dataset_type}') diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index 36c76786..040fe88a 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -323,6 +323,43 @@ def get_sample_current_dataset( 'mape': 35.19314237273801, 'rmse': 202.23194752188695, 'adj_r2': 0.9116805380966796, + 'variance': 0.23 +} + +grouped_regression_model_quality_dict = { + 'r2': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.85}, + ], + 'mae': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.88}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.9}, + ], + 'mse': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.86}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.88}, + ], + 'mape': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.81}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.83}, + ], + 'rmse': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.85}, + ], + 'adj_r2': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.85}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.87}, + ], + 'variance': [ + {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.82}, + {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.84}, + ], +} + +current_regression_model_quality_dict = { + 'global_metrics': regression_model_quality_dict, + 'grouped_metrics': grouped_regression_model_quality_dict, } regression_data_quality_dict = { diff --git a/api/tests/routes/upload_dataset_route_test.py 
b/api/tests/routes/upload_dataset_route_test.py index 73c16524..994f8438 100644 --- a/api/tests/routes/upload_dataset_route_test.py +++ b/api/tests/routes/upload_dataset_route_test.py @@ -82,7 +82,7 @@ def test_upload_current(self): path='test', date=str(datetime.datetime.now(tz=datetime.UTC)), status=JobStatus.IMPORTING, - correlation_id_column=None + correlation_id_column=None, ) self.file_service.upload_current_file = MagicMock( return_value=upload_file_result @@ -103,11 +103,9 @@ def test_bind_current(self): path='test', date=str(datetime.datetime.now(tz=datetime.UTC)), status=JobStatus.IMPORTING, - correlation_id_column=None - ) - self.file_service.bind_current_file = MagicMock( - return_value=upload_file_result + correlation_id_column=None, ) + self.file_service.bind_current_file = MagicMock(return_value=upload_file_result) res = self.client.post( f'{self.prefix}/{model_uuid}/current/bind', json=jsonable_encoder(file_ref), diff --git a/api/tests/services/file_service_test.py b/api/tests/services/file_service_test.py index 65efc577..92bf4fbc 100644 --- a/api/tests/services/file_service_test.py +++ b/api/tests/services/file_service_test.py @@ -220,7 +220,8 @@ def test_upload_current_file_ok(self): self.spark_k8s_client.submit_app = MagicMock() result = self.files_service.upload_current_file( - model.uuid, file, + model.uuid, + file, ) self.model_svc.get_model_by_uuid.assert_called_once() diff --git a/api/tests/services/metrics_service_test.py b/api/tests/services/metrics_service_test.py index ba7fafdf..c7808af8 100644 --- a/api/tests/services/metrics_service_test.py +++ b/api/tests/services/metrics_service_test.py @@ -635,6 +635,37 @@ def test_get_current_multiclass_model_quality_by_model_by_uuid(self): model_quality_data=current_metrics.model_quality, ) + def test_get_current_regression_model_quality_by_model_by_uuid(self): + status = JobStatus.SUCCEEDED + current_dataset = db_mock.get_sample_current_dataset(status=status.value) + current_metrics = 
db_mock.get_sample_current_metrics( + model_quality=db_mock.current_regression_model_quality_dict + ) + model = db_mock.get_sample_model(model_type=ModelType.REGRESSION) + self.model_service.get_model_by_uuid = MagicMock(return_value=model) + self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock( + return_value=current_dataset + ) + self.current_metrics_dao.get_current_metrics_by_model_uuid = MagicMock( + return_value=current_metrics + ) + res = self.metrics_service.get_current_model_quality_by_model_by_uuid( + model_uuid, current_dataset.uuid + ) + self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with( + model_uuid, current_dataset.uuid + ) + self.current_metrics_dao.get_current_metrics_by_model_uuid.assert_called_once_with( + model_uuid, current_dataset.uuid + ) + + assert res == ModelQualityDTO.from_dict( + dataset_type=DatasetType.CURRENT, + model_type=model.model_type, + job_status=current_dataset.status, + model_quality_data=current_metrics.model_quality, + ) + model_uuid = db_mock.MODEL_UUID current_uuid = db_mock.CURRENT_UUID diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py index f69f87f2..df4e15ea 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py @@ -89,11 +89,11 @@ class MultiClassificationModelQuality(ModelQuality): model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class RegressionModelQuality(ModelQuality): +class BaseRegressionMetrics(BaseModel): r2: Optional[float] = None mae: Optional[float] = None mse: Optional[float] = None - var: Optional[float] = None + variance: Optional[float] = None mape: Optional[float] = None rmse: Optional[float] = None adj_r2: Optional[float] = None @@ -101,5 +101,24 @@ class RegressionModelQuality(ModelQuality): model_config = ConfigDict(populate_by_name=True, 
alias_generator=to_camel) -class CurrentRegressionModelQuality(ModelQuality): +class GroupedBaseRegressionMetrics(BaseModel): + r2: List[Distribution] + mae: List[Distribution] + mse: List[Distribution] + variance: List[Distribution] + mape: List[Distribution] + rmse: List[Distribution] + adj_r2: List[Distribution] + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class RegressionModelQuality(ModelQuality, BaseRegressionMetrics): pass + + +class CurrentRegressionModelQuality(ModelQuality): + global_metrics: BaseRegressionMetrics + grouped_metrics: GroupedBaseRegressionMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index b395ced0..6ca3e16e 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -898,6 +898,13 @@ def test_regression_model_quality_ok(self): base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() + r2 = 0.91 + mae = 125.01 + mse = 408.76 + variance = 393.31 + mape = 35.19 + rmse = 202.23 + adj_r2 = 0.91 model_current_dataset = ModelCurrentDataset( base_url, model_id, @@ -915,17 +922,70 @@ def test_regression_model_quality_ok(self): method=responses.GET, url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/model-quality', status=200, - body="""{ + body=f"""{{ "datetime": "something_not_used", "jobStatus": "SUCCEEDED", - "modelQuality": {} - }""", + "modelQuality": {{ + "global_metrics": {{ + "r2": {r2}, + "mae": {mae}, + "mse": {mse}, + "variance": {variance}, + "mape": {mape}, + "rmse": {rmse}, + "adjR2": {adj_r2} + }}, + "grouped_metrics": {{ + "r2": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {r2}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}} + ], + "mae": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {mae}}}, + {{"timestamp": "2024-02-01T00:00:00Z", 
"value": 0.83}} + ], + "mse": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {mse}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}} + ], + "variance": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {variance}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "mape": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {mape}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}} + ], + "rmse": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {rmse}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}} + ], + "adjR2": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {adj_r2}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}} + ] + }} + }} + }}""", ) metrics = model_current_dataset.model_quality() assert isinstance(metrics, CurrentRegressionModelQuality) - # TODO: add asserts to properties + assert metrics.global_metrics.r2 == r2 + assert metrics.global_metrics.mae == mae + assert metrics.global_metrics.mse == mse + assert metrics.global_metrics.variance == variance + assert metrics.global_metrics.mape == mape + assert metrics.global_metrics.rmse == rmse + assert metrics.global_metrics.adj_r2 == adj_r2 + assert metrics.grouped_metrics.r2[0].value == r2 + assert metrics.grouped_metrics.mae[0].value == mae + assert metrics.grouped_metrics.mse[0].value == mse + assert metrics.grouped_metrics.variance[0].value == variance + assert metrics.grouped_metrics.mape[0].value == mape + assert metrics.grouped_metrics.rmse[0].value == rmse + assert metrics.grouped_metrics.adj_r2[0].value == adj_r2 assert model_current_dataset.status() == JobStatus.SUCCEEDED @responses.activate diff --git a/sdk/tests/apis/model_reference_dataset_test.py b/sdk/tests/apis/model_reference_dataset_test.py index 31a1fb55..a9bda9b1 100644 --- a/sdk/tests/apis/model_reference_dataset_test.py +++ b/sdk/tests/apis/model_reference_dataset_test.py @@ -355,7 +355,7 @@ def test_regression_model_metrics_ok(self): r2 = 0.91 mae = 125.01 mse = 408.76 - 
var = 393.31 + variance = 393.31 mape = 35.19 rmse = 202.23 adj_r2 = 0.91 @@ -382,7 +382,7 @@ def test_regression_model_metrics_ok(self): "r2": {r2}, "mae": {mae}, "mse": {mse}, - "var": {var}, + "variance": {variance}, "mape": {mape}, "rmse": {rmse}, "adjR2": {adj_r2} @@ -396,7 +396,7 @@ def test_regression_model_metrics_ok(self): assert metrics.r2 == r2 assert metrics.mae == mae assert metrics.mse == mse - assert metrics.var == var + assert metrics.variance == variance assert metrics.mape == mape assert metrics.rmse == rmse assert metrics.adj_r2 == adj_r2