Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add current model quality regression #90

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,61 @@ class MultiClassificationModelQuality(BaseModel):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class RegressionModelQuality(BaseModel):
class RegressionMetricsBase(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
mse: Optional[float] = None
var: Optional[float] = None
variance: Optional[float] = None
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


# Scalar regression quality metrics. Every field is optional and defaults to
# None, so a payload may carry any subset of them.
# model_config: aliases are generated by to_camel (presumably pydantic's
# camelCase generator, e.g. adj_r2 -> adjR2 -- confirm against the import), and
# populate_by_name=True lets input use the snake_case field names as well.
class BaseRegressionMetrics(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
mse: Optional[float] = None
variance: Optional[float] = None
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


# Per-metric series of regression quality values: one list of Distribution
# entries for each metric. Unlike the scalar metrics model, every field here is
# required (no default is declared).
# NOTE(review): Distribution is defined outside this view; the test fixtures
# suggest timestamp/value points -- confirm against its definition.
class GroupedBaseRegressionMetrics(BaseModel):
r2: List[Distribution]
mae: List[Distribution]
mse: List[Distribution]
variance: List[Distribution]
mape: List[Distribution]
rmse: List[Distribution]
adj_r2: List[Distribution]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


# Regression model quality as a flat set of scalar metrics. It inherits every
# field and the model config from BaseRegressionMetrics and adds nothing.
class RegressionModelQuality(BaseRegressionMetrics):
pass


# Regression model quality made of two required parts: the scalar metrics
# (global_metrics) and the per-metric series (grouped_metrics).
class CurrentRegressionModelQuality(BaseModel):
global_metrics: BaseRegressionMetrics
grouped_metrics: GroupedBaseRegressionMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| RegressionModelQuality
| CurrentRegressionModelQuality
]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Expand Down Expand Up @@ -143,11 +179,6 @@ def _create_model_quality(
model_type: ModelType,
dataset_type: DatasetType,
model_quality_data: Dict,
) -> (
BinaryClassificationModelQuality
| CurrentBinaryClassificationModelQuality
| MultiClassificationModelQuality
| RegressionModelQuality
):
"""Create a specific model quality instance based on model type and dataset type."""
if model_type == ModelType.BINARY:
Expand All @@ -158,7 +189,9 @@ def _create_model_quality(
if model_type == ModelType.MULTI_CLASS:
return MultiClassificationModelQuality(**model_quality_data)
if model_type == ModelType.REGRESSION:
return RegressionModelQuality(**model_quality_data)
return ModelQualityDTO._create_regression_model_quality(
dataset_type=dataset_type, model_quality_data=model_quality_data
)
raise MetricsInternalError(f'Invalid model type {model_type}')

@staticmethod
Expand All @@ -172,3 +205,15 @@ def _create_binary_model_quality(
if dataset_type == DatasetType.CURRENT:
return CurrentBinaryClassificationModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')

@staticmethod
def _create_regression_model_quality(
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> RegressionModelQuality | CurrentRegressionModelQuality:
    """Create a regression model quality instance based on dataset type.

    Args:
        dataset_type: Which dataset the metrics belong to; selects the
            model class to build.
        model_quality_data: Raw metrics, unpacked as keyword arguments
            into the selected model.

    Returns:
        A ``RegressionModelQuality`` for ``DatasetType.REFERENCE`` or a
        ``CurrentRegressionModelQuality`` for ``DatasetType.CURRENT``.

    Raises:
        MetricsInternalError: If ``dataset_type`` is neither REFERENCE
            nor CURRENT.
    """
    if dataset_type == DatasetType.REFERENCE:
        return RegressionModelQuality(**model_quality_data)
    if dataset_type == DatasetType.CURRENT:
        return CurrentRegressionModelQuality(**model_quality_data)
    raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
37 changes: 37 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,43 @@ def get_sample_current_dataset(
'mape': 35.19314237273801,
'rmse': 202.23194752188695,
'adj_r2': 0.9116805380966796,
'variance': 0.23
}

# Fixture: per-metric series for a regression model. Each metric maps to two
# timestamp/value points (2024-01-01 and 2024-02-01).
grouped_regression_model_quality_dict = {
'r2': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.85},
],
'mae': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.88},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.9},
],
'mse': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.86},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.88},
],
'mape': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.81},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.83},
],
'rmse': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.8},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.85},
],
'adj_r2': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.85},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.87},
],
'variance': [
{'timestamp': '2024-01-01T00:00:00Z', 'value': 0.82},
{'timestamp': '2024-02-01T00:00:00Z', 'value': 0.84},
],
}

