From aed740b0b8f62e1b22bd9d183306b62a1b409205 Mon Sep 17 00:00:00 2001 From: Daniele Tria <36860433+dtria91@users.noreply.github.com> Date: Tue, 2 Jul 2024 14:29:40 +0200 Subject: [PATCH] fix: define single structure for drift metrics (#67) * feat: define single structure for drift metrics * fix: remove import --- api/app/models/metrics/drift_dto.py | 43 +++-------- api/app/services/metrics_service.py | 6 -- api/tests/routes/metrics_route_test.py | 1 - api/tests/services/metrics_service_test.py | 15 +--- .../apis/model_current_dataset.py | 27 +------ .../models/__init__.py | 6 -- .../models/dataset_drift.py | 11 --- sdk/tests/apis/model_current_dataset_test.py | 76 +------------------ 8 files changed, 18 insertions(+), 167 deletions(-) diff --git a/api/app/models/metrics/drift_dto.py b/api/app/models/metrics/drift_dto.py index 99e79d34..09ada9b2 100644 --- a/api/app/models/metrics/drift_dto.py +++ b/api/app/models/metrics/drift_dto.py @@ -4,9 +4,7 @@ from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_camel -from app.models.exceptions import MetricsInternalError from app.models.job_status import JobStatus -from app.models.model_dto import ModelType class DriftAlgorithm(str, Enum): @@ -29,23 +27,15 @@ class FeatureMetrics(BaseModel): model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class BinaryClassDrift(BaseModel): +class Drift(BaseModel): feature_metrics: List[FeatureMetrics] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class MultiClassDrift(BaseModel): - pass - - -class RegressionDrift(BaseModel): - pass - - class DriftDTO(BaseModel): job_status: JobStatus - drift: Optional[BinaryClassDrift | MultiClassDrift | RegressionDrift] + drift: Optional[Drift] model_config = ConfigDict( arbitrary_types_allowed=True, @@ -55,21 +45,11 @@ class DriftDTO(BaseModel): @staticmethod def from_dict( - model_type: ModelType, job_status: JobStatus, drift_data: Optional[Dict], ) -> 'DriftDTO': """Create a DriftDTO from a dictionary of data.""" - if not drift_data: - return DriftDTO( - job_status=job_status, - drift=None, - ) - - drift = DriftDTO._create_drift( - model_type=model_type, - drift_data=drift_data, - ) + drift = DriftDTO._create_drift(drift_data=drift_data) return DriftDTO( job_status=job_status, @@ -78,14 +58,9 @@ def from_dict( @staticmethod def _create_drift( - model_type: ModelType, - drift_data: Dict, - ) -> BinaryClassDrift | MultiClassDrift | RegressionDrift: - """Create a specific drift instance based on the model type.""" - if model_type == ModelType.BINARY: - return BinaryClassDrift(**drift_data) - if model_type == ModelType.MULTI_CLASS: - return MultiClassDrift(**drift_data) - if model_type == ModelType.REGRESSION: - return RegressionDrift(**drift_data) - raise MetricsInternalError(f'Invalid model type {model_type}') + drift_data: Optional[Dict], + ) -> Optional[Drift]: + """Create a specific drift instance from a dictionary of data.""" + if not drift_data: + return None + return Drift(**drift_data) diff --git a/api/app/services/metrics_service.py b/api/app/services/metrics_service.py index 53adedda..8a8ea6a7 100644 --- a/api/app/services/metrics_service.py +++ b/api/app/services/metrics_service.py @@ -229,10 +229,8 @@ def _get_drift_by_model_uuid( missing_status, ) -> DriftDTO: """Retrieve drift for a model by its UUID.""" - model = self.model_service.get_model_by_uuid(model_uuid) dataset, metrics = dataset_and_metrics_getter(model_uuid) return self._create_drift_dto( - model_type=model.model_type, dataset=dataset, metrics=metrics, missing_status=missing_status, @@ -321,7 +319,6 @@ def _create_data_quality_dto( @staticmethod def _create_drift_dto( - model_type: ModelType, dataset: Optional[ReferenceDataset | CurrentDataset], metrics: Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics], missing_status, @@ -329,18 +326,15 @@ def _create_drift_dto( """Create a DriftDTO from the provided dataset and metrics.""" if not dataset: return DriftDTO.from_dict( - model_type=model_type, job_status=missing_status, drift_data=None, ) if not metrics: return DriftDTO.from_dict( - model_type=model_type, job_status=dataset.status, drift_data=None, ) return DriftDTO.from_dict( - model_type=model_type, job_status=dataset.status, drift_data=metrics.drift, ) diff --git a/api/tests/routes/metrics_route_test.py b/api/tests/routes/metrics_route_test.py index 062afb58..83f08a4e 100644 --- a/api/tests/routes/metrics_route_test.py +++ b/api/tests/routes/metrics_route_test.py @@ -141,7 +141,6 @@ def test_get_current_drift(self): current_metrics = db_mock.get_sample_current_metrics() drift = DriftDTO.from_dict( job_status=JobStatus.SUCCEEDED, - model_type=model.model_type, drift_data=current_metrics.drift, ) self.metrics_service.get_current_drift = MagicMock(return_value=drift) diff --git a/api/tests/services/metrics_service_test.py b/api/tests/services/metrics_service_test.py index a2c82992..0139a3b1 100644 --- a/api/tests/services/metrics_service_test.py +++ b/api/tests/services/metrics_service_test.py @@ -344,7 +344,6 @@ def test_get_current_drift(self): model = db_mock.get_sample_model() current_dataset = db_mock.get_sample_current_dataset(status=status.value) current_metrics = db_mock.get_sample_current_metrics() - self.model_service.get_model_by_uuid = MagicMock(return_value=model) self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock( return_value=current_dataset ) @@ -352,7 +351,6 @@ def test_get_current_drift(self): return_value=current_metrics ) res = self.metrics_service.get_current_drift(model.uuid, current_dataset.uuid) - self.model_service.get_model_by_uuid.assert_called_once_with(model.uuid) self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with( model.uuid, current_dataset.uuid ) @@ -362,46 +360,37 @@ def test_get_current_drift(self): assert res == DriftDTO.from_dict( job_status=status, - model_type=model.model_type, drift_data=current_metrics.drift, ) def test_get_empty_current_drift(self): status = JobStatus.IMPORTING - model = db_mock.get_sample_model() current_dataset = db_mock.get_sample_current_dataset(status=status.value) - self.model_service.get_model_by_uuid = MagicMock(return_value=model) self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock( return_value=current_dataset ) - res = self.metrics_service.get_current_drift(model.uuid, current_dataset.uuid) - self.model_service.get_model_by_uuid.assert_called_once_with(model.uuid) + res = self.metrics_service.get_current_drift(model_uuid, current_dataset.uuid) self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with( - model.uuid, current_dataset.uuid + model_uuid, current_dataset.uuid ) assert res == DriftDTO.from_dict( job_status=status, - model_type=model.model_type, drift_data=None, ) def test_get_missing_current_drift(self): status = JobStatus.MISSING_CURRENT - model = db_mock.get_sample_model() - self.model_service.get_model_by_uuid = MagicMock(return_value=model) self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock( return_value=None ) res = self.metrics_service.get_current_drift(model_uuid, current_uuid) - self.model_service.get_model_by_uuid.assert_called_once_with(model.uuid) self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with( model_uuid, current_uuid ) assert res == DriftDTO.from_dict( job_status=status, - model_type=model.model_type, drift_data=None, ) diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 0f20a9bf..a16cd862 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -7,7 +7,6 @@ from radicalbit_platform_sdk.commons import invoke from radicalbit_platform_sdk.errors import ClientError from radicalbit_platform_sdk.models import ( - BinaryClassDrift, ClassificationDataQuality, CurrentBinaryClassificationModelQuality, CurrentFileUpload, @@ -17,10 +16,8 @@ JobStatus, ModelQuality, ModelType, - MultiClassDrift, MultiClassificationModelQuality, RegressionDataQuality, - RegressionDrift, RegressionModelQuality, ) @@ -122,26 +119,10 @@ def __callback( response_json = response.json() job_status = JobStatus(response_json['jobStatus']) if 'drift' in response_json: - match self.__model_type: - case ModelType.BINARY: - return ( - job_status, - BinaryClassDrift.model_validate(response_json['drift']), - ) - case ModelType.MULTI_CLASS: - return ( - job_status, - MultiClassDrift.model_validate(response_json['drift']), - ) - case ModelType.REGRESSION: - return ( - job_status, - RegressionDrift.model_validate(response_json['drift']), - ) - case _: - raise ClientError( - 'Unable to parse metrics because of not managed model type' - ) from None + return ( + job_status, + Drift.model_validate(response_json['drift']), + ) except KeyError as e: raise ClientError(f'Unable to parse response: {response.text}') from e except ValidationError as e: diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py index 18d58c8d..28a218f2 100644 --- a/sdk/radicalbit_platform_sdk/models/__init__.py +++ b/sdk/radicalbit_platform_sdk/models/__init__.py @@ -15,13 +15,10 @@ RegressionDataQuality, ) from .dataset_drift import ( - BinaryClassDrift, Drift, DriftAlgorithm, FeatureDrift, FeatureDriftCalculation, - MultiClassDrift, - RegressionDrift, ) from .dataset_model_quality import ( BinaryClassificationModelQuality, @@ -71,9 +68,6 @@ 'FeatureDriftCalculation', 'FeatureDrift', 'Drift', - 'BinaryClassDrift', - 'MultiClassDrift', - 'RegressionDrift', 'ReferenceFileUpload', 'CurrentFileUpload', 'FileReference', diff --git a/sdk/radicalbit_platform_sdk/models/dataset_drift.py b/sdk/radicalbit_platform_sdk/models/dataset_drift.py index 503761bb..532774a7 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_drift.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_drift.py @@ -26,18 +26,7 @@ class FeatureDrift(BaseModel): class Drift(BaseModel): - pass - - -class BinaryClassDrift(Drift): feature_metrics: List[FeatureDrift] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) - -class MultiClassDrift(Drift): - pass - - -class RegressionDrift(BaseModel): - pass diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index 8d5b483b..b9606d4c 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -7,17 +7,15 @@ from radicalbit_platform_sdk.apis import ModelCurrentDataset from radicalbit_platform_sdk.errors import ClientError from radicalbit_platform_sdk.models import ( - BinaryClassDrift, ClassificationDataQuality, CurrentBinaryClassificationModelQuality, CurrentFileUpload, + Drift, DriftAlgorithm, JobStatus, ModelType, - MultiClassDrift, MultiClassificationModelQuality, RegressionDataQuality, - RegressionDrift, RegressionModelQuality, ) @@ -141,7 +139,7 @@ def test_statistics_key_error(self): model_current_dataset.statistics() @responses.activate - def test_binary_class_drift_ok(self): + def test_drift_ok(self): base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() @@ -185,7 +183,7 @@ def test_binary_class_drift_ok(self): drift = model_current_dataset.drift() - assert isinstance(drift, BinaryClassDrift) + assert isinstance(drift, Drift) assert len(drift.feature_metrics) == 3 assert drift.feature_metrics[1].feature_name == 'city' @@ -198,74 +196,6 @@ def test_binary_class_drift_ok(self): assert drift.feature_metrics[2].drift_calc.has_drift is True assert model_current_dataset.status() == JobStatus.SUCCEEDED - @responses.activate - def test_multi_class_drift_ok(self): - base_url = 'http://api:9000' - model_id = uuid.uuid4() - import_uuid = uuid.uuid4() - model_current_dataset = ModelCurrentDataset( - base_url, - model_id, - ModelType.MULTI_CLASS, - CurrentFileUpload( - uuid=import_uuid, - path='s3://bucket/file.csv', - date='2014', - correlation_id_column='column', - status=JobStatus.IMPORTING, - ), - ) - - responses.add( - method=responses.GET, - url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift', - status=200, - body="""{ - "jobStatus": "SUCCEEDED", - "drift": {} - }""", - ) - - drift = model_current_dataset.drift() - - assert isinstance(drift, MultiClassDrift) - # TODO: add asserts to properties - assert model_current_dataset.status() == JobStatus.SUCCEEDED - - @responses.activate - def test_regression_drift_ok(self): - base_url = 'http://api:9000' - model_id = uuid.uuid4() - import_uuid = uuid.uuid4() - model_current_dataset = ModelCurrentDataset( - base_url, - model_id, - ModelType.REGRESSION, - CurrentFileUpload( - uuid=import_uuid, - path='s3://bucket/file.csv', - date='2014', - correlation_id_column='column', - status=JobStatus.IMPORTING, - ), - ) - - responses.add( - method=responses.GET, - url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift', - status=200, - body="""{ - "jobStatus": "SUCCEEDED", - "drift": {} - }""", - ) - - drift = model_current_dataset.drift() - - assert isinstance(drift, RegressionDrift) - # TODO: add asserts to properties - assert model_current_dataset.status() == JobStatus.SUCCEEDED - @responses.activate def test_drift_validation_error(self): base_url = 'http://api:9000'