Skip to content

Commit

Permalink
add name resolution to threading codemod
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Oct 11, 2023
1 parent 90ce80b commit 31cb417
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
21 changes: 21 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ def find_assignments(
return set(next(iter(scope.accesses[node])).referents)
return set()

def find_used_names(self, node):
"""
Find all the used names in the scope of `node`.
"""
names = []
scope = self.get_metadata(ScopeProvider, node)
nodes = [x.node for x in scope.assignments]
for node in nodes:
visitor = GatherNamesVisitor()
node.visit(visitor)
names.extend(visitor.names)
return names

def find_single_assignment(
self,
node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator],
Expand Down Expand Up @@ -112,3 +125,11 @@ def _get_name(node: Union[cst.Import, cst.ImportFrom]) -> str:
if matchers.matches(node, matchers.Import()):
return get_full_name_for_node(node.names[0].name)
return ""


class GatherNamesVisitor(cst.CSTVisitor):
def __init__(self):
self.names = []

def visit_Name(self, node: cst.Name) -> None:
self.names.append(node.value)
8 changes: 5 additions & 3 deletions src/core_codemods/with_threading_lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import libcst as cst
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class WithThreadingLock(SemgrepCodemod):
class WithThreadingLock(SemgrepCodemod, NameResolutionMixin):
NAME = "bad-lock-with-statement"
SUMMARY = "Replace deprecated usage of threading lock classes as context managers"
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
Expand Down Expand Up @@ -39,8 +40,9 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With):
if len(original_node.items) == 1 and self.node_is_selected(
original_node.items[0]
):
# TODO: how to avoid name conflicts here?
name = cst.Name(value="lock")
current_names = self.find_used_names(original_node)
value = "lock" if "lock" not in current_names else "lock_"
name = cst.Name(value=value)
assign = cst.SimpleStatementLine(
body=[
cst.Assign(
Expand Down
30 changes: 30 additions & 0 deletions tests/codemods/test_with_threading_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,33 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass):
...
"""
self.run_and_assert(tmpdir, input_code, input_code)

@each_class
def test_name_resolution_var(self, tmpdir, klass):
input_code = f"""from threading import {klass}
lock = 1
with {klass}():
...
"""
expected = f"""from threading import {klass}
lock = 1
lock_ = {klass}()
with lock_:
...
"""
self.run_and_assert(tmpdir, input_code, expected)

@each_class
def test_name_resolution_import(self, tmpdir, klass):
input_code = f"""from threading import {klass}
from something import lock
with {klass}():
...
"""
expected = f"""from threading import {klass}
from something import lock
lock_ = {klass}()
with lock_:
...
"""
self.run_and_assert(tmpdir, input_code, expected)

0 comments on commit 31cb417

Please sign in to comment.