diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 88dcbc7f..a39fc0ed 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -20,6 +20,7 @@ ModelType, RegressionDataQuality, ) +from radicalbit_platform_sdk.models.dataset_percentages import Percentages class ModelCurrentDataset: @@ -42,6 +43,7 @@ def __init__( self.__model_metrics = None self.__data_metrics = None self.__drift = None + self.__percentages = None def uuid(self) -> UUID: return self.__uuid @@ -108,6 +110,56 @@ def __callback( return self.__statistics + def percentages(self) -> Optional[Percentages]: + """Get percentages about the actual dataset + + :return: The `Percentages` if exists + """ + + def __callback( + response: requests.Response, + ) -> tuple[JobStatus, Optional[Percentages]]: + try: + response_json = response.json() + job_status = JobStatus(response_json['jobStatus']) + if 'percentages' in response_json: + return ( + job_status, + Percentages.model_validate(response_json['percentages']), + ) + except KeyError as e: + raise ClientError(f'Unable to parse response: {response.text}') from e + except ValidationError as e: + raise ClientError(f'Unable to parse response: {response.text}') from e + else: + return job_status, None + + match self.__status: + case JobStatus.ERROR: + self.__percentages = None + case JobStatus.MISSING_CURRENT: + self.__percentages = None + case JobStatus.SUCCEEDED: + if self.__percentages is None: + _, percentages = invoke( + method='GET', + url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/percentages', + valid_response_code=200, + func=__callback, + ) + self.__percentages = percentages + case JobStatus.IMPORTING: + status, percentages = invoke( + method='GET', + url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/percentages', + valid_response_code=200, + func=__callback, + ) + self.__status = status + self.__percentages = percentages + + return self.__percentages + def drift(self) -> Optional[Drift]: """Get drift about the actual dataset diff --git a/sdk/radicalbit_platform_sdk/models/dataset_percentages.py b/sdk/radicalbit_platform_sdk/models/dataset_percentages.py new file mode 100644 index 00000000..225553aa --- /dev/null +++ b/sdk/radicalbit_platform_sdk/models/dataset_percentages.py @@ -0,0 +1,23 @@ +from typing import List, Optional + +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel + + +class DetailPercentage(BaseModel): + feature_name: str + score: float + + +class MetricPercentage(BaseModel): + value: float + details: List[Optional[DetailPercentage]] = None + + +class Percentages(BaseModel): + data_quality: MetricPercentage + model_quality: MetricPercentage + drift: MetricPercentage + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index 4ef64a3f..47901291 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -19,6 +19,7 @@ ModelType, RegressionDataQuality, ) +from radicalbit_platform_sdk.models.dataset_percentages import Percentages class ModelCurrentDatasetTest(unittest.TestCase): @@ -139,6 +140,137 @@ def test_statistics_key_error(self): with pytest.raises(ClientError): model_current_dataset.statistics() + @responses.activate + def test_percentages_ok(self): + base_url = 'http://api:9000' + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_current_dataset = ModelCurrentDataset( + base_url, + model_id, + ModelType.BINARY, + CurrentFileUpload( + uuid=import_uuid, + path='s3://bucket/file.csv', + date='2014', + correlation_id_column='column', + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/percentages', + status=200, + body="""{ + "jobStatus": "SUCCEEDED", + "percentages": { + "data_quality": { + "value": 0.9, + "details": [ + { + "feature_name": "num1", + "score": 0.4 + }, + { + "feature_name": "num2", + "score": 0.0 + }, + { + "feature_name": "cat1", + "score": 0.0 + }, + { + "feature_name": "cat2", + "score": 0.0 + } + ] + }, + "model_quality": { + "value": -1, + "details": [] + }, + "drift": { + "value": 0.75, + "details": [ + { + "feature_name": "num1", + "score": 1.0 + } + ] + } + } + }""", + ) + + percentages = model_current_dataset.percentages() + + assert isinstance(percentages, Percentages) + + assert percentages.data_quality.value == 0.9 + assert len(percentages.data_quality.details) == 4 + assert percentages.model_quality.value == -1 + assert len(percentages.model_quality.details) == 0 + assert percentages.drift.value == 0.75 + assert len(percentages.drift.details) == 1 + assert model_current_dataset.status() == JobStatus.SUCCEEDED + + @responses.activate + def test_percentages_validation_error(self): + base_url = 'http://api:9000' + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_current_dataset = ModelCurrentDataset( + base_url, + model_id, + ModelType.BINARY, + CurrentFileUpload( + uuid=import_uuid, + path='s3://bucket/file.csv', + date='2014', + correlation_id_column='column', + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/percentages', + status=200, + body='{"statistics": "wrong"}', + ) + + with pytest.raises(ClientError): + model_current_dataset.percentages() + + @responses.activate + def test_percentages_key_error(self): + base_url = 'http://api:9000' + model_id = uuid.uuid4() + import_uuid = uuid.uuid4() + model_current_dataset = ModelCurrentDataset( + base_url, + model_id, + ModelType.BINARY, + CurrentFileUpload( + uuid=import_uuid, + path='s3://bucket/file.csv', + date='2014', + correlation_id_column='column', + status=JobStatus.IMPORTING, + ), + ) + + responses.add( + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/percentages', + status=200, + body='{"wrong": "json"}', + ) + + with pytest.raises(ClientError): + model_current_dataset.percentages() + @responses.activate def test_drift_ok(self): base_url = 'http://api:9000'