Skip to content

Commit

Permalink
feat: refactoring model quality dtos (#92)
Browse files Browse the repository at this point in the history
* fix: edit model quality dto

* feat: (api) classification model quality refactoring
  • Loading branch information
dtria91 authored Jul 9, 2024
1 parent 460f79e commit 85d6a64
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 101 deletions.
121 changes: 86 additions & 35 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,43 @@
from app.models.model_dto import ModelType


class MetricsBase(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None


class BaseClassificationMetrics(BaseModel):
precision: Optional[float] = None
recall: Optional[float] = None
f_measure: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class GroupedBaseClassificationMetrics(BaseModel):
    """Time-series variant of the base classification metrics.

    Each field is a list of (timestamp, value) points — one ``Distribution``
    per time bucket — rather than a single scalar.
    """

    precision: List[Distribution]
    recall: List[Distribution]
    f_measure: List[Distribution]
    true_positive_rate: List[Distribution]
    false_positive_rate: List[Distribution]

    # Accept snake_case names on input while (de)serializing camelCase aliases;
    # protected_namespaces=() silences pydantic's "model_" prefix warnings.
    model_config = ConfigDict(
        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
    )


class AdditionalMetrics(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
weighted_precision: Optional[float] = None
weighted_recall: Optional[float] = None
weighted_f_measure: Optional[float] = None
weighted_true_positive_rate: Optional[float] = None
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
area_under_roc: Optional[float] = None
area_under_pr: Optional[float] = None

Expand All @@ -30,53 +54,56 @@ class MetricsBase(BaseModel):
)


class BinaryClassificationModelQuality(MetricsBase):
class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics):
    """Extends the grouped base metrics with aggregate/weighted time series.

    Inherits the per-timestamp precision/recall/f-measure/TPR/FPR series and
    adds f1, accuracy, the weighted variants, and the (optional) ROC/PR AUC
    series.
    """

    f1: List[Distribution]
    accuracy: List[Distribution]
    weighted_precision: List[Distribution]
    weighted_recall: List[Distribution]
    weighted_f_measure: List[Distribution]
    weighted_true_positive_rate: List[Distribution]
    weighted_false_positive_rate: List[Distribution]
    # AUC series may be absent — NOTE(review): presumably not computable for
    # every model/dataset combination; confirm against the metrics job.
    area_under_roc: Optional[List[Distribution]] = None
    area_under_pr: Optional[List[Distribution]] = None

    # camelCase aliases on the wire, snake_case in Python.
    model_config = ConfigDict(
        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
    )


class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics):
true_positive_count: int
false_positive_count: int
true_negative_count: int
false_negative_count: int


class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GroupedMetricsBase(BaseModel):
f1: Optional[List[Distribution]] = None
accuracy: Optional[List[Distribution]] = None
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
weighted_precision: Optional[List[Distribution]] = None
weighted_recall: Optional[List[Distribution]] = None
weighted_f_measure: Optional[List[Distribution]] = None
weighted_true_positive_rate: Optional[List[Distribution]] = None
weighted_false_positive_rate: Optional[List[Distribution]] = None
true_positive_rate: List[Distribution]
false_positive_rate: List[Distribution]
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
class BinaryClassificationModelQuality(GlobalBinaryMetrics):
    """Binary-classification model quality; identical in shape to GlobalBinaryMetrics."""

    pass


class CurrentBinaryClassificationModelQuality(BaseModel):
global_metrics: BinaryClassificationModelQuality
grouped_metrics: GroupedMetricsBase
global_metrics: GlobalBinaryMetrics
grouped_metrics: AdditionalGroupedMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ClassMetrics(BaseModel):
class_name: str
metrics: MetricsBase
grouped_metrics: Optional[GroupedMetricsBase] = None
metrics: BaseClassificationMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class AdditionalClassMetrics(ClassMetrics):
    """Per-class metrics extended with the grouped (time-series) metrics."""

    grouped_metrics: GroupedBaseClassificationMetrics

    # camelCase aliases on the wire, snake_case in Python.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMetrics(MetricsBase):
