Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): get data quality for current dataset #19

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassDrift,
BinaryClassificationDataQuality,
CurrentFileUpload,
DataQuality,
DatasetStats,
Drift,
JobStatus,
Expand Down Expand Up @@ -152,9 +154,59 @@ def __callback(

return self.__drift

def data_quality(self):
# TODO: implement get data quality
pass
def data_quality(self) -> Optional[DataQuality]:
"""Get data quality metrics about the current dataset

:return: The `DataQuality` if exists
"""

def __callback(
response: requests.Response,
) -> tuple[JobStatus, Optional[DataQuality]]:
try:
response_json = response.json()
job_status = JobStatus(response_json['jobStatus'])
if 'dataQuality' in response_json:
if self.__model_type is ModelType.BINARY:
return (
job_status,
BinaryClassificationDataQuality.model_validate(
response_json['dataQuality']
),
)
raise ClientError(
'Unable to parse get metrics for not binary models'
) from None
except KeyError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
except ValidationError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
else:
return job_status, None

match self.__status:
case JobStatus.ERROR:
self.__data_metrics = None
case JobStatus.SUCCEEDED:
if self.__data_metrics is None:
_, metrics = invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/data-quality',
valid_response_code=200,
func=__callback,
)
self.__data_metrics = metrics
case JobStatus.IMPORTING:
status, metrics = invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/data-quality',
valid_response_code=200,
func=__callback,
)
self.__status = status
self.__data_metrics = metrics

return self.__data_metrics

def model_quality(self):
# TODO: implement get model quality
Expand Down
174 changes: 160 additions & 14 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_statistics_ok(self):
numeric = 3
categorical = 6
datetime = 1
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_statistics_ok(self):
}}""",
)

stats = model_reference_dataset.statistics()
stats = model_current_dataset.statistics()

assert stats.n_variables == n_variables
assert stats.n_observations == n_observations
Expand All @@ -74,14 +74,14 @@ def test_statistics_ok(self):
assert stats.numeric == numeric
assert stats.categorical == categorical
assert stats.datetime == datetime
assert model_reference_dataset.status() == JobStatus.SUCCEEDED
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_statistics_validation_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -102,14 +102,14 @@ def test_statistics_validation_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.statistics()
model_current_dataset.statistics()

@responses.activate
def test_statistics_key_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -130,14 +130,14 @@ def test_statistics_key_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.statistics()
model_current_dataset.statistics()

@responses.activate
def test_drift_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_drift_ok(self):
}""",
)

drift = model_reference_dataset.drift()
drift = model_current_dataset.drift()

assert len(drift.feature_metrics) == 3
assert drift.feature_metrics[1].feature_name == 'city'
Expand All @@ -186,14 +186,14 @@ def test_drift_ok(self):
assert drift.feature_metrics[2].drift_calc.type == DriftAlgorithm.KS
assert drift.feature_metrics[2].drift_calc.value == 0.92
assert drift.feature_metrics[2].drift_calc.has_drift is True
assert model_reference_dataset.status() == JobStatus.SUCCEEDED
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_drift_validation_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -214,14 +214,14 @@ def test_drift_validation_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.drift()
model_current_dataset.drift()

@responses.activate
def test_drift_key_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
Expand All @@ -242,4 +242,150 @@ def test_drift_key_error(self):
)

with pytest.raises(ClientError):
model_reference_dataset.drift()
model_current_dataset.drift()

@responses.activate
def test_data_quality_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body="""{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {
"nObservations": 200,
"classMetrics": [
{"name": "classA", "count": 100, "percentage": 50.0},
{"name": "classB", "count": 100, "percentage": 50.0}
],
"featureMetrics": [
{
"featureName": "age",
"type": "numerical",
"mean": 29.5,
"std": 5.2,
"min": 18,
"max": 45,
"medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
"missingValue": {"count": 2, "percentage": 0.02},
"classMedianMetrics": [
{
"name": "classA",
"mean": 30.0,
"medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
},
{
"name": "classB",
"mean": 29.0,
"medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
}
],
"histogram": {
"buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
"referenceValues": [50, 150, 200, 150, 50],
"currentValues": [45, 140, 210, 145, 60]
}
},
{
"featureName": "gender",
"type": "categorical",
"distinctValue": 2,
"categoryFrequency": [
{"name": "male", "count": 90, "frequency": 0.45},
{"name": "female", "count": 110, "frequency": 0.55}
],
"missingValue": {"count": 0, "percentage": 0.0}
}
]
}
}""",
)

metrics = model_current_dataset.data_quality()

assert metrics.n_observations == 200
assert len(metrics.class_metrics) == 2
assert metrics.class_metrics[0].name == 'classA'
assert metrics.class_metrics[0].count == 100
assert metrics.class_metrics[0].percentage == 50.0
assert len(metrics.feature_metrics) == 2
assert metrics.feature_metrics[0].feature_name == 'age'
assert metrics.feature_metrics[0].type == 'numerical'
assert metrics.feature_metrics[0].mean == 29.5
assert metrics.feature_metrics[1].feature_name == 'gender'
assert metrics.feature_metrics[1].type == 'categorical'
assert metrics.feature_metrics[1].distinct_value == 2
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_data_quality_validation_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body='{"dataQuality": "wrong"}',
)

with pytest.raises(ClientError):
model_current_dataset.data_quality()

@responses.activate
def test_data_quality_key_error(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body='{"wrong": "json"}',
)

with pytest.raises(ClientError):
model_current_dataset.data_quality()