From b082b8038f1143b41a078354a0ededa44fe727c8 Mon Sep 17 00:00:00 2001
From: Daniel D'Avella
Date: Fri, 29 Sep 2023 12:18:25 -0400
Subject: [PATCH] First steps towards threading codemod

---
 src/core_codemods/with_threading_lock.py   | 77 ++++++++++++++++++++++
 tests/codemods/test_with_threading_lock.py | 88 +++++++++++++++++++++++++
 2 files changed, 165 insertions(+)
 create mode 100644 src/core_codemods/with_threading_lock.py
 create mode 100644 tests/codemods/test_with_threading_lock.py

diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py
new file mode 100644
index 00000000..022106f0
--- /dev/null
+++ b/src/core_codemods/with_threading_lock.py
@@ -0,0 +1,77 @@
+import libcst as cst
+from codemodder.codemods.base_codemod import ReviewGuidance
+from codemodder.codemods.api import SemgrepCodemod
+
+
+class WithThreadingLock(SemgrepCodemod):
+    """Hoist a lock created in a `with` header into its own assignment.
+
+    `with threading.Lock(): ...` is rewritten to `lock = threading.Lock()`
+    followed by `with lock: ...`. Only `with` statements that have exactly
+    one item are rewritten.
+    """
+
+    NAME = "bad-lock-with-statement"
+    SUMMARY = "Replace deprecated usage of threading.Lock context manager"
+    REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
+    DESCRIPTION = "Separates threading lock instantiation and call with `with` statement into two steps."
+
+    @classmethod
+    def rule(cls):
+        """Return the semgrep rule (as YAML text) for this codemod.
+
+        The rule focuses on the `$BODY` expression of a `with` statement when
+        that expression matches `threading.Lock()` inside `import threading`.
+        """
+        return """
+            rules:
+              - patterns:
+                - pattern: |
+                    with $BODY:
+                      ...
+                - metavariable-pattern:
+                    metavariable: $BODY
+                    patterns:
+                    - pattern: threading.Lock()
+                - pattern-inside: |
+                    import threading
+                    ...
+                - focus-metavariable: $BODY
+        """
+
+    def leave_With(self, original_node: cst.With, updated_node: cst.With):
+        """Split a selected single-item `with` into an assignment and a `with`.
+
+        Returns a `cst.FlattenSentinel` holding the new assignment and the
+        rewritten `with` when the statement is selected; otherwise returns
+        `updated_node` so that rewrites made to nested statements are kept.
+        """
+        # We deliberately restrict ourselves to simple cases where there's only one with clause for now.
+        # Semgrep appears to be insufficiently expressive to match multiple clauses correctly.
+        # We should probably just rewrite this codemod using libcst without semgrep.
+        if len(original_node.items) == 1 and self.node_is_selected(
+            original_node.items[0]
+        ):
+            # TODO: how to avoid name conflicts here?
+            name = cst.Name(value="lock")
+            assign = cst.SimpleStatementLine(
+                body=[
+                    cst.Assign(
+                        targets=[cst.AssignTarget(target=name)],
+                        value=updated_node.items[0].item,
+                    )
+                ]
+            )
+            # TODO: add result
+            return cst.FlattenSentinel(
+                [
+                    assign,
+                    updated_node.with_changes(
+                        items=[cst.WithItem(name, asname=updated_node.items[0].asname)]
+                    ),
+                ]
+            )
+
+        # Returning original_node here would throw away any changes libcst
+        # already applied to children of this node (e.g. a nested `with`).
+        return updated_node
diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py
new file mode 100644
index 00000000..2e5d65fb
--- /dev/null
+++ b/tests/codemods/test_with_threading_lock.py
@@ -0,0 +1,88 @@
+from core_codemods.with_threading_lock import WithThreadingLock
+from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest
+
+
+class TestWithThreadingLock(BaseSemgrepCodemodTest):
+    """Tests for the WithThreadingLock codemod."""
+
+    codemod = WithThreadingLock
+
+    def test_rule_ids(self):
+        assert self.codemod.name() == "bad-lock-with-statement"
+
+    def test_import(self, tmpdir):
+        input_code = """import threading
+with threading.Lock():
+    ...
+"""
+        expected = """import threading
+lock = threading.Lock()
+with lock:
+    ...
+"""
+        self.run_and_assert(tmpdir, input_code, expected)
+
+    def test_from_import(self, tmpdir):
+        input_code = """from threading import Lock
+with Lock():
+    ...
+"""
+        expected = """from threading import Lock
+lock = Lock()
+with lock:
+    ...
+"""
+        self.run_and_assert(tmpdir, input_code, expected)
+
+    def test_simple_replacement_with_as(self, tmpdir):
+        input_code = """import threading
+with threading.Lock() as foo:
+    ...
+"""
+        expected = """import threading
+lock = threading.Lock()
+with lock as foo:
+    ...
+"""
+        self.run_and_assert(tmpdir, input_code, expected)
+
+    def test_no_effect_sanity_check(self, tmpdir):
+        input_code = """import threading
+with threading.Lock():
+    ...
+
+with threading_lock():
+    ...
+"""
+        expected = """import threading
+lock = threading.Lock()
+with lock:
+    ...
+
+with threading_lock():
+    ...
+"""
+        self.run_and_assert(tmpdir, input_code, expected)
+
+    def test_no_effect_multiple_with_clauses(self, tmpdir):
+        """This is currently an unsupported case"""
+        input_code = """import threading
+with threading.Lock(), foo():
+    ...
+"""
+        self.run_and_assert(tmpdir, input_code, input_code)
+
+    def test_nested_with_is_rewritten(self, tmpdir):
+        """A lock nested inside an unrelated `with` is still rewritten"""
+        input_code = """import threading
+with foo():
+    with threading.Lock():
+        ...
+"""
+        expected = """import threading
+with foo():
+    lock = threading.Lock()
+    with lock:
+        ...
+"""
+        self.run_and_assert(tmpdir, input_code, expected)