-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial fix task init codemod * fix task codemod can handle kwargs and more complex cases * fix task init codemod can handle different loop types * document fix task init codemod * asyncio codemod handles eager_start * refactor codemod * change codemod metadata * fix codemod description
- Loading branch information
1 parent
76e1950
commit 6fa0615
Showing
9 changed files
with
513 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from core_codemods.fix_async_task_instantiation import FixAsyncTaskInstantiation | ||
from integration_tests.base_test import ( | ||
BaseIntegrationTest, | ||
original_and_expected_from_code_path, | ||
) | ||
|
||
|
||
class TestFixAsyncTaskInstantiation(BaseIntegrationTest): | ||
codemod = FixAsyncTaskInstantiation | ||
code_path = "tests/samples/fix_async_task_instantiation.py" | ||
original_code, expected_new_code = original_and_expected_from_code_path( | ||
code_path, | ||
[ | ||
( | ||
7, | ||
""" task = asyncio.create_task(my_coroutine(), name="my task")\n""", | ||
), | ||
], | ||
) | ||
|
||
# fmt: off | ||
expected_diff =( | ||
"""--- \n""" | ||
"""+++ \n""" | ||
"""@@ -5,7 +5,7 @@\n""" | ||
""" print("Task completed")\n""" | ||
""" \n""" | ||
""" async def main():\n""" | ||
"""- task = asyncio.Task(my_coroutine(), name="my task")\n""" | ||
"""+ task = asyncio.create_task(my_coroutine(), name="my task")\n""" | ||
""" await task\n""" | ||
""" \n""" | ||
""" asyncio.run(main())\n""" | ||
) | ||
# fmt: on | ||
|
||
expected_line_change = "8" | ||
change_description = FixAsyncTaskInstantiation.change_description | ||
num_changed_files = 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
src/core_codemods/docs/pixee_python_fix-async-task-instantiation.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
The `asyncio` [documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task) explicitly discourages manual instantiation of a `Task` instance and instead recommends calling `create_task`. This keeps your code in line with recommended best practices and promotes maintainability. | ||
|
||
Our changes look like the following: | ||
```diff | ||
import asyncio | ||
|
||
- task = asyncio.Task(my_coroutine(), name="my task") | ||
+ task = asyncio.create_task(my_coroutine(), name="my task") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import libcst as cst | ||
from libcst import MaybeSentinel | ||
from typing import Optional | ||
from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod, Reference | ||
from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin | ||
from codemodder.codemods.utils import BaseType, infer_expression_type | ||
|
||
|
||
class FixAsyncTaskInstantiation(SimpleCodemod, NameAndAncestorResolutionMixin): | ||
metadata = Metadata( | ||
name="fix-async-task-instantiation", | ||
summary="Use High-Level `asyncio` API Functions to Create Tasks", | ||
review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, | ||
references=[ | ||
Reference( | ||
url="https://docs.python.org/3/library/asyncio-task.html#asyncio.Task" | ||
), | ||
], | ||
) | ||
change_description = "Replace instantiation of `asyncio.Task` with higher-level functions to create tasks." | ||
_module_name = "asyncio" | ||
|
||
# pylint: disable=too-many-return-statements | ||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: | ||
if not self.filter_by_path_includes_or_excludes( | ||
self.node_position(original_node) | ||
): | ||
return updated_node | ||
|
||
if self.find_base_name(original_node) != "asyncio.Task": | ||
return updated_node | ||
coroutine_arg = original_node.args[0] | ||
loop_arg, eager_start_arg, other_args = self._split_args(original_node.args[1:]) | ||
|
||
loop_type = ( | ||
infer_expression_type(self.resolve_expression(loop_arg.value)) | ||
if loop_arg | ||
else None | ||
) | ||
|
||
eager_start_type = ( | ||
infer_expression_type(self.resolve_expression(eager_start_arg.value)) | ||
if eager_start_arg | ||
else None | ||
) | ||
|
||
if eager_start_type == BaseType.TRUE: | ||
if not loop_arg or self._is_invalid_loop_value(loop_type): | ||
# asking for eager_start without a loop or incorrectly setting loop is bad. | ||
# We won't do anything. | ||
return updated_node | ||
|
||
loop_arg = loop_arg.with_changes(keyword=None, equal=MaybeSentinel.DEFAULT) | ||
return self.node_eager_task( | ||
original_node, | ||
updated_node, | ||
replacement_args=[loop_arg, coroutine_arg] + other_args, | ||
) | ||
|
||
if loop_arg: | ||
if loop_type == BaseType.NONE: | ||
return self.node_create_task( | ||
original_node, | ||
updated_node, | ||
replacement_args=[coroutine_arg] + other_args, | ||
) | ||
if self._is_invalid_loop_value(loop_type): | ||
# incorrectly assigned loop kwarg to something that is not a loop. | ||
# We won't do anything. | ||
return updated_node | ||
|
||
return self.node_loop_create_task( | ||
original_node, coroutine_arg, loop_arg, other_args | ||
) | ||
return self.node_create_task( | ||
original_node, updated_node, replacement_args=[coroutine_arg] + other_args | ||
) | ||
|
||
def node_create_task( | ||
self, | ||
original_node: cst.Call, | ||
updated_node: cst.Call, | ||
replacement_args=list[cst.Arg], | ||
) -> cst.Call: | ||
"""Convert `asyncio.Task(...)` to `asyncio.create_task(...)`""" | ||
self.report_change(original_node) | ||
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) | ||
if (maybe_name := maybe_name or self._module_name) == self._module_name: | ||
self.add_needed_import(self._module_name) | ||
self.remove_unused_import(original_node) | ||
|
||
if len(replacement_args) == 1: | ||
replacement_args[0] = replacement_args[0].with_changes( | ||
comma=MaybeSentinel.DEFAULT | ||
) | ||
return self.update_call_target( | ||
updated_node, maybe_name, "create_task", replacement_args=replacement_args | ||
) | ||
|
||
def node_eager_task( | ||
self, | ||
original_node: cst.Call, | ||
updated_node: cst.Call, | ||
replacement_args=list[cst.Arg], | ||
) -> cst.Call: | ||
"""Convert `asyncio.Task(...)` to `asyncio.eager_task_factory(loop, coro...)`""" | ||
self.report_change(original_node) | ||
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) | ||
if (maybe_name := maybe_name or self._module_name) == self._module_name: | ||
self.add_needed_import(self._module_name) | ||
self.remove_unused_import(original_node) | ||
return self.update_call_target( | ||
updated_node, | ||
maybe_name, | ||
"eager_task_factory", | ||
replacement_args=replacement_args, | ||
) | ||
|
||
def node_loop_create_task( | ||
self, | ||
original_node: cst.Call, | ||
coroutine_arg: cst.Arg, | ||
loop_arg: cst.Arg, | ||
other_args: list[cst.Arg], | ||
) -> cst.Call: | ||
"""Convert `asyncio.Task(..., loop=loop,...)` to `loop.create_task(...)`""" | ||
self.report_change(original_node) | ||
coroutine_arg = coroutine_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT) | ||
loop_attr = loop_arg.value | ||
new_call = cst.Call( | ||
func=cst.Attribute(value=loop_attr, attr=cst.Name("create_task")), | ||
args=[coroutine_arg] + other_args, | ||
) | ||
self.remove_unused_import(original_node) | ||
return new_call | ||
|
||
def _split_args( | ||
self, args: list[cst.Arg] | ||
) -> tuple[Optional[cst.Arg], Optional[cst.Arg], list[cst.Arg]]: | ||
"""Find the loop kwarg and the eager_start kwarg from a list of args. | ||
Return any args or non-None kwargs. | ||
""" | ||
loop_arg, eager_start_arg = None, None | ||
other_args = [] | ||
for arg in args: | ||
match arg: | ||
case cst.Arg(keyword=cst.Name(value="loop")): | ||
loop_arg = arg | ||
case cst.Arg(keyword=cst.Name(value="eager_start")): | ||
eager_start_arg = arg | ||
case cst.Arg(keyword=cst.Name() as k) if k.value != "None": | ||
# keep kwarg that are not set to None | ||
other_args.append(arg) | ||
case cst.Arg(keyword=None): | ||
# keep post args | ||
other_args.append(arg) | ||
|
||
return loop_arg, eager_start_arg, other_args | ||
|
||
def _is_invalid_loop_value(self, loop_type): | ||
return loop_type in ( | ||
BaseType.NONE, | ||
BaseType.NUMBER, | ||
BaseType.LIST, | ||
BaseType.STRING, | ||
BaseType.BYTES, | ||
BaseType.TRUE, | ||
BaseType.FALSE, | ||
) |
Oops, something went wrong.