From f5054e129e0c60ddc09035e7de629bc4cb421f8b Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Mon, 16 Oct 2023 09:29:28 -0300 Subject: [PATCH] use type-specific names --- src/core_codemods/with_threading_lock.py | 21 +++++++++++++--- tests/codemods/test_with_threading_lock.py | 28 +++++++++++----------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py index 2b531abcb..9f7cc526c 100644 --- a/src/core_codemods/with_threading_lock.py +++ b/src/core_codemods/with_threading_lock.py @@ -41,9 +41,14 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): original_node.items[0] ): current_names = self.find_used_names_in_module() - # arbitrarily choose `lock_cm` if `lock` name is already taken - # in hopes that `lock_cm` is very unlikely to be used. - value = "lock" if "lock" not in current_names else "lock_cm" + name_for_node_type = _get_node_name(original_node) + # arbitrarily choose `name_cm` if `name` name is already taken + # in hopes that `name_cm` is very unlikely to be used. + value = ( + name_for_node_type + if name_for_node_type not in current_names + else f"{name_for_node_type}_cm" + ) name = cst.Name(value=value) assign = cst.SimpleStatementLine( body=[ @@ -64,3 +69,13 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): ) return original_node + + +def _get_node_name(original_node: cst.With): + func_call = original_node.items[0].item.func + if isinstance(func_call, cst.Name): + return func_call.value.lower() + elif isinstance(func_call, cst.Attribute): + return func_call.attr.value.lower() + else: + return "" diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py index c2e4a7717..78665d075 100644 --- a/tests/codemods/test_with_threading_lock.py +++ b/tests/codemods/test_with_threading_lock.py @@ -27,8 +27,8 @@ def test_import(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock: +{klass.lower()} = threading.{klass}() +with {klass.lower()}: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -40,8 +40,8 @@ def test_from_import(self, tmpdir, klass): ... """ expected = f"""from threading import {klass} -lock = {klass}() -with lock: +{klass.lower()} = {klass}() +with {klass.lower()}: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -53,8 +53,8 @@ def test_simple_replacement_with_as(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock as foo: +{klass.lower()} = threading.{klass}() +with {klass.lower()} as foo: ... """ self.run_and_assert(tmpdir, input_code, expected) @@ -69,8 +69,8 @@ def test_no_effect_sanity_check(self, tmpdir, klass): ... """ expected = f"""import threading -lock = threading.{klass}() -with lock: +{klass.lower()} = threading.{klass}() +with {klass.lower()}: ... with threading_lock(): @@ -136,30 +136,30 @@ def f(l): """, ), ( - """import threading + """import threading with threading.Lock(): int("1") - + with threading.Lock(): print() """, - """import threading + """import threading lock = threading.Lock() with lock: int("1") - + lock = threading.Lock() with lock: print() """, ), ( - """import threading + """import threading with threading.Lock(): with threading.Lock(): print() """, - """import threading + """import threading lock = threading.Lock() with lock: lock = threading.Lock()