Skip to content

Commit

Permalink
feat: (sdk) define single class for model quality metrics
Browse files: browse the repository at this point in the history
  • Loading branch information
dtria91 committed Jun 28, 2024
1 parent 6eb62cf commit c778951
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 52 deletions.
14 changes: 3 additions & 11 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassDrift,
BinaryClassificationDataQuality,
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
CurrentMultiClassificationModelQuality,
Expand All @@ -18,7 +18,6 @@
JobStatus,
ModelQuality,
ModelType,
MultiClassDataQuality,
MultiClassDrift,
RegressionDataQuality,
RegressionDrift,
Expand Down Expand Up @@ -188,17 +187,10 @@ def __callback(
job_status = JobStatus(response_json['jobStatus'])
if 'dataQuality' in response_json:
match self.__model_type:
case ModelType.BINARY:
return (
job_status,
BinaryClassificationDataQuality.model_validate(
response_json['dataQuality']
),
)
case ModelType.MULTI_CLASS:
case ModelType.BINARY | ModelType.MULTI_CLASS:
return (
job_status,
MultiClassDataQuality.model_validate(
ClassificationDataQuality.model_validate(
response_json['dataQuality']
),
)
Expand Down
14 changes: 3 additions & 11 deletions sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassificationDataQuality,
ClassificationDataQuality,
BinaryClassificationModelQuality,
DataQuality,
DatasetStats,
JobStatus,
ModelQuality,
ModelType,
MultiClassDataQuality,
MultiClassificationModelQuality,
ReferenceFileUpload,
RegressionDataQuality,
Expand Down Expand Up @@ -114,17 +113,10 @@ def __callback(
job_status = JobStatus(response_json['jobStatus'])
if 'dataQuality' in response_json:
match self.__model_type:
case ModelType.BINARY:
return (
job_status,
BinaryClassificationDataQuality.model_validate(
response_json['dataQuality']
),
)
case ModelType.MULTI_CLASS:
case ModelType.BINARY | ModelType.MULTI_CLASS:
return (
job_status,
MultiClassDataQuality.model_validate(
ClassificationDataQuality.model_validate(
response_json['dataQuality']
),
)
Expand Down
6 changes: 2 additions & 4 deletions sdk/radicalbit_platform_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .column_definition import ColumnDefinition
from .data_type import DataType
from .dataset_data_quality import (
BinaryClassificationDataQuality,
ClassificationDataQuality,
CategoricalFeatureMetrics,
CategoryFrequency,
ClassMedianMetrics,
Expand All @@ -11,7 +11,6 @@
FeatureMetrics,
MedianMetrics,
MissingValue,
MultiClassDataQuality,
NumericalFeatureMetrics,
RegressionDataQuality,
)
Expand Down Expand Up @@ -60,8 +59,7 @@
'CurrentMultiClassificationModelQuality',
'RegressionModelQuality',
'DataQuality',
'BinaryClassificationDataQuality',
'MultiClassDataQuality',
'ClassificationDataQuality',
'RegressionDataQuality',
'ClassMetrics',
'MedianMetrics',
Expand Down
6 changes: 1 addition & 5 deletions sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class DataQuality(BaseModel):
pass


class BinaryClassificationDataQuality(DataQuality):
class ClassificationDataQuality(DataQuality):
n_observations: int
class_metrics: List[ClassMetrics]
feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]]
Expand All @@ -96,9 +96,5 @@ class BinaryClassificationDataQuality(DataQuality):
)


class MultiClassDataQuality(DataQuality):
pass


class RegressionDataQuality(DataQuality):
pass
78 changes: 66 additions & 12 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,18 @@
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassDrift,
BinaryClassificationDataQuality,
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
CurrentMultiClassificationModelQuality,
DriftAlgorithm,
JobStatus,
ModelType,
MultiClassDataQuality,
MultiClassDrift,
RegressionDataQuality,
RegressionDrift,
RegressionModelQuality,
)
from radicalbit_platform_sdk.models.dataset_model_quality import (
CurrentMultiClassificationModelQuality,
)


class ModelCurrentDatasetTest(unittest.TestCase):
Expand Down Expand Up @@ -401,7 +398,7 @@ def test_binary_class_data_quality_ok(self):

metrics = model_current_dataset.data_quality()

assert isinstance(metrics, BinaryClassificationDataQuality)
assert isinstance(metrics, ClassificationDataQuality)

