Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multiclass dto for model quality of reference part #46

Merged
92 changes: 63 additions & 29 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.models.model_dto import ModelType


class BinaryClassModelQuality(BaseModel):
class MetricsBase(BaseModel):
f1: Optional[float] = None
accuracy: Optional[float] = None
precision: Optional[float] = None
Expand All @@ -22,10 +22,6 @@ class BinaryClassModelQuality(BaseModel):
weighted_false_positive_rate: Optional[float] = None
true_positive_rate: Optional[float] = None
false_positive_rate: Optional[float] = None
true_positive_count: int
false_positive_count: int
true_negative_count: int
false_negative_count: int
area_under_roc: Optional[float] = None
area_under_pr: Optional[float] = None

Expand All @@ -34,12 +30,19 @@ class BinaryClassModelQuality(BaseModel):
)


class BinaryClassificationModelQuality(MetricsBase):
    """Model quality for a binary classifier on a REFERENCE dataset.

    Extends MetricsBase with the four confusion-matrix cell counts, which
    are required (unlike the optional rate metrics inherited from the base).
    """

    true_positive_count: int
    false_positive_count: int
    true_negative_count: int
    false_negative_count: int


class Distribution(BaseModel):
    """A single time-series point: one metric value at a timestamp."""

    timestamp: str  # looks ISO-8601 in the fixtures ('2024-01-01T00:00:00Z') — format not validated here
    value: Optional[float] = None  # None when the metric is undefined for that window


class GroupedBinaryClassModelQuality(BaseModel):
class GroupedMetricsBase(BaseModel):
f1: List[Distribution]
accuracy: List[Distribution]
precision: List[Distribution]
Expand All @@ -55,21 +58,38 @@ class GroupedBinaryClassModelQuality(BaseModel):
area_under_roc: Optional[List[Distribution]] = None
area_under_pr: Optional[List[Distribution]] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentBinaryClassificationModelQuality(BaseModel):
    """Model quality for a binary classifier on a CURRENT dataset.

    Pairs the flat global metrics with their per-timestamp grouped series.
    """

    global_metrics: BinaryClassificationModelQuality
    grouped_metrics: GroupedMetricsBase

    # camelCase aliases on (de)serialization; field names still accepted.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ClassMetrics(BaseModel):
    """Per-class entry of multiclass model quality: one class name plus its metrics."""

    class_name: str
    metrics: MetricsBase

    # camelCase aliases on (de)serialization; field names still accepted.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class GlobalMetrics(MetricsBase):
    """Dataset-level multiclass metrics plus the confusion matrix."""

    # NxN count matrix over the model's classes; row/column orientation is
    # not documented here — confirm against the metrics producer.
    confusion_matrix: List[List[int]]

    # camelCase aliases on (de)serialization; field names still accepted.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MultiClassificationModelQuality(BaseModel):
    """Multiclass model quality for a REFERENCE dataset.

    Bundles the class list, the per-class metrics, and the dataset-level
    global metrics (including the confusion matrix).
    """

    classes: List[str]
    class_metrics: List[ClassMetrics]
    global_metrics: GlobalMetrics

    # camelCase aliases on (de)serialization; field names still accepted.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentMultiClassificationModelQuality(BaseModel):
    """Placeholder for multiclass model quality on a CURRENT dataset.

    No fields yet; _create_multiclass_model_quality already instantiates it
    for DatasetType.CURRENT, so adding fields here extends that path.
    """

    pass


Expand All @@ -80,15 +100,14 @@ class RegressionModelQuality(BaseModel):
class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| CurrentMultiClassificationModelQuality
| RegressionModelQuality
]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

