Skip to content

Commit

Permalink
feat: add current model quality regression (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
dtria91 authored Jul 8, 2024
1 parent 6a63d26 commit 65f4f65
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 24 deletions.
61 changes: 53 additions & 8 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,61 @@ class MultiClassificationModelQuality(BaseModel):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class RegressionModelQuality(BaseModel):
class RegressionMetricsBase(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
mse: Optional[float] = None
var: Optional[float] = None
variance: Optional[float] = None
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class BaseRegressionMetrics(BaseModel):
    # camelCase aliases are generated for every field (e.g. adj_r2 -> adjR2);
    # populate_by_name keeps the snake_case names usable on input as well.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Scalar regression metrics. Each one is optional and defaults to None,
    # so a payload may omit any of them.
    r2: float | None = None
    mae: float | None = None
    mse: float | None = None
    variance: float | None = None
    mape: float | None = None
    rmse: float | None = None
    adj_r2: float | None = None


class GroupedBaseRegressionMetrics(BaseModel):
    # camelCase aliases are generated for every field; populate_by_name keeps
    # the snake_case names usable on input as well.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # One series of Distribution entries per regression metric. All seven
    # series are required (no defaults).
    r2: List[Distribution]
    mae: List[Distribution]
    mse: List[Distribution]
    variance: List[Distribution]
    mape: List[Distribution]
    rmse: List[Distribution]
    adj_r2: List[Distribution]


# Regression model quality as returned for the REFERENCE dataset type: it
# exposes exactly the scalar fields inherited from BaseRegressionMetrics and
# adds nothing of its own.
class RegressionModelQuality(BaseRegressionMetrics):
    pass


class CurrentRegressionModelQuality(BaseModel):
    # camelCase aliases are generated for every field; populate_by_name keeps
    # the snake_case names usable on input as well.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Scalar metrics for the dataset as a whole.
    global_metrics: BaseRegressionMetrics
    # The same metrics, each as a series of Distribution entries.
    grouped_metrics: GroupedBaseRegressionMetrics


class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| RegressionModelQuality
| CurrentRegressionModelQuality
]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand Down Expand Up @@ -143,11 +179,6 @@ def _create_model_quality(
model_type: ModelType,
dataset_type: DatasetType,
model_quality_data: Dict,
) -> (
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| RegressionModelQuality
):
"""Create a specific model quality instance based on model type and dataset type."""
if model_type == ModelType.BINARY:
Expand All @@ -158,7 +189,9 @@ def _create_model_quality(
if model_type == ModelType.MULTI_CLASS:
return MultiClassificationModelQuality(**model_quality_data)
if model_type == ModelType.REGRESSION:
return RegressionModelQuality(**model_quality_data)
return ModelQualityDTO._create_regression_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
)
raise MetricsInternalError(f'Invalid model type {model_type}')

@staticmethod
Expand All @@ -172,3 +205,15 @@ def _create_binary_model_quality(
if dataset_type == DatasetType.CURRENT:
return CurrentBinaryClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_regression_model_quality(
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> RegressionModelQuality | CurrentRegressionModelQuality:
    """Create a regression model quality instance based on dataset type.

    Args:
        dataset_type: Selects the model to build: REFERENCE yields a
            ``RegressionModelQuality``, CURRENT yields a
            ``CurrentRegressionModelQuality``.
        model_quality_data: Mapping unpacked as keyword arguments into the
            selected model.

    Returns:
        The model quality instance matching ``dataset_type``.

    Raises:
        MetricsInternalError: If ``dataset_type`` is neither REFERENCE nor
            CURRENT.
    """
    if dataset_type == DatasetType.REFERENCE:
        return RegressionModelQuality(**model_quality_data)
    if dataset_type == DatasetType.CURRENT:
        return CurrentRegressionModelQuality(**model_quality_data)
    raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
37 changes: 37 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,43 @@ def get_sample_current_dataset(
'mape': 35.19314237273801,
'rmse': 202.23194752188695,
'adj_r2': 0.9116805380966796,
'variance': 0.23
}

# Fixture: for each regression metric, two timestamped samples
# (2024-01-01 and 2024-02-01). Used below as the `grouped_metrics` part of the
# current regression model quality payload.
grouped_regression_model_quality_dict = {
    'r2': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.85},
    ],
    'mae': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.88},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.9},
    ],
    'mse': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.86},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.88},
    ],
    'mape': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.81},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.83},
    ],
    'rmse': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.85},
    ],
    'adj_r2': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.85},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.87},
    ],
    'variance': [
        {'timestamp': '2024-01-01T00:00:00Z', 'value': 0.82},
        {'timestamp': '2024-02-01T00:00:00Z', 'value': 0.84},
    ],
}