# Fixture: current-dataset regression model quality, pairing the scalar metrics
# dict defined earlier in this module with the per-metric series above.
current_regression_model_quality_dict = {
'global_metrics': regression_model_quality_dict,
'grouped_metrics': grouped_regression_model_quality_dict,
}

regression_data_quality_dict = {
Expand Down
8 changes: 3 additions & 5 deletions api/tests/routes/upload_dataset_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_upload_current(self):
path='test',
date=str(datetime.datetime.now(tz=datetime.UTC)),
status=JobStatus.IMPORTING,
correlation_id_column=None
correlation_id_column=None,
)
self.file_service.upload_current_file = MagicMock(
return_value=upload_file_result
Expand All @@ -103,11 +103,9 @@ def test_bind_current(self):
path='test',
date=str(datetime.datetime.now(tz=datetime.UTC)),
status=JobStatus.IMPORTING,
correlation_id_column=None
)
self.file_service.bind_current_file = MagicMock(
return_value=upload_file_result
correlation_id_column=None,
)
self.file_service.bind_current_file = MagicMock(return_value=upload_file_result)
res = self.client.post(
f'{self.prefix}/{model_uuid}/current/bind',
json=jsonable_encoder(file_ref),
Expand Down
3 changes: 2 additions & 1 deletion api/tests/services/file_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def test_upload_current_file_ok(self):
self.spark_k8s_client.submit_app = MagicMock()

result = self.files_service.upload_current_file(
model.uuid, file,
model.uuid,
file,
)

self.model_svc.get_model_by_uuid.assert_called_once()
Expand Down
31 changes: 31 additions & 0 deletions api/tests/services/metrics_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,37 @@ def test_get_current_multiclass_model_quality_by_model_by_uuid(self):
model_quality_data=current_metrics.model_quality,
)

def test_get_current_regression_model_quality_by_model_by_uuid(self):
    """Current regression model quality is read through the DAOs and returned as a DTO."""
    # Fixtures: a succeeded current dataset, its regression metrics and a regression model.
    job_state = JobStatus.SUCCEEDED
    dataset = db_mock.get_sample_current_dataset(status=job_state.value)
    metrics = db_mock.get_sample_current_metrics(
        model_quality=db_mock.current_regression_model_quality_dict
    )
    regression_model = db_mock.get_sample_model(model_type=ModelType.REGRESSION)

    # Stub the collaborators; local handles are kept so the call checks below stay short.
    dataset_lookup = MagicMock(return_value=dataset)
    metrics_lookup = MagicMock(return_value=metrics)
    self.model_service.get_model_by_uuid = MagicMock(return_value=regression_model)
    self.current_dataset_dao.get_current_dataset_by_model_uuid = dataset_lookup
    self.current_metrics_dao.get_current_metrics_by_model_uuid = metrics_lookup

    actual = self.metrics_service.get_current_model_quality_by_model_by_uuid(
        model_uuid, dataset.uuid
    )

    # Both DAOs must have been queried exactly once with the same identifiers.
    lookup_args = (model_uuid, dataset.uuid)
    dataset_lookup.assert_called_once_with(*lookup_args)
    metrics_lookup.assert_called_once_with(*lookup_args)

    expected = ModelQualityDTO.from_dict(
        dataset_type=DatasetType.CURRENT,
        model_type=regression_model.model_type,
        job_status=dataset.status,
        model_quality_data=metrics.model_quality,
    )
    assert actual == expected


