Skip to content

Commit

Permalink
feat: (api) classification model quality refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
dtria91 committed Jul 8, 2024
1 parent db54d1a commit 7bca1d7
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 42 deletions.
121 changes: 86 additions & 35 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,43 @@
from app.models.model_dto import ModelType


class MetricsBase(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None


class BaseClassificationMetrics(BaseModel):
precision: Optional[float] = None
recall: Optional[float] = None
f_measure: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class GroupedBaseClassificationMetrics(BaseModel):
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
true_positive_rate: List[Distribution]
false_positive_rate: List[Distribution]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class AdditionalMetrics(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
weighted_precision: Optional[float] = None
weighted_recall: Optional[float] = None
weighted_f_measure: Optional[float] = None
weighted_true_positive_rate: Optional[float] = None
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
area_under_roc: Optional[float] = None
area_under_pr: Optional[float] = None

Expand All @@ -30,53 +54,56 @@ class MetricsBase(BaseModel):
)


class BinaryClassificationModelQuality(MetricsBase):
class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics):
f1: List[Distribution]
accuracy: List[Distribution]
weighted_precision: List[Distribution]
weighted_recall: List[Distribution]
weighted_f_measure: List[Distribution]
weighted_true_positive_rate: List[Distribution]
weighted_false_positive_rate: List[Distribution]
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics):
true_positive_count: int
false_positive_count: int
true_negative_count: int
false_negative_count: int


class Distribution(BaseModel):
timestamp: str
value: Optional[float] = None
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GroupedMetricsBase(BaseModel):
f1: Optional[List[Distribution]] = None
accuracy: Optional[List[Distribution]] = None
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
weighted_precision: Optional[List[Distribution]] = None
weighted_recall: Optional[List[Distribution]] = None
weighted_f_measure: Optional[List[Distribution]] = None
weighted_true_positive_rate: Optional[List[Distribution]] = None
weighted_false_positive_rate: Optional[List[Distribution]] = None
true_positive_rate: List[Distribution]
false_positive_rate: List[Distribution]
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
class BinaryClassificationModelQuality(GlobalBinaryMetrics):
pass


class CurrentBinaryClassificationModelQuality(BaseModel):
global_metrics: BinaryClassificationModelQuality
grouped_metrics: GroupedMetricsBase
global_metrics: GlobalBinaryMetrics
grouped_metrics: AdditionalGroupedMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ClassMetrics(BaseModel):
class_name: str
metrics: MetricsBase
grouped_metrics: Optional[GroupedMetricsBase] = None
metrics: BaseClassificationMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class AdditionalClassMetrics(ClassMetrics):
grouped_metrics: GroupedBaseClassificationMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMetrics(MetricsBase):
class GlobalMulticlassMetrics(AdditionalMetrics):
confusion_matrix: List[List[int]]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand All @@ -85,7 +112,15 @@ class GlobalMetrics(MetricsBase):
class MultiClassificationModelQuality(BaseModel):
classes: List[str]
class_metrics: List[ClassMetrics]
global_metrics: GlobalMetrics
global_metrics: GlobalMulticlassMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentMultiClassificationModelQuality(BaseModel):
classes: List[str]
class_metrics: List[AdditionalClassMetrics]
global_metrics: GlobalMulticlassMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down Expand Up @@ -143,6 +178,7 @@ class ModelQualityDTO(BaseModel):
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| CurrentMultiClassificationModelQuality
| RegressionModelQuality
| CurrentRegressionModelQuality
]
Expand Down Expand Up @@ -187,7 +223,10 @@ def _create_model_quality(
model_quality_data=model_quality_data,
)
if model_type == ModelType.MULTI_CLASS:
return MultiClassificationModelQuality(**model_quality_data)
return ModelQualityDTO._create_multiclass_model_quality(
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)
if model_type == ModelType.REGRESSION:
return ModelQualityDTO._create_regression_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
Expand All @@ -206,12 +245,24 @@ def _create_binary_model_quality(
return CurrentBinaryClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_multiclass_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality:
"""Create a multiclass model quality instance based on dataset type."""
if dataset_type == DatasetType.REFERENCE:
return MultiClassificationModelQuality(**model_quality_data)
if dataset_type == DatasetType.CURRENT:
return CurrentMultiClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_regression_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> RegressionModelQuality | CurrentRegressionModelQuality:
"""Create a binary model quality instance based on dataset type."""
"""Create a regression model quality instance based on dataset type."""
if dataset_type == DatasetType.REFERENCE:
return RegressionModelQuality(**model_quality_data)
if dataset_type == DatasetType.CURRENT:
Expand Down
2 changes: 1 addition & 1 deletion api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def get_sample_current_dataset(
'mape': 35.19314237273801,
'rmse': 202.23194752188695,
'adj_r2': 0.9116805380966796,
'variance': 0.23
'variance': 0.23,
}

grouped_regression_model_quality_dict = {
Expand Down
12 changes: 6 additions & 6 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Distribution(BaseModel):
value: Optional[float] = None


class BaseMetrics(BaseModel):
class BaseClassificationMetrics(BaseModel):
precision: Optional[float] = None
recall: Optional[float] = None
f_measure: Optional[float] = None
Expand All @@ -25,7 +25,7 @@ class BaseMetrics(BaseModel):
)


class GroupedBaseMetrics(BaseModel):
class GroupedBaseClassificationMetrics(BaseModel):
precision: List[Distribution]
recall: List[Distribution]
f_measure: List[Distribution]
Expand Down Expand Up @@ -53,7 +53,7 @@ class AdditionalMetrics(BaseModel):
)


class AdditionalGroupedMetrics(GroupedBaseMetrics):
class AdditionalGroupedMetrics(GroupedBaseClassificationMetrics):
f1: List[Distribution]
accuracy: List[Distribution]
weighted_precision: List[Distribution]
Expand All @@ -69,7 +69,7 @@ class AdditionalGroupedMetrics(GroupedBaseMetrics):
)


class GlobalBinaryMetrics(BaseMetrics, AdditionalMetrics):
class GlobalBinaryMetrics(BaseClassificationMetrics, AdditionalMetrics):
true_positive_count: int
false_positive_count: int
true_negative_count: int
Expand All @@ -91,13 +91,13 @@ class CurrentBinaryClassificationModelQuality(ModelQuality):

class ClassMetrics(BaseModel):
class_name: str
metrics: BaseMetrics
metrics: BaseClassificationMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class AdditionalClassMetrics(ClassMetrics):
grouped_metrics: GroupedBaseMetrics
grouped_metrics: GroupedBaseClassificationMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down

0 comments on commit 7bca1d7

Please sign in to comment.