From 2e584f47f30ee04fa79abfb86d6bbc0c33990683 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Wed, 11 Oct 2023 09:04:39 -0300 Subject: [PATCH 1/7] add name resolution to threading codemod --- src/codemodder/codemods/utils_mixin.py | 21 +++++++++++++++ src/core_codemods/with_threading_lock.py | 8 +++--- tests/codemods/test_with_threading_lock.py | 30 ++++++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 98f3d658..73d00cdf 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -70,6 +70,19 @@ def find_assignments( return set(next(iter(scope.accesses[node])).referents) return set() + def find_used_names(self, node): + """ + Find all the used names in the scope of `node`. + """ + names = [] + scope = self.get_metadata(ScopeProvider, node) + nodes = [x.node for x in scope.assignments] + for other_nodes in nodes: + visitor = GatherNamesVisitor() + other_nodes.visit(visitor) + names.extend(visitor.names) + return names + def find_single_assignment( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], @@ -112,3 +125,11 @@ def _get_name(node: Union[cst.Import, cst.ImportFrom]) -> str: if matchers.matches(node, matchers.Import()): return get_full_name_for_node(node.names[0].name) return "" + + +class GatherNamesVisitor(cst.CSTVisitor): + def __init__(self): + self.names = [] + + def visit_Name(self, node: cst.Name) -> None: + self.names.append(node.value) diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index b7ae9705..2ee277a3 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -1,9 +1,10 @@ import libcst as cst from codemodder.codemods.base_codemod import ReviewGuidance from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.utils_mixin import NameResolutionMixin -class WithThreadingLock(SemgrepCodemod): +class WithThreadingLock(SemgrepCodemod, NameResolutionMixin): NAME = "bad-lock-with-statement" SUMMARY = "Separate Lock Instantiation from `with` Call" DESCRIPTION = ( @@ -51,8 +52,9 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): if len(original_node.items) == 1 and self.node_is_selected( original_node.items[0] ): - # TODO: how to avoid name conflicts here? - name = cst.Name(value="lock") + current_names = self.find_used_names(original_node) + value = "lock" if "lock" not in current_names else "lock_" + name = cst.Name(value=value) assign = cst.SimpleStatementLine( body=[ cst.Assign( diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 483cf0a8..27de9a42 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -86,3 +86,33 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass): ... """ self.run_and_assert(tmpdir, input_code, input_code) + + @each_class + def test_name_resolution_var(self, tmpdir, klass): + input_code = f"""from threading import {klass} +lock = 1 +with {klass}(): + ... +""" + expected = f"""from threading import {klass} +lock = 1 +lock_ = {klass}() +with lock_: + ... +""" + self.run_and_assert(tmpdir, input_code, expected) + + @each_class + def test_name_resolution_import(self, tmpdir, klass): + input_code = f"""from threading import {klass} +from something import lock +with {klass}(): + ... +""" + expected = f"""from threading import {klass} +from something import lock +lock_ = {klass}() +with lock_: + ... +""" + self.run_and_assert(tmpdir, input_code, expected) From 084851aae32599e75385ecf2d639c6a2b521b261 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Thu, 12 Oct 2023 07:56:25 -0300 Subject: [PATCH 2/7] gather all names in global scope --- src/codemodder/codemods/utils_mixin.py | 18 +++++- src/core_codemods/with_threading_lock.py | 6 +- tests/codemods/test_with_threading_lock.py | 65 +++++++++++++++------- 3 files changed, 63 insertions(+), 26 deletions(-) diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 73d00cdf..5555f26b 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -3,6 +3,7 @@ from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider +from libcst.metadata.scope_provider import GlobalScope class NameResolutionMixin(MetadataDependent): @@ -70,12 +71,15 @@ def find_assignments( return set(next(iter(scope.accesses[node])).referents) return set() - def find_used_names(self, node): + def find_used_names_in_module(self): """ - Find all the used names in the scope of `node`. + Find all the used names in the scope of a libcst Module. """ names = [] - scope = self.get_metadata(ScopeProvider, node) + scope = self.find_global_scope() + if scope is None: + return [] + nodes = [x.node for x in scope.assignments] for other_nodes in nodes: visitor = GatherNamesVisitor() @@ -83,6 +87,14 @@ def find_used_names(self, node): names.extend(visitor.names) return names + def find_global_scope(self): + """Find the global scope for a libcst Module node.""" + scopes = self.context.wrapper.resolve(ScopeProvider).values() + for scope in scopes: + if isinstance(scope, GlobalScope): + return scope + return None + def find_single_assignment( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index 2ee277a3..7039d21d 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -52,8 +52,10 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): if len(original_node.items) == 1 and self.node_is_selected( original_node.items[0] ): - current_names = self.find_used_names(original_node) - value = "lock" if "lock" not in current_names else "lock_" + current_names = self.find_used_names_in_module() + # arbitrarily choose `lock_cm` if `lock` name is already taken + # in hopes that `lock_cm` is very unlikely to be used. + value = "lock" if "lock" not in current_names else "lock_cm" name = cst.Name(value=value) assign = cst.SimpleStatementLine( body=[ diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 27de9a42..1dd72941 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -87,32 +87,55 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass): """ self.run_and_assert(tmpdir, input_code, input_code) - @each_class - def test_name_resolution_var(self, tmpdir, klass): - input_code = f"""from threading import {klass} + +class TestThreadingNameResolution(BaseSemgrepCodemodTest): + codemod = WithThreadingLock + + @pytest.mark.parametrize( + "input_code,expected_code", + [ + ( + """from threading import Lock lock = 1 -with {klass}(): +with Lock(): ... -""" - expected = f"""from threading import {klass} +""", + """from threading import Lock lock = 1 -lock_ = {klass}() -with lock_: +lock_cm = Lock() +with lock_cm: ... -""" - self.run_and_assert(tmpdir, input_code, expected) - - @each_class - def test_name_resolution_import(self, tmpdir, klass): - input_code = f"""from threading import {klass} +""", + ), + ( + """from threading import Lock from something import lock -with {klass}(): +with Lock(): ... -""" - expected = f"""from threading import {klass} +""", + """from threading import Lock from something import lock -lock_ = {klass}() -with lock_: +lock_cm = Lock() +with lock_cm: ... -""" - self.run_and_assert(tmpdir, input_code, expected) +""", + ), + ( + """import threading +lock = 1 +def f(l): + with threading.Lock(): + return [lock_ for lock_ in l] +""", + """import threading +lock = 1 +def f(l): + lock_cm = threading.Lock() + with lock_cm: + return [lock_ for lock_ in l] +""", + ), + ], + ) + def test_name_resolution(self, tmpdir, input_code, expected_code): + self.run_and_assert(tmpdir, input_code, expected_code) From 08138e61d087eb2059110f123f5ecba56c18f560 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Thu, 12 Oct 2023 08:21:05 -0300 Subject: [PATCH 3/7] add no cover to just-in-case lines --- src/codemodder/codemods/utils_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 5555f26b..2bd2107d 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -78,7 +78,7 @@ def find_used_names_in_module(self): names = [] scope = self.find_global_scope() if scope is None: - return [] + return [] # pragma: no cover nodes = [x.node for x in scope.assignments] for other_nodes in nodes: @@ -93,7 +93,7 @@ def find_global_scope(self): for scope in scopes: if isinstance(scope, GlobalScope): return scope - return None + return None # pragma: no cover def find_single_assignment( self, From 9cc3f8aa268fec88efbd1d950b7397cb807d2f11 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Mon, 16 Oct 2023 09:08:53 -0300 Subject: [PATCH 4/7] add test cases for multiple locks in a module --- tests/codemods/test_with_threading_lock.py | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 1dd72941..c2e4a771 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -133,6 +133,38 @@ def f(l): lock_cm = threading.Lock() with lock_cm: return [lock_ for lock_ in l] +""", + ), + ( + """import threading +with threading.Lock(): + int("1") + +with threading.Lock(): + print() +""", + """import threading +lock = threading.Lock() +with lock: + int("1") + +lock = threading.Lock() +with lock: + print() +""", + ), + ( + """import threading +with threading.Lock(): + with threading.Lock(): + print() +""", + """import threading +lock = threading.Lock() +with lock: + lock = threading.Lock() + with lock: + print() """, ), ], From a27a063f7d739924c3c0562f79e499cfc5dfcd01 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Mon, 16 Oct 2023 09:29:28 -0300 Subject: [PATCH 5/7] use type-specific names --- src/core_codemods/with_threading_lock.py | 20 ++++++++++++++--- tests/codemods/test_with_threading_lock.py | 26 ++++++++++------------ 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index 7039d21d..21767c5e 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -53,9 +53,14 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): original_node.items[0] ): current_names = self.find_used_names_in_module() - # arbitrarily choose `lock_cm` if `lock` name is already taken - # in hopes that `lock_cm` is very unlikely to be used. - value = "lock" if "lock" not in current_names else "lock_cm" + name_for_node_type = _get_node_name(original_node) + # arbitrarily choose `name_cm` if `name` name is already taken + # in hopes that `name_cm` is very unlikely to be used. + value = ( + name_for_node_type + if name_for_node_type not in current_names + else f"{name_for_node_type}_cm" + ) name = cst.Name(value=value) assign = cst.SimpleStatementLine( body=[ @@ -76,3 +81,12 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): ) return original_node + + +def _get_node_name(original_node: cst.With): + func_call = original_node.items[0].item.func + if isinstance(func_call, cst.Name): + return func_call.value.lower() + if isinstance(func_call, cst.Attribute): + return func_call.attr.value.lower() + return "" diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index c2e4a771..3e6c712f 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -27,8 +27,8 @@ def test_import(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock: +{klass.lower()} = threading.{klass}() +with {klass.lower()}: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -40,8 +40,8 @@ def test_from_import(self, tmpdir, klass): ... """ expected = f"""from threading import {klass} -lock = {klass}() -with lock: +{klass.lower()} = {klass}() +with {klass.lower()}: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -53,8 +53,8 @@ def test_simple_replacement_with_as(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock as foo: +{klass.lower()} = threading.{klass}() +with {klass.lower()} as foo: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -69,8 +69,8 @@ def test_no_effect_sanity_check(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock: +{klass.lower()} = threading.{klass}() +with {klass.lower()}: ... with threading_lock(): @@ -136,30 +136,28 @@ def f(l): """, ), ( - """import threading + """import threading with threading.Lock(): int("1") - with threading.Lock(): print() """, - """import threading + """import threading lock = threading.Lock() with lock: int("1") - lock = threading.Lock() with lock: print() """, ), ( - """import threading + """import threading with threading.Lock(): with threading.Lock(): print() """, - """import threading + """import threading lock = threading.Lock() with lock: lock = threading.Lock() From 969d6bb189c7c40fa81761b74347467d688f61d1 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Mon, 16 Oct 2023 10:59:26 -0300 Subject: [PATCH 6/7] keep track of additional locks added --- src/core_codemods/with_threading_lock.py | 34 +++++++++++++++------- tests/codemods/base_codemod_test.py | 6 ++-- tests/codemods/test_with_threading_lock.py | 27 ++++++++++------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index 21767c5e..39b466d5 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -45,6 +45,27 @@ def rule(cls): - focus-metavariable: $BODY """ + def __init__(self, *args): + SemgrepCodemod.__init__(self, *args) + NameResolutionMixin.__init__(self) + self.names_in_module = self.find_used_names_in_module() + + def _create_new_variable(self, original_node: cst.With): + """ + Create an appropriately named variable for the new + lock, condition, or semaphore. + Keep track of this addition in case that are other additions. + """ + base_name = _get_node_name(original_node) + value = base_name + counter = 1 + while value in self.names_in_module: + value = f"{base_name}_{counter}" + counter += 1 + + self.names_in_module.append(value) + return cst.Name(value=value) + def leave_With(self, original_node: cst.With, updated_node: cst.With): # We deliberately restrict ourselves to simple cases where there's only one with clause for now. # Semgrep appears to be insufficiently expressive to match multiple clauses correctly. @@ -52,16 +73,7 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): if len(original_node.items) == 1 and self.node_is_selected( original_node.items[0] ): - current_names = self.find_used_names_in_module() - name_for_node_type = _get_node_name(original_node) - # arbitrarily choose `name_cm` if `name` name is already taken - # in hopes that `name_cm` is very unlikely to be used. - value = ( - name_for_node_type - if name_for_node_type not in current_names - else f"{name_for_node_type}_cm" - ) - name = cst.Name(value=value) + name = self._create_new_variable(original_node) assign = cst.SimpleStatementLine( body=[ cst.Assign( @@ -89,4 +101,4 @@ def _get_node_name(original_node: cst.With): return func_call.value.lower() if isinstance(func_call, cst.Attribute): return func_call.attr.value.lower() - return "" + return "" # pragma: no cover diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 90d6d5dc..0fbd4edd 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -37,8 +37,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): [], defaultdict(list), ) + wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( - CodemodContext(), + CodemodContext(wrapper=wrapper), self.execution_context, self.file_context, ) @@ -83,8 +84,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): [], results, ) + wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( - CodemodContext(), + CodemodContext(wrapper=wrapper), self.execution_context, self.file_context, ) diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index 3e6c712f..b9620144 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -102,8 +102,8 @@ class TestThreadingNameResolution(BaseSemgrepCodemodTest): """, """from threading import Lock lock = 1 -lock_cm = Lock() -with lock_cm: +lock_1 = Lock() +with lock_1: ... """, ), @@ -115,8 +115,8 @@ class TestThreadingNameResolution(BaseSemgrepCodemodTest): """, """from threading import Lock from something import lock -lock_cm = Lock() -with lock_cm: +lock_1 = Lock() +with lock_1: ... """, ), @@ -130,8 +130,8 @@ def f(l): """import threading lock = 1 def f(l): - lock_cm = threading.Lock() - with lock_cm: + lock_1 = threading.Lock() + with lock_1: return [lock_ for lock_ in l] """, ), @@ -139,6 +139,9 @@ def f(l): """import threading with threading.Lock(): int("1") +with threading.Lock(): + print() +var = 1 with threading.Lock(): print() """, @@ -146,8 +149,12 @@ def f(l): lock = threading.Lock() with lock: int("1") -lock = threading.Lock() -with lock: +lock_1 = threading.Lock() +with lock_1: + print() +var = 1 +lock_2 = threading.Lock() +with lock_2: print() """, ), @@ -158,8 +165,8 @@ def f(l): print() """, """import threading -lock = threading.Lock() -with lock: +lock_1 = threading.Lock() +with lock_1: lock = threading.Lock() with lock: print() From 800fb773bc146c3bd47039a70fb28a4345d58dbd Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Tue, 17 Oct 2023 07:55:36 -0300 Subject: [PATCH 7/7] add unit test --- tests/codemods/test_with_threading_lock.py | 23 ++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index b9620144..7a1d56f0 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -125,14 +125,14 @@ class TestThreadingNameResolution(BaseSemgrepCodemodTest): lock = 1 def f(l): with threading.Lock(): - return [lock_ for lock_ in l] + return [lock_1 for lock_1 in l] """, """import threading lock = 1 def f(l): - lock_1 = threading.Lock() - with lock_1: - return [lock_ for lock_ in l] + lock_2 = threading.Lock() + with lock_2: + return [lock_1 for lock_1 in l] """, ), ( @@ -170,6 +170,21 @@ def f(l): lock = threading.Lock() with lock: print() +""", + ), + ( + """import threading +def my_func(): + lock = "whatever" + with threading.Lock(): + foo() +""", + """import threading +def my_func(): + lock = "whatever" + lock_1 = threading.Lock() + with lock_1: + foo() """, ), ],