Skip to content

Commit

Permalink
feat(sdk): get data quality for current dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte committed Jun 21, 2024
1 parent 41fd0cf commit 1a5f9f7
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 17 deletions.
58 changes: 55 additions & 3 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassDrift,
BinaryClassificationDataQuality,
CurrentFileUpload,
DataQuality,
DatasetStats,
Drift,
JobStatus,
Expand Down Expand Up @@ -152,9 +154,59 @@ def __callback(

return self.__drift

def data_quality(self):
# TODO: implement get data quality
pass
def data_quality(self) -> Optional[DataQuality]:
"""Get data quality metrics about the current dataset
:return: The `DataQuality` if exists
"""

def __callback(
response: requests.Response,
) -> tuple[JobStatus, Optional[DataQuality]]:
try:
response_json = response.json()
job_status = JobStatus(response_json['jobStatus'])
if 'dataQuality' in response_json:
if self.__model_type is ModelType.BINARY:
return (
job_status,
BinaryClassificationDataQuality.model_validate(
response_json['dataQuality']
),
)
raise ClientError(
'Unable to parse get metrics for not binary models'
) from None
except KeyError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
except ValidationError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
else:
return job_status, None

match self.__status:
case JobStatus.ERROR:
self.__data_metrics = None
case JobStatus.SUCCEEDED:
if self.__data_metrics is None:
_, metrics = invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/data-quality',
valid_response_code=200,
func=__callback,
)
self.__data_metrics = metrics
case JobStatus.IMPORTING:
status, metrics = invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/data-quality',
valid_response_code=200,
func=__callback,
)
self.__status = status
self.__data_metrics = metrics

return self.__data_metrics

def model_quality(self):
# TODO: implement get model quality
Expand Down
174 changes: 160 additions & 14 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_statistics_ok(self):
numeric = 3
categorical = 6
datetime = 1
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_statistics_ok(self):
}}""",
)

stats = model_reference_dataset.statistics()
stats = model_current_dataset.statistics()

assert stats.n_variables == n_variables
assert stats.n_observations == n_observations
Expand All @@ -74,14 +74,14 @@ def test_statistics_ok(self):
assert stats.numeric == numeric
assert stats.categorical == categorical
assert stats.datetime == datetime
assert model_reference_dataset.status() == JobStatus.SUCCEEDED
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_statistics_validation_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -102,14 +102,14 @@ def test_statistics_validation_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.statistics()
model_current_dataset.statistics()

@responses.activate
def test_statistics_key_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -130,14 +130,14 @@ def test_statistics_key_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.statistics()
model_current_dataset.statistics()

@responses.activate
def test_drift_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_drift_ok(self):
}""",
)

drift = model_reference_dataset.drift()
drift = model_current_dataset.drift()

assert len(drift.feature_metrics) == 3
assert drift.feature_metrics[1].feature_name == 'city'
Expand All @@ -186,14 +186,14 @@ def test_drift_ok(self):
assert drift.feature_metrics[2].drift_calc.type == DriftAlgorithm.KS
assert drift.feature_metrics[2].drift_calc.value == 0.92
assert drift.feature_metrics[2].drift_calc.has_drift is True
assert model_reference_dataset.status() == JobStatus.SUCCEEDED
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_drift_validation_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -214,14 +214,14 @@ def test_drift_validation_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.drift()
model_current_dataset.drift()

@responses.activate
def test_drift_key_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -242,4 +242,150 @@ def test_drift_key_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.drift()
model_current_dataset.drift()

@responses.activate
def test_data_quality_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body="""{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {
"nObservations": 200,
"classMetrics": [
{"name": "classA", "count": 100, "percentage": 50.0},
{"name": "classB", "count": 100, "percentage": 50.0}
],
"featureMetrics": [
{
"featureName": "age",
"type": "numerical",
"mean": 29.5,
"std": 5.2,
"min": 18,
"max": 45,
"medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
"missingValue": {"count": 2, "percentage": 0.02},
"classMedianMetrics": [
{
"name": "classA",
"mean": 30.0,
"medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
},
{
"name": "classB",
"mean": 29.0,
"medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
}
],
"histogram": {
"buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
"referenceValues": [50, 150, 200, 150, 50],
"currentValues": [45, 140, 210, 145, 60]
}
},
{
"featureName": "gender",
"type": "categorical",
"distinctValue": 2,
"categoryFrequency": [
{"name": "male", "count": 90, "frequency": 0.45},
{"name": "female", "count": 110, "frequency": 0.55}
],
"missingValue": {"count": 0, "percentage": 0.0}
}
]
}
}""",
)

metrics = model_current_dataset.data_quality()

assert metrics.n_observations == 200
assert len(metrics.class_metrics) == 2
assert metrics.class_metrics[0].name == 'classA'
assert metrics.class_metrics[0].count == 100
assert metrics.class_metrics[0].percentage == 50.0
assert len(metrics.feature_metrics) == 2
assert metrics.feature_metrics[0].feature_name == 'age'
assert metrics.feature_metrics[0].type == 'numerical'
assert metrics.feature_metrics[0].mean == 29.5
assert metrics.feature_metrics[1].feature_name == 'gender'
assert metrics.feature_metrics[1].type == 'categorical'
assert metrics.feature_metrics[1].distinct_value == 2
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_data_quality_validation_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body='{"dataQuality": "wrong"}',
)

with pytest.raises(ClientError):
model_current_dataset.data_quality()

@responses.activate
def test_data_quality_key_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body='{"wrong": "json"}',
)

with pytest.raises(ClientError):
model_current_dataset.data_quality()

0 comments on commit 1a5f9f7

Please sign in to comment.