diff --git a/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff.py b/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff.py index aac1903d46..b0eb623b66 100644 --- a/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff.py +++ b/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff.py @@ -39,6 +39,28 @@ """) +def get_affected_tables(): + """ + This method gets all the tables that are affected by this cleaning rule which are all the CDM tables + except for the person table. The birth date field in the person table will be cleaned in another + cleaning rule where all participants under the age of 18 will be dropped. Ignoring this table will + optimize this cleaning rule's runtime. + + :return: list of affected tables + """ + tables = [] + for table in AOU_REQUIRED: + + # skips the person table + if table == 'person': + continue + + # appends all CDM tables except for the person table + else: + tables.append(table) + return tables + + class EhrSubmissionDataCutoff(BaseCleaningRule): """ All rows of data in the RDR ETL with dates after the cutoff date should be sandboxed and dropped @@ -74,28 +96,8 @@ def __init__(self, affected_datasets=[cdr_consts.UNIONED], project_id=project_id, dataset_id=dataset_id, - sandbox_dataset_id=sandbox_dataset_id) - - def get_affected_tables(self): - """ - This method gets all the tables that are affected by this cleaning rule which are all the CDM tables - except for the person table. The birth date field in the person table will be cleaned in another - cleaning rule where all participants under the age of 18 will be dropped. Ignoring this table will - optimize this cleaning rule's runtime.
- - :return: list of affected tables - """ - tables = [] - for table in AOU_REQUIRED: - - # skips the person table - if table == 'person': - continue - - # appends all CDM tables except for the person table - else: - tables.append(table) - return tables + sandbox_dataset_id=sandbox_dataset_id, + affected_tables=get_affected_tables()) def get_query_specs(self, *args, **keyword_args): """ @@ -109,7 +111,7 @@ def get_query_specs(self, *args, **keyword_args): queries_list = [] sandbox_queries_list = [] - for table in self.get_affected_tables(): + for table in self.affected_tables: # gets all fields from the affected table fields = fields_for(table) diff --git a/tests/integration_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py b/tests/integration_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py index 8604dcc439..965ea5b251 100644 --- a/tests/integration_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py +++ b/tests/integration_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py @@ -43,6 +43,14 @@ def setUpClass(cls): cutoff_date = "2020-05-01" cls.kwargs.update({'cutoff_date': cutoff_date}) + # patches get_affected_tables so the rule only loops through visit_occurrence; the + # patcher itself is kept on cls because tearDownClass must stop the patcher, not the mock + cls.get_affected_tables_patch = patch( + 'cdr_cleaner.cleaning_rules.ehr_submission_data_cutoff.get_affected_tables' + ) + cls.mock_get_affected_tables = cls.get_affected_tables_patch.start() + cls.mock_get_affected_tables.return_value = [common.VISIT_OCCURRENCE] + cls.rule_instance = EhrSubmissionDataCutoff(project_id, dataset_id, sandbox_id) @@ -75,16 +83,11 @@ def setUp(self): super().setUp() - @patch.object(EhrSubmissionDataCutoff, 'get_affected_tables') - def test_ehr_submission_data_cutoff(self, mock_get_affected_tables): + def test_ehr_submission_data_cutoff(self): """ Validates pre conditions, tests execution, and post conditions
based on the load statements and the tables_and_counts variable. """ - # mocks the return value of get_affected_tables as we only want to loop through the - # visit_occurrence not all of the CDM tables - mock_get_affected_tables.return_value = [common.VISIT_OCCURRENCE] - queries = [] visit_occurrence_tmpl = self.jinja_env.from_string(""" INSERT INTO `{{fq_dataset_name}}.{{cdm_table}}` @@ -133,3 +136,8 @@ def test_ehr_submission_data_cutoff(self, mock_get_affected_tables): }] self.default_test(table_and_counts) + + @classmethod + def tearDownClass(cls): + cls.get_affected_tables_patch.stop() + super().tearDownClass() diff --git a/tests/unit_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py b/tests/unit_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py index 293aeacb80..56eee06dd2 100644 --- a/tests/unit_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py +++ b/tests/unit_tests/data_steward/cdr_cleaner/cleaning_rules/ehr_submission_data_cutoff_test.py @@ -45,6 +45,13 @@ def setUp(self): ] self.cutoff_date = str(datetime.now().date()) + get_affected_tables_patch = patch( + 'cdr_cleaner.cleaning_rules.ehr_submission_data_cutoff.get_affected_tables' + ) + mock_get_affected_tables = get_affected_tables_patch.start() + mock_get_affected_tables.return_value = [common.VISIT_OCCURRENCE] + self.addCleanup(get_affected_tables_patch.stop) + self.rule_instance = data_cutoff.EhrSubmissionDataCutoff( self.project_id, self.dataset_id, self.sandbox_id) @@ -58,11 +65,7 @@ def test_setup_rule(self): # No errors are raised, nothing will happen - @patch.object(data_cutoff.EhrSubmissionDataCutoff, 'get_affected_tables') - def test_get_query_specs(self, mock_get_affected_tables): - # Pre conditions - mock_get_affected_tables.return_value = [common.VISIT_OCCURRENCE] - + def test_get_query_specs(self): table = common.VISIT_OCCURRENCE sandbox_query = {