Skip to content

Commit

Permalink
fix: edit model quality dto
Browse files Browse the repository at this point in the history
  • Loading branch information
dtria91 committed Jul 8, 2024
1 parent c1dbcb2 commit c29749e
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 65 deletions.
4 changes: 2 additions & 2 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
CurrentMultiClassificationModelQuality,
CurrentRegressionModelQuality,
DataQuality,
DatasetStats,
Drift,
JobStatus,
ModelQuality,
ModelType,
MultiClassificationModelQuality,
RegressionDataQuality,
)

Expand Down Expand Up @@ -241,7 +241,7 @@ def __callback(
case ModelType.MULTI_CLASS:
return (
job_status,
MultiClassificationModelQuality.model_validate(
CurrentMultiClassificationModelQuality.model_validate(
response_json['modelQuality']
),
)
Expand Down
2 changes: 2 additions & 0 deletions sdk/radicalbit_platform_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .dataset_model_quality import (
BinaryClassificationModelQuality,
CurrentBinaryClassificationModelQuality,
CurrentMultiClassificationModelQuality,
CurrentRegressionModelQuality,
ModelQuality,
MultiClassificationModelQuality,
Expand Down Expand Up @@ -52,6 +53,7 @@
'ModelQuality',
'BinaryClassificationModelQuality',
'CurrentBinaryClassificationModelQuality',
'CurrentMultiClassificationModelQuality',
'MultiClassificationModelQuality',
'RegressionModelQuality',
'CurrentRegressionModelQuality',
Expand Down
101 changes: 68 additions & 33 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,43 @@ class ModelQuality(BaseModel):
pass


class MetricsBase(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None


class BaseMetrics(BaseModel):
precision: Optional[float] = None
recall: Optional[float] = None
f_measure: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class GroupedBaseMetrics(BaseModel):
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
true_positive_rate: List[Distribution]
false_positive_rate: List[Distribution]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class AdditionalMetrics(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
weighted_precision: Optional[float] = None
weighted_recall: Optional[float] = None
weighted_f_measure: Optional[float] = None
weighted_true_positive_rate: Optional[float] = None
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
area_under_roc: Optional[float] = None
area_under_pr: Optional[float] = None

Expand All @@ -29,53 +53,56 @@ class MetricsBase(BaseModel):
)


class BinaryClassificationModelQuality(ModelQuality, MetricsBase):
class AdditionalGroupedMetrics(GroupedBaseMetrics):
f1: List[Distribution]
accuracy: List[Distribution]
weighted_precision: List[Distribution]
weighted_recall: List[Distribution]
weighted_f_measure: List[Distribution]
weighted_true_positive_rate: List[Distribution]
weighted_false_positive_rate: List[Distribution]
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class GlobalBinaryMetrics(BaseMetrics, AdditionalMetrics):
true_positive_count: int
false_positive_count: int
true_negative_count: int
false_negative_count: int


class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GroupedMetricsBase(BaseModel):
f1: Optional[List[Distribution]] = None
accuracy: Optional[List[Distribution]] = None
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
weighted_precision: Optional[List[Distribution]] = None
weighted_recall: Optional[List[Distribution]] = None
weighted_f_measure: Optional[List[Distribution]] = None
weighted_true_positive_rate: Optional[List[Distribution]] = None
weighted_false_positive_rate: Optional[List[Distribution]] = None
true_positive_rate: List[Distribution]
false_positive_rate: List[Distribution]
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
class BinaryClassificationModelQuality(ModelQuality, GlobalBinaryMetrics):
pass


class CurrentBinaryClassificationModelQuality(ModelQuality):
global_metrics: BinaryClassificationModelQuality
grouped_metrics: GroupedMetricsBase
global_metrics: GlobalBinaryMetrics
grouped_metrics: AdditionalGroupedMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ClassMetrics(BaseModel):
class_name: str
metrics: MetricsBase
grouped_metrics: Optional[GroupedMetricsBase] = None
metrics: BaseMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMetrics(MetricsBase):
class AdditionalClassMetrics(ClassMetrics):
grouped_metrics: GroupedBaseMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMulticlassMetrics(AdditionalMetrics):
confusion_matrix: List[List[int]]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand All @@ -84,7 +111,15 @@ class GlobalMetrics(MetricsBase):
class MultiClassificationModelQuality(ModelQuality):
classes: List[str]
class_metrics: List[ClassMetrics]
global_metrics: GlobalMetrics
global_metrics: GlobalMulticlassMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentMultiClassificationModelQuality(ModelQuality):
classes: List[str]
class_metrics: List[AdditionalClassMetrics]
global_metrics: GlobalMulticlassMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down
93 changes: 73 additions & 20 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
CurrentMultiClassificationModelQuality,
CurrentRegressionModelQuality,
Drift,
DriftAlgorithm,
JobStatus,
ModelType,
MultiClassificationModelQuality,
RegressionDataQuality,
)

Expand Down Expand Up @@ -746,7 +746,6 @@ def test_multi_class_model_quality_ok(self):
weighted_f_measure = 2.45
true_positive_rate = 4.12
false_positive_rate = 5.89
precision = 2.33
weighted_recall = 4.22
f_measure = 9.33
confusion_matrix = [
Expand Down Expand Up @@ -781,7 +780,7 @@ def test_multi_class_model_quality_ok(self):
{{
"className": "classA",
"metrics": {{
"accuracy": {accuracy}
"recall": {recall}
}},
"groupedMetrics": {{
"precision": [
Expand Down Expand Up @@ -810,35 +809,96 @@ def test_multi_class_model_quality_ok(self):
"className": "classB",
"metrics": {{
"fMeasure": {f_measure}
}}
}},
"groupedMetrics": {{
"precision": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}}
],
"recall": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"fMeasure": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}}
],
"truePositiveRate": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"falsePositiveRate": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.14}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
]
}}
}},
{{
"className": "classC",
"metrics": {{
"recall": {recall}
}}
}},
"groupedMetrics": {{
"precision": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}}
],
"recall": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"fMeasure": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}}
],
"truePositiveRate": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"falsePositiveRate": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.14}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
]
}}
}},
{{
"className": "classD",
"metrics": {{
"truePositiveRate": {true_positive_rate},
"falsePositiveRate": {false_positive_rate}
}}
}},
"groupedMetrics": {{
"precision": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}}
],
"recall": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"fMeasure": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}}
],
"truePositiveRate": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"falsePositiveRate": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": 0.14}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
]
}}
}}
],
"globalMetrics": {{
"f1": {f1},
"accuracy": {accuracy},
"precision": {precision},
"recall": {recall},
"fMeasure": {f_measure},
"weightedPrecision": {weighted_precision},
"weightedRecall": {weighted_recall},
"weightedFMeasure": {weighted_f_measure},
"weightedTruePositiveRate": {weighted_true_positive_rate},
"weightedFalsePositiveRate": {weighted_false_positive_rate},
"truePositiveRate": {true_positive_rate},
"falsePositiveRate": {false_positive_rate},
"confusionMatrix": {confusion_matrix}
}}
}}
Expand All @@ -847,9 +907,8 @@ def test_multi_class_model_quality_ok(self):

