add name resolution to threading codemod #71

Merged
merged 7 commits into from
Oct 17, 2023
33 changes: 33 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
@@ -3,6 +3,7 @@
from libcst import MetadataDependent, matchers
from libcst.helpers import get_full_name_for_node
from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider
from libcst.metadata.scope_provider import GlobalScope


class NameResolutionMixin(MetadataDependent):
@@ -70,6 +71,30 @@ def find_assignments(
return set(next(iter(scope.accesses[node])).referents)
return set()

    def find_used_names_in_module(self):
        """
        Find all the used names in the scope of a libcst Module.
        """
        names = []
        scope = self.find_global_scope()
        if scope is None:
            return [] # pragma: no cover

        nodes = [x.node for x in scope.assignments]
        for other_nodes in nodes:
            visitor = GatherNamesVisitor()
            other_nodes.visit(visitor)
            names.extend(visitor.names)
        return names

    def find_global_scope(self):
        """Find the global scope for a libcst Module node."""
        scopes = self.context.wrapper.resolve(ScopeProvider).values()
        for scope in scopes:
            if isinstance(scope, GlobalScope):
                return scope
        return None # pragma: no cover
Comment on lines +74 to +96
Contributor

As it stands, the names are gathered from the global scope. This means that no matter where the Lock() call is detected, current_names will always contain the same names.

You should calculate current_names in __init__ and update it whenever you add a new name.

Contributor Author

> This means no matter where the Lock() call is detected current_names will always return the same names.

This is correct, and I believe it is exactly what we want. When the threading codemod reaches leave_With, it is handling a with clause somewhere in a module. With the current approach, we take the entire module that contains the with clause and find all of its variable and import names. This means we never pick a name for the lock that is already used at any scope in the module, even in cases where we could safely pick lock because it is only used in some other function. This is the safest option.

Calculating current_names in __init__ would produce exactly the same result, with the downside that we would have to change the codemod architecture. With the current approach, all I had to do was update leave_With to call find_used_names_in_module. This is more generalizable than adding an __init__ call.

If you think I'm mistaken here, give me a unit test that fails with the current approach so I can iterate.

Member

This is more of a question but I'm unclear what "global" scope means in this context. Is it the module-level scope? Or is it some kind of scope containing all variables declared anywhere in the module? I would be rather surprised to find out that it's the latter. But here's a potential test case:

# no conflicts declared at module-level scope
def my_func():
    # Will we correctly detect this conflict?
    lock = "whatever"
    with threading.Lock():
        foo()

I'm definitely curious about this, especially since I think all of the existing test cases are implemented in terms of module-level scope.

Contributor Author

> Or is it some kind of scope containing all variables declared anywhere in the module?

It is exactly this. I mentioned it somewhere above, but it does in fact find all variable and import names in the module. This is good because it's extra safe, although of course there are scopes in which we could reuse a name but don't.

I've added the test case you pointed out and it passes as expected: we add a new lock_1 variable right under lock = "whatever".

Contributor

It's module-level scope. I believe the current implementation will be able to detect the conflict in your particular example.
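
To make the point under discussion concrete, here is a minimal sketch that uses the standard-library `ast` module as a stand-in for libcst (an assumption made only to keep the example self-contained; the PR itself relies on `ScopeProvider` and `GatherNamesVisitor`, and the exact set of names libcst collects may differ). The helper `gather_names` is hypothetical. It walks the whole module tree, so a name assigned only inside a function body is still reported as used.

```python
import ast

# The test case proposed in the thread: no conflict at module level,
# but `lock` is assigned inside the function body.
SOURCE = '''
import threading

def my_func():
    lock = "whatever"
    with threading.Lock():
        foo()
'''


def gather_names(source: str) -> set:
    """Collect the identifier of every ast.Name node anywhere in the module."""
    tree = ast.parse(source)
    return {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}


names = gather_names(SOURCE)
print(sorted(names))
```

Because the walk is not limited to module-level statements, `lock` appears in the result even though it is local to `my_func`, which is the conservative behavior the author describes.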


def find_single_assignment(
self,
node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator],
@@ -112,3 +137,11 @@ def _get_name(node: Union[cst.Import, cst.ImportFrom]) -> str:
if matchers.matches(node, matchers.Import()):
return get_full_name_for_node(node.names[0].name)
return ""


class GatherNamesVisitor(cst.CSTVisitor):
    def __init__(self):
        self.names = []

    def visit_Name(self, node: cst.Name) -> None:
        self.names.append(node.value)
36 changes: 33 additions & 3 deletions src/core_codemods/with_threading_lock.py
@@ -1,9 +1,10 @@
import libcst as cst
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


-class WithThreadingLock(SemgrepCodemod):
+class WithThreadingLock(SemgrepCodemod, NameResolutionMixin):
NAME = "bad-lock-with-statement"
SUMMARY = "Separate Lock Instantiation from `with` Call"
DESCRIPTION = (
@@ -44,15 +45,35 @@ def rule(cls):
- focus-metavariable: $BODY
"""

    def __init__(self, *args):
        SemgrepCodemod.__init__(self, *args)
        NameResolutionMixin.__init__(self)
        self.names_in_module = self.find_used_names_in_module()

    def _create_new_variable(self, original_node: cst.With):
        """
        Create an appropriately named variable for the new
        lock, condition, or semaphore.
        Keep track of this addition in case there are other additions.
        """
        base_name = _get_node_name(original_node)
        value = base_name
        counter = 1
        while value in self.names_in_module:
            value = f"{base_name}_{counter}"
            counter += 1

        self.names_in_module.append(value)
        return cst.Name(value=value)

def leave_With(self, original_node: cst.With, updated_node: cst.With):
# We deliberately restrict ourselves to simple cases where there's only one with clause for now.
# Semgrep appears to be insufficiently expressive to match multiple clauses correctly.
# We should probably just rewrite this codemod using libcst without semgrep.
if len(original_node.items) == 1 and self.node_is_selected(
original_node.items[0]
):
-# TODO: how to avoid name conflicts here?
-name = cst.Name(value="lock")
+name = self._create_new_variable(original_node)
assign = cst.SimpleStatementLine(
body=[
cst.Assign(
@@ -72,3 +93,12 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With):
)

return original_node


def _get_node_name(original_node: cst.With):
    func_call = original_node.items[0].item.func
    if isinstance(func_call, cst.Name):
        return func_call.value.lower()
    if isinstance(func_call, cst.Attribute):
        return func_call.attr.value.lower()
    return "" # pragma: no cover
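
The base-name rule in `_get_node_name` can be tried out in isolation. The sketch below re-expresses it with the standard-library `ast` module instead of libcst (an assumption for self-containment); `base_name_for_call` is a hypothetical helper, not part of the PR. It lowercases the called name, whether the call is written as a bare name or as an attribute access.

```python
import ast


def base_name_for_call(expr: str) -> str:
    """Return the lowercased name of the callable in a call expression."""
    call = ast.parse(expr, mode="eval").body
    if not isinstance(call, ast.Call):
        return ""
    func = call.func
    if isinstance(func, ast.Name):  # e.g. Lock()
        return func.id.lower()
    if isinstance(func, ast.Attribute):  # e.g. threading.Lock()
        return func.attr.lower()
    return ""  # anything else, e.g. a subscripted callable


print(base_name_for_call("threading.Lock()"))
print(base_name_for_call("Semaphore()"))
```

This is why the updated tests expect `{klass.lower()}` as the variable name: the name follows the class being instantiated instead of always being `lock`.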
6 changes: 4 additions & 2 deletions tests/codemods/base_codemod_test.py
@@ -37,8 +37,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
[],
defaultdict(list),
)
+wrapper = cst.MetadataWrapper(input_tree)
 command_instance = self.codemod(
-CodemodContext(),
+CodemodContext(wrapper=wrapper),
self.execution_context,
self.file_context,
)
@@ -83,8 +84,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
[],
results,
)
+wrapper = cst.MetadataWrapper(input_tree)
 command_instance = self.codemod(
-CodemodContext(),
+CodemodContext(wrapper=wrapper),
self.execution_context,
self.file_context,
)
121 changes: 113 additions & 8 deletions tests/codemods/test_with_threading_lock.py
@@ -27,8 +27,8 @@ def test_import(self, tmpdir, klass):
...
"""
expected = f"""import threading
-lock = threading.{klass}()
-with lock:
+{klass.lower()} = threading.{klass}()
+with {klass.lower()}:
...
"""
self.run_and_assert(tmpdir, input_code, expected)
@@ -40,8 +40,8 @@ def test_from_import(self, tmpdir, klass):
...
"""
expected = f"""from threading import {klass}
-lock = {klass}()
-with lock:
+{klass.lower()} = {klass}()
+with {klass.lower()}:
...
"""
self.run_and_assert(tmpdir, input_code, expected)
@@ -53,8 +53,8 @@ def test_simple_replacement_with_as(self, tmpdir, klass):
...
"""
expected = f"""import threading
-lock = threading.{klass}()
-with lock as foo:
+{klass.lower()} = threading.{klass}()
+with {klass.lower()} as foo:
...
"""
self.run_and_assert(tmpdir, input_code, expected)
@@ -69,8 +69,8 @@ def test_no_effect_sanity_check(self, tmpdir, klass):
...
"""
expected = f"""import threading
-lock = threading.{klass}()
-with lock:
+{klass.lower()} = threading.{klass}()
+with {klass.lower()}:
...

with threading_lock():
@@ -86,3 +86,108 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass):
...
"""
self.run_and_assert(tmpdir, input_code, input_code)


class TestThreadingNameResolution(BaseSemgrepCodemodTest):
Member

I think we need another test here:

from threading import Lock

with Lock():
    foo()
    
with Lock():
    bar()

This case probably isn't such a concern because variable reuse is probably acceptable. But we also need another test case:

with Lock():
    with Lock():
        foo()

Contributor Author

@andrecsilva, I'm wondering whether these are the cases you had in mind when you suggested computing current_names at init. Right now the codemod will change

with Lock():
    with Lock():
        foo()

to

import threading
lock = threading.Lock()
with lock:
    lock = threading.Lock()
    with lock:
        foo()

which is not great, since the codemod itself is introducing a name clash. This happens because the two calls to leave_With happen one after the other, and the second one does not "learn about" the lock = ... assignment added by the first. I don't believe this would be solved by moving current_names = self.find_used_names_in_module() to the codemod's __init__, because that too would happen only once.

Contributor
@andrecsilva, Oct 16, 2023

Yes, this is what I was afraid would happen.
To reiterate, my suggestion is to move the calculation of current_names to WithThreadingLock's __init__ and to update current_names with the newly added name within leave_With.

Contributor Author

ahh I see what you mean now.
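
The approach the thread settles on (collect the used names once, then record every name the codemod hands out) can be sketched without any libcst machinery. `NameAllocator` below is a hypothetical illustration of that bookkeeping and not the PR's class, although its loop follows the same counter-suffix scheme as `_create_new_variable`.

```python
class NameAllocator:
    """Hand out names that do not clash with known or previously issued names."""

    def __init__(self, used_names):
        # Computed once up front, like names_in_module in the codemod's __init__.
        self.used = set(used_names)

    def allocate(self, base_name: str) -> str:
        candidate = base_name
        counter = 1
        while candidate in self.used:
            candidate = f"{base_name}_{counter}"
            counter += 1
        # Remember the new name so the next call cannot reuse it.
        self.used.add(candidate)
        return candidate


allocator = NameAllocator({"threading", "foo"})
first = allocator.allocate("lock")
second = allocator.allocate("lock")
print(first, second)
```

Without the `self.used.add(candidate)` step, both calls would return `lock`, which is the clash shown in the nested `with` example above.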

    codemod = WithThreadingLock

    @pytest.mark.parametrize(
        "input_code,expected_code",
        [
            (
                """from threading import Lock
lock = 1
with Lock():
    ...
""",
                """from threading import Lock
lock = 1
lock_1 = Lock()
with lock_1:
    ...
""",
            ),
            (
                """from threading import Lock
from something import lock
with Lock():
    ...
""",
                """from threading import Lock
from something import lock
lock_1 = Lock()
with lock_1:
    ...
""",
            ),
            (
                """import threading
lock = 1
def f(l):
    with threading.Lock():
        return [lock_1 for lock_1 in l]
""",
                """import threading
lock = 1
def f(l):
    lock_2 = threading.Lock()
    with lock_2:
        return [lock_1 for lock_1 in l]
""",
            ),
            (
                """import threading
with threading.Lock():
    int("1")
with threading.Lock():
    print()
var = 1
with threading.Lock():
    print()
""",
                """import threading
lock = threading.Lock()
with lock:
    int("1")
lock_1 = threading.Lock()
with lock_1:
    print()
var = 1
lock_2 = threading.Lock()
with lock_2:
    print()
""",
            ),
            (
                """import threading
with threading.Lock():
    with threading.Lock():
        print()
""",
                """import threading
lock_1 = threading.Lock()
with lock_1:
    lock = threading.Lock()
    with lock:
        print()
Contributor Author

Note this test: libcst processes the inner with statement first, so the names come out in this order. Not a huge deal.

""",
            ),
            (
                """import threading
def my_func():
    lock = "whatever"
    with threading.Lock():
        foo()
""",
                """import threading
def my_func():
    lock = "whatever"
    lock_1 = threading.Lock()
    with lock_1:
        foo()
""",
            ),
        ],
    )
    def test_name_resolution(self, tmpdir, input_code, expected_code):
        self.run_and_assert(tmpdir, input_code, expected_code)
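
The nested-`with` case in the parametrized tests gives the outer statement `lock_1` and the inner one `lock` because the rewrite happens when a node is left, so inner statements are handled first. The sketch below reproduces that ordering with the standard-library `ast.NodeTransformer` in place of libcst (an assumption for self-containment). `HoistLock` and `is_lock_call` are hypothetical names, and the lock detection is deliberately loose compared with the Semgrep rule the real codemod uses.

```python
import ast


def is_lock_call(expr: ast.expr) -> bool:
    """True for `Lock()` or `<anything>.Lock()` (a deliberately loose check)."""
    if not isinstance(expr, ast.Call):
        return False
    func = expr.func
    return (isinstance(func, ast.Name) and func.id == "Lock") or (
        isinstance(func, ast.Attribute) and func.attr == "Lock"
    )


class HoistLock(ast.NodeTransformer):
    """Hypothetical stand-in for the codemod: hoist the lock out of `with`."""

    def __init__(self, used_names):
        self.used = set(used_names)

    def _fresh_name(self, base: str) -> str:
        candidate, counter = base, 1
        while candidate in self.used:
            candidate = f"{base}_{counter}"
            counter += 1
        self.used.add(candidate)
        return candidate

    def visit_With(self, node: ast.With):
        # Visit children first so the rewrite happens on the way back up,
        # which mirrors the post-order timing of a leave_With callback.
        self.generic_visit(node)
        if len(node.items) != 1 or not is_lock_call(node.items[0].context_expr):
            return node
        name = self._fresh_name("lock")
        assign = ast.Assign(
            targets=[ast.Name(id=name, ctx=ast.Store())],
            value=node.items[0].context_expr,
        )
        node.items[0].context_expr = ast.Name(id=name, ctx=ast.Load())
        # Returning a list replaces the single statement with two statements.
        return [assign, node]


SOURCE = """import threading
with threading.Lock():
    with threading.Lock():
        print()
"""

tree = ast.parse(SOURCE)
used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
new_tree = ast.fix_missing_locations(HoistLock(used).visit(tree))
result = ast.unparse(new_tree)  # requires Python 3.9+
print(result)
```

In this sketch the inner statement claims `lock` first and the outer statement then falls back to `lock_1`, matching the ordering the author points out in the test.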