Skip to content

Commit

Permalink
feat(sdk): get current dataset statistics (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte authored Jun 21, 2024
1 parent a9c2393 commit dec2d05
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 6 deletions.
4 changes: 3 additions & 1 deletion sdk/radicalbit_platform_sdk/apis/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,9 @@ def __bind_current_dataset(
def __callback(response: requests.Response) -> ModelCurrentDataset:
    """
    Build a `ModelCurrentDataset` from an HTTP response

    :param response: HTTP response whose JSON body is expected to
        validate as a `CurrentFileUpload`
    :return: a `ModelCurrentDataset` created from this object's base URL,
        UUID and model type plus the parsed upload
    :raises ClientError: if the body does not validate as `CurrentFileUpload`
    """
    try:
        # Keep the parsed payload under its own name: `response` must keep
        # pointing at the HTTP response so `response.text` stays valid in
        # the error message below.
        upload = CurrentFileUpload.model_validate(response.json())
    except ValidationError as err:
        raise ClientError(f"Unable to parse response: {response.text}") from err
    return ModelCurrentDataset(
        self.__base_url, self.__uuid, self.__model_type, upload
    )

Expand Down
76 changes: 71 additions & 5 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
from radicalbit_platform_sdk.models import CurrentFileUpload
from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.models import (
ModelType,
CurrentFileUpload,
JobStatus,
DatasetStats,
)
from radicalbit_platform_sdk.errors import ClientError
from pydantic import ValidationError
from typing import Optional
import requests
from uuid import UUID


class ModelCurrentDataset:
def __init__(self, base_url: str, upload: CurrentFileUpload) -> None:
def __init__(
    self,
    base_url: str,
    model_uuid: UUID,
    model_type: ModelType,
    upload: CurrentFileUpload,
) -> None:
    """
    Create a handle on one current dataset of a model

    :param base_url: base URL used to build API request URLs
    :param model_uuid: UUID of the model the dataset belongs to
    :param model_type: type of the owning model
    :param upload: upload record whose fields are copied onto this object
    """
    # Location of the dataset: API root and owning model.
    self.__base_url = base_url
    self.__model_uuid = model_uuid
    self.__model_type = model_type

    # Snapshot of the upload record at construction time.
    self.__uuid = upload.uuid
    self.__path = upload.path
    self.__correlation_id_column = upload.correlation_id_column
    self.__date = upload.date
    self.__status = upload.status

    # Result slots start empty; `statistics()` fills `__statistics` on demand.
    self.__statistics = None
    self.__model_metrics = None
    self.__data_metrics = None
    self.__drift = None

def uuid(self) -> UUID:
return self.__uuid
Expand All @@ -26,9 +48,53 @@ def date(self) -> str:
def status(self) -> str:
return self.__status

def statistics(self):
# TODO: implement get statistics
pass
def statistics(self) -> Optional[DatasetStats]:
"""
Get statistics about the current dataset
:return: The `DatasetStats` if exists
"""

def __callback(
response: requests.Response,
) -> tuple[JobStatus, Optional[DatasetStats]]:
try:
response_json = response.json()
job_status = JobStatus(response_json["jobStatus"])
if "statistics" in response_json:
return job_status, DatasetStats.model_validate(
response_json["statistics"]
)
else:
return job_status, None
except KeyError as _:
raise ClientError(f"Unable to parse response: {response.text}")
except ValidationError as _:
raise ClientError(f"Unable to parse response: {response.text}")

match self.__status:
case JobStatus.ERROR:
self.__statistics = None
case JobStatus.SUCCEEDED:
if self.__statistics is None:
_, stats = invoke(
method="GET",
url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics",
valid_response_code=200,
func=__callback,
)
self.__statistics = stats
case JobStatus.IMPORTING:
status, stats = invoke(
method="GET",
url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics",
valid_response_code=200,
func=__callback,
)
self.__status = status
self.__statistics = stats

return self.__statistics

def drift(self):
# TODO: implement get drift
Expand Down
131 changes: 131 additions & 0 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from radicalbit_platform_sdk.apis import ModelCurrentDataset
from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus
from radicalbit_platform_sdk.errors import ClientError
import responses
import unittest
import uuid


class ModelCurrentDatasetTest(unittest.TestCase):
    """Tests for `ModelCurrentDataset.statistics` against a mocked HTTP API."""

    BASE_URL = "http://api:9000"

    def setUp(self):
        # Every test starts from a dataset that is still importing, so that
        # `statistics()` always queries the (mocked) API.
        self.model_id = uuid.uuid4()
        self.import_uuid = uuid.uuid4()
        self.model_current_dataset = ModelCurrentDataset(
            self.BASE_URL,
            self.model_id,
            ModelType.BINARY,
            CurrentFileUpload(
                uuid=self.import_uuid,
                path="s3://bucket/file.csv",
                date="2014",
                correlation_id_column="column",
                status=JobStatus.IMPORTING,
            ),
        )

    def _mock_statistics_response(self, body: str) -> None:
        """Register a 200 response with `body` for the statistics endpoint."""
        responses.add(
            method=responses.GET,
            url=f"{self.BASE_URL}/api/models/{str(self.model_id)}/current/{str(self.import_uuid)}/statistics",
            status=200,
            body=body,
        )

    @responses.activate
    def test_statistics_ok(self):
        n_variables = 10
        n_observations = 1000
        missing_cells = 10
        missing_cells_perc = 1
        duplicate_rows = 10
        duplicate_rows_perc = 1
        numeric = 3
        categorical = 6
        datetime_count = 1

        # The top-level "datetime" key is extra payload the SDK must ignore.
        self._mock_statistics_response(
            f"""{{
                "datetime": "something_not_used",
                "jobStatus": "SUCCEEDED",
                "statistics": {{
                    "nVariables": {n_variables},
                    "nObservations": {n_observations},
                    "missingCells": {missing_cells},
                    "missingCellsPerc": {missing_cells_perc},
                    "duplicateRows": {duplicate_rows},
                    "duplicateRowsPerc": {duplicate_rows_perc},
                    "numeric": {numeric},
                    "categorical": {categorical},
                    "datetime": {datetime_count}
                }}
            }}"""
        )

        stats = self.model_current_dataset.statistics()

        self.assertEqual(stats.n_variables, n_variables)
        self.assertEqual(stats.n_observations, n_observations)
        self.assertEqual(stats.missing_cells, missing_cells)
        self.assertEqual(stats.missing_cells_perc, missing_cells_perc)
        self.assertEqual(stats.duplicate_rows, duplicate_rows)
        self.assertEqual(stats.duplicate_rows_perc, duplicate_rows_perc)
        self.assertEqual(stats.numeric, numeric)
        self.assertEqual(stats.categorical, categorical)
        self.assertEqual(stats.datetime, datetime_count)
        # The IMPORTING status must have been refreshed from the response.
        self.assertEqual(self.model_current_dataset.status(), JobStatus.SUCCEEDED)

    @responses.activate
    def test_statistics_validation_error(self):
        # "jobStatus" is present and valid so parsing gets as far as the
        # "statistics" payload, which is what must fail validation here.
        self._mock_statistics_response(
            '{"jobStatus": "SUCCEEDED", "statistics": "wrong"}'
        )

        with self.assertRaises(ClientError):
            self.model_current_dataset.statistics()

    @responses.activate
    def test_statistics_key_error(self):
        # No "jobStatus" key at all.
        self._mock_statistics_response('{"wrong": "json"}')

        with self.assertRaises(ClientError):
            self.model_current_dataset.statistics()

0 comments on commit dec2d05

Please sign in to comment.