Skip to content

Commit

Permalink
fix: add subset for in get_errors calls for process_split and process…
Browse files Browse the repository at this point in the history
…_imputer
  • Loading branch information
hlbotterman committed Oct 23, 2024
1 parent 1233456 commit 40e900a
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions qolmat/benchmark/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def process_split(
df_with_holes = df_origin.copy()
df_with_holes[df_mask] = np.nan

subset = self.generator_holes.subset
if subset is None:
raise ValueError(
"HoleGenerator `subset` should be overwritten in split "
"but it is none!"
)

split_results = {}
for imputer_name, imputer in self.dict_imputers.items():
dict_config_opti_imputer = self.dict_config_opti.get(
Expand All @@ -131,7 +138,9 @@ def process_split(
)

df_imputed = imputer_opti.fit_transform(df_with_holes)
errors = self.get_errors(df_origin, df_imputed, df_mask)
errors = self.get_errors(
df_origin[subset], df_imputed[subset], df_mask[subset]
)
split_results[imputer_name] = errors

return pd.concat(split_results, axis=1)
Expand All @@ -154,6 +163,13 @@ def process_imputer(
"""
imputer_name, imputer, all_masks, df_origin = imputer_data

subset = self.generator_holes.subset
if subset is None:
raise ValueError(
"HoleGenerator `subset` should be overwritten in split "
"but it is none!"
)

dict_config_opti_imputer = self.dict_config_opti.get(imputer_name, {})
imputer_opti = hyperparameters.optimize(
imputer,
Expand All @@ -170,7 +186,9 @@ def process_imputer(
df_with_holes = df_origin.copy()
df_with_holes[df_mask] = np.nan
df_imputed = imputer_opti.fit_transform(df_with_holes)
errors = self.get_errors(df_origin, df_imputed, df_mask)
errors = self.get_errors(
df_origin[subset], df_imputed[subset], df_mask[subset]
)
imputer_results.append(errors)

return imputer_name, pd.concat(imputer_results).groupby(
Expand Down

0 comments on commit 40e900a

Please sign in to comment.