def __init__(
    self,
    base_url: str,
    model_uuid: UUID,
    model_type: ModelType,
    upload: CurrentFileUpload,
) -> None:
    """
    Build the handle of a current dataset from its upload descriptor

    :param base_url: base URL used to build the API endpoints
    :param model_uuid: identifier of the model the dataset belongs to
    :param model_type: type of the model the dataset belongs to
    :param upload: the `CurrentFileUpload` whose fields are copied onto the instance
    """
    # Where the dataset lives: API endpoint and owning model.
    self.__base_url = base_url
    self.__model_uuid = model_uuid
    self.__model_type = model_type

    # Snapshot of the upload descriptor at construction time.
    self.__uuid = upload.uuid
    self.__path = upload.path
    self.__correlation_id_column = upload.correlation_id_column
    self.__date = upload.date
    self.__status = upload.status

    # Nothing has been fetched yet: every cached result starts empty.
    self.__statistics = self.__model_metrics = self.__data_metrics = self.__drift = None
+48,53 @@ def date(self) -> str: def status(self) -> str: return self.__status - def statistics(self): - # TODO: implement get statistics - pass + def statistics(self) -> Optional[DatasetStats]: + """ + Get statistics about the current dataset + + :return: The `DatasetStats` if exists + """ + + def __callback( + response: requests.Response, + ) -> tuple[JobStatus, Optional[DatasetStats]]: + try: + response_json = response.json() + job_status = JobStatus(response_json["jobStatus"]) + if "statistics" in response_json: + return job_status, DatasetStats.model_validate( + response_json["statistics"] + ) + else: + return job_status, None + except KeyError as _: + raise ClientError(f"Unable to parse response: {response.text}") + except ValidationError as _: + raise ClientError(f"Unable to parse response: {response.text}") + + match self.__status: + case JobStatus.ERROR: + self.__statistics = None + case JobStatus.SUCCEEDED: + if self.__statistics is None: + _, stats = invoke( + method="GET", + url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics", + valid_response_code=200, + func=__callback, + ) + self.__statistics = stats + case JobStatus.IMPORTING: + status, stats = invoke( + method="GET", + url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics", + valid_response_code=200, + func=__callback, + ) + self.__status = status + self.__statistics = stats + + return self.__statistics def drift(self): # TODO: implement get drift diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py new file mode 100644 index 00000000..b547986b --- /dev/null +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -0,0 +1,131 @@ +from radicalbit_platform_sdk.apis import ModelCurrentDataset +from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus +from radicalbit_platform_sdk.errors import ClientError +import responses +import unittest 
class ModelCurrentDatasetTest(unittest.TestCase):
    """Tests for `ModelCurrentDataset.statistics` against a mocked HTTP API."""

    def setUp(self):
        # Every test starts from a dataset that is still being imported, so
        # that `statistics()` has to call the (mocked) endpoint.
        self.base_url = "http://api:9000"
        self.model_id = uuid.uuid4()
        self.import_uuid = uuid.uuid4()
        self.statistics_url = f"{self.base_url}/api/models/{str(self.model_id)}/current/{str(self.import_uuid)}/statistics"
        self.model_current_dataset = ModelCurrentDataset(
            self.base_url,
            self.model_id,
            ModelType.BINARY,
            CurrentFileUpload(
                uuid=self.import_uuid,
                path="s3://bucket/file.csv",
                date="2014",
                correlation_id_column="column",
                status=JobStatus.IMPORTING,
            ),
        )

    def _mock_statistics_response(self, body: str) -> None:
        """Register a 200 response with the given body on the statistics endpoint."""
        responses.add(
            **{
                "method": responses.GET,
                "url": self.statistics_url,
                "status": 200,
                "body": body,
            }
        )

    @responses.activate
    def test_statistics_ok(self):
        n_variables = 10
        n_observations = 1000
        missing_cells = 10
        missing_cells_perc = 1
        duplicate_rows = 10
        duplicate_rows_perc = 1
        numeric = 3
        categorical = 6
        datetime = 1

        self._mock_statistics_response(
            f"""{{
                    "datetime": "something_not_used",
                    "jobStatus": "SUCCEEDED",
                    "statistics": {{
                        "nVariables": {n_variables},
                        "nObservations": {n_observations},
                        "missingCells": {missing_cells},
                        "missingCellsPerc": {missing_cells_perc},
                        "duplicateRows": {duplicate_rows},
                        "duplicateRowsPerc": {duplicate_rows_perc},
                        "numeric": {numeric},
                        "categorical": {categorical},
                        "datetime": {datetime}
                    }}
                }}"""
        )

        stats = self.model_current_dataset.statistics()

        self.assertEqual(stats.n_variables, n_variables)
        self.assertEqual(stats.n_observations, n_observations)
        self.assertEqual(stats.missing_cells, missing_cells)
        self.assertEqual(stats.missing_cells_perc, missing_cells_perc)
        self.assertEqual(stats.duplicate_rows, duplicate_rows)
        self.assertEqual(stats.duplicate_rows_perc, duplicate_rows_perc)
        self.assertEqual(stats.numeric, numeric)
        self.assertEqual(stats.categorical, categorical)
        self.assertEqual(stats.datetime, datetime)
        # The status carried by the response replaces the IMPORTING one.
        self.assertEqual(self.model_current_dataset.status(), JobStatus.SUCCEEDED)

    @responses.activate
    def test_statistics_validation_error(self):
        # A valid "jobStatus" is required here: without it the parser fails
        # on the missing key and never reaches the statistics validation.
        self._mock_statistics_response(
            '{"jobStatus": "SUCCEEDED", "statistics": "wrong"}'
        )

        with self.assertRaises(ClientError):
            self.model_current_dataset.statistics()

    @responses.activate
    def test_statistics_key_error(self):
        # No "jobStatus" key at all.
        self._mock_statistics_response('{"wrong": "json"}')

        with self.assertRaises(ClientError):
            self.model_current_dataset.statistics()