
Commit

feat(sdk): align reference metrics business models with API
maocorte committed Jun 21, 2024
1 parent c9c37ff commit 8479099
Showing 5 changed files with 198 additions and 51 deletions.
5 changes: 4 additions & 1 deletion sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
@@ -102,8 +102,11 @@ def data_quality(self) -> Optional[DataQuality]:
         :return: The `DataQuality` if exists
         """
 
-        def __callback(response: requests.Response) -> Optional[DataQuality]:
+        def __callback(
+            response: requests.Response,
+        ) -> tuple[JobStatus, Optional[DataQuality]]:
             try:
+                print(response.text)
                 response_json = response.json()
                 job_status = JobStatus(response_json["jobStatus"])
                 if "dataQuality" in response_json:
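For readers of this diff, a rough usage sketch of the reworked call path, mirroring the new test_data_quality_ok test added below; the import paths and constructor arguments are assumptions, not part of the commit:

    import uuid

    from radicalbit_platform_sdk.apis import ModelReferenceDataset  # assumed import path
    from radicalbit_platform_sdk.models import (  # assumed import path
        JobStatus,
        ModelType,
        ReferenceFileUpload,
    )

    reference = ModelReferenceDataset(
        "http://api:9000",
        uuid.uuid4(),
        ModelType.BINARY,
        ReferenceFileUpload(
            uuid=uuid.uuid4(),
            path="s3://bucket/file.csv",
            date="2014",
            status=JobStatus.IMPORTING,
        ),
    )

    # data_quality() still returns the parsed DataQuality model (or None if the
    # metrics are not available); the job status is read back via status().
    metrics = reference.data_quality()
    if reference.status() == JobStatus.SUCCEEDED and metrics is not None:
        print(metrics.n_observations, len(metrics.feature_metrics))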
66 changes: 30 additions & 36 deletions sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
@@ -1,87 +1,82 @@
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
-from typing import List
+from typing import List, Optional, Union
 
 
 class ClassMetrics(BaseModel):
     name: str
     count: int
-    percentage: float
+    percentage: Optional[float] = None
 
-    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+    model_config = ConfigDict(populate_by_name=True)
 
 
 class MedianMetrics(BaseModel):
-    perc_25: float
-    median: float
-    perc_75: float
+    perc_25: Optional[float] = None
+    median: Optional[float] = None
+    perc_75: Optional[float] = None
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
 class MissingValue(BaseModel):
     count: int
-    percentage: float
+    percentage: Optional[float] = None
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True)
 
 
 class ClassMedianMetrics(BaseModel):
     name: str
-    mean: float
+    mean: Optional[float] = None
     median_metrics: MedianMetrics
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
 class FeatureMetrics(BaseModel):
     feature_name: str
     type: str
     missing_value: MissingValue
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
+class Histogram(BaseModel):
+    buckets: List[float]
+    reference_values: List[int]
+    current_values: Optional[List[int]] = None
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
 class NumericalFeatureMetrics(FeatureMetrics):
     type: str = "numerical"
-    mean: float
-    std: float
-    min: float
-    max: float
+    mean: Optional[float] = None
+    std: Optional[float] = None
+    min: Optional[float] = None
+    max: Optional[float] = None
     median_metrics: MedianMetrics
     class_median_metrics: List[ClassMedianMetrics]
+    histogram: Histogram
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
 class CategoryFrequency(BaseModel):
     name: str
     count: int
-    frequency: float
+    frequency: Optional[float] = None
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True)
 
 
 class CategoricalFeatureMetrics(FeatureMetrics):
     type: str = "categorical"
     category_frequency: List[CategoryFrequency]
     distinct_value: int
 
-    model_config = ConfigDict(
-        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
-    )
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
 class DataQuality(BaseModel):
@@ -91,13 +86,12 @@ class DataQuality(BaseModel):
 class BinaryClassificationDataQuality(DataQuality):
     n_observations: int
     class_metrics: List[ClassMetrics]
-    feature_metrics: List[FeatureMetrics]
+    feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]]
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
         populate_by_name=True,
         alias_generator=to_camel,
-        protected_namespaces=(),
     )
 
 
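A small sketch, not part of the commit, of how the relaxed models above behave under pydantic v2: fields typed Optional[...] with a None default may simply be absent from the payload, and camelCase aliases still apply wherever alias_generator=to_camel was kept (module path as in the file above):

    from radicalbit_platform_sdk.models.dataset_data_quality import (
        ClassMetrics,
        MedianMetrics,
    )

    # "median" omitted: the Optional field falls back to None instead of failing validation
    median = MedianMetrics.model_validate({"perc25": 25.0, "perc75": 34.0})
    assert median.median is None and median.perc_25 == 25.0

    # ClassMetrics dropped the camelCase alias generator, but its field names are single
    # words, so the JSON keys are unchanged; "percentage" may now be missing as well
    class_metrics = ClassMetrics.model_validate({"name": "classA", "count": 100})
    assert class_metrics.percentage is None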
24 changes: 12 additions & 12 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
@@ -8,18 +8,18 @@ class ModelQuality(BaseModel):
 
 
 class BinaryClassificationModelQuality(ModelQuality):
-    f1: float
-    accuracy: float
-    precision: float
-    recall: float
-    f_measure: float
-    weighted_precision: float
-    weighted_recall: float
-    weighted_f_measure: float
-    weighted_true_positive_rate: float
-    weighted_false_positive_rate: float
-    true_positive_rate: float
-    false_positive_rate: float
+    f1: Optional[float] = None
+    accuracy: Optional[float] = None
+    precision: Optional[float] = None
+    recall: Optional[float] = None
+    f_measure: Optional[float] = None
+    weighted_precision: Optional[float] = None
+    weighted_recall: Optional[float] = None
+    weighted_f_measure: Optional[float] = None
+    weighted_true_positive_rate: Optional[float] = None
+    weighted_false_positive_rate: Optional[float] = None
+    true_positive_rate: Optional[float] = None
+    false_positive_rate: Optional[float] = None
     true_positive_count: int
     false_positive_count: int
     true_negative_count: int
5 changes: 3 additions & 2 deletions sdk/radicalbit_platform_sdk/models/dataset_stats.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
 
@@ -6,9 +7,9 @@ class DatasetStats(BaseModel):
     n_variables: int
     n_observations: int
     missing_cells: int
-    missing_cells_perc: float
+    missing_cells_perc: Optional[float]
     duplicate_rows: int
-    duplicate_rows_perc: float
+    duplicate_rows_perc: Optional[float]
     numeric: int
     categorical: int
     datetime: int
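One nuance in this file: Optional[float] is used here without a = None default, so under pydantic v2 the keys stay required but explicit nulls become acceptable. A stand-alone sketch (the full DatasetStats definition is not shown in this diff, so a hypothetical model is used instead):

    from typing import Optional

    from pydantic import BaseModel, ValidationError

    class StatsSketch(BaseModel):
        missing_cells_perc: Optional[float]  # required field, but null is allowed

    StatsSketch(missing_cells_perc=None)  # accepted: explicit null
    try:
        StatsSketch()  # still rejected: the key must be present
    except ValidationError as err:
        print(err.error_count(), "validation error")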
149 changes: 149 additions & 0 deletions sdk/tests/apis/model_reference_dataset_test.py
@@ -273,3 +273,152 @@ def test_model_metrics_key_error(self):
 
         with self.assertRaises(ClientError):
             model_reference_dataset.model_quality()
+
+    @responses.activate
+    def test_data_quality_ok(self):
+        base_url = "http://api:9000"
+        model_id = uuid.uuid4()
+        import_uuid = uuid.uuid4()
+        model_reference_dataset = ModelReferenceDataset(
+            base_url,
+            model_id,
+            ModelType.BINARY,
+            ReferenceFileUpload(
+                uuid=import_uuid,
+                path="s3://bucket/file.csv",
+                date="2014",
+                status=JobStatus.IMPORTING,
+            ),
+        )
+
+        responses.add(
+            **{
+                "method": responses.GET,
+                "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
+                "status": 200,
+                "body": """{
+                    "datetime": "something_not_used",
+                    "jobStatus": "SUCCEEDED",
+                    "dataQuality": {
+                        "nObservations": 200,
+                        "classMetrics": [
+                            {"name": "classA", "count": 100, "percentage": 50.0},
+                            {"name": "classB", "count": 100, "percentage": 50.0}
+                        ],
+                        "featureMetrics": [
+                            {
+                                "featureName": "age",
+                                "type": "numerical",
+                                "mean": 29.5,
+                                "std": 5.2,
+                                "min": 18,
+                                "max": 45,
+                                "medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
+                                "missingValue": {"count": 2, "percentage": 0.02},
+                                "classMedianMetrics": [
+                                    {
+                                        "name": "classA",
+                                        "mean": 30.0,
+                                        "medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
+                                    },
+                                    {
+                                        "name": "classB",
+                                        "mean": 29.0,
+                                        "medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
+                                    }
+                                ],
+                                "histogram": {
+                                    "buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
+                                    "referenceValues": [50, 150, 200, 150, 50],
+                                    "currentValues": [45, 140, 210, 145, 60]
+                                }
+                            },
+                            {
+                                "featureName": "gender",
+                                "type": "categorical",
+                                "distinctValue": 2,
+                                "categoryFrequency": [
+                                    {"name": "male", "count": 90, "frequency": 0.45},
+                                    {"name": "female", "count": 110, "frequency": 0.55}
+                                ],
+                                "missingValue": {"count": 0, "percentage": 0.0}
+                            }
+                        ]
+                    }
+                }""",
+            }
+        )
+
+        metrics = model_reference_dataset.data_quality()
+
+        assert metrics.n_observations == 200
+        assert len(metrics.class_metrics) == 2
+        assert metrics.class_metrics[0].name == "classA"
+        assert metrics.class_metrics[0].count == 100
+        assert metrics.class_metrics[0].percentage == 50.0
+        assert len(metrics.feature_metrics) == 2
+        assert metrics.feature_metrics[0].feature_name == "age"
+        assert metrics.feature_metrics[0].type == "numerical"
+        assert metrics.feature_metrics[0].mean == 29.5
+        assert metrics.feature_metrics[1].feature_name == "gender"
+        assert metrics.feature_metrics[1].type == "categorical"
+        assert metrics.feature_metrics[1].distinct_value == 2
+        assert model_reference_dataset.status() == JobStatus.SUCCEEDED
+
+    @responses.activate
+    def test_data_quality_validation_error(self):
+        base_url = "http://api:9000"
+        model_id = uuid.uuid4()
+        import_uuid = uuid.uuid4()
+        model_reference_dataset = ModelReferenceDataset(
+            base_url,
+            model_id,
+            ModelType.BINARY,
+            ReferenceFileUpload(
+                uuid=import_uuid,
+                path="s3://bucket/file.csv",
+                date="2014",
+                status=JobStatus.IMPORTING,
+            ),
+        )
+
+        responses.add(
+            **{
+                "method": responses.GET,
+                "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
+                "status": 200,
+                "body": '{"dataQuality": "wrong"}',
+            }
+        )
+
+        with self.assertRaises(ClientError):
+            model_reference_dataset.data_quality()
+
+    @responses.activate
+    def test_data_quality_key_error(self):
+        base_url = "http://api:9000"
+        model_id = uuid.uuid4()
+        import_uuid = uuid.uuid4()
+        model_reference_dataset = ModelReferenceDataset(
+            base_url,
+            model_id,
+            ModelType.BINARY,
+            ReferenceFileUpload(
+                uuid=import_uuid,
+                path="s3://bucket/file.csv",
+                date="2014",
+                status=JobStatus.IMPORTING,
+            ),
+        )
+
+        responses.add(
+            **{
+                "method": responses.GET,
+                "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
+                "status": 200,
+                "body": '{"wrong": "json"}',
+            }
+        )
+
+        with self.assertRaises(ClientError):
+            model_reference_dataset.data_quality()