# Fixture: full current-dataset regression payload, combining the scalar
# metrics dict (global) with the per-timestamp series above (grouped).
current_regression_model_quality_dict = {
    'global_metrics': regression_model_quality_dict,
    'grouped_metrics': grouped_regression_model_quality_dict,
}

regression_data_quality_dict = {
Expand Down
8 changes: 3 additions & 5 deletions api/tests/routes/upload_dataset_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_upload_current(self):
path='test',
date=str(datetime.datetime.now(tz=datetime.UTC)),
status=JobStatus.IMPORTING,
correlation_id_column=None
correlation_id_column=None,
)
self.file_service.upload_current_file = MagicMock(
return_value=upload_file_result
Expand All @@ -103,11 +103,9 @@ def test_bind_current(self):
path='test',
date=str(datetime.datetime.now(tz=datetime.UTC)),
status=JobStatus.IMPORTING,
correlation_id_column=None
)
self.file_service.bind_current_file = MagicMock(
return_value=upload_file_result
correlation_id_column=None,
)
self.file_service.bind_current_file = MagicMock(return_value=upload_file_result)
res = self.client.post(
f'{self.prefix}/{model_uuid}/current/bind',
json=jsonable_encoder(file_ref),
Expand Down
3 changes: 2 additions & 1 deletion api/tests/services/file_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def test_upload_current_file_ok(self):
self.spark_k8s_client.submit_app = MagicMock()

result = self.files_service.upload_current_file(
model.uuid, file,
model.uuid,
file,
)

self.model_svc.get_model_by_uuid.assert_called_once()
Expand Down
31 changes: 31 additions & 0 deletions api/tests/services/metrics_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,37 @@ def test_get_current_multiclass_model_quality_by_model_by_uuid(self):
model_quality_data=current_metrics.model_quality,
)

def test_get_current_regression_model_quality_by_model_by_uuid(self):
    # Fixtures: a SUCCEEDED current dataset, metrics carrying the current
    # regression payload, and a regression model.
    dataset = db_mock.get_sample_current_dataset(status=JobStatus.SUCCEEDED.value)
    metrics = db_mock.get_sample_current_metrics(
        model_quality=db_mock.current_regression_model_quality_dict
    )
    regression_model = db_mock.get_sample_model(model_type=ModelType.REGRESSION)

    # Replace the collaborators' lookups with mocks returning the fixtures.
    self.model_service.get_model_by_uuid = MagicMock(return_value=regression_model)
    dataset_lookup = MagicMock(return_value=dataset)
    metrics_lookup = MagicMock(return_value=metrics)
    self.current_dataset_dao.get_current_dataset_by_model_uuid = dataset_lookup
    self.current_metrics_dao.get_current_metrics_by_model_uuid = metrics_lookup

    result = self.metrics_service.get_current_model_quality_by_model_by_uuid(
        model_uuid, dataset.uuid
    )

    # Both DAOs must be queried exactly once with the model/dataset pair.
    dataset_lookup.assert_called_once_with(model_uuid, dataset.uuid)
    metrics_lookup.assert_called_once_with(model_uuid, dataset.uuid)

    # The service result must equal the DTO built directly from the fixtures.
    expected = ModelQualityDTO.from_dict(
        dataset_type=DatasetType.CURRENT,
        model_type=regression_model.model_type,
        job_status=dataset.status,
        model_quality_data=metrics.model_quality,
    )
    assert result == expected


