From e1f3d45312a0cf8b0b0e4efec9b6c693a5e5a9ad Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 16:07:31 -0500 Subject: [PATCH 1/8] Inject tag iterator for greater flexibility. Fix some names. --- dbt_common/clients/_jinja_blocks.py | 4 ++-- dbt_common/clients/jinja.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 9a830570..6b9a922f 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -272,8 +272,8 @@ def __iter__(self): class BlockIterator: - def __init__(self, data): - self.tag_parser = TagIterator(data) + def __init__(self, tag_iterator): + self.tag_parser = tag_iterator self.current = None self.stack = [] self.last_position = 0 diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index ca9a4b55..db297e6d 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -534,4 +534,5 @@ def extract_toplevel_blocks( :return: A list of `BlockTag`s matching the allowed block types and (if `collect_raw_data` is `True`) `BlockData` objects. """ - return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) + tag_iterator = TagIterator(data) + return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) From 547867b043725d4cd10fba4419447802221aa023 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 16:11:29 -0500 Subject: [PATCH 2/8] Inject TagIterator into BlockIterator for greater flexibility. --- .changes/unreleased/Under the Hood-20240123-161107.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20240123-161107.yaml diff --git a/.changes/unreleased/Under the Hood-20240123-161107.yaml b/.changes/unreleased/Under the Hood-20240123-161107.yaml new file mode 100644 index 00000000..68a83ab4 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240123-161107.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Inject TagIterator into BlockIterator for greater flexibility. +time: 2024-01-23T16:11:07.24321-05:00 +custom: + Author: peterallenwebb + Issue: "38" From 23d0077d15ac7fe491908a167be976df1c1a0d25 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 16:30:02 -0500 Subject: [PATCH 3/8] Refine names. --- dbt_common/clients/_jinja_blocks.py | 18 +++++++++--------- dbt_common/clients/jinja.py | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 6b9a922f..3a09bbbf 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -98,22 +98,22 @@ def end_pat(self): class TagIterator: - def __init__(self, data): - self.data = data + def __init__(self, text): + self.text = text self.blocks = [] self._parenthesis_stack = [] self.pos = 0 def linepos(self, end=None) -> str: - """Given an absolute position in the input data, return a pair of + """Given an absolute position in the input text, return a pair of line number + relative position to the start of the line. """ end_val: int = self.pos if end is None else end - data = self.data[:end_val] + text = self.text[:end_val] # if not found, rfind returns -1, and -1+1=0, which is perfect! - last_line_start = data.rfind("\n") + 1 + last_line_start = text.rfind("\n") + 1 # it's easy to forget this, but line numbers are 1-indexed - line_number = data.count("\n") + 1 + line_number = text.count("\n") + 1 return f"{line_number}:{end_val - last_line_start}" def advance(self, new_position): @@ -123,10 +123,10 @@ def rewind(self, amount=1): self.pos -= amount def _search(self, pattern): - return pattern.search(self.data, self.pos) + return pattern.search(self.text, self.pos) def _match(self, pattern): - return pattern.match(self.data, self.pos) + return pattern.match(self.text, self.pos) def _first_match(self, *patterns, **kwargs): matches = [] @@ -147,7 +147,7 @@ def _first_match(self, *patterns, **kwargs): def _expect_match(self, expected_name, *patterns, **kwargs): match = self._first_match(*patterns, **kwargs) if match is None: - raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :]) + raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :]) return match def handle_expr(self, match): diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index db297e6d..b8c0d03a 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -22,7 +22,7 @@ get_materialization_macro_name, get_test_macro_name, ) -from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag +from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag, TagIterator from dbt_common.exceptions import ( CompilationError, @@ -516,7 +516,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str: def extract_toplevel_blocks( - data: str, + text: str, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True, ) -> List[Union[BlockData, BlockTag]]: @@ -534,5 +534,5 @@ def extract_toplevel_blocks( :return: A list of `BlockTag`s matching the allowed block types and (if `collect_raw_data` is `True`) `BlockData` objects. """ - tag_iterator = TagIterator(data) + tag_iterator = TagIterator(text) return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) From f6040c651f3e9109f2319eb4c86094e64c2ffa61 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 16:32:57 -0500 Subject: [PATCH 4/8] Fix name mismatch. --- dbt_common/clients/_jinja_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 3a09bbbf..cdc069da 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -287,7 +287,7 @@ def current_end(self): @property def data(self): - return self.tag_parser.data + return self.tag_parser.text def is_current_end(self, tag): return ( From 49f15a38934202b124e8d775eaed60610a48c6fa Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 18:31:32 -0500 Subject: [PATCH 5/8] Add type annotations --- dbt_common/clients/_jinja_blocks.py | 83 ++++++++++++++--------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index cdc069da..67daa3ff 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -1,5 +1,6 @@ import re from collections import namedtuple +from typing import Iterator, Optional, List from dbt_common.exceptions import ( BlockDefinitionNotAtTopError, @@ -12,40 +13,40 @@ ) -def regex(pat): +def regex(pat: str) -> re.Pattern: return re.compile(pat, re.DOTALL | re.MULTILINE) class BlockData: """raw plaintext data from the top level of the file.""" - def __init__(self, contents): + def __init__(self, contents: str) -> None: self.block_type_name = "__dbt__data" - self.contents = contents + self.contents: str = contents self.full_block = contents class BlockTag: - def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw): + def __init__(self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None) -> None: self.block_type_name = block_type_name self.block_name = block_name self.contents = contents self.full_block = full_block - def __str__(self): + def __str__(self) -> str: return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name) - def __repr__(self): + def __repr__(self) -> str: return str(self) @property - def end_block_type_name(self): + def end_block_type_name(self) -> str: return "end{}".format(self.block_type_name) - def end_pat(self): + def end_pat(self) -> re.Pattern: # we don't want to use string formatting here because jinja uses most # of the string formatting operators in its syntax... - pattern = "".join( + pattern: str = "".join( ( r"(?P((?:\s*\{\%\-|\{\%)\s*", self.end_block_type_name, @@ -98,13 +99,11 @@ def end_pat(self): class TagIterator: - def __init__(self, text): - self.text = text - self.blocks = [] - self._parenthesis_stack = [] - self.pos = 0 + def __init__(self, text: str) -> None: + self.text: str = text + self.pos: int = 0 - def linepos(self, end=None) -> str: + def linepos(self, end: Optional[int] = None) -> str: """Given an absolute position in the input text, return a pair of line number + relative position to the start of the line. """ @@ -116,26 +115,22 @@ def linepos(self, end=None) -> str: line_number = text.count("\n") + 1 return f"{line_number}:{end_val - last_line_start}" - def advance(self, new_position): + def advance(self, new_position: int) -> None: self.pos = new_position - def rewind(self, amount=1): + def rewind(self, amount: int = 1) -> None: self.pos -= amount - def _search(self, pattern): + def _search(self, pattern: re.Pattern) -> Optional[re.Match]: return pattern.search(self.text, self.pos) - def _match(self, pattern): + def _match(self, pattern: re.Pattern) -> Optional[re.Match]: return pattern.match(self.text, self.pos) - def _first_match(self, *patterns, **kwargs): + def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore matches = [] for pattern in patterns: - # default to 'search', but sometimes we want to 'match'. - if kwargs.get("method", "search") == "search": - match = self._search(pattern) - else: - match = self._match(pattern) + match = self._search(pattern) if match: matches.append(match) if not matches: @@ -144,13 +139,13 @@ def _first_match(self, *patterns, **kwargs): # TODO: do I need to account for m.start(), or is this ok? return min(matches, key=lambda m: m.end()) - def _expect_match(self, expected_name, *patterns, **kwargs): - match = self._first_match(*patterns, **kwargs) + def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore + match = self._first_match(*patterns) if match is None: - raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :]) + raise UnexpectedMacroEOFError(expected_name, self.text[self.pos:]) return match - def handle_expr(self, match): + def handle_expr(self, match: re.Match) -> None: """Handle an expression. At this point we're at a string like: {{ 1 + 2 }} ^ right here @@ -176,12 +171,12 @@ def handle_expr(self, match): self.advance(match.end()) - def handle_comment(self, match): + def handle_comment(self, match: re.Match) -> None: self.advance(match.end()) match = self._expect_match("#}", COMMENT_END_PATTERN) self.advance(match.end()) - def _expect_block_close(self): + def _expect_block_close(self) -> None: """Search for the tag close marker. To the right of the type name, there are a few possiblities: - a name (handled by the regex's 'block_name') @@ -203,13 +198,13 @@ def _expect_block_close(self): string_match = self._expect_match("string", STRING_PATTERN) self.advance(string_match.end()) - def handle_raw(self): + def handle_raw(self) -> int: # raw blocks are super special, they are a single complete regex match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) self.advance(match.end()) return match.end() - def handle_tag(self, match): + def handle_tag(self, match: re.Match) -> Tag: """The tag could be one of a few things: {% mytag %} @@ -234,7 +229,7 @@ def handle_tag(self, match): self._expect_block_close() return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos) - def find_tags(self): + def find_tags(self) -> Iterator[Tag]: while True: match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN) if match is None: @@ -259,7 +254,7 @@ def find_tags(self): "Invalid regex match in next_block, expected block start, " "expr start, or comment start" ) - def __iter__(self): + def __iter__(self) -> Iterator[Tag]: return self.find_tags() @@ -272,31 +267,31 @@ def __iter__(self): class BlockIterator: - def __init__(self, tag_iterator): + def __init__(self, tag_iterator: TagIterator) -> None: self.tag_parser = tag_iterator - self.current = None - self.stack = [] - self.last_position = 0 + self.current: Optional[Tag] = None + self.stack: List[str] = [] + self.last_position: int = 0 @property - def current_end(self): + def current_end(self) -> int: if self.current is None: return 0 else: return self.current.end @property - def data(self): + def data(self) -> str: return self.tag_parser.text - def is_current_end(self, tag): + def is_current_end(self, tag: Tag) -> bool: return ( tag.block_type_name.startswith("end") and self.current is not None and tag.block_type_name[3:] == self.current.block_type_name ) - def find_blocks(self, allowed_blocks=None, collect_raw_data=True): + def find_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> Iterator[BlockData | BlockTag]: """Find all top-level blocks in the data.""" if allowed_blocks is None: allowed_blocks = {"snapshot", "macro", "materialization", "docs"} @@ -347,5 +342,5 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True): if raw_data: yield BlockData(raw_data) - def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True): + def lex_for_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> List[BlockData | BlockTag]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)) From b9c0c2d8006c6581f86c5241c8232d1cc14bea57 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 18:41:28 -0500 Subject: [PATCH 6/8] Fix formatting for black. --- dbt_common/clients/_jinja_blocks.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 67daa3ff..41877db9 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -27,7 +27,9 @@ def __init__(self, contents: str) -> None: class BlockTag: - def __init__(self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None) -> None: + def __init__( + self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None + ) -> None: self.block_type_name = block_type_name self.block_name = block_name self.contents = contents @@ -142,7 +144,7 @@ def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore match = self._first_match(*patterns) if match is None: - raise UnexpectedMacroEOFError(expected_name, self.text[self.pos:]) + raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :]) return match def handle_expr(self, match: re.Match) -> None: @@ -291,7 +293,9 @@ def is_current_end(self, tag: Tag) -> bool: and tag.block_type_name[3:] == self.current.block_type_name ) - def find_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> Iterator[BlockData | BlockTag]: + def find_blocks( + self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True + ) -> Iterator[BlockData | BlockTag]: """Find all top-level blocks in the data.""" if allowed_blocks is None: allowed_blocks = {"snapshot", "macro", "materialization", "docs"} @@ -342,5 +346,7 @@ def find_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_dat if raw_data: yield BlockData(raw_data) - def lex_for_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> List[BlockData | BlockTag]: + def lex_for_blocks( + self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True + ) -> List[BlockData | BlockTag]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)) From 3a9652838f4934426fb5a88be6b3e6ec7bdc0656 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 18:46:12 -0500 Subject: [PATCH 7/8] Tweak type annotation. --- dbt_common/clients/_jinja_blocks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 41877db9..62a90122 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -1,6 +1,6 @@ import re from collections import namedtuple -from typing import Iterator, Optional, List +from typing import Iterator, Optional, List, Set from dbt_common.exceptions import ( BlockDefinitionNotAtTopError, @@ -294,7 +294,7 @@ def is_current_end(self, tag: Tag) -> bool: ) def find_blocks( - self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True + self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True ) -> Iterator[BlockData | BlockTag]: """Find all top-level blocks in the data.""" if allowed_blocks is None: @@ -347,6 +347,6 @@ def find_blocks( yield BlockData(raw_data) def lex_for_blocks( - self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True + self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True ) -> List[BlockData | BlockTag]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)) From 7c3d207aeb9989c8511b8d913a68959b5d5b3a5c Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Tue, 23 Jan 2024 18:49:05 -0500 Subject: [PATCH 8/8] Use Union type to make earlier python versions happy. --- dbt_common/clients/_jinja_blocks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 62a90122..c6058bfa 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -1,6 +1,6 @@ import re from collections import namedtuple -from typing import Iterator, Optional, List, Set +from typing import Iterator, List, Optional, Set, Union from dbt_common.exceptions import ( BlockDefinitionNotAtTopError, @@ -295,7 +295,7 @@ def is_current_end(self, tag: Tag) -> bool: def find_blocks( self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True - ) -> Iterator[BlockData | BlockTag]: + ) -> Iterator[Union[BlockData, BlockTag]]: """Find all top-level blocks in the data.""" if allowed_blocks is None: allowed_blocks = {"snapshot", "macro", "materialization", "docs"} @@ -348,5 +348,5 @@ def find_blocks( def lex_for_blocks( self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True - ) -> List[BlockData | BlockTag]: + ) -> List[Union[BlockData, BlockTag]]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))