Skip to content

Commit

Permalink
Support using sql in unit testing fixtures (#9873)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Apr 17, 2024
1 parent a70024f commit 86b349f
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 17 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240408-094132.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support SQL in unit testing fixtures
time: 2024-04-08T09:41:32.15936-04:00
custom:
Author: gshank
Issue: "9405"
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/unit_test_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class UnitTestConfig(BaseConfig):
class UnitTestFormat(StrEnum):
    """Allowed values for a unit test fixture's ``format:`` key.

    CSV and Dict fixtures carry tabular row data; SQL fixtures carry a raw
    SQL string used verbatim.
    """

    CSV = "csv"
    Dict = "dict"
    SQL = "sql"


@dataclass
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def insensitive_patterns(*patterns: str):
@dataclass
class UnitTestNodeConfig(NodeConfig):
    """Config for the synthetic node built to execute a unit test case."""

    # Expected output rows for csv/dict-format expectations.
    expected_rows: List[Dict[str, Any]] = field(default_factory=list)
    # Raw SQL expectation; populated instead of expected_rows when the
    # expectation's format is "sql".
    expected_sql: Optional[str] = None


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def same_contents(self, other: Optional["UnitTestDefinition"]) -> bool:
@dataclass
class UnitTestFileFixture(BaseNode):
    """Parsed unit-test fixture file.

    CSV fixtures are parsed into a list of row dicts; SQL fixtures keep the
    raw file contents as a single string.
    """

    resource_type: Literal[NodeType.Fixture]
    # List[Dict] for csv fixtures, str for sql fixtures.
    # NOTE(review): the scraped diff showed both the old (List-only) and new
    # (Union) declarations of this field; only the updated one is kept here.
    rows: Optional[Union[List[Dict[str, Any]], str]] = None


# ====================================
Expand Down
7 changes: 6 additions & 1 deletion core/dbt/parser/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,19 @@ def parse_file(self, file_block: FileBlock):
assert isinstance(file_block.file, FixtureSourceFile)
unique_id = self.generate_unique_id(file_block.name)

if file_block.file.path.relative_path.endswith(".sql"):
rows = file_block.file.contents # type: ignore
else: # endswith('.csv')
rows = self.get_rows(file_block.file.contents) # type: ignore

fixture = UnitTestFileFixture(
name=file_block.name,
path=file_block.file.path.relative_path,
original_file_path=file_block.path.original_file_path,
package_name=self.project.project_name,
unique_id=unique_id,
resource_type=NodeType.Fixture,
rows=self.get_rows(file_block.file.contents),
rows=rows,
)
self.manifest.add_fixture(file_block.file, fixture)

Expand Down
6 changes: 3 additions & 3 deletions core/dbt/parser/read_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def get_source_files(project, paths, extension, parse_file_type, saved_files, ig
if parse_file_type == ParseFileType.Seed:
fb_list.append(load_seed_source_file(fp, project.project_name))
# singular tests live in /tests but only generic tests live
# in /tests/generic so we want to skip those
# in /tests/generic and fixtures in /tests/fixtures so we want to skip those
else:
if parse_file_type == ParseFileType.SingularTest:
path = pathlib.Path(fp.relative_path)
if path.parts[0] == "generic":
if path.parts[0] in ["generic", "fixtures"]:
continue
file = load_source_file(fp, parse_file_type, project.project_name, saved_files)
# only append the list if it has contents. added to fix #3568
Expand Down Expand Up @@ -431,7 +431,7 @@ def get_file_types_for_project(project):
},
ParseFileType.Fixture: {
"paths": project.fixture_paths,
"extensions": [".csv"],
"extensions": [".csv", ".sql"],
"parser": "FixtureParser",
},
}
Expand Down
51 changes: 39 additions & 12 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
name = test_case.name
if tested_node.is_versioned:
name = name + f"_v{tested_node.version}"
expected_sql: Optional[str] = None
if test_case.expect.format == UnitTestFormat.SQL:
expected_rows: List[Dict[str, Any]] = []
expected_sql = test_case.expect.rows # type: ignore
else:
assert isinstance(test_case.expect.rows, List)
expected_rows = deepcopy(test_case.expect.rows)

