Skip to content

Commit

Permalink
Cache static objects between tests to accelerate integration runs.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Jan 31, 2024
1 parent 9e081a5 commit ce595e2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
41 changes: 38 additions & 3 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jinja2.parser # type: ignore
import jinja2.sandbox # type: ignore

from dbt_common.tests import test_caching_enabled
from dbt_common.utils.jinja import (
get_dbt_macro_name,
get_docs_macro_name,
Expand Down Expand Up @@ -504,9 +505,19 @@ def catch_jinja(node=None) -> Iterator[None]:
raise


_TESTING_PARSE_CACHE: Dict[str, jinja2.Template] = {}


def parse(string):
str_string = str(string)
if test_caching_enabled() and str_string in _TESTING_PARSE_CACHE:
return _TESTING_PARSE_CACHE[str_string]

with catch_jinja():
return get_environment().parse(str(string))
parsed = get_environment().parse(str(string))
if test_caching_enabled():
_TESTING_PARSE_CACHE[str_string] = parsed
return parsed


def get_template(
Expand All @@ -528,6 +539,18 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:
return template.render(ctx)


_TESTING_BLOCKS_CACHE: Dict[int, List[Union[BlockData, BlockTag]]] = {}


def _get_blocks_hash(text: str,
allowed_blocks: Optional[Set[str]],
collect_raw_data: bool
) -> int:
"""Provides a hash function over the arguments to extract_toplevel_blocks, in order to support caching."""
allowed_tuple = tuple(sorted(allowed_blocks) or [])
return text.__hash__() + allowed_tuple.__hash__() + collect_raw_data.__hash__()


def extract_toplevel_blocks(
text: str,
allowed_blocks: Optional[Set[str]] = None,
Expand All @@ -537,7 +560,7 @@ def extract_toplevel_blocks(
Includes some special handling for block nesting.
:param data: The data to extract blocks from.
:param text: The data to extract blocks from.
:param allowed_blocks: The names of the blocks to extract from the file.
They may not be nested within if/for blocks. If None, use the default
values.
Expand All @@ -548,7 +571,19 @@ def extract_toplevel_blocks(
:return: A list of `BlockTag`s matching the allowed block types and (if
`collect_raw_data` is `True`) `BlockData` objects.
"""

if test_caching_enabled():
hash = _get_blocks_hash(text, allowed_blocks, collect_raw_data)
if hash in _TESTING_BLOCKS_CACHE:
return _TESTING_BLOCKS_CACHE[hash]

tag_iterator = TagIterator(text)
return BlockIterator(tag_iterator).lex_for_blocks(
blocks = BlockIterator(tag_iterator).lex_for_blocks(
allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data
)

if test_caching_enabled():
hash = _get_blocks_hash(text, allowed_blocks, collect_raw_data)
_TESTING_BLOCKS_CACHE[hash] = blocks

return blocks
12 changes: 12 additions & 0 deletions dbt_common/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_TEST_CACHING_ENABLED: bool = False


def test_caching_enabled() -> bool:
return _TEST_CACHING_ENABLED


def enable_test_caching() -> None:
global _TEST_CACHING_ENABLED
if _TEST_CACHING_ENABLED is False:
print("ENABLING CACHES")
_TEST_CACHING_ENABLED = True

0 comments on commit ce595e2

Please sign in to comment.