Skip to content

Commit

Permalink
feat(sdk): implement get current dataset drift (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte authored Jun 21, 2024
1 parent b86c5ec commit 7513b39
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 4 deletions.
58 changes: 55 additions & 3 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
CurrentFileUpload,
JobStatus,
DatasetStats,
Drift,
BinaryClassDrift,
)
from radicalbit_platform_sdk.errors import ClientError
from pydantic import ValidationError
Expand Down Expand Up @@ -96,9 +98,59 @@ def __callback(

return self.__statistics

def drift(self):
# TODO: implement get drift
pass
def drift(self) -> Optional[Drift]:
"""
Get drift about the current dataset
:return: The `Drift` if exists
"""

def __callback(
response: requests.Response,
) -> tuple[JobStatus, Optional[Drift]]:
try:
response_json = response.json()
job_status = JobStatus(response_json["jobStatus"])
if "drift" in response_json:
if self.__model_type is ModelType.BINARY:
return (
job_status,
BinaryClassDrift.model_validate(response_json["drift"]),
)
else:
raise ClientError(
"Unable to parse get metrics for not binary models"
)
else:
return job_status, None
except KeyError as _:
raise ClientError(f"Unable to parse response: {response.text}")
except ValidationError as _:
raise ClientError(f"Unable to parse response: {response.text}")

match self.__status:
case JobStatus.ERROR:
self.__drift = None
case JobStatus.SUCCEEDED:
if self.__drift is None:
_, drift = invoke(
method="GET",
url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift",
valid_response_code=200,
func=__callback,
)
self.__drift = drift
case JobStatus.IMPORTING:
status, drift = invoke(
method="GET",
url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift",
valid_response_code=200,
func=__callback,
)
self.__status = status
self.__drift = drift

return self.__drift

def data_quality(self):
# TODO: implement get data quality
Expand Down
16 changes: 16 additions & 0 deletions sdk/radicalbit_platform_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
CategoryFrequency,
CategoricalFeatureMetrics,
)
from .dataset_drift import (
DriftAlgorithm,
FeatureDriftCalculation,
FeatureDrift,
Drift,
BinaryClassDrift,
MultiClassDrift,
RegressionDrift,
)
from .column_definition import ColumnDefinition
from .aws_credentials import AwsCredentials

Expand Down Expand Up @@ -59,6 +68,13 @@
"NumericalFeatureMetrics",
"CategoryFrequency",
"CategoricalFeatureMetrics",
"DriftAlgorithm",
"FeatureDriftCalculation",
"FeatureDrift",
"Drift",
"BinaryClassDrift",
"MultiClassDrift",
"RegressionDrift",
"PaginatedModelDefinitions",
"ReferenceFileUpload",
"CurrentFileUpload",
Expand Down
43 changes: 43 additions & 0 deletions sdk/radicalbit_platform_sdk/models/dataset_drift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class DriftAlgorithm(str, Enum):
KS = "KS"
CHI2 = "CHI2"


class FeatureDriftCalculation(BaseModel):
type: DriftAlgorithm
value: Optional[float] = None
has_drift: bool

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class FeatureDrift(BaseModel):
feature_name: str
drift_calc: FeatureDriftCalculation

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class Drift(BaseModel):
pass


class BinaryClassDrift(Drift):
feature_metrics: List[FeatureDrift]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MultiClassDrift(Drift):
pass


class RegressionDrift(BaseModel):
pass
120 changes: 119 additions & 1 deletion sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from radicalbit_platform_sdk.apis import ModelCurrentDataset
from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus
from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus, DriftAlgorithm
from radicalbit_platform_sdk.errors import ClientError
import responses
import unittest
Expand Down Expand Up @@ -129,3 +129,121 @@ def test_statistics_key_error(self):

with self.assertRaises(ClientError):
model_reference_dataset.statistics()

@responses.activate
def test_drift_ok(self):
base_url = "http://api:9000"
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path="s3://bucket/file.csv",
date="2014",
correlation_id_column="column",
status=JobStatus.IMPORTING,
),
)

responses.add(
**{
"method": responses.GET,
"url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift",
"status": 200,
"body": """{
"jobStatus": "SUCCEEDED",
"drift": {
"featureMetrics": [
{
"featureName": "gender",
"driftCalc": {"type": "CHI2", "value": 0.87, "hasDrift": true}
},
{
"featureName": "city",
"driftCalc": {"type": "CHI2", "value": 0.12, "hasDrift": false}
},
{
"featureName": "age",
"driftCalc": {"type": "KS", "value": 0.92, "hasDrift": true}
}
]
}
}""",
}
)

drift = model_reference_dataset.drift()

assert len(drift.feature_metrics) == 3
assert drift.feature_metrics[1].feature_name == "city"
assert drift.feature_metrics[1].drift_calc.type == DriftAlgorithm.CHI2
assert drift.feature_metrics[1].drift_calc.value == 0.12
assert drift.feature_metrics[1].drift_calc.has_drift is False
assert drift.feature_metrics[2].feature_name == "age"
assert drift.feature_metrics[2].drift_calc.type == DriftAlgorithm.KS
assert drift.feature_metrics[2].drift_calc.value == 0.92
assert drift.feature_metrics[2].drift_calc.has_drift is True
assert model_reference_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_drift_validation_error(self):
base_url = "http://api:9000"
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path="s3://bucket/file.csv",
date="2014",
correlation_id_column="column",
status=JobStatus.IMPORTING,
),
)

responses.add(
**{
"method": responses.GET,
"url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift",
"status": 200,
"body": '{"statistics": "wrong"}',
}
)

with self.assertRaises(ClientError):
model_reference_dataset.drift()

@responses.activate
def test_drift_key_error(self):
base_url = "http://api:9000"
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_reference_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.BINARY,
CurrentFileUpload(
uuid=import_uuid,
path="s3://bucket/file.csv",
date="2014",
correlation_id_column="column",
status=JobStatus.IMPORTING,
),
)

responses.add(
**{
"method": responses.GET,
"url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift",
"status": 200,
"body": '{"wrong": "json"}',
}
)

with self.assertRaises(ClientError):
model_reference_dataset.drift()

0 comments on commit 7513b39

Please sign in to comment.