From 7bca1d7a4d758f7196e6d990a9342270cb44d47d Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 8 Jul 2024 16:27:51 +0200 Subject: [PATCH] feat: (api) classification model quality refactoring --- api/app/models/metrics/model_quality_dto.py | 121 +++++++++++++----- api/tests/commons/db_mock.py | 2 +- .../models/dataset_model_quality.py | 12 +- 3 files changed, 93 insertions(+), 42 deletions(-) diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py index 5998c752..965c107a 100644 --- a/api/app/models/metrics/model_quality_dto.py +++ b/api/app/models/metrics/model_quality_dto.py @@ -9,19 +9,43 @@ from app.models.model_dto import ModelType -class MetricsBase(BaseModel): - f1: Optional[float] = None - accuracy: Optional[float] = None +class Distribution(BaseModel): + timestamp: str + value: Optional[float] = None + + +class BaseClassificationMetrics(BaseModel): precision: Optional[float] = None recall: Optional[float] = None f_measure: Optional[float] = None + true_positive_rate: Optional[float] = None + false_positive_rate: Optional[float] = None + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class GroupedBaseClassificationMetrics(BaseModel): + precision: List[Distribution] + recall: List[Distribution] + f_measure: List[Distribution] + true_positive_rate: List[Distribution] + false_positive_rate: List[Distribution] + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class AdditionalMetrics(BaseModel): + f1: Optional[float] = None + accuracy: Optional[float] = None weighted_precision: Optional[float] = None weighted_recall: Optional[float] = None weighted_f_measure: Optional[float] = None weighted_true_positive_rate: Optional[float] = None weighted_false_positive_rate: Optional[float] = None - true_positive_rate: Optional[float] = None - false_positive_rate: Optional[float] = None area_under_roc: Optional[float] = None area_under_pr: Optional[float] = None @@ -30,53 +54,56 @@ class MetricsBase(BaseModel): ) -class BinaryClassificationModelQuality(MetricsBase): +class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics): + f1: List[Distribution] + accuracy: List[Distribution] + weighted_precision: List[Distribution] + weighted_recall: List[Distribution] + weighted_f_measure: List[Distribution] + weighted_true_positive_rate: List[Distribution] + weighted_false_positive_rate: List[Distribution] + area_under_roc: Optional[List[Distribution]] = None + area_under_pr: Optional[List[Distribution]] = None + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics): true_positive_count: int false_positive_count: int true_negative_count: int false_negative_count: int - -class Distribution(BaseModel): - timestamp: str - value: Optional[float] = None + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class GroupedMetricsBase(BaseModel): - f1: Optional[List[Distribution]] = None - accuracy: Optional[List[Distribution]] = None - precision: List[Distribution] - recall: List[Distribution] - f_measure: List[Distribution] - weighted_precision: Optional[List[Distribution]] = None - weighted_recall: Optional[List[Distribution]] = None - weighted_f_measure: Optional[List[Distribution]] = None - weighted_true_positive_rate: Optional[List[Distribution]] = None - 
weighted_false_positive_rate: Optional[List[Distribution]] = None - true_positive_rate: List[Distribution] - false_positive_rate: List[Distribution] - area_under_roc: Optional[List[Distribution]] = None - area_under_pr: Optional[List[Distribution]] = None - - model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) +class BinaryClassificationModelQuality(GlobalBinaryMetrics): + pass class CurrentBinaryClassificationModelQuality(BaseModel): - global_metrics: BinaryClassificationModelQuality - grouped_metrics: GroupedMetricsBase + global_metrics: GlobalBinaryMetrics + grouped_metrics: AdditionalGroupedMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class ClassMetrics(BaseModel): class_name: str - metrics: MetricsBase - grouped_metrics: Optional[GroupedMetricsBase] = None + metrics: BaseClassificationMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class AdditionalClassMetrics(ClassMetrics): + grouped_metrics: GroupedBaseClassificationMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class GlobalMetrics(MetricsBase): +class GlobalMulticlassMetrics(AdditionalMetrics): confusion_matrix: List[List[int]] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -85,7 +112,15 @@ class GlobalMetrics(MetricsBase): class MultiClassificationModelQuality(BaseModel): classes: List[str] class_metrics: List[ClassMetrics] - global_metrics: GlobalMetrics + global_metrics: GlobalMulticlassMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class CurrentMultiClassificationModelQuality(BaseModel): + classes: List[str] + class_metrics: List[AdditionalClassMetrics] + global_metrics: GlobalMulticlassMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -143,6 +178,7 @@ class ModelQualityDTO(BaseModel): BinaryClassificationModelQuality | CurrentBinaryClassificationModelQuality | MultiClassificationModelQuality + | CurrentMultiClassificationModelQuality | RegressionModelQuality | CurrentRegressionModelQuality ] @@ -187,7 +223,10 @@ def _create_model_quality( model_quality_data=model_quality_data, ) if model_type == ModelType.MULTI_CLASS: - return MultiClassificationModelQuality(**model_quality_data) + return ModelQualityDTO._create_multiclass_model_quality( + dataset_type=dataset_type, + model_quality_data=model_quality_data, + ) if model_type == ModelType.REGRESSION: return ModelQualityDTO._create_regression_model_quality( dataset_type=dataset_type, model_quality_data=model_quality_data @@ -206,12 +245,24 @@ def _create_binary_model_quality( return CurrentBinaryClassificationModelQuality(**model_quality_data) raise MetricsInternalError(f'Invalid dataset type {dataset_type}') + @staticmethod + def _create_multiclass_model_quality( + dataset_type: DatasetType, + model_quality_data: Dict, + ) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality: + """Create a multiclass model quality instance based on dataset type.""" + if dataset_type == DatasetType.REFERENCE: + return MultiClassificationModelQuality(**model_quality_data) + if dataset_type == DatasetType.CURRENT: + return CurrentMultiClassificationModelQuality(**model_quality_data) + raise MetricsInternalError(f'Invalid dataset type {dataset_type}') + @staticmethod def _create_regression_model_quality( dataset_type: DatasetType, model_quality_data: Dict, ) -> RegressionModelQuality | CurrentRegressionModelQuality: - """Create 
a binary model quality instance based on dataset type.""" + """Create a regression model quality instance based on dataset type.""" if dataset_type == DatasetType.REFERENCE: return RegressionModelQuality(**model_quality_data) if dataset_type == DatasetType.CURRENT: diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index 040fe88a..717b9612 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -323,7 +323,7 @@ def get_sample_current_dataset( 'mape': 35.19314237273801, 'rmse': 202.23194752188695, 'adj_r2': 0.9116805380966796, - 'variance': 0.23 + 'variance': 0.23, } grouped_regression_model_quality_dict = { diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py index 1f67bc0b..258c10e2 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py @@ -13,7 +13,7 @@ class Distribution(BaseModel): value: Optional[float] = None -class BaseMetrics(BaseModel): +class BaseClassificationMetrics(BaseModel): precision: Optional[float] = None recall: Optional[float] = None f_measure: Optional[float] = None @@ -25,7 +25,7 @@ class BaseMetrics(BaseModel): ) -class GroupedBaseMetrics(BaseModel): +class GroupedBaseClassificationMetrics(BaseModel): precision: List[Distribution] recall: List[Distribution] f_measure: List[Distribution] @@ -53,7 +53,7 @@ class AdditionalMetrics(BaseModel): ) -class AdditionalGroupedMetrics(GroupedBaseMetrics): +class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics): f1: List[Distribution] accuracy: List[Distribution] weighted_precision: List[Distribution] @@ -69,7 +69,7 @@ class AdditionalGroupedMetrics(GroupedBaseMetrics): ) -class GlobalBinaryMetrics(BaseMetrics, AdditionalMetrics): +class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics): true_positive_count: int false_positive_count: int true_negative_count: int @@ -91,13 +91,13 @@ class CurrentBinaryClassificationModelQuality(ModelQuality): class ClassMetrics(BaseModel): class_name: str - metrics: BaseMetrics + metrics: BaseClassificationMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class AdditionalClassMetrics(ClassMetrics): - grouped_metrics: GroupedBaseMetrics + grouped_metrics: GroupedBaseClassificationMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
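
Reviewer note (illustrative, not part of the patch): below is a minimal sketch of the payload shape the new CurrentMultiClassificationModelQuality DTO introduced by this patch is expected to accept. It assumes snake_case keys are accepted thanks to populate_by_name=True, that the snippet runs inside the api package (so the app.* import path resolves), and all metric values are made up for illustration.

    from app.models.metrics.model_quality_dto import (
        CurrentMultiClassificationModelQuality,
    )

    # Hypothetical payload; values are illustrative only.
    payload = {
        'classes': ['cat', 'dog'],
        'class_metrics': [
            {
                'class_name': 'cat',
                # per-class point-in-time metrics (BaseClassificationMetrics)
                'metrics': {'precision': 0.91, 'recall': 0.84},
                # per-class time series (GroupedBaseClassificationMetrics)
                'grouped_metrics': {
                    'precision': [{'timestamp': '2024-07-08T00:00:00Z', 'value': 0.91}],
                    'recall': [{'timestamp': '2024-07-08T00:00:00Z', 'value': 0.84}],
                    'f_measure': [],
                    'true_positive_rate': [],
                    'false_positive_rate': [],
                },
            },
        ],
        # aggregate metrics plus the confusion matrix (GlobalMulticlassMetrics)
        'global_metrics': {
            'accuracy': 0.87,
            'confusion_matrix': [[42, 8], [5, 45]],
        },
    }

    model_quality = CurrentMultiClassificationModelQuality(**payload)
    print(model_quality.model_dump(by_alias=True))

The same data would be routed automatically by ModelQualityDTO._create_multiclass_model_quality, which per this patch returns MultiClassificationModelQuality for DatasetType.REFERENCE and CurrentMultiClassificationModelQuality for DatasetType.CURRENT.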