Skip to content

Commit

Permalink
Refactor macro contexts.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Jan 21, 2024
1 parent 069043b commit 47cfe30
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import tempfile
from ast import literal_eval
from collections import ChainMap
from contextlib import contextmanager
from itertools import chain, islice
from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from typing_extensions import Protocol

import jinja2
Expand Down Expand Up @@ -99,6 +100,41 @@ def _compile(self, source, filename):
return super()._compile(source, filename) # type: ignore


class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
environment_class = MacroFuzzEnvironment

def new_context(
self,
vars: Optional[Dict[str, Any]] = None,
shared: bool = False,
locals: Optional[Mapping[str, Any]] = None,
) -> jinja2.runtime.Context:
# This custom override makes the assumption that the locals and shared
# parameters are not used, so enforce that.
if shared or locals:
raise Exception("The MacroFuzzTemplate.new_context() override cannot use the shared or locals parameters.")

parent = ChainMap(vars, self.globals) if self.globals else vars

return self.environment.context_class(self.environment, parent, self.name, self.blocks)

def render(self, *args: Any, **kwargs: Any) -> Any:
if kwargs or len(args) != 1:
raise Exception("The MacroFuzzTemplate.render() override requires exactly one argument.")

ctx = self.new_context(args[0])

try:
return self.environment_class.concat( # type: ignore
self.root_render_func(ctx) # type: ignore
)
except Exception:
return self.environment.handle_exception()


MacroFuzzEnvironment.template_class = MacroFuzzTemplate


class NativeSandboxEnvironment(MacroFuzzEnvironment):
code_generator_class = jinja2.nativetypes.NativeCodeGenerator

Expand Down Expand Up @@ -171,7 +207,7 @@ def render(self, *args, **kwargs):
with :func:`ast.literal_eval`, the parsed value is returned.
Otherwise, the string is returned.
"""
vars = dict(*args, **kwargs)
vars = args[0]

try:
return quoted_native_concat(self.root_render_func(self.new_context(vars)))
Expand Down Expand Up @@ -226,7 +262,7 @@ def get_macro(self):
# make_module is in jinja2.environment. It returns a TemplateModule
module = template.make_module(vars=self.context, shared=False)
macro = module.__dict__[get_dbt_macro_name(name)]
module.__dict__.update(self.context)

return macro

@contextmanager
Expand Down

0 comments on commit 47cfe30

Please sign in to comment.