diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 0c18f928..fc5f4c14 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -4,6 +4,8 @@ CurrentFileUpload, JobStatus, DatasetStats, + Drift, + BinaryClassDrift, ) from radicalbit_platform_sdk.errors import ClientError from pydantic import ValidationError @@ -96,9 +98,59 @@ def __callback( return self.__statistics - def drift(self): - # TODO: implement get drift - pass + def drift(self) -> Optional[Drift]: + """ + Get drift about the current dataset + + :return: The `Drift` if exists + """ + + def __callback( + response: requests.Response, + ) -> tuple[JobStatus, Optional[Drift]]: + try: + response_json = response.json() + job_status = JobStatus(response_json["jobStatus"]) + if "drift" in response_json: + if self.__model_type is ModelType.BINARY: + return ( + job_status, + BinaryClassDrift.model_validate(response_json["drift"]), + ) + else: + raise ClientError( + "Unable to parse get metrics for not binary models" + ) + else: + return job_status, None + except KeyError as _: + raise ClientError(f"Unable to parse response: {response.text}") + except ValidationError as _: + raise ClientError(f"Unable to parse response: {response.text}") + + match self.__status: + case JobStatus.ERROR: + self.__drift = None + case JobStatus.SUCCEEDED: + if self.__drift is None: + _, drift = invoke( + method="GET", + url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift", + valid_response_code=200, + func=__callback, + ) + self.__drift = drift + case JobStatus.IMPORTING: + status, drift = invoke( + method="GET", + url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift", + valid_response_code=200, + func=__callback, + ) + self.__status = status + self.__drift = drift + + return self.__drift def 
"""Pydantic models describing the drift computed for a dataset.

Every model that declares fields accepts both the snake_case field names
and their camelCase aliases (``populate_by_name=True`` together with the
``to_camel`` alias generator).
"""

from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class DriftAlgorithm(str, Enum):
    """Identifier of the algorithm used to compute a feature drift."""

    # NOTE(review): presumably the Kolmogorov-Smirnov test - confirm.
    KS = "KS"
    # NOTE(review): presumably the chi-squared test - confirm.
    CHI2 = "CHI2"


class FeatureDriftCalculation(BaseModel):
    """Outcome of a drift calculation for a single feature."""

    # Algorithm that produced this calculation.
    type: DriftAlgorithm
    # Value of the calculation; optional, defaults to None when absent.
    value: Optional[float] = None
    # Whether drift was reported for the feature.
    has_drift: bool

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class FeatureDrift(BaseModel):
    """Drift calculation associated with a named feature."""

    feature_name: str
    drift_calc: FeatureDriftCalculation

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class Drift(BaseModel):
    """Common base class of every drift model."""


class BinaryClassDrift(Drift):
    """Drift of a binary classification model: one entry per feature."""

    feature_metrics: List[FeatureDrift]

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MultiClassDrift(Drift):
    """Placeholder: no fields are defined yet."""


class RegressionDrift(Drift):
    """Placeholder: no fields are defined yet.

    Derives from `Drift`, like the other drift models, so that it can be
    used wherever a `Drift` is expected.
    """
@responses.activate
def test_drift_ok(self):
    """A well-formed response is parsed and the job status is refreshed."""
    base_url = "http://api:9000"
    model_id = uuid.uuid4()
    import_uuid = uuid.uuid4()
    model_current_dataset = ModelCurrentDataset(
        base_url,
        model_id,
        ModelType.BINARY,
        CurrentFileUpload(
            uuid=import_uuid,
            path="s3://bucket/file.csv",
            date="2014",
            correlation_id_column="column",
            status=JobStatus.IMPORTING,
        ),
    )

    responses.add(
        method=responses.GET,
        url=f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift",
        status=200,
        body="""{
            "jobStatus": "SUCCEEDED",
            "drift": {
                "featureMetrics": [
                    {
                        "featureName": "gender",
                        "driftCalc": {"type": "CHI2", "value": 0.87, "hasDrift": true}
                    },
                    {
                        "featureName": "city",
                        "driftCalc": {"type": "CHI2", "value": 0.12, "hasDrift": false}
                    },
                    {
                        "featureName": "age",
                        "driftCalc": {"type": "KS", "value": 0.92, "hasDrift": true}
                    }
                ]
            }
        }""",
    )

    drift = model_current_dataset.drift()

    assert len(drift.feature_metrics) == 3
    assert drift.feature_metrics[1].feature_name == "city"
    assert drift.feature_metrics[1].drift_calc.type == DriftAlgorithm.CHI2
    assert drift.feature_metrics[1].drift_calc.value == 0.12
    assert drift.feature_metrics[1].drift_calc.has_drift is False
    assert drift.feature_metrics[2].feature_name == "age"
    assert drift.feature_metrics[2].drift_calc.type == DriftAlgorithm.KS
    assert drift.feature_metrics[2].drift_calc.value == 0.92
    assert drift.feature_metrics[2].drift_calc.has_drift is True
    assert model_current_dataset.status() == JobStatus.SUCCEEDED


@responses.activate
def test_drift_validation_error(self):
    """A drift payload that fails model validation raises ClientError."""
    base_url = "http://api:9000"
    model_id = uuid.uuid4()
    import_uuid = uuid.uuid4()
    model_current_dataset = ModelCurrentDataset(
        base_url,
        model_id,
        ModelType.BINARY,
        CurrentFileUpload(
            uuid=import_uuid,
            path="s3://bucket/file.csv",
            date="2014",
            correlation_id_column="column",
            status=JobStatus.IMPORTING,
        ),
    )

    responses.add(
        method=responses.GET,
        url=f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift",
        status=200,
        # "jobStatus" is present and valid so that parsing reaches the
        # drift model validation, which is what this test is about.
        body='{"jobStatus": "SUCCEEDED", "drift": "wrong"}',
    )

    with self.assertRaises(ClientError):
        model_current_dataset.drift()


@responses.activate
def test_drift_key_error(self):
    """A payload without the "jobStatus" key raises ClientError."""
    base_url = "http://api:9000"
    model_id = uuid.uuid4()
    import_uuid = uuid.uuid4()
    model_current_dataset = ModelCurrentDataset(
        base_url,
        model_id,
        ModelType.BINARY,
        CurrentFileUpload(
            uuid=import_uuid,
            path="s3://bucket/file.csv",
            date="2014",
            correlation_id_column="column",
            status=JobStatus.IMPORTING,
        ),
    )

    responses.add(
        method=responses.GET,
        url=f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift",
        status=200,
        body='{"wrong": "json"}',
    )

    with self.assertRaises(ClientError):
        model_current_dataset.drift()