metrics = model_current_dataset.model_quality()

assert isinstance(metrics, MultiClassificationModelQuality)
assert isinstance(metrics, CurrentMultiClassificationModelQuality)
assert metrics.classes == ['classA', 'classB', 'classC', 'classD']
assert metrics.global_metrics.accuracy == accuracy
assert metrics.global_metrics.weighted_precision == weighted_precision
assert metrics.global_metrics.weighted_recall == weighted_recall
assert (
Expand All @@ -860,14 +919,8 @@ def test_multi_class_model_quality_ok(self):
metrics.global_metrics.weighted_false_positive_rate
== weighted_false_positive_rate
)
assert metrics.global_metrics.weighted_f_measure == weighted_f_measure
assert metrics.global_metrics.true_positive_rate == true_positive_rate
assert metrics.global_metrics.false_positive_rate == false_positive_rate
assert metrics.global_metrics.precision == precision
assert metrics.global_metrics.f_measure == f_measure
assert metrics.class_metrics[0].class_name == 'classA'
assert metrics.class_metrics[0].metrics.accuracy == accuracy
assert metrics.class_metrics[0].grouped_metrics.recall[0].value == recall
assert metrics.class_metrics[0].metrics.recall == recall
assert metrics.class_metrics[1].class_name == 'classB'
assert metrics.class_metrics[1].metrics.f_measure == f_measure
assert metrics.class_metrics[2].class_name == 'classC'
Expand Down
10 changes: 0 additions & 10 deletions sdk/tests/apis/model_reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def test_multi_class_model_metrics_ok(self):
weighted_f_measure = 2.45
true_positive_rate = 4.12
false_positive_rate = 5.89
precision = 2.33
weighted_recall = 4.22
f_measure = 9.33
confusion_matrix = [
Expand Down Expand Up @@ -297,16 +296,12 @@ def test_multi_class_model_metrics_ok(self):
"globalMetrics": {{
"f1": {f1},
"accuracy": {accuracy},
"precision": {precision},
"recall": {recall},
"fMeasure": {f_measure},
"weightedPrecision": {weighted_precision},
"weightedRecall": {weighted_recall},
"weightedFMeasure": {weighted_f_measure},
"weightedTruePositiveRate": {weighted_true_positive_rate},
"weightedFalsePositiveRate": {weighted_false_positive_rate},
"truePositiveRate": {true_positive_rate},
"falsePositiveRate": {false_positive_rate},
"confusionMatrix": {confusion_matrix}
}}
}}
Expand All @@ -329,12 +324,7 @@ def test_multi_class_model_metrics_ok(self):
== weighted_false_positive_rate
)
assert metrics.global_metrics.weighted_f_measure == weighted_f_measure
assert metrics.global_metrics.true_positive_rate == true_positive_rate
assert metrics.global_metrics.false_positive_rate == false_positive_rate
assert metrics.global_metrics.precision == precision
assert metrics.global_metrics.f_measure == f_measure
assert metrics.class_metrics[0].class_name == 'classA'
assert metrics.class_metrics[0].metrics.accuracy == accuracy
assert metrics.class_metrics[1].class_name == 'classB'
assert metrics.class_metrics[1].metrics.f_measure == f_measure
assert metrics.class_metrics[2].class_name == 'classC'
Expand Down

0 comments on commit c29749e

Please sign in to comment.