model_uuid = db_mock.MODEL_UUID
current_uuid = db_mock.CURRENT_UUID
25 changes: 22 additions & 3 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,36 @@ class MultiClassificationModelQuality(ModelQuality):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class RegressionModelQuality(ModelQuality):
class BaseRegressionMetrics(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
mse: Optional[float] = None
var: Optional[float] = None
variance: Optional[float] = None
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CurrentRegressionModelQuality(ModelQuality):
# SDK-side per-metric series of regression quality values: one required list of
# Distribution entries for each metric (no defaults are declared).
# model_config: to_camel generates the aliases (the SDK tests send "adjR2" for
# adj_r2), and populate_by_name=True also accepts the snake_case field names.
class GroupedBaseRegressionMetrics(BaseModel):
r2: List[Distribution]
mae: List[Distribution]
mse: List[Distribution]
variance: List[Distribution]
mape: List[Distribution]
rmse: List[Distribution]
adj_r2: List[Distribution]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


# SDK regression model quality as a flat set of scalar metrics: combines the
# ModelQuality base with the fields of BaseRegressionMetrics and adds nothing.
class RegressionModelQuality(ModelQuality, BaseRegressionMetrics):
pass


# SDK regression model quality for a current dataset: the scalar metrics
# (global_metrics) plus the per-metric series (grouped_metrics), both required.
class CurrentRegressionModelQuality(ModelQuality):
global_metrics: BaseRegressionMetrics
grouped_metrics: GroupedBaseRegressionMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
68 changes: 64 additions & 4 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,13 @@ def test_regression_model_quality_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
r2 = 0.91
mae = 125.01
mse = 408.76
variance = 393.31
mape = 35.19
rmse = 202.23
adj_r2 = 0.91
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
Expand All @@ -915,17 +922,70 @@ def test_regression_model_quality_ok(self):
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/model-quality',
status=200,
body="""{
body=f"""{{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"modelQuality": {}
}""",
"modelQuality": {{
"global_metrics": {{
"r2": {r2},
"mae": {mae},
"mse": {mse},
"variance": {variance},
"mape": {mape},
"rmse": {rmse},
"adjR2": {adj_r2}
}},
"grouped_metrics": {{
"r2": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {r2}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.88}}
],
"mae": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {mae}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"mse": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {mse}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.85}}
],
"variance": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {variance}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.83}}
],
"mape": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {mape}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
],
"rmse": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {rmse}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
],
"adjR2": [
{{"timestamp": "2024-01-01T00:00:00Z", "value": {adj_r2}}},
{{"timestamp": "2024-02-01T00:00:00Z", "value": 0.12}}
]
}}
}}
}}""",
)

metrics = model_current_dataset.model_quality()

assert isinstance(metrics, CurrentRegressionModelQuality)
# TODO: add asserts to properties
assert metrics.global_metrics.r2 == r2
assert metrics.global_metrics.mae == mae
assert metrics.global_metrics.mse == mse
assert metrics.global_metrics.variance == variance
assert metrics.global_metrics.mape == mape
assert metrics.global_metrics.rmse == rmse
assert metrics.global_metrics.adj_r2 == adj_r2
assert metrics.grouped_metrics.r2[0].value == r2
assert metrics.grouped_metrics.mae[0].value == mae
assert metrics.grouped_metrics.mse[0].value == mse
assert metrics.grouped_metrics.variance[0].value == variance
assert metrics.grouped_metrics.mape[0].value == mape
assert metrics.grouped_metrics.rmse[0].value == rmse
assert metrics.grouped_metrics.adj_r2[0].value == adj_r2
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
Expand Down
6 changes: 3 additions & 3 deletions sdk/tests/apis/model_reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def test_regression_model_metrics_ok(self):
r2 = 0.91
mae = 125.01
mse = 408.76
var = 393.31
variance = 393.31
mape = 35.19
rmse = 202.23
adj_r2 = 0.91
Expand All @@ -382,7 +382,7 @@ def test_regression_model_metrics_ok(self):
"r2": {r2},
"mae": {mae},
"mse": {mse},
"var": {var},
"variance": {variance},
"mape": {mape},
"rmse": {rmse},
"adjR2": {adj_r2}
Expand All @@ -396,7 +396,7 @@ def test_regression_model_metrics_ok(self):
assert metrics.r2 == r2
assert metrics.mae == mae
assert metrics.mse == mse
assert metrics.var == var
assert metrics.variance == variance
assert metrics.mape == mape
assert metrics.rmse == rmse
assert metrics.adj_r2 == adj_r2
Expand Down