@staticmethod
def from_dict(
Expand Down Expand Up @@ -121,9 +140,9 @@ def _create_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> (
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| RegressionModelQuality
):
"""Create a specific model quality instance based on model type and dataset type."""
Expand All @@ -133,7 +152,10 @@ def _create_model_quality(
model_quality_data=model_quality_data,
)
if model_type == ModelType.MULTI_CLASS:
return MultiClassModelQuality(**model_quality_data)
return ModelQualityDTO._create_multiclass_model_quality(
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)
if model_type == ModelType.REGRESSION:
return RegressionModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid model type {model_type}')
Expand All @@ -142,10 +164,22 @@ def _create_model_quality(
def _create_binary_model_quality(
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> BinaryClassificationModelQuality | CurrentBinaryClassificationModelQuality:
    """Create a binary model quality instance based on dataset type.

    REFERENCE datasets carry only flat metrics; CURRENT datasets also carry
    the grouped (per-timestamp) series.

    Raises:
        MetricsInternalError: if dataset_type is neither REFERENCE nor CURRENT.
    """
    if dataset_type == DatasetType.REFERENCE:
        return BinaryClassificationModelQuality(**model_quality_data)
    if dataset_type == DatasetType.CURRENT:
        return CurrentBinaryClassificationModelQuality(**model_quality_data)
    raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_multiclass_model_quality(
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> MultiClassificationModelQuality | CurrentMultiClassificationModelQuality:
    """Create a multiclass model quality instance based on dataset type.

    Mirrors _create_binary_model_quality: REFERENCE gets the full multiclass
    model, CURRENT gets the (currently empty) current-dataset model.

    Raises:
        MetricsInternalError: if dataset_type is neither REFERENCE nor CURRENT.
    """
    if dataset_type == DatasetType.REFERENCE:
        return MultiClassificationModelQuality(**model_quality_data)
    if dataset_type == DatasetType.CURRENT:
        return CurrentMultiClassificationModelQuality(**model_quality_data)
    raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
49 changes: 42 additions & 7 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_sample_current_dataset(
'datetime': 1,
}

model_quality_dict = {
model_quality_base_dict = {
'f1': None,
'accuracy': 0.90,
'precision': 0.88,
Expand All @@ -141,16 +141,20 @@ def get_sample_current_dataset(
'weightedFalsePositiveRate': 0.10,
'truePositiveRate': 0.87,
'falsePositiveRate': 0.13,
'areaUnderRoc': 0.92,
'areaUnderPr': 0.91,
}

binary_model_quality_dict = {
'truePositiveCount': 870,
'falsePositiveCount': 130,
'trueNegativeCount': 820,
'falseNegativeCount': 180,
'areaUnderRoc': 0.92,
'areaUnderPr': 0.91,
**model_quality_base_dict,
}

current_model_quality_dict = {
'globalMetrics': model_quality_dict,
binary_current_model_quality_dict = {
'globalMetrics': binary_model_quality_dict,
'groupedMetrics': {
'f1': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
Expand Down Expand Up @@ -211,6 +215,37 @@ def get_sample_current_dataset(
},
}

# Fixture matching MultiClassificationModelQuality: 3 classes, per-class
# metrics, and dataset-level global metrics with a confusion matrix.
multiclass_model_quality_dict = {
    'classes': [
        'classA',
        'classB',
        'classC',
    ],
    'class_metrics': [
        {
            'class_name': 'classA',
            'metrics': model_quality_base_dict,
        },
        {
            'class_name': 'classB',
            'metrics': model_quality_base_dict,
        },
        {
            'class_name': 'classC',
            'metrics': model_quality_base_dict,
        },
    ],
    'global_metrics': {
        # 3x3 integer matrix: one row/column per class above, matching the
        # List[List[int]] type of GlobalMetrics.confusion_matrix.
        'confusion_matrix': [
            [3, 0, 0],
            [0, 2, 1],
            [0, 0, 1],
        ],
        **model_quality_base_dict,
    },
}

data_quality_dict = {
'nObservations': 200,
'classMetrics': [
Expand Down Expand Up @@ -278,7 +313,7 @@ def get_sample_current_dataset(

def get_sample_reference_metrics(
reference_uuid: uuid.UUID = REFERENCE_UUID,
model_quality: Dict = model_quality_dict,
model_quality: Dict = binary_model_quality_dict,
data_quality: Dict = data_quality_dict,
statistics: Dict = statistics_dict,
) -> ReferenceDatasetMetrics:
Expand All @@ -292,7 +327,7 @@ def get_sample_reference_metrics(

def get_sample_current_metrics(
current_uuid: uuid.UUID = CURRENT_UUID,
model_quality: Dict = model_quality_dict,
model_quality: Dict = binary_current_model_quality_dict,
data_quality: Dict = data_quality_dict,
statistics: Dict = statistics_dict,
drift: Dict = drift_dict,
Expand Down
2 changes: 1 addition & 1 deletion api/tests/routes/metrics_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_get_current_model_quality_by_model_by_uuid(self):
model_uuid = uuid.uuid4()
current_uuid = uuid.uuid4()
current_metrics = db_mock.get_sample_current_metrics(
model_quality=db_mock.current_model_quality_dict
model_quality=db_mock.binary_current_model_quality_dict
)
model_quality = ModelQualityDTO.from_dict(
dataset_type=DatasetType.CURRENT,
Expand Down
64 changes: 61 additions & 3 deletions api/tests/services/metrics_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from app.models.metrics.drift_dto import DriftDTO
from app.models.metrics.model_quality_dto import ModelQualityDTO
from app.models.metrics.statistics_dto import StatisticsDTO
from app.models.model_dto import ModelType
from app.services.metrics_service import MetricsService
from app.services.model_service import ModelService
from tests.commons import db_mock
Expand Down Expand Up @@ -176,6 +177,37 @@ def test_get_empty_reference_model_quality_by_model_by_uuid(self):
model_quality_data=None,
)

def test_get_reference_multiclass_model_quality_by_model_by_uuid(self):
    """Reference model-quality lookup for a MULTI_CLASS model returns the
    DTO built from the stored multiclass metrics dict."""
    status = JobStatus.SUCCEEDED
    reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
    # Override the default (binary) fixture with the multiclass one.
    reference_metrics = db_mock.get_sample_reference_metrics(
        model_quality=db_mock.multiclass_model_quality_dict
    )
    model = db_mock.get_sample_model(model_type=ModelType.MULTI_CLASS)
    self.model_service.get_model_by_uuid = MagicMock(return_value=model)
    self.reference_dataset_dao.get_reference_dataset_by_model_uuid = MagicMock(
        return_value=reference_dataset
    )
    self.reference_metrics_dao.get_reference_metrics_by_model_uuid = MagicMock(
        return_value=reference_metrics
    )
    res = self.metrics_service.get_reference_model_quality_by_model_by_uuid(
        model_uuid
    )
    # Both DAOs must be queried exactly once for the requested model.
    self.reference_dataset_dao.get_reference_dataset_by_model_uuid.assert_called_once_with(
        model_uuid
    )
    self.reference_metrics_dao.get_reference_metrics_by_model_uuid.assert_called_once_with(
        model_uuid
    )

    # The service result must equal the DTO assembled from the same inputs.
    assert res == ModelQualityDTO.from_dict(
        dataset_type=DatasetType.REFERENCE,
        model_type=model.model_type,
        job_status=reference_dataset.status,
        model_quality_data=reference_metrics.model_quality,
    )

def test_get_reference_binary_class_data_quality_by_model_by_uuid(self):
status = JobStatus.SUCCEEDED
reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
Expand Down Expand Up @@ -225,6 +257,34 @@ def test_get_empty_reference_data_quality_by_model_by_uuid(self):
data_quality_data=None,
)

def test_get_reference_multiclass_data_quality_by_model_by_uuid(self):
    """Reference data-quality lookup for a MULTI_CLASS model returns the
    DTO built from the stored data-quality dict (default fixture)."""
    status = JobStatus.SUCCEEDED
    reference_dataset = db_mock.get_sample_reference_dataset(status=status.value)
    # Default fixture is enough here: only data_quality is exercised.
    reference_metrics = db_mock.get_sample_reference_metrics()
    model = db_mock.get_sample_model(model_type=ModelType.MULTI_CLASS)
    self.model_service.get_model_by_uuid = MagicMock(return_value=model)
    self.reference_dataset_dao.get_reference_dataset_by_model_uuid = MagicMock(
        return_value=reference_dataset
    )
    self.reference_metrics_dao.get_reference_metrics_by_model_uuid = MagicMock(
        return_value=reference_metrics
    )
    res = self.metrics_service.get_reference_data_quality_by_model_by_uuid(
        model_uuid
    )
    # Both DAOs must be queried exactly once for the requested model.
    self.reference_dataset_dao.get_reference_dataset_by_model_uuid.assert_called_once_with(
        model_uuid
    )
    self.reference_metrics_dao.get_reference_metrics_by_model_uuid.assert_called_once_with(
        model_uuid
    )

    # The service result must equal the DTO assembled from the same inputs.
    assert res == DataQualityDTO.from_dict(
        model_type=model.model_type,
        job_status=reference_dataset.status,
        data_quality_data=reference_metrics.data_quality,
    )

def test_get_current_statistics_by_model_by_uuid(self):
status = JobStatus.SUCCEEDED
current_dataset = db_mock.get_sample_current_dataset(status=status.value)
Expand Down Expand Up @@ -444,9 +504,7 @@ def test_get_empty_current_data_quality_by_model_by_uuid(self):
def test_get_current_binary_class_model_quality_by_model_by_uuid(self):
status = JobStatus.SUCCEEDED
current_dataset = db_mock.get_sample_current_dataset(status=status.value)
current_metrics = db_mock.get_sample_current_metrics(
model_quality=db_mock.current_model_quality_dict
)
current_metrics = db_mock.get_sample_current_metrics()
model = db_mock.get_sample_model()
self.model_service.get_model_by_uuid = MagicMock(return_value=model)
self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock(
Expand Down
4 changes: 2 additions & 2 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BinaryClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
CurrentMultiClassificationModelQuality,
DataQuality,
DatasetStats,
Drift,
Expand All @@ -19,7 +20,6 @@
ModelType,
MultiClassDataQuality,
MultiClassDrift,
MultiClassModelQuality,
RegressionDataQuality,
RegressionDrift,
RegressionModelQuality,
Expand Down Expand Up @@ -268,7 +268,7 @@ def __callback(
case ModelType.MULTI_CLASS:
return (
job_status,
MultiClassModelQuality.model_validate(
CurrentMultiClassificationModelQuality.model_validate(
response_json['modelQuality']
),
)
Expand Down
4 changes: 2 additions & 2 deletions sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ModelQuality,
ModelType,
MultiClassDataQuality,
MultiClassModelQuality,
MultiClassificationModelQuality,
ReferenceFileUpload,
RegressionDataQuality,
RegressionModelQuality,
Expand Down Expand Up @@ -194,7 +194,7 @@ def __callback(
case ModelType.MULTI_CLASS:
return (
job_status,
MultiClassModelQuality.model_validate(
MultiClassificationModelQuality.model_validate(
response_json['modelQuality']
),
)
Expand Down
Loading