Skip to content

Commit

Permalink
Improve error message for mismatched copies in code blocks (#31535)
Browse files: browse the repository at this point in the history
improve error message for mismatched code blocks
  • Loading branch information
molbap authored Jun 25, 2024
1 parent e73a97a commit aab0829
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@
},
}


# This is to make sure the transformers module imported is the one in the repo.
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)

Expand All @@ -185,7 +184,7 @@ def _should_continue(line: str, indent: str) -> bool:
return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line)


def _sanity_check_splits(splits_1, splits_2, is_class):
def _sanity_check_splits(splits_1, splits_2, is_class, filename):
"""Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match.
For the case of `class`, they must be of one of the following 3 cases:
Expand Down Expand Up @@ -246,11 +245,12 @@ def g(x):
["block_without_name", "block_with_name"],
]:
raise ValueError(
"For a class, it must have a specific structure. See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`"
f"""Class defined in {filename} doesn't have the expected stucture.
See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`""",
)

if block_names_1 != block_names_2:
raise ValueError("The structures in the 2 code blocks differ.")
raise ValueError(f"In {filename}, two code blocks expected to be copies have different structures.")


def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
Expand Down Expand Up @@ -661,11 +661,8 @@ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = No
diffs = []
line_index = 0
# Not a for loop because `lines` is going to change (if `overwrite=True`).
search_re = _re_copy_warning_for_test_file if filename.startswith("tests") else _re_copy_warning
while line_index < len(lines):
search_re = _re_copy_warning
if filename.startswith("tests"):
search_re = _re_copy_warning_for_test_file

search = search_re.search(lines[line_index])
if search is None:
line_index += 1
Expand Down Expand Up @@ -718,7 +715,7 @@ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = No

is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ")
# sanity check
_sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class)
_sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class, filename=filename)

# observed code in a structured way (a dict mapping block names to blocks' code)
observed_code_blocks = OrderedDict()
Expand Down

0 comments on commit aab0829

Please sign in to comment.