assert metrics.n_observations == 200
assert len(metrics.class_metrics) == 2
Expand Down Expand Up @@ -440,16 +437,73 @@ def test_multi_class_data_quality_ok(self):
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/data-quality',
status=200,
body="""{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {}
}""",
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {
"nObservations": 200,
"classMetrics": [
{"name": "classA", "count": 100, "percentage": 50.0},
{"name": "classB", "count": 100, "percentage": 50.0}
],
"featureMetrics": [
{
"featureName": "age",
"type": "numerical",
"mean": 29.5,
"std": 5.2,
"min": 18,
"max": 45,
"medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
"missingValue": {"count": 2, "percentage": 0.02},
"classMedianMetrics": [
{
"name": "classA",
"mean": 30.0,
"medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
},
{
"name": "classB",
"mean": 29.0,
"medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
}
],
"histogram": {
"buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
"referenceValues": [50, 150, 200, 150, 50],
"currentValues": [45, 140, 210, 145, 60]
}
},
{
"featureName": "gender",
"type": "categorical",
"distinctValue": 2,
"categoryFrequency": [
{"name": "male", "count": 90, "frequency": 0.45},
{"name": "female", "count": 110, "frequency": 0.55}
],
"missingValue": {"count": 0, "percentage": 0.0}
}
]
}
}""",
)

metrics = model_current_dataset.data_quality()

assert isinstance(metrics, MultiClassDataQuality)
# TODO: add asserts to properties
assert isinstance(metrics, ClassificationDataQuality)

assert metrics.n_observations == 200
assert len(metrics.class_metrics) == 2
assert metrics.class_metrics[0].name == 'classA'
assert metrics.class_metrics[0].count == 100
assert metrics.class_metrics[0].percentage == 50.0
assert len(metrics.feature_metrics) == 2
assert metrics.feature_metrics[0].feature_name == 'age'
assert metrics.feature_metrics[0].type == 'numerical'
assert metrics.feature_metrics[0].mean == 29.5
assert metrics.feature_metrics[1].feature_name == 'gender'
assert metrics.feature_metrics[1].type == 'categorical'
assert metrics.feature_metrics[1].distinct_value == 2
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
Expand Down
74 changes: 65 additions & 9 deletions sdk/tests/apis/model_reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from radicalbit_platform_sdk.apis import ModelReferenceDataset
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassificationDataQuality,
ClassificationDataQuality,
BinaryClassificationModelQuality,
JobStatus,
ModelType,
MultiClassDataQuality,
MultiClassificationModelQuality,
ReferenceFileUpload,
RegressionDataQuality,
Expand Down Expand Up @@ -511,7 +510,7 @@ def test_binary_class_data_quality_ok(self):

metrics = model_reference_dataset.data_quality()

assert isinstance(metrics, BinaryClassificationDataQuality)
assert isinstance(metrics, ClassificationDataQuality)

assert metrics.n_observations == 200
assert len(metrics.class_metrics) == 2
Expand Down Expand Up @@ -549,16 +548,73 @@ def test_multi_class_data_quality_ok(self):
url=f'{base_url}/api/models/{str(model_id)}/reference/data-quality',
status=200,
body="""{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {}
}""",
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {
"nObservations": 200,
"classMetrics": [
{"name": "classA", "count": 100, "percentage": 50.0},
{"name": "classB", "count": 100, "percentage": 50.0}
],
"featureMetrics": [
{
"featureName": "age",
"type": "numerical",
"mean": 29.5,
"std": 5.2,
"min": 18,
"max": 45,
"medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
"missingValue": {"count": 2, "percentage": 0.02},
"classMedianMetrics": [
{
"name": "classA",
"mean": 30.0,
"medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
},
{
"name": "classB",
"mean": 29.0,
"medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
}
],
"histogram": {
"buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
"referenceValues": [50, 150, 200, 150, 50],
"currentValues": [45, 140, 210, 145, 60]
}
},
{
"featureName": "gender",
"type": "categorical",
"distinctValue": 2,
"categoryFrequency": [
{"name": "male", "count": 90, "frequency": 0.45},
{"name": "female", "count": 110, "frequency": 0.55}
],
"missingValue": {"count": 0, "percentage": 0.0}
}
]
}
}""",
)

metrics = model_reference_dataset.data_quality()

assert isinstance(metrics, MultiClassDataQuality)
# TODO: add asserts to properties
assert isinstance(metrics, ClassificationDataQuality)

assert metrics.n_observations == 200
assert len(metrics.class_metrics) == 2
assert metrics.class_metrics[0].name == 'classA'
assert metrics.class_metrics[0].count == 100
assert metrics.class_metrics[0].percentage == 50.0
assert len(metrics.feature_metrics) == 2
assert metrics.feature_metrics[0].feature_name == 'age'
assert metrics.feature_metrics[0].type == 'numerical'
assert metrics.feature_metrics[0].mean == 29.5
assert metrics.feature_metrics[1].feature_name == 'gender'
assert metrics.feature_metrics[1].type == 'categorical'
assert metrics.feature_metrics[1].distinct_value == 2
assert model_reference_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
Expand Down

0 comments on commit c778951

Please sign in to comment.