Skip to content

Commit

Permalink
postmodeling report subset fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
Alice Lai committed Jan 10, 2024
1 parent 35955dd commit 80a0431
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
17 changes: 13 additions & 4 deletions src/triage/component/postmodeling/model_analyzer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -99,7 +99,7 @@ def train_end_time(self):
def train_label_timespan(self):
# Accessor: returns the 'training_label_timespan' entry stored in this object's `metadata` mapping.
return self.metadata['training_label_timespan']

def get_predictions(self, matrix_uuid=None, fetch_null_labels=True, predictions_table='test_results.predictions'):
def get_predictions(self, matrix_uuid=None, fetch_null_labels=True, predictions_table='test_results.predictions', subset_hash=None):
"""Fetch the predictions from the DB for a given matrix
args:
Expand All @@ -114,6 +114,13 @@ def get_predictions(self, matrix_uuid=None, fetch_null_labels=True, predictions_
if not fetch_null_labels:
where_clause += f" AND label_value IS NOT NULL"

if subset_hash:
# get subset table name
q = f"select config from triage_metadata.subsets where subset_hash='{subset_hash}'"
config_df = pd.read_sql(q, self.engine)
table_name = f"subset_{config_df.iloc[0]['config']['name']}_{subset_hash}"
predictions_table += f" preds join {table_name} subset on preds.entity_id = subset.entity_id and preds.as_of_date = subset.as_of_date"

query = f"""
SELECT model_id,
entity_id,
Expand Down Expand Up @@ -195,13 +202,15 @@ def get_aequitas(self, parameter=None, attribute_name=None, subset_hash=None):
where_clause = f'WHERE model_id={self.model_id}'

if subset_hash is not None:
where_clause += f" AND subset_hash={subset_hash}"
where_clause += f" AND 'subset_hash='{subset_hash}'"
else:
where_clause += f" AND subset_hash=''"

if parameter is not None:
where_clause += f" AND parameter={parameter}"
where_clause += f" AND parameter='{parameter}'"

if attribute_name:
where_clause += f" AND attribute_name={attribute_name}"
where_clause += f" AND attribute_name='{attribute_name}'"

# TODO don't return all columns ?
q = f"""
Expand Down
23 changes: 21 additions & 2 deletions src/triage/component/postmodeling/report_generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -88,9 +88,28 @@ def cohort_summary(self):
evaluation_start_time as train_end_time,
num_labeled_examples as cohort_size,
num_positive_labels,
num_positive_labels::float/num_labeled_examples as label_base_rate
case when num_labeled_examples > 0 then num_positive_labels::float/num_labeled_examples else 0 end as label_base_rate
from triage_metadata.experiment_matrices join test_results.evaluations using(matrix_uuid)
where experiment_hash in ('{"','".join(self.experiment_hashes)}')
where experiment_hash in ('{"','".join(self.experiment_hashes)}') and subset_hash = ''
order by 1
"""

matrices = pd.read_sql(q, self.engine)

print(matrices)


def subset_summary(self, subset_hash):
q = f"""
select distinct on(train_end_time)
-- matrix_uuid,
evaluation_start_time as train_end_time,
num_labeled_examples as cohort_size,
num_positive_labels,
case when num_labeled_examples > 0 then num_positive_labels::float/num_labeled_examples else 0 end as label_base_rate
from triage_metadata.experiment_matrices join test_results.evaluations using(matrix_uuid)
where experiment_hash in ('{"','".join(self.experiment_hashes)}')
and subset_hash = '{subset_hash}'
order by 1
"""

Expand Down

0 comments on commit 80a0431

Please sign in to comment.