diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 73d00cdf..5555f26b 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -3,6 +3,7 @@ from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider +from libcst.metadata.scope_provider import GlobalScope class NameResolutionMixin(MetadataDependent): @@ -70,12 +71,15 @@ def find_assignments( return set(next(iter(scope.accesses[node])).referents) return set() - def find_used_names(self, node): + def find_used_names_in_module(self): """ - Find all the used names in the scope of `node`. + Find all the used names in the scope of a libcst Module. """ names = [] - scope = self.get_metadata(ScopeProvider, node) + scope = self.find_global_scope() + if scope is None: + return [] + nodes = [x.node for x in scope.assignments] for other_nodes in nodes: visitor = GatherNamesVisitor() @@ -83,6 +87,14 @@ def find_used_names(self, node): names.extend(visitor.names) return names + def find_global_scope(self): + """Find the global scope for a libcst Module node.""" + scopes = self.context.wrapper.resolve(ScopeProvider).values() + for scope in scopes: + if isinstance(scope, GlobalScope): + return scope + return None + def find_single_assignment( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index c36a3c5f..2b531abc 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -40,8 +40,10 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): if len(original_node.items) == 1 and self.node_is_selected( original_node.items[0] ): - current_names = self.find_used_names(original_node) - value = "lock" if "lock" not in current_names else "lock_" + current_names = self.find_used_names_in_module() + # arbitrarily choose `lock_cm` if `lock` name is already taken + # in hopes that `lock_cm` is very unlikely to be used. + value = "lock" if "lock" not in current_names else "lock_cm" name = cst.Name(value=value) assign = cst.SimpleStatementLine( body=[ diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 27de9a42..1dd72941 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -87,32 +87,55 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass): """ self.run_and_assert(tmpdir, input_code, input_code) - @each_class - def test_name_resolution_var(self, tmpdir, klass): - input_code = f"""from threading import {klass} + +class TestThreadingNameResolution(BaseSemgrepCodemodTest): + codemod = WithThreadingLock + + @pytest.mark.parametrize( + "input_code,expected_code", + [ + ( + """from threading import Lock lock = 1 -with {klass}(): +with Lock(): ... -""" - expected = f"""from threading import {klass} +""", + """from threading import Lock lock = 1 -lock_ = {klass}() -with lock_: +lock_cm = Lock() +with lock_cm: ... -""" - self.run_and_assert(tmpdir, input_code, expected) - - @each_class - def test_name_resolution_import(self, tmpdir, klass): - input_code = f"""from threading import {klass} +""", + ), + ( + """from threading import Lock from something import lock -with {klass}(): +with Lock(): ... -""" - expected = f"""from threading import {klass} +""", + """from threading import Lock from something import lock -lock_ = {klass}() -with lock_: +lock_cm = Lock() +with lock_cm: ... -""" - self.run_and_assert(tmpdir, input_code, expected) +""", + ), + ( + """import threading +lock = 1 +def f(l): + with threading.Lock(): + return [lock_ for lock_ in l] +""", + """import threading +lock = 1 +def f(l): + lock_cm = threading.Lock() + with lock_cm: + return [lock_ for lock_ in l] +""", + ), + ], + ) + def test_name_resolution(self, tmpdir, input_code, expected_code): + self.run_and_assert(tmpdir, input_code, expected_code)