From 9d69aff1fe5ac9a7c347b236bee17f4014a5bc3a Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Wed, 11 Oct 2023 09:04:39 -0300 Subject: [PATCH] add name resolution to threading codemod --- src/codemodder/codemods/utils_mixin.py | 21 +++++++++++++++ src/core_codemods/with_threading_lock.py | 8 +++--- tests/codemods/test_with_threading_lock.py | 30 ++++++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 98f3d658d..73d00cdf0 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -70,6 +70,19 @@ def find_assignments( return set(next(iter(scope.accesses[node])).referents) return set() + def find_used_names(self, node): + """ + Find all the used names in the scope of `node`. + """ + names = [] + scope = self.get_metadata(ScopeProvider, node) + nodes = [x.node for x in scope.assignments] + for other_nodes in nodes: + visitor = GatherNamesVisitor() + other_nodes.visit(visitor) + names.extend(visitor.names) + return names + def find_single_assignment( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], @@ -112,3 +125,11 @@ def _get_name(node: Union[cst.Import, cst.ImportFrom]) -> str: if matchers.matches(node, matchers.Import()): return get_full_name_for_node(node.names[0].name) return "" + + +class GatherNamesVisitor(cst.CSTVisitor): + def __init__(self): + self.names = [] + + def visit_Name(self, node: cst.Name) -> None: + self.names.append(node.value) diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index 349bb48fc..c36a3c5f9 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -1,9 +1,10 @@ import libcst as cst from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.utils_mixin import NameResolutionMixin -class WithThreadingLock(SemgrepCodemod): +class WithThreadingLock(SemgrepCodemod, NameResolutionMixin): NAME = "bad-lock-with-statement" SUMMARY = "Replace deprecated usage of threading lock classes as context managers" REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW @@ -39,8 +40,9 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): if len(original_node.items) == 1 and self.node_is_selected( original_node.items[0] ): - # TODO: how to avoid name conflicts here? - name = cst.Name(value="lock") + current_names = self.find_used_names(original_node) + value = "lock" if "lock" not in current_names else "lock_" + name = cst.Name(value=value) assign = cst.SimpleStatementLine( body=[ cst.Assign( diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 483cf0a81..27de9a423 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -86,3 +86,33 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass): ... """ self.run_and_assert(tmpdir, input_code, input_code) + + @each_class + def test_name_resolution_var(self, tmpdir, klass): + input_code = f"""from threading import {klass} +lock = 1 +with {klass}(): + ... +""" + expected = f"""from threading import {klass} +lock = 1 +lock_ = {klass}() +with lock_: + ... +""" + self.run_and_assert(tmpdir, input_code, expected) + + @each_class + def test_name_resolution_import(self, tmpdir, klass): + input_code = f"""from threading import {klass} +from something import lock +with {klass}(): + ... +""" + expected = f"""from threading import {klass} +from something import lock +lock_ = {klass}() +with lock_: + ... +""" + self.run_and_assert(tmpdir, input_code, expected)