diff --git a/src/triage/component/postmodeling/model_analyzer.py b/src/triage/component/postmodeling/model_analyzer.py index 1fbb2bd3b..ed782469f 100644 --- a/src/triage/component/postmodeling/model_analyzer.py +++ b/src/triage/component/postmodeling/model_analyzer.py @@ -99,7 +99,7 @@ def train_end_time(self): def train_label_timespan(self): return self.metadata['training_label_timespan'] - def get_predictions(self, matrix_uuid=None, fetch_null_labels=True, predictions_table='test_results.predictions'): + def get_predictions(self, matrix_uuid=None, fetch_null_labels=True, predictions_table='test_results.predictions', subset_hash=None): """Fetch the predictions from the DB for a given matrix args: @@ -114,6 +114,13 @@ def get_predictions(self, matrix_uuid=None, fetch_null_labels=True, predictions_ if not fetch_null_labels: where_clause += f" AND label_value IS NOT NULL" + if subset_hash: + # get subset table name + q = f"select config from triage_metadata.subsets where subset_hash='{subset_hash}'" + config_df = pd.read_sql(q, self.engine) + table_name = f"subset_{config_df.iloc[0]['config']['name']}_{subset_hash}" + predictions_table += f" preds join {table_name} subset on preds.entity_id = subset.entity_id and preds.as_of_date = subset.as_of_date" + query = f""" SELECT model_id, entity_id, @@ -195,13 +202,15 @@ def get_aequitas(self, parameter=None, attribute_name=None, subset_hash=None): where_clause = f'WHERE model_id={self.model_id}' if subset_hash is not None: - where_clause += f" AND subset_hash={subset_hash}" + where_clause += f" AND 'subset_hash='{subset_hash}'" + else: + where_clause += f" AND subset_hash=''" if parameter is not None: - where_clause += f" AND parameter={parameter}" + where_clause += f" AND parameter='{parameter}'" if attribute_name: - where_clause += f" AND attribute_name={attribute_name}" + where_clause += f" AND attribute_name='{attribute_name}'" # TODO don't return all columns ? q = f""" diff --git a/src/triage/component/postmodeling/report_generator.py b/src/triage/component/postmodeling/report_generator.py index cf4805109..bb9bfc0e7 100644 --- a/src/triage/component/postmodeling/report_generator.py +++ b/src/triage/component/postmodeling/report_generator.py @@ -88,9 +88,28 @@ def cohort_summary(self): evaluation_start_time as train_end_time, num_labeled_examples as cohort_size, num_positive_labels, - num_positive_labels::float/num_labeled_examples as label_base_rate + case when num_labeled_examples > 0 then num_positive_labels::float/num_labeled_examples else 0 end as label_base_rate from triage_metadata.experiment_matrices join test_results.evaluations using(matrix_uuid) - where experiment_hash in ('{"','".join(self.experiment_hashes)}') + where experiment_hash in ('{"','".join(self.experiment_hashes)}') and subset_hash = '' + order by 1 + """ + + matrices = pd.read_sql(q, self.engine) + + print(matrices) + + + def subset_summary(self, subset_hash): + q = f""" + select distinct on(train_end_time) + -- matrix_uuid, + evaluation_start_time as train_end_time, + num_labeled_examples as cohort_size, + num_positive_labels, + case when num_labeled_examples > 0 then num_positive_labels::float/num_labeled_examples else 0 end as label_base_rate + from triage_metadata.experiment_matrices join test_results.evaluations using(matrix_uuid) + where experiment_hash in ('{"','".join(self.experiment_hashes)}') + and subset_hash = '{subset_hash}' order by 1 """