diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 98f3d658..2bd2107d 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -3,6 +3,7 @@ from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider +from libcst.metadata.scope_provider import GlobalScope class NameResolutionMixin(MetadataDependent): @@ -70,6 +71,30 @@ def find_assignments( return set(next(iter(scope.accesses[node])).referents) return set() + def find_used_names_in_module(self): + """ + Find all the used names in the scope of a libcst Module. + """ + names = [] + scope = self.find_global_scope() + if scope is None: + return [] # pragma: no cover + + nodes = [x.node for x in scope.assignments] + for other_nodes in nodes: + visitor = GatherNamesVisitor() + other_nodes.visit(visitor) + names.extend(visitor.names) + return names + + def find_global_scope(self): + """Find the global scope for a libcst Module node.""" + scopes = self.context.wrapper.resolve(ScopeProvider).values() + for scope in scopes: + if isinstance(scope, GlobalScope): + return scope + return None # pragma: no cover + def find_single_assignment( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], @@ -112,3 +137,11 @@ def _get_name(node: Union[cst.Import, cst.ImportFrom]) -> str: if matchers.matches(node, matchers.Import()): return get_full_name_for_node(node.names[0].name) return "" + + +class GatherNamesVisitor(cst.CSTVisitor): + def __init__(self): + self.names = [] + + def visit_Name(self, node: cst.Name) -> None: + self.names.append(node.value) diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index b7ae9705..39b466d5 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -1,9 +1,10 @@ import 
libcst as cst from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.utils_mixin import NameResolutionMixin -class WithThreadingLock(SemgrepCodemod): +class WithThreadingLock(SemgrepCodemod, NameResolutionMixin): NAME = "bad-lock-with-statement" SUMMARY = "Separate Lock Instantiation from `with` Call" DESCRIPTION = ( @@ -44,6 +45,27 @@ def rule(cls): - focus-metavariable: $BODY """ + def __init__(self, *args): + SemgrepCodemod.__init__(self, *args) + NameResolutionMixin.__init__(self) + self.names_in_module = self.find_used_names_in_module() + + def _create_new_variable(self, original_node: cst.With): + """ + Create an appropriately named variable for the new + lock, condition, or semaphore. + Keep track of this addition in case there are other additions. + """ + base_name = _get_node_name(original_node) + value = base_name + counter = 1 + while value in self.names_in_module: + value = f"{base_name}_{counter}" + counter += 1 + + self.names_in_module.append(value) + return cst.Name(value=value) + def leave_With(self, original_node: cst.With, updated_node: cst.With): # We deliberately restrict ourselves to simple cases where there's only one with clause for now. # Semgrep appears to be insufficiently expressive to match multiple clauses correctly. @@ -51,8 +73,7 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): if len(original_node.items) == 1 and self.node_is_selected( original_node.items[0] ): - # TODO: how to avoid name conflicts here? 
- name = cst.Name(value="lock") + name = self._create_new_variable(original_node) assign = cst.SimpleStatementLine( body=[ cst.Assign( @@ -72,3 +93,12 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): ) return original_node + + +def _get_node_name(original_node: cst.With): + func_call = original_node.items[0].item.func + if isinstance(func_call, cst.Name): + return func_call.value.lower() + if isinstance(func_call, cst.Attribute): + return func_call.attr.value.lower() + return "" # pragma: no cover diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 90d6d5dc..0fbd4edd 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -37,8 +37,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): [], defaultdict(list), ) + wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( - CodemodContext(), + CodemodContext(wrapper=wrapper), self.execution_context, self.file_context, ) @@ -83,8 +84,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): [], results, ) + wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( - CodemodContext(), + CodemodContext(wrapper=wrapper), self.execution_context, self.file_context, ) diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 483cf0a8..7a1d56f0 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -27,8 +27,8 @@ def test_import(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock: +{klass.lower()} = threading.{klass}() +with {klass.lower()}: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -40,8 +40,8 @@ def test_from_import(self, tmpdir, klass): ... """ expected = f"""from threading import {klass} -lock = {klass}() -with lock: +{klass.lower()} = {klass}() +with {klass.lower()}: ... 
""" self.run_and_assert(tmpdir, input_code, expected) @@ -53,8 +53,8 @@ def test_simple_replacement_with_as(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock as foo: +{klass.lower()} = threading.{klass}() +with {klass.lower()} as foo: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -69,8 +69,8 @@ def test_no_effect_sanity_check(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock: +{klass.lower()} = threading.{klass}() +with {klass.lower()}: ... with threading_lock(): @@ -86,3 +86,108 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass): ... """ self.run_and_assert(tmpdir, input_code, input_code) + + +class TestThreadingNameResolution(BaseSemgrepCodemodTest): + codemod = WithThreadingLock + + @pytest.mark.parametrize( + "input_code,expected_code", + [ + ( + """from threading import Lock +lock = 1 +with Lock(): + ... +""", + """from threading import Lock +lock = 1 +lock_1 = Lock() +with lock_1: + ... +""", + ), + ( + """from threading import Lock +from something import lock +with Lock(): + ... +""", + """from threading import Lock +from something import lock +lock_1 = Lock() +with lock_1: + ... 
+""", + ), + ( + """import threading +lock = 1 +def f(l): + with threading.Lock(): + return [lock_1 for lock_1 in l] +""", + """import threading +lock = 1 +def f(l): + lock_2 = threading.Lock() + with lock_2: + return [lock_1 for lock_1 in l] +""", + ), + ( + """import threading +with threading.Lock(): + int("1") +with threading.Lock(): + print() +var = 1 +with threading.Lock(): + print() +""", + """import threading +lock = threading.Lock() +with lock: + int("1") +lock_1 = threading.Lock() +with lock_1: + print() +var = 1 +lock_2 = threading.Lock() +with lock_2: + print() +""", + ), + ( + """import threading +with threading.Lock(): + with threading.Lock(): + print() +""", + """import threading +lock_1 = threading.Lock() +with lock_1: + lock = threading.Lock() + with lock: + print() +""", + ), + ( + """import threading +def my_func(): + lock = "whatever" + with threading.Lock(): + foo() +""", + """import threading +def my_func(): + lock = "whatever" + lock_1 = threading.Lock() + with lock_1: + foo() +""", + ), + ], + ) + def test_name_resolution(self, tmpdir, input_code, expected_code): + self.run_and_assert(tmpdir, input_code, expected_code)