assert isinstance(expected_rows, List)
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
Expand All @@ -76,8 +85,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=UnitTestNodeConfig(
materialized="unit",
expected_rows=deepcopy(test_case.expect.rows), # type:ignore
materialized="unit", expected_rows=expected_rows, expected_sql=expected_sql
),
raw_code=tested_node.raw_code,
database=tested_node.database,
Expand Down Expand Up @@ -132,7 +140,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
"schema": original_input_node.schema,
"fqn": original_input_node.fqn,
"checksum": FileHash.empty(),
"raw_code": self._build_fixture_raw_code(given.rows, None),
"raw_code": self._build_fixture_raw_code(given.rows, None, given.format),
"package_name": original_input_node.package_name,
"unique_id": f"model.{original_input_node.package_name}.{input_name}",
"name": input_name,
Expand Down Expand Up @@ -172,12 +180,15 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# Add unique ids of input_nodes to depends_on
unit_test_node.depends_on.nodes.append(input_node.unique_id)

def _build_fixture_raw_code(self, rows, column_name_to_data_types, fixture_format) -> str:
    """Render the raw_code for a fixture input node.

    SQL-format fixtures are used verbatim as the node's raw code; csv/dict
    fixtures are wrapped in a ``get_fixture_sql()`` macro call that renders
    the fixture rows to SQL at compile time.

    The scraped diff interleaved the pre-change signature/body with the
    post-change ones; this is the coherent post-change method.
    """
    # We're not currently using column_name_to_data_types, but leaving here for
    # possible future use.
    if fixture_format == UnitTestFormat.SQL:
        return rows
    else:
        return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
            rows=rows, column_name_to_data_types=column_name_to_data_types
        )

def _get_original_input_node(self, input: str, tested_node: ModelNode, test_case_name: str):
"""
Expand Down Expand Up @@ -352,13 +363,29 @@ def _validate_and_normalize_rows(self, ut_fixture, unit_test_definition, fixture
)

if ut_fixture.fixture:
# find fixture file object and store unit_test_definition unique_id
fixture = self._get_fixture(ut_fixture.fixture, self.project.project_name)
fixture_source_file = self.manifest.files[fixture.file_id]
fixture_source_file.unit_tests.append(unit_test_definition.unique_id)
ut_fixture.rows = fixture.rows
ut_fixture.rows = self.get_fixture_file_rows(
ut_fixture.fixture, self.project.project_name, unit_test_definition.unique_id
)
else:
ut_fixture.rows = self._convert_csv_to_list_of_dicts(ut_fixture.rows)
elif ut_fixture.format == UnitTestFormat.SQL:
if not (isinstance(ut_fixture.rows, str) or isinstance(ut_fixture.fixture, str)):
raise ParsingError(
f"Unit test {unit_test_definition.name} has {fixture_type} rows or fixtures "
f"which do not match format {ut_fixture.format}. Expected string."
)

if ut_fixture.fixture:
ut_fixture.rows = self.get_fixture_file_rows(
ut_fixture.fixture, self.project.project_name, unit_test_definition.unique_id
)

def get_fixture_file_rows(self, fixture_name, project_name, utdef_unique_id):
    """Resolve a named fixture, record the unit test definition's unique id
    on the fixture's source file, and return the fixture's rows."""
    found = self._get_fixture(fixture_name, project_name)
    # Remember which unit test definitions depend on this fixture file.
    self.manifest.files[found.file_id].unit_tests.append(utdef_unique_id)
    return found.rows

def _convert_csv_to_list_of_dicts(self, csv_string: str) -> List[Dict[str, Any]]:
dummy_file = StringIO(csv_string)
Expand Down
Loading

0 comments on commit 86b349f

Please sign in to comment.