diff --git a/utils/check_copies.py b/utils/check_copies.py index b50f5845886b0b..4bb5c6fef4eeb7 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -169,7 +169,6 @@ }, } - # This is to make sure the transformers module imported is the one in the repo. transformers_module = direct_transformers_import(TRANSFORMERS_PATH) @@ -185,7 +184,7 @@ def _should_continue(line: str, indent: str) -> bool: return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line) -def _sanity_check_splits(splits_1, splits_2, is_class): +def _sanity_check_splits(splits_1, splits_2, is_class, filename): """Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match. For the case of `class`, they must be of one of the following 3 cases: @@ -246,11 +245,12 @@ def g(x): ["block_without_name", "block_with_name"], ]: raise ValueError( - "For a class, it must have a specific structure. See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`" + f"""Class defined in {filename} doesn't have the expected stucture. + See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`""", ) if block_names_1 != block_names_2: - raise ValueError("The structures in the 2 code blocks differ.") + raise ValueError(f"In {filename}, two code blocks expected to be copies have different structures.") def find_block_end(lines: List[str], start_index: int, indent: int) -> int: @@ -661,11 +661,8 @@ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = No diffs = [] line_index = 0 # Not a for loop cause `lines` is going to change (if `overwrite=True`). + search_re = _re_copy_warning_for_test_file if filename.startswith("tests") else _re_copy_warning while line_index < len(lines): - search_re = _re_copy_warning - if filename.startswith("tests"): - search_re = _re_copy_warning_for_test_file - search = search_re.search(lines[line_index]) if search is None: line_index += 1 @@ -718,7 +715,7 @@ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = No is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ") # sanity check - _sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class) + _sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class, filename=filename) # observed code in a structured way (a dict mapping block names to blocks' code) observed_code_blocks = OrderedDict()