Skip to content

Commit

Permalink
feat: added field type in drift return, removes threshold of 15 for n…
Browse files Browse the repository at this point in the history
…umerical variables (#141)
  • Loading branch information
SteZamboni authored Jul 26, 2024
1 parent dd6e85b commit b08e97d
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 170 deletions.
93 changes: 20 additions & 73 deletions spark/jobs/metrics/drift_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from metrics.psi import PSI
from models.current_dataset import CurrentDataset
from models.reference_dataset import ReferenceDataset
from utils.models import FieldTypes


class DriftCalculator:
Expand All @@ -30,6 +31,7 @@ def calculate_drift(
for column in categorical_features:
feature_dict_to_append = {
"feature_name": column,
"field_type": FieldTypes.categorical.value,
"drift_calc": {
"type": "CHI2",
},
Expand All @@ -55,45 +57,18 @@ def calculate_drift(
for column in float_features:
feature_dict_to_append = {
"feature_name": column,
"drift_calc": {},
"field_type": FieldTypes.numerical.value,
"drift_calc": {
"type": "KS",
},
}
unique_values_ref = (
reference_dataset.reference.select(column)
.distinct()
.rdd.flatMap(lambda x: x)
.collect()
result_tmp = ks.test(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["ks_statistic"]
)
unique_values_cur = (
current_dataset.current.select(column)
.distinct()
.rdd.flatMap(lambda x: x)
.collect()
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["ks_statistic"] > result_tmp["critical_value"]
)
unique_values_refcur = unique_values_ref + unique_values_cur
lookup = set()
unique_values_tot = [
x
for x in unique_values_refcur
if x is not None and x not in lookup and lookup.add(x) is None
]
if len(unique_values_tot) < 15:
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["type"] = "KS"
result_tmp = ks.test(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["ks_statistic"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["ks_statistic"] > result_tmp["critical_value"]
)
drift_result["feature_metrics"].append(feature_dict_to_append)

int_features = [
Expand All @@ -107,46 +82,18 @@ def calculate_drift(
for column in int_features:
feature_dict_to_append = {
"feature_name": column,
"drift_calc": {},
"field_type": FieldTypes.numerical.value,
"drift_calc": {
"type": "PSI",
},
}
unique_values_ref = (
reference_dataset.reference.select(column)
.distinct()
.rdd.flatMap(lambda x: x)
.collect()
result_tmp = psi_obj.calculate_psi(column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["psi_value"]
)
unique_values_cur = (
current_dataset.current.select(column)
.distinct()
.rdd.flatMap(lambda x: x)
.collect()
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["psi_value"] >= 0.1
)
unique_values_refcur = unique_values_ref + unique_values_cur
lookup = set()
unique_values_tot = [
x
for x in unique_values_refcur
if x is not None and x not in lookup and lookup.add(x) is None
]
if len(unique_values_tot) < 15:
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["type"] = "PSI"
result_tmp = psi_obj.calculate_psi(column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["psi_value"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["psi_value"] >= 0.1
)
drift_result["feature_metrics"].append(feature_dict_to_append)

return drift_result
20 changes: 10 additions & 10 deletions spark/tests/drift_calculator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def test_drift_phone(spark_fixture, drift_dataset_phone):
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="num_cores", type=SupportedTypes.int, field_type=FieldTypes.numerical
name="num_cores", type=SupportedTypes.int, field_type=FieldTypes.categorical
),
ColumnDefinition(
name="processor_speed",
Expand All @@ -592,7 +592,7 @@ def test_drift_phone(spark_fixture, drift_dataset_phone):
ColumnDefinition(
name="fast_charging_available",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="fast_charging",
Expand All @@ -602,12 +602,12 @@ def test_drift_phone(spark_fixture, drift_dataset_phone):
ColumnDefinition(
name="ram_capacity",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="internal_memory",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="screen_size",
Expand All @@ -622,12 +622,12 @@ def test_drift_phone(spark_fixture, drift_dataset_phone):
ColumnDefinition(
name="num_rear_cameras",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="num_front_cameras",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="os", type=SupportedTypes.string, field_type=FieldTypes.categorical
Expand All @@ -645,22 +645,22 @@ def test_drift_phone(spark_fixture, drift_dataset_phone):
ColumnDefinition(
name="extended_memory_available",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="extended_upto",
type=SupportedTypes.float,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="resolution_width",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
ColumnDefinition(
name="resolution_height",
type=SupportedTypes.int,
field_type=FieldTypes.numerical,
field_type=FieldTypes.categorical,
),
]
model = ModelOut(
Expand Down
Loading

0 comments on commit b08e97d

Please sign in to comment.