Skip to content

Commit

Permalink
Added equals macro that handles null value comparison (#394)
Browse files Browse the repository at this point in the history
Co-authored-by: Mila Page <[email protected]>
  • Loading branch information
adrianburusdbt and VersusFacit authored Dec 23, 2024
1 parent 5741dc2 commit a035cd9
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 33 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20241217-110536.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Added new equals macro that handles null value checks in sql
time: 2024-12-17T11:05:36.363421+02:00
custom:
Author: adrianburusdbt
Issue: "159"
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@
select 'NY','New York','Manhattan','2021-04-01'
union all
select 'PA','Philadelphia','Philadelphia','2021-05-21'
union all
select 'CO','Denver',null,'2021-06-18'
"""

Expand All @@ -265,6 +267,8 @@
select 'NY','New York','Manhattan','2021-04-01'
union all
select 'PA','Philadelphia','Philadelphia','2021-05-21'
union all
select 'CO','Denver',null,'2021-06-18'
"""

Expand All @@ -288,6 +292,7 @@
NY,Kings,Brooklyn,2021-04-02
NY,New York,Manhattan,2021-04-01
PA,Philadelphia,Philadelphia,2021-05-21
CO,Denver,,2021-06-18
"""

seeds__add_new_rows_sql = """
Expand Down Expand Up @@ -439,7 +444,7 @@ def fail_to_build_inc_missing_unique_key_column(self, incremental_model_name):
def test__no_unique_keys(self, project):
"""with no unique keys, seed and model should match"""

expected_fields = self.get_expected_fields(relation="seed", seed_rows=8)
expected_fields = self.get_expected_fields(relation="seed", seed_rows=9)
test_case_fields = self.get_test_fields(
project, seed="seed", incremental_model="no_unique_key", update_sql_file="add_new_rows"
)
Expand All @@ -449,7 +454,7 @@ def test__no_unique_keys(self, project):
def test__empty_str_unique_key(self, project):
"""with empty string for unique key, seed and model should match"""

expected_fields = self.get_expected_fields(relation="seed", seed_rows=8)
expected_fields = self.get_expected_fields(relation="seed", seed_rows=9)
test_case_fields = self.get_test_fields(
project,
seed="seed",
Expand All @@ -462,7 +467,7 @@ def test__one_unique_key(self, project):
"""with one unique key, model will overwrite existing row"""

expected_fields = self.get_expected_fields(
relation="one_str__overwrite", seed_rows=7, opt_model_count=1
relation="one_str__overwrite", seed_rows=8, opt_model_count=1
)
test_case_fields = self.get_test_fields(
project,
Expand All @@ -487,7 +492,7 @@ def test__bad_unique_key(self, project):
def test__empty_unique_key_list(self, project):
"""with no unique keys, seed and model should match"""

expected_fields = self.get_expected_fields(relation="seed", seed_rows=8)
expected_fields = self.get_expected_fields(relation="seed", seed_rows=9)
test_case_fields = self.get_test_fields(
project,
seed="seed",
Expand All @@ -500,7 +505,7 @@ def test__unary_unique_key_list(self, project):
"""with one unique key, model will overwrite existing row"""

expected_fields = self.get_expected_fields(
relation="unique_key_list__inplace_overwrite", seed_rows=7, opt_model_count=1
relation="unique_key_list__inplace_overwrite", seed_rows=8, opt_model_count=1
)
test_case_fields = self.get_test_fields(
project,
Expand All @@ -515,7 +520,7 @@ def test__duplicated_unary_unique_key_list(self, project):
"""with two of the same unique key, model will overwrite existing row"""

expected_fields = self.get_expected_fields(
relation="unique_key_list__inplace_overwrite", seed_rows=7, opt_model_count=1
relation="unique_key_list__inplace_overwrite", seed_rows=8, opt_model_count=1
)
test_case_fields = self.get_test_fields(
project,
Expand All @@ -530,7 +535,7 @@ def test__trinary_unique_key_list(self, project):
"""with three unique keys, model will overwrite existing row"""

expected_fields = self.get_expected_fields(
relation="unique_key_list__inplace_overwrite", seed_rows=7, opt_model_count=1
relation="unique_key_list__inplace_overwrite", seed_rows=8, opt_model_count=1
)
test_case_fields = self.get_test_fields(
project,
Expand All @@ -545,7 +550,7 @@ def test__trinary_unique_key_list_no_update(self, project):
"""even with three unique keys, adding distinct rows to seed does not
cause seed and model to diverge"""

expected_fields = self.get_expected_fields(relation="seed", seed_rows=8)
expected_fields = self.get_expected_fields(relation="seed", seed_rows=9)
test_case_fields = self.get_test_fields(
project,
seed="seed",
Expand Down
19 changes: 9 additions & 10 deletions dbt-tests-adapter/dbt/tests/adapter/utils/base_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import pytest
from dbt.tests.util import run_dbt


macros__equals_sql = """
{% macro equals(expr1, expr2) -%}
case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null))
then 0
else 1
end = 0
{% endmacro %}
"""

macros__test_assert_equal_sql = """
{% test assert_equal(model, actual, expected) %}
select * from {{ model }}
Expand All @@ -27,6 +17,15 @@
{% endmacro %}
"""

macros__equals_sql = """
{% macro equals(expr1, expr2) -%}
case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null))
then 0
else 1
end = 0
{% endmacro %}
"""


class BaseUtils:
# setup
Expand Down
8 changes: 1 addition & 7 deletions dbt-tests-adapter/dbt/tests/adapter/utils/test_equals.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import pytest

from dbt.tests.adapter.utils import base_utils, fixture_equals
from dbt.tests.adapter.utils import fixture_equals
from dbt.tests.util import relation_from_name, run_dbt


class BaseEquals:
@pytest.fixture(scope="class")
def macros(self):
return {
"equals.sql": base_utils.macros__equals_sql,
}

@pytest.fixture(scope="class")
def seeds(self):
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@
{% do predicates.append(this_key_match) %}
{% endfor %}
{% else %}
{% set source_unique_key %}
DBT_INTERNAL_SOURCE.{{ unique_key }}
{% endset %}
{% set target_unique_key %}
DBT_INTERNAL_DEST.{{ unique_key }}
{% endset %}
{% set unique_key_match %}
DBT_INTERNAL_SOURCE.{{ unique_key }} = DBT_INTERNAL_DEST.{{ unique_key }}
{{ equals(source_unique_key, target_unique_key) }}
{% endset %}
{% do predicates.append(unique_key_match) %}
{% endif %}
Expand Down Expand Up @@ -62,11 +68,18 @@

{% if unique_key %}
{% if unique_key is sequence and unique_key is not string %}
delete from {{target }}
delete from {{ target }}
using {{ source }}
where (
{% for key in unique_key %}
{{ source }}.{{ key }} = {{ target }}.{{ key }}
{% set source_unique_key %}
{{ source }}.{{ key }}
{% endset %}
{% set target_unique_key %}
{{ target }}.{{ key }}
{% endset %}

{{ equals(source_unique_key, target_unique_key) }}
{{ "and " if not loop.last}}
{% endfor %}
{% if incremental_predicates %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,14 @@
from {{ target_relation }}
where
{% if config.get('dbt_valid_to_current') %}
{# Check for either dbt_valid_to_current OR null, in order to correctly update records with nulls #}
( {{ columns.dbt_valid_to }} = {{ config.get('dbt_valid_to_current') }} or {{ columns.dbt_valid_to }} is null)
{% set source_unique_key %}
columns.dbt_valid_to
{% endset %}
{% set target_unique_key %}
config.get('dbt_valid_to_current')
{% endset %}

{{ equals(source_unique_key, target_unique_key) }}
{% else %}
{{ columns.dbt_valid_to }} is null
{% endif %}
Expand Down Expand Up @@ -276,7 +282,14 @@
{% macro unique_key_join_on(unique_key, identifier, from_identifier) %}
{% if unique_key | is_list %}
{% for key in unique_key %}
{{ identifier }}.dbt_unique_key_{{ loop.index }} = {{ from_identifier }}.dbt_unique_key_{{ loop.index }}
{% set source_unique_key %}
{{ identifier }}.dbt_unique_key_{{ loop.index }}
{% endset %}
{% set target_unique_key %}
{{ from_identifier }}.dbt_unique_key_{{ loop.index }}
{% endset %}

{{ equals(source_unique_key, target_unique_key) }}
{%- if not loop.last %} and {%- endif %}
{% endfor %}
{% else %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@

when matched
{% if config.get("dbt_valid_to_current") %}
and (DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} = {{ config.get('dbt_valid_to_current') }} or
DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} is null)
{% set source_unique_key %}
DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }}
{% endset %}
{% set target_unique_key %}
{{ config.get('dbt_valid_to_current') }}
{% endset %}
and {{ equals(source_unique_key, target_unique_key) }}

{% else %}
and DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} is null
{% endif %}
Expand Down
12 changes: 12 additions & 0 deletions dbt/include/global_project/macros/utils/equals.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{% macro equals(expr1, expr2) %}
{{ return(adapter.dispatch('equals', 'dbt') (expr1, expr2)) }}
{%- endmacro %}

{% macro default__equals(expr1, expr2) -%}

case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null))
then 0
else 1
end = 0

{% endmacro %}

0 comments on commit a035cd9

Please sign in to comment.