model_uuid = db_mock.MODEL_UUID
current_uuid = db_mock.CURRENT_UUID
25 changes: 22 additions & 3 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,36 @@ class MultiClassificationModelQuality(ModelQuality):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class RegressionModelQuality(ModelQuality):
class BaseRegressionMetrics(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
mse: Optional[float] = None
var: Optional[float] = None
variance: Optional[float] = None
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentRegressionModelQuality(ModelQuality):
class GroupedBaseRegressionMetrics(BaseModel):
    # camelCase aliases are generated for every field; populate_by_name keeps
    # the snake_case names usable on input as well.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # One series of Distribution entries per regression metric. All seven
    # series are required (no defaults).
    r2: List[Distribution]
    mae: List[Distribution]
    mse: List[Distribution]
    variance: List[Distribution]
    mape: List[Distribution]
    rmse: List[Distribution]
    adj_r2: List[Distribution]


# Combines the ModelQuality base with the scalar regression fields of
# BaseRegressionMetrics; declares no fields of its own.
class RegressionModelQuality(ModelQuality, BaseRegressionMetrics):
    pass


class CurrentRegressionModelQuality(ModelQuality):
    # camelCase aliases are generated for every field; populate_by_name keeps
    # the snake_case names usable on input as well.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Scalar metrics for the dataset as a whole.
    global_metrics: BaseRegressionMetrics
    # The same metrics, each as a series of Distribution entries.
    grouped_metrics: GroupedBaseRegressionMetrics
68 changes: 64 additions & 4 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,13 @@ def test_regression_model_quality_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
r2 = 0.91
mae = 125.01
mse = 408.76
variance = 393.31
mape = 35.19
rmse = 202.23
adj_r2 = 0.91
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
Expand All @@ -915,17 +922,70 @@ def test_regression_model_quality_ok(self):
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/model-quality',
status=200,
body="""{
body=f"""{{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"modelQuality": {}
}""",
"modelQuality": {{
"global_metrics": {{
"r2": {r2},
"mae": {mae},
"mse": {mse},
"variance": {variance},
"mape": {mape},
"rmse": {rmse},
"adjR2": {adj_r2}
}},
"grouped_metrics": {{
"r2": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {r2}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}}
],
"mae": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {mae}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"mse": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {mse}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}}
],
"variance": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {variance}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"mape": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {mape}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
],
"rmse": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {rmse}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
],
"adjR2": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {adj_r2}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
]
}}
}}
}}""",
)

metrics = model_current_dataset.model_quality()

assert isinstance(metrics, CurrentRegressionModelQuality)
# TODO: add asserts to properties
assert metrics.global_metrics.r2 == r2
assert metrics.global_metrics.mae == mae
assert metrics.global_metrics.mse == mse
assert metrics.global_metrics.variance == variance
assert metrics.global_metrics.mape == mape
assert metrics.global_metrics.rmse == rmse
assert metrics.global_metrics.adj_r2 == adj_r2
assert metrics.grouped_metrics.r2[0].value == r2
assert metrics.grouped_metrics.mae[0].value == mae
assert metrics.grouped_metrics.mse[0].value == mse
assert metrics.grouped_metrics.variance[0].value == variance
assert metrics.grouped_metrics.mape[0].value == mape
assert metrics.grouped_metrics.rmse[0].value == rmse
assert metrics.grouped_metrics.adj_r2[0].value == adj_r2
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
Expand Down
6 changes: 3 additions & 3 deletions sdk/tests/apis/model_reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def test_regression_model_metrics_ok(self):
r2 = 0.91
mae = 125.01
mse = 408.76
var = 393.31
variance = 393.31
mape = 35.19
rmse = 202.23
adj_r2 = 0.91
Expand All @@ -382,7 +382,7 @@ def test_regression_model_metrics_ok(self):
"r2": {r2},
"mae": {mae},
"mse": {mse},
"var": {var},
"variance": {variance},
"mape": {mape},
"rmse": {rmse},
"adjR2": {adj_r2}
Expand All @@ -396,7 +396,7 @@ def test_regression_model_metrics_ok(self):
assert metrics.r2 == r2
assert metrics.mae == mae
assert metrics.mse == mse
assert metrics.var == var
assert metrics.variance == variance
assert metrics.mape == mape
assert metrics.rmse == rmse
assert metrics.adj_r2 == adj_r2
Expand Down

0 comments on commit 65f4f65

Please sign in to comment.