From 17bed6b28ff50ca423373076f11a3ae65c3d1b20 Mon Sep 17 00:00:00 2001 From: Daniele Tria <36860433+dtria91@users.noreply.github.com> Date: Tue, 2 Jul 2024 16:30:30 +0200 Subject: [PATCH] fix: set optional field of current multiclass model quality (#77) * fix: set optional field for current model quality of multiclass metrics * fix: (sdk) set optional field for current model quality of multiclass metrics --- api/app/models/metrics/model_quality_dto.py | 14 ++++---- .../models/dataset_drift.py | 1 - .../models/dataset_model_quality.py | 14 ++++---- sdk/tests/apis/model_current_dataset_test.py | 32 ++----------------- spark/jobs/utils/current_multiclass.py | 2 -- 5 files changed, 16 insertions(+), 47 deletions(-) diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py index 37b70aed..6fdcbfba 100644 --- a/api/app/models/metrics/model_quality_dto.py +++ b/api/app/models/metrics/model_quality_dto.py @@ -43,16 +43,16 @@ class Distribution(BaseModel): class GroupedMetricsBase(BaseModel): - f1: List[Distribution] - accuracy: List[Distribution] + f1: Optional[List[Distribution]] = None + accuracy: Optional[List[Distribution]] = None precision: List[Distribution] recall: List[Distribution] f_measure: List[Distribution] - weighted_precision: List[Distribution] - weighted_recall: List[Distribution] - weighted_f_measure: List[Distribution] - weighted_true_positive_rate: List[Distribution] - weighted_false_positive_rate: List[Distribution] + weighted_precision: Optional[List[Distribution]] = None + weighted_recall: Optional[List[Distribution]] = None + weighted_f_measure: Optional[List[Distribution]] = None + weighted_true_positive_rate: Optional[List[Distribution]] = None + weighted_false_positive_rate: Optional[List[Distribution]] = None true_positive_rate: List[Distribution] false_positive_rate: List[Distribution] area_under_roc: Optional[List[Distribution]] = None diff --git a/sdk/radicalbit_platform_sdk/models/dataset_drift.py b/sdk/radicalbit_platform_sdk/models/dataset_drift.py index 532774a7..dd7f8e10 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_drift.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_drift.py @@ -29,4 +29,3 @@ class Drift(BaseModel): feature_metrics: List[FeatureDrift] model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) - diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py index 5a68f5ed..6bb60480 100644 --- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py +++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py @@ -42,16 +42,16 @@ class Distribution(BaseModel): class GroupedMetricsBase(BaseModel): - f1: List[Distribution] - accuracy: List[Distribution] + f1: Optional[List[Distribution]] = None + accuracy: Optional[List[Distribution]] = None precision: List[Distribution] recall: List[Distribution] f_measure: List[Distribution] - weighted_precision: List[Distribution] - weighted_recall: List[Distribution] - weighted_f_measure: List[Distribution] - weighted_true_positive_rate: List[Distribution] - weighted_false_positive_rate: List[Distribution] + weighted_precision: Optional[List[Distribution]] = None + weighted_recall: Optional[List[Distribution]] = None + weighted_f_measure: Optional[List[Distribution]] = None + weighted_true_positive_rate: Optional[List[Distribution]] = None + weighted_false_positive_rate: Optional[List[Distribution]] = None true_positive_rate: List[Distribution] false_positive_rate: List[Distribution] area_under_roc: Optional[List[Distribution]] = None diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py index b9606d4c..6669807c 100644 --- a/sdk/tests/apis/model_current_dataset_test.py +++ b/sdk/tests/apis/model_current_dataset_test.py @@ -702,46 +702,18 @@ def test_multi_class_model_quality_ok(self): "accuracy": {accuracy} }}, "groupedMetrics": {{ - "f1": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}} - ], - "accuracy": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": {accuracy}}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.9}} - ], "precision": [ {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.86}}, {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}} ], "recall": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}}, + {{"timestamp": "2024-01-01T00:00:00Z", "value": {recall}}}, {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} ], "fMeasure": [ {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.8}}, {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}} ], - "weightedPrecision": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.85}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.87}} - ], - "weightedRecall": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.82}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.84}} - ], - "weightedFMeasure": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.84}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.86}} - ], - "weightedTruePositiveRate": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.88}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.9}} - ], - "weightedFalsePositiveRate": [ - {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.12}}, - {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.1}} - ], "truePositiveRate": [ {{"timestamp": "2024-01-01T00:00:00Z", "value": 0.81}}, {{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}} @@ -813,7 +785,7 @@ def test_multi_class_model_quality_ok(self): assert metrics.global_metrics.f_measure == f_measure assert metrics.class_metrics[0].class_name == 'classA' assert metrics.class_metrics[0].metrics.accuracy == accuracy - assert metrics.class_metrics[0].grouped_metrics.accuracy[0].value == accuracy + assert metrics.class_metrics[0].grouped_metrics.recall[0].value == recall assert metrics.class_metrics[1].class_name == 'classB' assert metrics.class_metrics[1].metrics.f_measure == f_measure assert metrics.class_metrics[2].class_name == 'classC' diff --git a/spark/jobs/utils/current_multiclass.py b/spark/jobs/utils/current_multiclass.py index 0bddfa39..22738723 100644 --- a/spark/jobs/utils/current_multiclass.py +++ b/spark/jobs/utils/current_multiclass.py @@ -121,8 +121,6 @@ def calculate_multiclass_model_quality_group_by_timestamp(self): ] ) - dataset_with_group.show() - list_of_time_group = ( dataset_with_group.select("time_group") .distinct()