Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): align reference metrics business models with API #11

Merged
merged 2 commits
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def data_quality(self) -> Optional[DataQuality]:
:return: The `DataQuality` if exists
"""

def __callback(response: requests.Response) -> Optional[DataQuality]:
def __callback(
response: requests.Response,
) -> tuple[JobStatus, Optional[DataQuality]]:
try:
response_json = response.json()
job_status = JobStatus(response_json["jobStatus"])
Expand Down
66 changes: 30 additions & 36 deletions sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,82 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
from typing import List
from typing import List, Optional, Union


class ClassMetrics(BaseModel):
    """Reference metrics for a single target class."""

    name: str
    count: int
    # Optional: the API may omit or null the percentage
    percentage: Optional[float] = None

    # Field names already match the API keys exactly, so no camelCase aliasing
    model_config = ConfigDict(populate_by_name=True)


class MedianMetrics(BaseModel):
    """Quartile summary (25th / 50th / 75th percentile) of a numerical feature."""

    perc_25: Optional[float] = None
    median: Optional[float] = None
    perc_75: Optional[float] = None

    # camelCase aliases ("perc25", "perc75") match the API payload keys
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MissingValue(BaseModel):
    """Count and (optional) share of missing values for a feature."""

    count: int
    percentage: Optional[float] = None

    # Field names already match the API keys, so no alias generator needed
    model_config = ConfigDict(populate_by_name=True)


class ClassMedianMetrics(BaseModel):
    """Per-class mean and quartile metrics of a numerical feature."""

    name: str
    mean: Optional[float] = None
    median_metrics: MedianMetrics

    # "medianMetrics" in the API payload
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class FeatureMetrics(BaseModel):
    """Base shape shared by numerical and categorical feature metrics."""

    feature_name: str
    # Discriminator string; subclasses pin it to "numerical" / "categorical"
    type: str
    missing_value: MissingValue

    # "featureName" / "missingValue" in the API payload
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class Histogram(BaseModel):
    """Bucketed value distribution of a numerical feature."""

    buckets: List[float]
    reference_values: List[int]
    # Optional: absent when the response carries reference metrics only
    current_values: Optional[List[int]] = None

    # camelCase aliases ("referenceValues", "currentValues") match the API payload
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class NumericalFeatureMetrics(FeatureMetrics):
    """Metrics for a numerical feature, including its distribution histogram."""

    type: str = "numerical"
    # Summary statistics are optional: the API may return nulls for them
    mean: Optional[float] = None
    std: Optional[float] = None
    min: Optional[float] = None
    max: Optional[float] = None
    median_metrics: MedianMetrics
    class_median_metrics: List[ClassMedianMetrics]
    histogram: Histogram

    # "medianMetrics" / "classMedianMetrics" in the API payload
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CategoryFrequency(BaseModel):
    """Occurrence count and (optional) relative frequency of one category value."""

    name: str
    count: int
    frequency: Optional[float] = None

    # Field names already match the API keys, so no alias generator needed
    model_config = ConfigDict(populate_by_name=True)


class CategoricalFeatureMetrics(FeatureMetrics):
    """Metrics for a categorical feature."""

    type: str = "categorical"
    category_frequency: List[CategoryFrequency]
    distinct_value: int

    # "categoryFrequency" / "distinctValue" in the API payload
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class DataQuality(BaseModel):
Expand All @@ -91,13 +86,12 @@ class DataQuality(BaseModel):
class BinaryClassificationDataQuality(DataQuality):
    """Data-quality report for a binary-classification reference dataset."""

    n_observations: int
    class_metrics: List[ClassMetrics]
    # Union lets pydantic resolve each item to its concrete subtype
    feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]]

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        populate_by_name=True,
        alias_generator=to_camel,
    )


Expand Down
24 changes: 12 additions & 12 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ class ModelQuality(BaseModel):


class BinaryClassificationModelQuality(ModelQuality):
f1: float
accuracy: float
precision: float
recall: float
f_measure: float
weighted_precision: float
weighted_recall: float
weighted_f_measure: float
weighted_true_positive_rate: float
weighted_false_positive_rate: float
true_positive_rate: float
false_positive_rate: float
f1: Optional[float] = None
accuracy: Optional[float] = None
precision: Optional[float] = None
recall: Optional[float] = None
f_measure: Optional[float] = None
weighted_precision: Optional[float] = None
weighted_recall: Optional[float] = None
weighted_f_measure: Optional[float] = None
weighted_true_positive_rate: Optional[float] = None
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
true_positive_count: int
false_positive_count: int
true_negative_count: int
Expand Down
5 changes: 3 additions & 2 deletions sdk/radicalbit_platform_sdk/models/dataset_stats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

Expand All @@ -6,9 +7,9 @@ class DatasetStats(BaseModel):
n_variables: int
n_observations: int
missing_cells: int
missing_cells_perc: float
missing_cells_perc: Optional[float]
duplicate_rows: int
duplicate_rows_perc: float
duplicate_rows_perc: Optional[float]
numeric: int
categorical: int
datetime: int
Expand Down
149 changes: 149 additions & 0 deletions sdk/tests/apis/model_reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,152 @@ def test_model_metrics_key_error(self):

with self.assertRaises(ClientError):
model_reference_dataset.model_quality()

@responses.activate
def test_data_quality_ok(self):
    """Happy path: a SUCCEEDED data-quality response is parsed into typed metrics.

    Stubs the reference data-quality endpoint with a complete payload and
    verifies that class metrics plus numerical/categorical feature metrics
    deserialize with the expected values.
    """
    base_url = "http://api:9000"
    model_id = uuid.uuid4()
    import_uuid = uuid.uuid4()
    # Dataset under test; the upload starts in IMPORTING state
    model_reference_dataset = ModelReferenceDataset(
        base_url,
        model_id,
        ModelType.BINARY,
        ReferenceFileUpload(
            uuid=import_uuid,
            path="s3://bucket/file.csv",
            date="2014",
            status=JobStatus.IMPORTING,
        ),
    )

    # Mock the GET endpoint with a well-formed payload covering both
    # a numerical ("age") and a categorical ("gender") feature.
    responses.add(
        **{
            "method": responses.GET,
            "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
            "status": 200,
            "body": """{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"dataQuality": {
"nObservations": 200,
"classMetrics": [
{"name": "classA", "count": 100, "percentage": 50.0},
{"name": "classB", "count": 100, "percentage": 50.0}
],
"featureMetrics": [
{
"featureName": "age",
"type": "numerical",
"mean": 29.5,
"std": 5.2,
"min": 18,
"max": 45,
"medianMetrics": {"perc25": 25.0, "median": 29.0, "perc75": 34.0},
"missingValue": {"count": 2, "percentage": 0.02},
"classMedianMetrics": [
{
"name": "classA",
"mean": 30.0,
"medianMetrics": {"perc25": 27.0, "median": 30.0, "perc75": 33.0}
},
{
"name": "classB",
"mean": 29.0,
"medianMetrics": {"perc25": 24.0, "median": 28.0, "perc75": 32.0}
}
],
"histogram": {
"buckets": [40.0, 45.0, 50.0, 55.0, 60.0],
"referenceValues": [50, 150, 200, 150, 50],
"currentValues": [45, 140, 210, 145, 60]
}
},
{
"featureName": "gender",
"type": "categorical",
"distinctValue": 2,
"categoryFrequency": [
{"name": "male", "count": 90, "frequency": 0.45},
{"name": "female", "count": 110, "frequency": 0.55}
],
"missingValue": {"count": 0, "percentage": 0.0}
}
]
}
}""",
        }
    )

    metrics = model_reference_dataset.data_quality()

    # Top-level observation count and per-class metrics
    assert metrics.n_observations == 200
    assert len(metrics.class_metrics) == 2
    assert metrics.class_metrics[0].name == "classA"
    assert metrics.class_metrics[0].count == 100
    assert metrics.class_metrics[0].percentage == 50.0
    # Feature metrics deserialize to their concrete subtypes
    assert len(metrics.feature_metrics) == 2
    assert metrics.feature_metrics[0].feature_name == "age"
    assert metrics.feature_metrics[0].type == "numerical"
    assert metrics.feature_metrics[0].mean == 29.5
    assert metrics.feature_metrics[1].feature_name == "gender"
    assert metrics.feature_metrics[1].type == "categorical"
    assert metrics.feature_metrics[1].distinct_value == 2
    # status() reflects the "jobStatus" field of the mocked response
    assert model_reference_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_data_quality_validation_error(self):
    """A payload that fails model validation must surface as ClientError."""
    api_base = "http://api:9000"
    model_uuid = uuid.uuid4()
    reference_file = ReferenceFileUpload(
        uuid=uuid.uuid4(),
        path="s3://bucket/file.csv",
        date="2014",
        status=JobStatus.IMPORTING,
    )
    dataset = ModelReferenceDataset(
        api_base,
        model_uuid,
        ModelType.BINARY,
        reference_file,
    )

    # Respond with JSON whose "dataQuality" value cannot be parsed
    # into the expected business models.
    responses.add(
        method=responses.GET,
        url=f"{api_base}/api/models/{str(model_uuid)}/reference/data-quality",
        status=200,
        body='{"dataQuality": "wrong"}',
    )

    with self.assertRaises(ClientError):
        dataset.data_quality()

@responses.activate
def test_data_quality_key_error(self):
    """A payload missing the expected keys must surface as ClientError."""
    api_base = "http://api:9000"
    model_uuid = uuid.uuid4()
    reference_file = ReferenceFileUpload(
        uuid=uuid.uuid4(),
        path="s3://bucket/file.csv",
        date="2014",
        status=JobStatus.IMPORTING,
    )
    dataset = ModelReferenceDataset(
        api_base,
        model_uuid,
        ModelType.BINARY,
        reference_file,
    )

    # Respond with JSON lacking the "jobStatus"/"dataQuality" keys entirely
    responses.add(
        method=responses.GET,
        url=f"{api_base}/api/models/{str(model_uuid)}/reference/data-quality",
        status=200,
        body='{"wrong": "json"}',
    )

    with self.assertRaises(ClientError):
        dataset.data_quality()