Skip to content

Commit

Permalink
feat: add multiclass dto for model quality of reference part (#46)
Browse files Browse the repository at this point in the history
* feat: add multiclass dto for model quality of reference part

* feat: add confusion matrix

* refactor: improve compose

* feat: align sdk to multiclass model quality dto

* fix: ruff check

* fix: remove previous wrong commit

* feat: (sdk) define single class for model quality metrics

* fix: ruff fix check

---------

Co-authored-by: Marco Riva <[email protected]>
  • Loading branch information
dtria91 and rivamarco authored Jun 28, 2024
1 parent 0ba6c0a commit 20e2d6b
Show file tree
Hide file tree
Showing 11 changed files with 446 additions and 112 deletions.
92 changes: 63 additions & 29 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.models.model_dto import ModelType


class BinaryClassModelQuality(BaseModel):
class MetricsBase(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
precision: Optional[float] = None
Expand All @@ -22,10 +22,6 @@ class BinaryClassModelQuality(BaseModel):
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
true_positive_count: int
false_positive_count: int
true_negative_count: int
false_negative_count: int
area_under_roc: Optional[float] = None
area_under_pr: Optional[float] = None

Expand All @@ -34,12 +30,19 @@ class BinaryClassModelQuality(BaseModel):
)


# Quality metrics for a binary classifier: everything inherited from
# MetricsBase plus the four confusion counts. The counts have no default,
# so they must always be supplied.
class BinaryClassificationModelQuality(MetricsBase):
    # Positive-class outcomes.
    true_positive_count: int
    false_positive_count: int
    # Negative-class outcomes.
    true_negative_count: int
    false_negative_count: int


# A single point of a metric series: a timestamp string and the metric value
# at that timestamp (the value may be absent).
class Distribution(BaseModel):
    # Sample fixtures in this change use ISO-8601 strings such as
    # '2024-01-01T00:00:00Z'; the field itself is an unvalidated str.
    timestamp: str
    value: Optional[float] = None


class GroupedBinaryClassModelQuality(BaseModel):
class GroupedMetricsBase(BaseModel):
f1: List[Distribution]
accuracy: List[Distribution]
precision: List[Distribution]
Expand All @@ -55,21 +58,38 @@ class GroupedBinaryClassModelQuality(BaseModel):
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentBinaryClassModelQuality(BaseModel):
global_metrics: BinaryClassModelQuality
grouped_metrics: GroupedBinaryClassModelQuality
class CurrentBinaryClassificationModelQuality(BaseModel):
global_metrics: BinaryClassificationModelQuality
grouped_metrics: GroupedMetricsBase

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


# Metrics computed for one class of a multiclass model: the class label and
# its MetricsBase values.
class ClassMetrics(BaseModel):
    class_name: str
    metrics: MetricsBase

    # Aliases are generated with to_camel; populate_by_name also lets the
    # original field names be used as input keys.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


class MultiClassModelQuality(BaseModel):
# Model-wide metrics of a multiclass model: the MetricsBase values plus the
# confusion matrix.
class GlobalMetrics(MetricsBase):
    # Matrix rows as lists of integers.
    # NOTE(review): presumably one row/column per class — confirm with the
    # producer of this payload.
    confusion_matrix: List[List[int]]

    # Aliases are generated with to_camel; populate_by_name also lets the
    # original field names be used as input keys.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


# Model quality of a multiclass model: the list of class labels, one
# ClassMetrics entry per class, and the model-wide GlobalMetrics.
class MultiClassificationModelQuality(BaseModel):
    classes: List[str]
    class_metrics: List[ClassMetrics]
    global_metrics: GlobalMetrics

    # Aliases are generated with to_camel; populate_by_name also lets the
    # original field names be used as input keys.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


# Model quality of a multiclass model on a current dataset.
# NOTE(review): declares no fields yet — looks like a placeholder until the
# current-dataset multiclass metrics are modelled; confirm before relying on it.
class CurrentMultiClassificationModelQuality(BaseModel):
    ...


Expand All @@ -80,15 +100,14 @@ class RegressionModelQuality(BaseModel):
class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| CurrentMultiClassificationModelQuality
| RegressionModelQuality
]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

