Skip to content

Commit

Permalink
Merge pull request #735 from raymyers/extract-incomplete-block
Browse files Browse the repository at this point in the history
Add error handling for extract of incomplete block
  • Loading branch information
lieryan authored Jan 11, 2024
2 parents 95585e8 + 9cb3031 commit ffc8551
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- #719 Allows the in-memory db to be shared across threads (@tkrabel)
- #720 create one sqlite3.Connection per thread using a thread local (@tkrabel)
- #715 change AutoImport's `get_modules` to be case sensitive (@bagel897)
- #734 raise exception when extracting the start of a block without the end

# Release 1.10.0

Expand Down
38 changes: 38 additions & 0 deletions rope/refactor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,10 @@ def __call__(self, info):
def base_conditions(self, info):
if info.region[1] > info.scope_region[1]:
raise RefactoringError("Bad region selected for extract method")

end_line = info.region_lines[1]
end_scope = info.global_scope.get_inner_scope_for_line(end_line)

if end_scope != info.scope and end_scope.get_end() != end_line:
raise RefactoringError("Bad region selected for extract method")
try:
Expand Down Expand Up @@ -497,6 +499,14 @@ def multi_line_conditions(self, info):
raise RefactoringError(
"Extracted piece should contain complete statements."
)
unbalanced_region_finder = _UnbalancedRegionFinder(
info.region_lines[0], info.region_lines[1]
)
unbalanced_region_finder.visit(info.pymodule.ast_node)
if unbalanced_region_finder.error:
raise RefactoringError(
"Extracted piece cannot contain the start of a block without the end."
)

def _is_region_on_a_word(self, info):
if (
Expand Down Expand Up @@ -1093,6 +1103,34 @@ def _ClassDef(self, node):
pass


class _UnbalancedRegionFinder(_BaseErrorFinder):
"""
Flag an error if we are including the start of a block without the end.
We detect this by ensuring there is no AST node that starts inside the
selected range but ends outside of it.
"""

def __init__(self, line_start: int, line_end: int):
self.error = False
self.line_start = line_start
self.line_end = line_end

def generic_visit(self, node: ast.AST):
if not hasattr(node, "end_lineno"):
super().generic_visit(node) # Visit children
return
ends_before_range_starts = node.end_lineno < self.line_start
starts_after_range_ends = node.lineno > self.line_end
if ends_before_range_starts or starts_after_range_ends:
return # Don't visit children
starts_on_or_after_range_start = node.lineno >= self.line_start
ends_after_range_ends = node.end_lineno > self.line_end
if starts_on_or_after_range_start and ends_after_range_ends:
self.error = True
return # Don't visit children
super().generic_visit(node) # Visit children


class _GlobalFinder(ast.RopeNodeVisitor):
def __init__(self):
self.globals_ = OrderedSet()
Expand Down
58 changes: 58 additions & 0 deletions ropetest/refactor/extracttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,64 @@ def xxx_test_raising_exception_on_function_parens(self):
end = code.rindex(")") + 1
with self.assertRaises(rope.base.exceptions.RefactoringError):
self.do_extract_method(code, start, end, "new_func")

def test_raising_exception_on_incomplete_block(self):
code = dedent("""\
if True:
a = 1
b = 2
""")
start = code.index("if")
end = code.index("1") + 1
with self.assertRaises(rope.base.exceptions.RefactoringError):
self.do_extract_method(code, start, end, "new_func")

def test_raising_exception_on_incomplete_block_2(self):
code = dedent("""\
if True:
a = 1
#
b = 2
""")
start = code.index("if")
end = code.index("1") + 1
with self.assertRaises(rope.base.exceptions.RefactoringError):
self.do_extract_method(code, start, end, "new_func")

def test_raising_exception_on_incomplete_block_3(self):
code = dedent("""\
if True:
a = 1
b = 2
""")
start = code.index("if")
end = code.index("1") + 1
with self.assertRaises(rope.base.exceptions.RefactoringError):
self.do_extract_method(code, start, end, "new_func")

def test_raising_exception_on_incomplete_block_4(self):
code = dedent("""\
#
if True:
a = 1
b = 2
""")
start = code.index("#")
end = code.index("1") + 1
with self.assertRaises(rope.base.exceptions.RefactoringError):
self.do_extract_method(code, start, end, "new_func")

def test_raising_exception_on_incomplete_block_5(self):
code = dedent("""\
if True:
if 0:
a = 1
""")
start = code.index("if")
end = code.index("0:") + 2
with self.assertRaises(rope.base.exceptions.RefactoringError):
self.do_extract_method(code, start, end, "new_func")

def test_extract_method_and_extra_blank_lines(self):
code = dedent("""\
Expand Down

0 comments on commit ffc8551

Please sign in to comment.