diff --git a/.changes/unreleased/Features-20240408-094132.yaml b/.changes/unreleased/Features-20240408-094132.yaml new file mode 100644 index 00000000000..0b7a251e926 --- /dev/null +++ b/.changes/unreleased/Features-20240408-094132.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support SQL in unit testing fixtures +time: 2024-04-08T09:41:32.15936-04:00 +custom: + Author: gshank + Issue: "9405" diff --git a/core/dbt/artifacts/resources/v1/unit_test_definition.py b/core/dbt/artifacts/resources/v1/unit_test_definition.py index 7ef10f52cdf..fc265fa36b9 100644 --- a/core/dbt/artifacts/resources/v1/unit_test_definition.py +++ b/core/dbt/artifacts/resources/v1/unit_test_definition.py @@ -30,6 +30,7 @@ class UnitTestConfig(BaseConfig): class UnitTestFormat(StrEnum): CSV = "csv" Dict = "dict" + SQL = "sql" @dataclass diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index 18765dc5eaa..b45c313327c 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -36,6 +36,7 @@ def insensitive_patterns(*patterns: str): @dataclass class UnitTestNodeConfig(NodeConfig): expected_rows: List[Dict[str, Any]] = field(default_factory=list) + expected_sql: Optional[str] = None @dataclass diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 134c272db23..e1f409ff1de 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -991,7 +991,7 @@ def same_contents(self, other: Optional["UnitTestDefinition"]) -> bool: @dataclass class UnitTestFileFixture(BaseNode): resource_type: Literal[NodeType.Fixture] - rows: Optional[List[Dict[str, Any]]] = None + rows: Optional[Union[List[Dict[str, Any]], str]] = None # ==================================== diff --git a/core/dbt/parser/fixtures.py b/core/dbt/parser/fixtures.py index f12cc6f272a..b3002725674 100644 --- a/core/dbt/parser/fixtures.py +++ b/core/dbt/parser/fixtures.py @@ -26,6 +26,11 @@ def parse_file(self, file_block: FileBlock): assert isinstance(file_block.file, FixtureSourceFile) unique_id = self.generate_unique_id(file_block.name) + if file_block.file.path.relative_path.endswith(".sql"): + rows = file_block.file.contents # type: ignore + else: # endswith('.csv') + rows = self.get_rows(file_block.file.contents) # type: ignore + fixture = UnitTestFileFixture( name=file_block.name, path=file_block.file.path.relative_path, @@ -33,7 +38,7 @@ def parse_file(self, file_block: FileBlock): package_name=self.project.project_name, unique_id=unique_id, resource_type=NodeType.Fixture, - rows=self.get_rows(file_block.file.contents), + rows=rows, ) self.manifest.add_fixture(file_block.file, fixture) diff --git a/core/dbt/parser/read_files.py b/core/dbt/parser/read_files.py index a44bd2fbb22..314a2a0fdd1 100644 --- a/core/dbt/parser/read_files.py +++ b/core/dbt/parser/read_files.py @@ -145,11 +145,11 @@ def get_source_files(project, paths, extension, parse_file_type, saved_files, ig if parse_file_type == ParseFileType.Seed: fb_list.append(load_seed_source_file(fp, project.project_name)) # singular tests live in /tests but only generic tests live - # in /tests/generic so we want to skip those + # in /tests/generic and fixtures in /tests/fixture so we want to skip those else: if parse_file_type == ParseFileType.SingularTest: path = pathlib.Path(fp.relative_path) - if path.parts[0] == "generic": + if path.parts[0] in ["generic", "fixtures"]: continue file = load_source_file(fp, parse_file_type, project.project_name, saved_files) # only append the list if it has contents. added to fix #3568 @@ -431,7 +431,7 @@ def get_file_types_for_project(project): }, ParseFileType.Fixture: { "paths": project.fixture_paths, - "extensions": [".csv"], + "extensions": [".csv", ".sql"], "parser": "FixtureParser", }, } diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 763efab44aa..0abadca5cf9 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -68,6 +68,15 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): name = test_case.name if tested_node.is_versioned: name = name + f"_v{tested_node.version}" + expected_sql: Optional[str] = None + if test_case.expect.format == UnitTestFormat.SQL: + expected_rows: List[Dict[str, Any]] = [] + expected_sql = test_case.expect.rows # type: ignore + else: + assert isinstance(test_case.expect.rows, List) + expected_rows = deepcopy(test_case.expect.rows) + + assert isinstance(expected_rows, List) unit_test_node = UnitTestNode( name=name, resource_type=NodeType.Unit, @@ -76,8 +85,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): original_file_path=test_case.original_file_path, unique_id=test_case.unique_id, config=UnitTestNodeConfig( - materialized="unit", - expected_rows=deepcopy(test_case.expect.rows), # type:ignore + materialized="unit", expected_rows=expected_rows, expected_sql=expected_sql ), raw_code=tested_node.raw_code, database=tested_node.database, @@ -132,7 +140,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): "schema": original_input_node.schema, "fqn": original_input_node.fqn, "checksum": FileHash.empty(), - "raw_code": self._build_fixture_raw_code(given.rows, None), + "raw_code": self._build_fixture_raw_code(given.rows, None, given.format), "package_name": original_input_node.package_name, "unique_id": f"model.{original_input_node.package_name}.{input_name}", "name": input_name, @@ -172,12 +180,15 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): # Add unique ids of input_nodes to depends_on unit_test_node.depends_on.nodes.append(input_node.unique_id) - def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str: + def _build_fixture_raw_code(self, rows, column_name_to_data_types, fixture_format) -> str: # We're not currently using column_name_to_data_types, but leaving here for # possible future use. - return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format( - rows=rows, column_name_to_data_types=column_name_to_data_types - ) + if fixture_format == UnitTestFormat.SQL: + return rows + else: + return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format( + rows=rows, column_name_to_data_types=column_name_to_data_types + ) def _get_original_input_node(self, input: str, tested_node: ModelNode, test_case_name: str): """ @@ -352,13 +363,29 @@ def _validate_and_normalize_rows(self, ut_fixture, unit_test_definition, fixture ) if ut_fixture.fixture: - # find fixture file object and store unit_test_definition unique_id - fixture = self._get_fixture(ut_fixture.fixture, self.project.project_name) - fixture_source_file = self.manifest.files[fixture.file_id] - fixture_source_file.unit_tests.append(unit_test_definition.unique_id) - ut_fixture.rows = fixture.rows + ut_fixture.rows = self.get_fixture_file_rows( + ut_fixture.fixture, self.project.project_name, unit_test_definition.unique_id + ) else: ut_fixture.rows = self._convert_csv_to_list_of_dicts(ut_fixture.rows) + elif ut_fixture.format == UnitTestFormat.SQL: + if not (isinstance(ut_fixture.rows, str) or isinstance(ut_fixture.fixture, str)): + raise ParsingError( + f"Unit test {unit_test_definition.name} has {fixture_type} rows or fixtures " + f"which do not match format {ut_fixture.format}. Expected string." + ) + + if ut_fixture.fixture: + ut_fixture.rows = self.get_fixture_file_rows( + ut_fixture.fixture, self.project.project_name, unit_test_definition.unique_id + ) + + def get_fixture_file_rows(self, fixture_name, project_name, utdef_unique_id): + # find fixture file object and store unit_test_definition unique_id + fixture = self._get_fixture(fixture_name, project_name) + fixture_source_file = self.manifest.files[fixture.file_id] + fixture_source_file.unit_tests.append(utdef_unique_id) + return fixture.rows def _convert_csv_to_list_of_dicts(self, csv_string: str) -> List[Dict[str, Any]]: dummy_file = StringIO(csv_string) diff --git a/tests/functional/unit_testing/test_sql_format.py b/tests/functional/unit_testing/test_sql_format.py new file mode 100644 index 00000000000..6b5af93e1ba --- /dev/null +++ b/tests/functional/unit_testing/test_sql_format.py @@ -0,0 +1,245 @@ +import pytest +from dbt.tests.util import run_dbt + +wizards_csv = """id,w_name,email,email_tld,phone,world +1,Albus Dumbledore,a.dumbledore@gmail.com,gmail.com,813-456-9087,1 +2,Gandalf,gandy811@yahoo.com,yahoo.com,551-329-8367,2 +3,Winifred Sanderson,winnie@hocuspocus.com,hocuspocus.com,,6 +4,Marnie Piper,cromwellwitch@gmail.com,gmail.com,,5 +5,Grace Goheen,grace.goheen@dbtlabs.com,dbtlabs.com,,3 +6,Glinda,glinda_good@hotmail.com,hotmail.com,912-458-3289,4 +""" + +top_level_email_domains_csv = """tld +gmail.com +yahoo.com +hocuspocus.com +dbtlabs.com +hotmail.com +""" + +worlds_csv = """id,name +1,The Wizarding World +2,Middle-earth +3,dbt Labs +4,Oz +5,Halloweentown +6,Salem +""" + +stg_wizards_sql = """ +select + id as wizard_id, + w_name as wizard_name, + email, + email_tld as email_top_level_domain, + phone as phone_number, + world as world_id +from {{ ref('wizards') }} +""" + +stg_worlds_sql = """ +select + id as world_id, + name as world_name +from {{ ref('worlds') }} +""" + +dim_wizards_sql = """ +with wizards as ( + + select * from {{ ref('stg_wizards') }} + +), + +worlds as ( + + select * from {{ ref('stg_worlds') }} + +), + +accepted_email_domains as ( + + select * from {{ ref('top_level_email_domains') }} + +), + +check_valid_emails as ( + + select + wizards.wizard_id, + wizards.wizard_name, + wizards.email, + wizards.phone_number, + wizards.world_id, + + coalesce ( + wizards.email ~ '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}$' + = true + and accepted_email_domains.tld is not null, + false) as is_valid_email_address + + from wizards + left join accepted_email_domains + on wizards.email_top_level_domain = lower(accepted_email_domains.tld) + +) + +select + check_valid_emails.wizard_id, + check_valid_emails.wizard_name, + check_valid_emails.email, + check_valid_emails.is_valid_email_address, + check_valid_emails.phone_number, + worlds.world_name +from check_valid_emails +left join worlds + on check_valid_emails.world_id = worlds.world_id +""" + +orig_schema_yml = """ +unit_tests: + - name: test_valid_email_address + model: dim_wizards + given: + - input: ref('stg_wizards') + rows: + - {email: cool@example.com, email_top_level_domain: example.com} + - {email: cool@unknown.com, email_top_level_domain: unknown.com} + - {email: badgmail.com, email_top_level_domain: gmail.com} + - {email: missingdot@gmailcom, email_top_level_domain: gmail.com} + - input: ref('top_level_email_domains') + rows: + - {tld: example.com} + - {tld: gmail.com} + - input: ref('stg_worlds') + rows: [] + expect: + rows: + - {email: cool@example.com, is_valid_email_address: true} + - {email: cool@unknown.com, is_valid_email_address: false} + - {email: badgmail.com, is_valid_email_address: false} + - {email: missingdot@gmailcom, is_valid_email_address: false} +""" + +schema_yml = """ +unit_tests: + - name: test_valid_email_address + model: dim_wizards + given: + - input: ref('stg_wizards') + format: sql + rows: | + select 1 as wizard_id, 'joe' as wizard_name, 'cool@example.com' as email, 'example.com' as email_top_level_domain, '123' as phone_number, 1 as world_id union all + select 2 as wizard_id, 'don' as wizard_name, 'cool@unknown.com' as email, 'unknown.com' as email_top_level_domain, '456' as phone_number, 2 as world_id union all + select 3 as wizard_id, 'mary' as wizard_name, 'badgmail.com' as email, 'gmail.com' as email_top_level_domain, '789' as phone_number, 3 as world_id union all + select 4 as wizard_id, 'jane' as wizard_name, 'missingdot@gmailcom' as email, 'gmail.com' as email_top_level_domain, '102' as phone_number, 4 as world_id + - input: ref('top_level_email_domains') + format: sql + rows: | + select 'example.com' as tld union all + select 'gmail.com' as tld + - input: ref('stg_worlds') + rows: [] + expect: + format: sql + rows: | + select 1 as wizard_id, 'joe' as wizard_name, 'cool@example.com' as email, true as is_valid_email_address, '123' as phone_number, null as world_name union all + select 2 as wizard_id, 'don' as wizard_name, 'cool@unknown.com' as email, false as is_valid_email_address, '456' as phone_number, null as world_name union all + select 3 as wizard_id, 'mary' as wizard_name, 'badgmail.com' as email, false as is_valid_email_address, '789' as phone_number, null as world_name union all + select 4 as wizard_id, 'jane' as wizard_name, 'missingdot@gmailcom' as email, false as is_valid_email_address, '102' as phone_number, null as world_name +""" + + +class TestSQLFormat: + @pytest.fixture(scope="class") + def seeds(self): + return { + "wizards.csv": wizards_csv, + "top_level_email_domains.csv": top_level_email_domains_csv, + "worlds.csv": worlds_csv, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "stg_wizards.sql": stg_wizards_sql, + "stg_worlds.sql": stg_worlds_sql, + "dim_wizards.sql": dim_wizards_sql, + "schema.yml": schema_yml, + } + + def test_sql_format(self, project): + results = run_dbt(["build"]) + assert len(results) == 7 + + +stg_wizards_fixture_sql = """ + select 1 as wizard_id, 'joe' as wizard_name, 'cool@example.com' as email, 'example.com' as email_top_level_domain, '123' as phone_number, 1 as world_id union all + select 2 as wizard_id, 'don' as wizard_name, 'cool@unknown.com' as email, 'unknown.com' as email_top_level_domain, '456' as phone_number, 2 as world_id union all + select 3 as wizard_id, 'mary' as wizard_name, 'badgmail.com' as email, 'gmail.com' as email_top_level_domain, '789' as phone_number, 3 as world_id union all + select 4 as wizard_id, 'jane' as wizard_name, 'missingdot@gmailcom' as email, 'gmail.com' as email_top_level_domain, '102' as phone_number, 4 as world_id +""" + +top_level_email_domains_fixture_sql = """ + select 'example.com' as tld union all + select 'gmail.com' as tld +""" + +test_valid_email_address_fixture_sql = """ + select 1 as wizard_id, 'joe' as wizard_name, 'cool@example.com' as email, true as is_valid_email_address, '123' as phone_number, null as world_name union all + select 2 as wizard_id, 'don' as wizard_name, 'cool@unknown.com' as email, false as is_valid_email_address, '456' as phone_number, null as world_name union all + select 3 as wizard_id, 'mary' as wizard_name, 'badgmail.com' as email, false as is_valid_email_address, '789' as phone_number, null as world_name union all + select 4 as wizard_id, 'jane' as wizard_name, 'missingdot@gmailcom' as email, false as is_valid_email_address, '102' as phone_number, null as world_name +""" + +fixture_schema_yml = """ +unit_tests: + - name: test_valid_email_address + model: dim_wizards + given: + - input: ref('stg_wizards') + format: sql + fixture: stg_wizards_fixture + - input: ref('top_level_email_domains') + format: sql + fixture: top_level_email_domains_fixture + - input: ref('stg_worlds') + rows: [] + expect: + format: sql + fixture: test_valid_email_address_fixture +""" + + +class TestSQLFormatFixtures: + @pytest.fixture(scope="class") + def tests(self): + return { + "fixtures": { + "test_valid_email_address_fixture.sql": test_valid_email_address_fixture_sql, + "top_level_email_domains_fixture.sql": top_level_email_domains_fixture_sql, + "stg_wizards_fixture.sql": stg_wizards_fixture_sql, + } + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "wizards.csv": wizards_csv, + "top_level_email_domains.csv": top_level_email_domains_csv, + "worlds.csv": worlds_csv, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "stg_wizards.sql": stg_wizards_sql, + "stg_worlds.sql": stg_worlds_sql, + "dim_wizards.sql": dim_wizards_sql, + "schema.yml": fixture_schema_yml, + } + + def test_sql_format_fixtures(self, project): + results = run_dbt(["build"]) + assert len(results) == 7