@staticmethod
def from_dict(
Expand Down Expand Up @@ -121,9 +140,9 @@ def _create_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> (
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| RegressionModelQuality
):
"""Create a specific model quality instance based on model type and dataset type."""
Expand All @@ -133,7 +152,10 @@ def _create_model_quality(
model_quality_data=model_quality_data,
)
if model_type == ModelType.MULTI_CLASS:
return MultiClassModelQuality(**model_quality_data)
return ModelQualityDTO._create_multiclass_model_quality(
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)
if model_type == ModelType.REGRESSION:
return RegressionModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid model type {model_type}')
Expand All @@ -142,10 +164,22 @@ def _create_model_quality(
def _create_binary_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> BinaryClassModelQuality | CurrentBinaryClassModelQuality:
) -> BinaryClassificationModelQuality | CurrentBinaryClassificationModelQuality:
"""Create a binary model quality instance based on dataset type."""
if dataset_type == DatasetType.REFERENCE:
return BinaryClassModelQuality(**model_quality_data)
return BinaryClassificationModelQuality(**model_quality_data)
if dataset_type == DatasetType.CURRENT:
return CurrentBinaryClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_multiclass_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality:
"""Create a multiclass model quality instance based on dataset type."""
if dataset_type == DatasetType.REFERENCE:
return MultiClassificationModelQuality(**model_quality_data)
if dataset_type == DatasetType.CURRENT:
return CurrentBinaryClassModelQuality(**model_quality_data)
return CurrentMultiClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
49 changes: 42 additions & 7 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_sample_current_dataset(
'datetime': 1,
}

model_quality_dict = {
model_quality_base_dict = {
'f1': None,
'accuracy': 0.90,
'precision': 0.88,
Expand All @@ -141,16 +141,20 @@ def get_sample_current_dataset(
'weightedFalsePositiveRate': 0.10,
'truePositiveRate': 0.87,
'falsePositiveRate': 0.13,
'areaUnderRoc': 0.92,
'areaUnderPr': 0.91,
}

binary_model_quality_dict = {
'truePositiveCount': 870,
'falsePositiveCount': 130,
'trueNegativeCount': 820,
'falseNegativeCount': 180,
'areaUnderRoc': 0.92,
'areaUnderPr': 0.91,
**model_quality_base_dict,
}

current_model_quality_dict = {
'globalMetrics': model_quality_dict,
binary_current_model_quality_dict = {
'globalMetrics': binary_model_quality_dict,
'groupedMetrics': {
'f1': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
Expand Down Expand Up @@ -211,6 +215,37 @@ def get_sample_current_dataset(
},
}

# Sample reference model-quality payload for a three-class model, used to
# build multiclass model-quality DTOs in tests. Every class reuses the same
# base metrics dictionary.
multiclass_model_quality_dict = {
    'classes': [
        'classA',
        'classB',
        'classC',
    ],
    'class_metrics': [
        {
            'class_name': 'classA',
            'metrics': model_quality_base_dict,
        },
        {
            'class_name': 'classB',
            'metrics': model_quality_base_dict,
        },
        {
            'class_name': 'classC',
            'metrics': model_quality_base_dict,
        },
    ],
    'global_metrics': {
        # Square matrix with one row and one column per entry in 'classes'
        # (3x3 for the three classes above).
        # NOTE(review): values are whole-number floats while the DTO field is
        # declared List[List[int]] — confirm floats are the intended stored form.
        'confusion_matrix': [
            [3.0, 0.0, 0.0],
            [0.0, 2.0, 1.0],
            [0.0, 0.0, 3.0],
        ],
        **model_quality_base_dict,
    },
}

data_quality_dict = {
'nObservations': 200,
'classMetrics': [
Expand Down Expand Up @@ -278,7 +313,7 @@ def get_sample_current_dataset(

def get_sample_reference_metrics(
reference_uuid: uuid.UUID = REFERENCE_UUID,
model_quality: Dict = model_quality_dict,
model_quality: Dict = binary_model_quality_dict,
data_quality: Dict = data_quality_dict,
statistics: Dict = statistics_dict,
) -> ReferenceDatasetMetrics:
Expand All @@ -292,7 +327,7 @@ def get_sample_reference_metrics(

def get_sample_current_metrics(
current_uuid: uuid.UUID = CURRENT_UUID,
model_quality: Dict = model_quality_dict,
model_quality: Dict = binary_current_model_quality_dict,
data_quality: Dict = data_quality_dict,
statistics: Dict = statistics_dict,
drift: Dict = drift_dict,
Expand Down
2 changes: 1 addition & 1 deletion api/tests/routes/metrics_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_get_current_model_quality_by_model_by_uuid(self):
model_uuid = uuid.uuid4()
current_uuid = uuid.uuid4()
current_metrics = db_mock.get_sample_current_metrics(
model_quality=db_mock.current_model_quality_dict
model_quality=db_mock.binary_current_model_quality_dict
)
model_quality = ModelQualityDTO.from_dict(
dataset_type=DatasetType.CURRENT,
Expand Down
64 changes: 61 additions & 3 deletions api/tests/services/metrics_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from app.models.metrics.drift_dto import DriftDTO
from app.models.metrics.model_quality_dto import ModelQualityDTO
from app.models.metrics.statistics_dto import StatisticsDTO
from app.models.model_dto import ModelType
from app.services.metrics_service import MetricsService
from app.services.model_service import ModelService
from tests.commons import db_mock
Expand Down Expand Up @@ -176,6 +177,37 @@ def test_get_empty_reference_model_quality_by_model_by_uuid(self):
model_quality_data=None,
)

def test_get_reference_multiclass_model_quality_by_model_by_uuid(self):
    """Reference model quality of a multiclass model is returned as a DTO.

    The model service and both reference DAOs are replaced with mocks; the
    service result must equal the DTO built directly from the mocked data.
    """
    # Arrange: a succeeded reference dataset with multiclass quality metrics.
    dataset = db_mock.get_sample_reference_dataset(
        status=JobStatus.SUCCEEDED.value
    )
    metrics = db_mock.get_sample_reference_metrics(
        model_quality=db_mock.multiclass_model_quality_dict
    )
    multiclass_model = db_mock.get_sample_model(model_type=ModelType.MULTI_CLASS)

    self.model_service.get_model_by_uuid = MagicMock(return_value=multiclass_model)
    dataset_dao = self.reference_dataset_dao
    dataset_dao.get_reference_dataset_by_model_uuid = MagicMock(
        return_value=dataset
    )
    metrics_dao = self.reference_metrics_dao
    metrics_dao.get_reference_metrics_by_model_uuid = MagicMock(
        return_value=metrics
    )

    # Act. model_uuid is not defined in this method — presumably a
    # module-level fixture of the test file.
    res = self.metrics_service.get_reference_model_quality_by_model_by_uuid(
        model_uuid
    )

    # Assert: each DAO was queried exactly once with the model uuid.
    dataset_dao.get_reference_dataset_by_model_uuid.assert_called_once_with(
        model_uuid
    )
    metrics_dao.get_reference_metrics_by_model_uuid.assert_called_once_with(
        model_uuid
    )

    expected = ModelQualityDTO.from_dict(
        dataset_type=DatasetType.REFERENCE,
        model_type=multiclass_model.model_type,
        job_status=dataset.status,
        model_quality_data=metrics.model_quality,
    )
    assert res == expected

def test_get_reference_binary_class_data_quality_by_model_by_uuid(self):
status = JobStatus.SUCCEEDED
reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
Expand Down Expand Up @@ -225,6 +257,34 @@ def test_get_empty_reference_data_quality_by_model_by_uuid(self):
data_quality_data=None,
)

def test_get_reference_multiclass_data_quality_by_model_by_uuid(self):
    """Reference data quality of a multiclass model is returned as a DTO.

    The model service and both reference DAOs are replaced with mocks; the
    service result must equal the DTO built directly from the mocked data.
    """
    # Arrange: a succeeded reference dataset with the default sample metrics.
    dataset = db_mock.get_sample_reference_dataset(
        status=JobStatus.SUCCEEDED.value
    )
    metrics = db_mock.get_sample_reference_metrics()
    multiclass_model = db_mock.get_sample_model(model_type=ModelType.MULTI_CLASS)

    self.model_service.get_model_by_uuid = MagicMock(return_value=multiclass_model)
    dataset_dao = self.reference_dataset_dao
    dataset_dao.get_reference_dataset_by_model_uuid = MagicMock(
        return_value=dataset
    )
    metrics_dao = self.reference_metrics_dao
    metrics_dao.get_reference_metrics_by_model_uuid = MagicMock(
        return_value=metrics
    )

    # Act. model_uuid is not defined in this method — presumably a
    # module-level fixture of the test file.
    res = self.metrics_service.get_reference_data_quality_by_model_by_uuid(
        model_uuid
    )

    # Assert: each DAO was queried exactly once with the model uuid.
    dataset_dao.get_reference_dataset_by_model_uuid.assert_called_once_with(
        model_uuid
    )
    metrics_dao.get_reference_metrics_by_model_uuid.assert_called_once_with(
        model_uuid
    )

    expected = DataQualityDTO.from_dict(
        model_type=multiclass_model.model_type,
        job_status=dataset.status,
        data_quality_data=metrics.data_quality,
    )
    assert res == expected

def test_get_current_statistics_by_model_by_uuid(self):
status = JobStatus.SUCCEEDED
current_dataset = db_mock.get_sample_current_dataset(status=status.value)
Expand Down Expand Up @@ -444,9 +504,7 @@ def test_get_empty_current_data_quality_by_model_by_uuid(self):
def test_get_current_binary_class_model_quality_by_model_by_uuid(self):
status = JobStatus.SUCCEEDED
current_dataset = db_mock.get_sample_current_dataset(status=status.value)
current_metrics = db_mock.get_sample_current_metrics(
model_quality=db_mock.current_model_quality_dict
)
current_metrics = db_mock.get_sample_current_metrics()
model = db_mock.get_sample_model()
self.model_service.get_model_by_uuid = MagicMock(return_value=model)
self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock(
Expand Down
Loading

0 comments on commit 20e2d6b

Please sign in to comment.