From c29749e16feb1e1e2c819375e8da76319ef774fd Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 8 Jul 2024 10:04:06 +0200 Subject: [PATCH 1/2] fix: edit model quality dto --- .../apis/model_current_dataset.py | 4 +- .../models/__init__.py | 2 + .../models/dataset_model_quality.py | 101 ++++++++++++------ sdk/tests/apis/model_current_dataset_test.py | 93 ++++++++++++---- .../apis/model_reference_dataset_test.py | 10 -- 5 files changed, 145 insertions(+), 65 deletions(-) diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py index 06430527..6c1acd1f 100644 --- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py +++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py @@ -10,6 +10,7 @@ ClassificationDataQuality, CurrentBinaryClassificationModelQuality, CurrentFileUpload, + CurrentMultiClassificationModelQuality, CurrentRegressionModelQuality, DataQuality, DatasetStats, @@ -17,7 +18,6 @@ JobStatus, ModelQuality, ModelType, - MultiClassificationModelQuality, RegressionDataQuality, ) @@ -241,7 +241,7 @@ def __callback( case ModelType.MULTI_CLASS: return ( job_status, - MultiClassificationModelQuality.model_validate( + CurrentMultiClassificationModelQuality.model_validate( response_json['modelQuality'] ), ) diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py index fc7b63c7..9f9a3964 100644 --- a/sdk/radicalbit_platform_sdk/models/__init__.py +++ b/sdk/radicalbit_platform_sdk/models/__init__.py @@ -23,6 +23,7 @@ from .dataset_model_quality import ( BinaryClassificationModelQuality, CurrentBinaryClassificationModelQuality, + CurrentMultiClassificationModelQuality, CurrentRegressionModelQuality, ModelQuality, MultiClassificationModelQuality, @@ -52,6 +53,7 @@ 'ModelQuality', 'BinaryClassificationModelQuality', 'CurrentBinaryClassificationModelQuality', + 'CurrentMultiClassificationModelQuality', 'MultiClassificationModelQuality', 'RegressionModelQuality', 'CurrentRegressionModelQuality', diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py index f69f87f2..2fa63a60 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py @@ -8,19 +8,43 @@ class ModelQuality(BaseModel): pass -class MetricsBase(BaseModel): - f1: Optional[float] = None - accuracy: Optional[float] = None +class Distribution(BaseModel): + timestamp: str + value: Optional[float] = None + + +class BaseMetrics(BaseModel): precision: Optional[float] = None recall: Optional[float] = None f_measure: Optional[float] = None + true_positive_rate: Optional[float] = None + false_positive_rate: Optional[float] = None + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class GroupedBaseMetrics(BaseModel): + precision: List[Distribution] + recall: List[Distribution] + f_measure: List[Distribution] + true_positive_rate: List[Distribution] + false_positive_rate: List[Distribution] + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class AdditionalMetrics(BaseModel): + f1: Optional[float] = None + accuracy: Optional[float] = None weighted_precision: Optional[float] = None weighted_recall: Optional[float] = None weighted_f_measure: Optional[float] = None weighted_true_positive_rate: Optional[float] = None weighted_false_positive_rate: Optional[float] = None - true_positive_rate: Optional[float] = None - false_positive_rate: Optional[float] = None area_under_roc: Optional[float] = None area_under_pr: Optional[float] = None @@ -29,53 +53,56 @@ class MetricsBase(BaseModel): ) -class BinaryClassificationModelQuality(ModelQuality, MetricsBase): +class AdditionalGroupedMetrics(GroupedBaseMetrics): + f1: List[Distribution] + accuracy: List[Distribution] + weighted_precision: List[Distribution] + weighted_recall: List[Distribution] + weighted_f_measure: List[Distribution] + weighted_true_positive_rate: List[Distribution] + weighted_false_positive_rate: List[Distribution] + area_under_roc: Optional[List[Distribution]] = None + area_under_pr: Optional[List[Distribution]] = None + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class GlobalBinaryMetrics(BaseMetrics, AdditionalMetrics): true_positive_count: int false_positive_count: int true_negative_count: int false_negative_count: int - -class Distribution(BaseModel): - timestamp: str - value: Optional[float] = None + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class GroupedMetricsBase(BaseModel): - f1: Optional[List[Distribution]] = None - accuracy: Optional[List[Distribution]] = None - precision: List[Distribution] - recall: List[Distribution] - f_measure: List[Distribution] - weighted_precision: Optional[List[Distribution]] = None - weighted_recall: Optional[List[Distribution]] = None - weighted_f_measure: Optional[List[Distribution]] = None - weighted_true_positive_rate: Optional[List[Distribution]] = None - weighted_false_positive_rate: Optional[List[Distribution]] = None - true_positive_rate: List[Distribution] - false_positive_rate: List[Distribution] - area_under_roc: Optional[List[Distribution]] = None - area_under_pr: Optional[List[Distribution]] = None - - model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) +class BinaryClassificationModelQuality(ModelQuality, GlobalBinaryMetrics): + pass class CurrentBinaryClassificationModelQuality(ModelQuality): - global_metrics: BinaryClassificationModelQuality - grouped_metrics: GroupedMetricsBase + global_metrics: GlobalBinaryMetrics + grouped_metrics: AdditionalGroupedMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class ClassMetrics(BaseModel): class_name: str - metrics: MetricsBase - grouped_metrics: Optional[GroupedMetricsBase] = None + metrics: BaseMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class GlobalMetrics(MetricsBase): +class AdditionalClassMetrics(ClassMetrics): + grouped_metrics: GroupedBaseMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class GlobalMulticlassMetrics(AdditionalMetrics): confusion_matrix: List[List[int]] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -84,7 +111,15 @@ class GlobalMetrics(MetricsBase): class MultiClassificationModelQuality(ModelQuality): classes: List[str] class_metrics: List[ClassMetrics] - global_metrics: GlobalMetrics + global_metrics: GlobalMulticlassMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class CurrentMultiClassificationModelQuality(ModelQuality): + classes: List[str] + class_metrics: List[AdditionalClassMetrics] + global_metrics: GlobalMulticlassMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index 039fadad..5df194b0 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -10,12 +10,12 @@ ClassificationDataQuality, CurrentBinaryClassificationModelQuality, CurrentFileUpload, + CurrentMultiClassificationModelQuality, CurrentRegressionModelQuality, Drift, DriftAlgorithm, JobStatus, ModelType, - MultiClassificationModelQuality, RegressionDataQuality, ) @@ -746,7 +746,6 @@ def test_multi_class_model_quality_ok(self): weighted_f_measure = 2.45 true_positive_rate = 4.12 false_positive_rate = 5.89 - precision = 2.33 weighted_recall = 4.22 f_measure = 9.33 confusion_matrix = [ @@ -781,7 +780,7 @@ def test_multi_class_model_quality_ok(self): {{ "className": "classA", "metrics": {{ - "accuracy": {accuracy} + "recall": {recall} }}, "groupedMetrics": {{ "precision": [ @@ -810,35 +809,96 @@ def test_multi_class_model_quality_ok(self): "className": "classB", "metrics": {{ "fMeasure": {f_measure} - }} + }}, + "groupedMetrics": {{ + "precision": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}} + ], + "recall": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "fMeasure": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}} + ], + "truePositiveRate": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "falsePositiveRate": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.14}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}} + ] + }} }}, {{ "className": "classC", "metrics": {{ "recall": {recall} - }} + }}, + "groupedMetrics": {{ + "precision": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}} + ], + "recall": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "fMeasure": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}} + ], + "truePositiveRate": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "falsePositiveRate": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.14}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}} + ] + }} }}, {{ "className": "classD", "metrics": {{ "truePositiveRate": {true_positive_rate}, "falsePositiveRate": {false_positive_rate} - }} + }}, + "groupedMetrics": {{ + "precision": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}} + ], + "recall": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "fMeasure": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}} + ], + "truePositiveRate": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} + ], + "falsePositiveRate": [ + {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.14}}, + {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}} + ] + }} }} ], "globalMetrics": {{ "f1": {f1}, "accuracy": {accuracy}, - "precision": {precision}, - "recall": {recall}, - "fMeasure": {f_measure}, "weightedPrecision": {weighted_precision}, "weightedRecall": {weighted_recall}, "weightedFMeasure": {weighted_f_measure}, "weightedTruePositiveRate": {weighted_true_positive_rate}, "weightedFalsePositiveRate": {weighted_false_positive_rate}, - "truePositiveRate": {true_positive_rate}, - "falsePositiveRate": {false_positive_rate}, "confusionMatrix": {confusion_matrix} }} }} @@ -847,9 +907,8 @@ def test_multi_class_model_quality_ok(self): metrics = model_current_dataset.model_quality() - assert isinstance(metrics, MultiClassificationModelQuality) + assert isinstance(metrics, CurrentMultiClassificationModelQuality) assert metrics.classes == ['classA', 'classB', 'classC', 'classD'] - assert metrics.global_metrics.accuracy == accuracy assert metrics.global_metrics.weighted_precision == weighted_precision assert metrics.global_metrics.weighted_recall == weighted_recall assert ( @@ -860,14 +919,8 @@ def test_multi_class_model_quality_ok(self): metrics.global_metrics.weighted_false_positive_rate == weighted_false_positive_rate ) - assert metrics.global_metrics.weighted_f_measure == weighted_f_measure - assert metrics.global_metrics.true_positive_rate == true_positive_rate - assert metrics.global_metrics.false_positive_rate == false_positive_rate - assert metrics.global_metrics.precision == precision - assert metrics.global_metrics.f_measure == f_measure assert metrics.class_metrics[0].class_name == 'classA' - assert metrics.class_metrics[0].metrics.accuracy == accuracy - assert metrics.class_metrics[0].grouped_metrics.recall[0].value == recall + assert metrics.class_metrics[0].metrics.recall == recall assert metrics.class_metrics[1].class_name == 'classB' assert metrics.class_metrics[1].metrics.f_measure == f_measure assert metrics.class_metrics[2].class_name == 'classC' diff --git a/sdk/tests/apis/model_reference_dataset_test.py b/sdk/tests/apis/model_reference_dataset_test.py index 38c8b43c..d51c5397 100644 --- a/sdk/tests/apis/model_reference_dataset_test.py +++ b/sdk/tests/apis/model_reference_dataset_test.py @@ -236,7 +236,6 @@ def test_multi_class_model_metrics_ok(self): weighted_f_measure = 2.45 true_positive_rate = 4.12 false_positive_rate = 5.89 - precision = 2.33 weighted_recall = 4.22 f_measure = 9.33 confusion_matrix = [ @@ -297,16 +296,12 @@ def test_multi_class_model_metrics_ok(self): "globalMetrics": {{ "f1": {f1}, "accuracy": {accuracy}, - "precision": {precision}, "recall": {recall}, - "fMeasure": {f_measure}, "weightedPrecision": {weighted_precision}, "weightedRecall": {weighted_recall}, "weightedFMeasure": {weighted_f_measure}, "weightedTruePositiveRate": {weighted_true_positive_rate}, "weightedFalsePositiveRate": {weighted_false_positive_rate}, - "truePositiveRate": {true_positive_rate}, - "falsePositiveRate": {false_positive_rate}, "confusionMatrix": {confusion_matrix} }} }} @@ -329,12 +324,7 @@ def test_multi_class_model_metrics_ok(self): == weighted_false_positive_rate ) assert metrics.global_metrics.weighted_f_measure == weighted_f_measure - assert metrics.global_metrics.true_positive_rate == true_positive_rate - assert metrics.global_metrics.false_positive_rate == false_positive_rate - assert metrics.global_metrics.precision == precision - assert metrics.global_metrics.f_measure == f_measure assert metrics.class_metrics[0].class_name == 'classA' - assert metrics.class_metrics[0].metrics.accuracy == accuracy assert metrics.class_metrics[1].class_name == 'classB' assert metrics.class_metrics[1].metrics.f_measure == f_measure assert metrics.class_metrics[2].class_name == 'classC' From 7bca1d7a4d758f7196e6d990a9342270cb44d47d Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Mon, 8 Jul 2024 16:27:51 +0200 Subject: [PATCH 2/2] feat: (api) classification model quality refactoring --- api/app/models/metrics/model_quality_dto.py | 121 +++++++++++++----- api/tests/commons/db_mock.py | 2 +- .../models/dataset_model_quality.py | 12 +- 3 files changed, 93 insertions(+), 42 deletions(-) diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py index 5998c752..965c107a 100644 --- a/api/app/models/metrics/model_quality_dto.py +++ b/api/app/models/metrics/model_quality_dto.py @@ -9,19 +9,43 @@ from app.models.model_dto import ModelType -class MetricsBase(BaseModel): - f1: Optional[float] = None - accuracy: Optional[float] = None +class Distribution(BaseModel): + timestamp: str + value: Optional[float] = None + + +class BaseClassificationMetrics(BaseModel): precision: Optional[float] = None recall: Optional[float] = None f_measure: Optional[float] = None + true_positive_rate: Optional[float] = None + false_positive_rate: Optional[float] = None + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class GroupedBaseClassificationMetrics(BaseModel): + precision: List[Distribution] + recall: List[Distribution] + f_measure: List[Distribution] + true_positive_rate: List[Distribution] + false_positive_rate: List[Distribution] + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class AdditionalMetrics(BaseModel): + f1: Optional[float] = None + accuracy: Optional[float] = None weighted_precision: Optional[float] = None weighted_recall: Optional[float] = None weighted_f_measure: Optional[float] = None weighted_true_positive_rate: Optional[float] = None weighted_false_positive_rate: Optional[float] = None - true_positive_rate: Optional[float] = None - false_positive_rate: Optional[float] = None area_under_roc: Optional[float] = None area_under_pr: Optional[float] = None @@ -30,53 +54,56 @@ class MetricsBase(BaseModel): ) -class BinaryClassificationModelQuality(MetricsBase): +class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics): + f1: List[Distribution] + accuracy: List[Distribution] + weighted_precision: List[Distribution] + weighted_recall: List[Distribution] + weighted_f_measure: List[Distribution] + weighted_true_positive_rate: List[Distribution] + weighted_false_positive_rate: List[Distribution] + area_under_roc: Optional[List[Distribution]] = None + area_under_pr: Optional[List[Distribution]] = None + + model_config = ConfigDict( + populate_by_name=True, alias_generator=to_camel, protected_namespaces=() + ) + + +class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics): true_positive_count: int false_positive_count: int true_negative_count: int false_negative_count: int - -class Distribution(BaseModel): - timestamp: str - value: Optional[float] = None + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class GroupedMetricsBase(BaseModel): - f1: Optional[List[Distribution]] = None - accuracy: Optional[List[Distribution]] = None - precision: List[Distribution] - recall: List[Distribution] - f_measure: List[Distribution] - weighted_precision: Optional[List[Distribution]] = None - weighted_recall: Optional[List[Distribution]] = None - weighted_f_measure: Optional[List[Distribution]] = None - weighted_true_positive_rate: Optional[List[Distribution]] = None - weighted_false_positive_rate: Optional[List[Distribution]] = None - true_positive_rate: List[Distribution] - false_positive_rate: List[Distribution] - area_under_roc: Optional[List[Distribution]] = None - area_under_pr: Optional[List[Distribution]] = None - - model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) +class BinaryClassificationModelQuality(GlobalBinaryMetrics): + pass class CurrentBinaryClassificationModelQuality(BaseModel): - global_metrics: BinaryClassificationModelQuality - grouped_metrics: GroupedMetricsBase + global_metrics: GlobalBinaryMetrics + grouped_metrics: AdditionalGroupedMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class ClassMetrics(BaseModel): class_name: str - metrics: MetricsBase - grouped_metrics: Optional[GroupedMetricsBase] = None + metrics: BaseClassificationMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class AdditionalClassMetrics(ClassMetrics): + grouped_metrics: GroupedBaseClassificationMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) -class GlobalMetrics(MetricsBase): +class GlobalMulticlassMetrics(AdditionalMetrics): confusion_matrix: List[List[int]] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -85,7 +112,15 @@ class GlobalMetrics(MetricsBase): class MultiClassificationModelQuality(BaseModel): classes: List[str] class_metrics: List[ClassMetrics] - global_metrics: GlobalMetrics + global_metrics: GlobalMulticlassMetrics + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) + + +class CurrentMultiClassificationModelQuality(BaseModel): + classes: List[str] + class_metrics: List[AdditionalClassMetrics] + global_metrics: GlobalMulticlassMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @@ -143,6 +178,7 @@ class ModelQualityDTO(BaseModel): BinaryClassificationModelQuality | CurrentBinaryClassificationModelQuality | MultiClassificationModelQuality + | CurrentMultiClassificationModelQuality | RegressionModelQuality | CurrentRegressionModelQuality ] @@ -187,7 +223,10 @@ def _create_model_quality( model_quality_data=model_quality_data, ) if model_type == ModelType.MULTI_CLASS: - return MultiClassificationModelQuality(**model_quality_data) + return ModelQualityDTO._create_multiclass_model_quality( + dataset_type=dataset_type, + model_quality_data=model_quality_data, + ) if model_type == ModelType.REGRESSION: return ModelQualityDTO._create_regression_model_quality( dataset_type=dataset_type, model_quality_data=model_quality_data @@ -206,12 +245,24 @@ def _create_binary_model_quality( return CurrentBinaryClassificationModelQuality(**model_quality_data) raise MetricsInternalError(f'Invalid dataset type {dataset_type}') + @staticmethod + def _create_multiclass_model_quality( + dataset_type: DatasetType, + model_quality_data: Dict, + ) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality: + """Create a multiclass model quality instance based on dataset type.""" + if dataset_type == DatasetType.REFERENCE: + return MultiClassificationModelQuality(**model_quality_data) + if dataset_type == DatasetType.CURRENT: + return CurrentMultiClassificationModelQuality(**model_quality_data) + raise MetricsInternalError(f'Invalid dataset type {dataset_type}') + @staticmethod def _create_regression_model_quality( dataset_type: DatasetType, model_quality_data: Dict, ) -> RegressionModelQuality | CurrentRegressionModelQuality: - """Create a binary model quality instance based on dataset type.""" + """Create a regression model quality instance based on dataset type.""" if dataset_type == DatasetType.REFERENCE: return RegressionModelQuality(**model_quality_data) if dataset_type == DatasetType.CURRENT: diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index 040fe88a..717b9612 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -323,7 +323,7 @@ def get_sample_current_dataset( 'mape': 35.19314237273801, 'rmse': 202.23194752188695, 'adj_r2': 0.9116805380966796, - 'variance': 0.23 + 'variance': 0.23, } grouped_regression_model_quality_dict = { diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py index 1f67bc0b..258c10e2 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py @@ -13,7 +13,7 @@ class Distribution(BaseModel): value: Optional[float] = None -class BaseMetrics(BaseModel): +class BaseClassificationMetrics(BaseModel): precision: Optional[float] = None recall: Optional[float] = None f_measure: Optional[float] = None @@ -25,7 +25,7 @@ class BaseMetrics(BaseModel): ) -class GroupedBaseMetrics(BaseModel): +class GroupedBaseClassificationMetrics(BaseModel): precision: List[Distribution] recall: List[Distribution] f_measure: List[Distribution] @@ -53,7 +53,7 @@ class AdditionalMetrics(BaseModel): ) -class AdditionalGroupedMetrics(GroupedBaseMetrics): +class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics): f1: List[Distribution] accuracy: List[Distribution] weighted_precision: List[Distribution] @@ -69,7 +69,7 @@ class AdditionalGroupedMetrics(GroupedBaseMetrics): ) -class GlobalBinaryMetrics(BaseMetrics, AdditionalMetrics): +class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics): true_positive_count: int false_positive_count: int true_negative_count: int @@ -91,13 +91,13 @@ class CurrentBinaryClassificationModelQuality(ModelQuality): class ClassMetrics(BaseModel): class_name: str - metrics: BaseMetrics + metrics: BaseClassificationMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) class AdditionalClassMetrics(ClassMetrics): - grouped_metrics: GroupedBaseMetrics + grouped_metrics: GroupedBaseClassificationMetrics model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)