class GlobalMulticlassMetrics(AdditionalMetrics):
    """Global (dataset-wide) multiclass metrics: additional metrics plus confusion matrix."""

    # Square matrix of prediction counts — NOTE(review): row/column ordering is
    # presumably aligned with the ``classes`` list of the enclosing DTO; confirm.
    confusion_matrix: List[List[int]]

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand All @@ -85,7 +112,15 @@ class GlobalMetrics(MetricsBase):
class MultiClassificationModelQuality(BaseModel):
classes: List[str]
class_metrics: List[ClassMetrics]
global_metrics: GlobalMetrics
global_metrics: GlobalMulticlassMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentMultiClassificationModelQuality(BaseModel):
    """Multiclass model quality for a CURRENT dataset.

    Unlike the reference variant, each class entry also carries grouped
    (time-series) metrics via ``AdditionalClassMetrics``.
    """

    classes: List[str]
    class_metrics: List[AdditionalClassMetrics]
    global_metrics: GlobalMulticlassMetrics

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down Expand Up @@ -143,6 +178,7 @@ class ModelQualityDTO(BaseModel):
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| CurrentMultiClassificationModelQuality
| RegressionModelQuality
| CurrentRegressionModelQuality
]
Expand Down Expand Up @@ -187,7 +223,10 @@ def _create_model_quality(
model_quality_data=model_quality_data,
)
if model_type == ModelType.MULTI_CLASS:
return MultiClassificationModelQuality(**model_quality_data)
return ModelQualityDTO._create_multiclass_model_quality(
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)
if model_type == ModelType.REGRESSION:
return ModelQualityDTO._create_regression_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
Expand All @@ -206,12 +245,24 @@ def _create_binary_model_quality(
return CurrentBinaryClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_multiclass_model_quality(
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality:
    """Build the multiclass model-quality DTO that matches *dataset_type*.

    REFERENCE data maps to ``MultiClassificationModelQuality``; CURRENT data
    maps to ``CurrentMultiClassificationModelQuality``.

    Raises:
        MetricsInternalError: if *dataset_type* is neither REFERENCE nor CURRENT.
    """
    dto_by_dataset_type = {
        DatasetType.REFERENCE: MultiClassificationModelQuality,
        DatasetType.CURRENT: CurrentMultiClassificationModelQuality,
    }
    dto_class = dto_by_dataset_type.get(dataset_type)
    if dto_class is None:
        raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
    return dto_class(**model_quality_data)

@staticmethod
def _create_regression_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> RegressionModelQuality | CurrentRegressionModelQuality:
"""Create a binary model quality instance based on dataset type."""
"""Create a regression model quality instance based on dataset type."""
if dataset_type == DatasetType.REFERENCE:
return RegressionModelQuality(**model_quality_data)
if dataset_type == DatasetType.CURRENT:
Expand Down
2 changes: 1 addition & 1 deletion api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def get_sample_current_dataset(
'mape': 35.19314237273801,
'rmse': 202.23194752188695,
'adj_r2': 0.9116805380966796,
'variance': 0.23
'variance': 0.23,
}

grouped_regression_model_quality_dict = {
Expand Down
4 changes: 2 additions & 2 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
CurrentMultiClassificationModelQuality,
CurrentRegressionModelQuality,
DataQuality,
DatasetStats,
Drift,
JobStatus,
ModelQuality,
ModelType,
MultiClassificationModelQuality,
RegressionDataQuality,
)

Expand Down Expand Up @@ -241,7 +241,7 @@ def __callback(
case ModelType.MULTI_CLASS:
return (
job_status,
MultiClassificationModelQuality.model_validate(
CurrentMultiClassificationModelQuality.model_validate(
response_json['modelQuality']
),
)
Expand Down
2 changes: 2 additions & 0 deletions sdk/radicalbit_platform_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .dataset_model_quality import (
BinaryClassificationModelQuality,
CurrentBinaryClassificationModelQuality,
CurrentMultiClassificationModelQuality,
CurrentRegressionModelQuality,
ModelQuality,
MultiClassificationModelQuality,
Expand Down Expand Up @@ -52,6 +53,7 @@
'ModelQuality',
'BinaryClassificationModelQuality',
'CurrentBinaryClassificationModelQuality',
'CurrentMultiClassificationModelQuality',
'MultiClassificationModelQuality',
'RegressionModelQuality',
'CurrentRegressionModelQuality',
Expand Down
101 changes: 68 additions & 33 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,43 @@ class ModelQuality(BaseModel):
pass


class MetricsBase(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None


class BaseClassificationMetrics(BaseModel):
precision: Optional[float] = None
recall: Optional[float] = None
f_measure: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class GroupedBaseClassificationMetrics(BaseModel):
    """Time-series variant of the base classification metrics.

    Each field is a list of (timestamp, value) points — one ``Distribution``
    per time bucket — rather than a single scalar.
    """

    precision: List[Distribution]
    recall: List[Distribution]
    f_measure: List[Distribution]
    true_positive_rate: List[Distribution]
    false_positive_rate: List[Distribution]

    # Accept snake_case names on input while (de)serializing camelCase aliases;
    # protected_namespaces=() silences pydantic's "model_" prefix warnings.
    model_config = ConfigDict(
        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
    )


class AdditionalMetrics(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
weighted_precision: Optional[float] = None
weighted_recall: Optional[float] = None
weighted_f_measure: Optional[float] = None
weighted_true_positive_rate: Optional[float] = None
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
area_under_roc: Optional[float] = None
area_under_pr: Optional[float] = None

Expand All @@ -29,53 +53,56 @@ class MetricsBase(BaseModel):
)


class BinaryClassificationModelQuality(ModelQuality, MetricsBase):
class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics):
    """Extends the grouped base metrics with aggregate/weighted time series.

    Inherits the per-timestamp precision/recall/f-measure/TPR/FPR series and
    adds f1, accuracy, the weighted variants, and the (optional) ROC/PR AUC
    series.
    """

    f1: List[Distribution]
    accuracy: List[Distribution]
    weighted_precision: List[Distribution]
    weighted_recall: List[Distribution]
    weighted_f_measure: List[Distribution]
    weighted_true_positive_rate: List[Distribution]
    weighted_false_positive_rate: List[Distribution]
    # AUC series may be absent — NOTE(review): presumably not computable for
    # every model/dataset combination; confirm against the platform API.
    area_under_roc: Optional[List[Distribution]] = None
    area_under_pr: Optional[List[Distribution]] = None

    # camelCase aliases on the wire, snake_case in Python.
    model_config = ConfigDict(
        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
    )


class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics):
true_positive_count: int
false_positive_count: int
true_negative_count: int
false_negative_count: int


class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GroupedMetricsBase(BaseModel):
f1: Optional[List[Distribution]] = None
accuracy: Optional[List[Distribution]] = None
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
weighted_precision: Optional[List[Distribution]] = None
weighted_recall: Optional[List[Distribution]] = None
weighted_f_measure: Optional[List[Distribution]] = None
weighted_true_positive_rate: Optional[List[Distribution]] = None
weighted_false_positive_rate: Optional[List[Distribution]] = None
true_positive_rate: List[Distribution]
false_positive_rate: List[Distribution]
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
class BinaryClassificationModelQuality(ModelQuality, GlobalBinaryMetrics):
    """Binary-classification model quality; a ModelQuality with the shape of GlobalBinaryMetrics."""

    pass


class CurrentBinaryClassificationModelQuality(ModelQuality):
global_metrics: BinaryClassificationModelQuality
grouped_metrics: GroupedMetricsBase
global_metrics: GlobalBinaryMetrics
grouped_metrics: AdditionalGroupedMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ClassMetrics(BaseModel):
class_name: str
metrics: MetricsBase
grouped_metrics: Optional[GroupedMetricsBase] = None
metrics: BaseClassificationMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMetrics(MetricsBase):
class AdditionalClassMetrics(ClassMetrics):
    """Per-class metrics extended with the grouped (time-series) metrics."""

    grouped_metrics: GroupedBaseClassificationMetrics

    # camelCase aliases on the wire, snake_case in Python.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMulticlassMetrics(AdditionalMetrics):
    """Global (dataset-wide) multiclass metrics: additional metrics plus confusion matrix."""

    # Square matrix of prediction counts — NOTE(review): row/column ordering is
    # presumably aligned with the ``classes`` list of the enclosing DTO; confirm.
    confusion_matrix: List[List[int]]

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand All @@ -84,7 +111,15 @@ class GlobalMetrics(MetricsBase):
class MultiClassificationModelQuality(ModelQuality):
classes: List[str]
class_metrics: List[ClassMetrics]
global_metrics: GlobalMetrics
global_metrics: GlobalMulticlassMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentMultiClassificationModelQuality(ModelQuality):
    """Multiclass model quality for a CURRENT dataset.

    Unlike the reference variant, each class entry also carries grouped
    (time-series) metrics via ``AdditionalClassMetrics``.
    """

    classes: List[str]
    class_metrics: List[AdditionalClassMetrics]
    global_metrics: GlobalMulticlassMetrics

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down
Loading

0 comments on commit 85d6a64

Please sign in to comment.