Skip to content

Commit

Permalink
Codemod to fix asyncio.Task (#248)
Browse files Browse the repository at this point in the history
* initial fix task init codemod

* fix task codemod can handle kwargs and more complex cases

* fix task init codemod can handle different loop types

* document fix task init codemod

* asyncio codemod handles eager_start

* refactor codemod

* change codemod metadata

* fix codemod description
  • Loading branch information
clavedeluna authored Feb 16, 2024
1 parent 76e1950 commit 6fa0615
Show file tree
Hide file tree
Showing 9 changed files with 513 additions and 0 deletions.
39 changes: 39 additions & 0 deletions integration_tests/test_fix_task_instantiation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from core_codemods.fix_async_task_instantiation import FixAsyncTaskInstantiation
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestFixAsyncTaskInstantiation(BaseIntegrationTest):
    """Integration test for the `FixAsyncTaskInstantiation` codemod.

    Declares the sample file to transform, the expected resulting code, the
    expected unified diff and the expected change metadata as class attributes
    for `BaseIntegrationTest` to consume.
    """

    codemod = FixAsyncTaskInstantiation
    code_path = "tests/samples/fix_async_task_instantiation.py"
    # Build the (original, expected) source pair from the sample file plus a
    # list of (line index, replacement line) tuples.
    # NOTE(review): the index looks 0-based (7 here vs. line "8" below) --
    # confirm against `original_and_expected_from_code_path`.
    original_code, expected_new_code = original_and_expected_from_code_path(
        code_path,
        [
            (
                7,
                """ task = asyncio.create_task(my_coroutine(), name="my task")\n""",
            ),
        ],
    )

    # Expected unified diff, written as adjacent string literals (one per diff
    # line) so formatting stays stable; hence the fmt: off/on guards.
    # NOTE(review): confirm the leading whitespace of the context lines matches
    # the sample file's real indentation.
    # fmt: off
    expected_diff =(
    """--- \n"""
    """+++ \n"""
    """@@ -5,7 +5,7 @@\n"""
    """ print("Task completed")\n"""
    """ \n"""
    """ async def main():\n"""
    """- task = asyncio.Task(my_coroutine(), name="my task")\n"""
    """+ task = asyncio.create_task(my_coroutine(), name="my task")\n"""
    """ await task\n"""
    """ \n"""
    """ asyncio.run(main())\n"""
    )
    # fmt: on

    # Line of the sample file on which the change is reported (as a string).
    expected_line_change = "8"
    # Reuse the codemod's own description so the two cannot drift apart.
    change_description = FixAsyncTaskInstantiation.change_description
    num_changed_files = 1
9 changes: 9 additions & 0 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class BaseType(Enum):
LIST = 2
STRING = 3
BYTES = 4
NONE = 5
TRUE = 6
FALSE = 7


# pylint: disable-next=R0911
Expand All @@ -26,6 +29,12 @@ def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]:
"""
# The current implementation covers some common cases and is in no way complete
match node:
case cst.Name(value="None"):
return BaseType.NONE
case cst.Name(value="True"):
return BaseType.TRUE
case cst.Name(value="False"):
return BaseType.FALSE
case (
cst.Integer()
| cst.Imaginary()
Expand Down
4 changes: 4 additions & 0 deletions src/codemodder/scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ class DocMetadata:
importance="Medium",
guidance_explained="While string concatenation inside a sequence iterable is likely a mistake, there are instances when you may choose to use them..",
),
"fix-async-task-instantiation": DocMetadata(
importance="Low",
guidance_explained="Manual instantiation of `asyncio.Task` is discouraged. We believe this change is safe and will not cause any issues.",
),
}

METADATA = CORE_METADATA | {
Expand Down
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .sonar.sonar_django_json_response_type import SonarDjangoJsonResponseType
from .lazy_logging import LazyLogging
from .str_concat_in_seq_literal import StrConcatInSeqLiteral
from .fix_async_task_instantiation import FixAsyncTaskInstantiation

registry = CodemodCollection(
origin="pixee",
Expand Down Expand Up @@ -118,6 +119,7 @@
FixAssertTuple,
LazyLogging,
StrConcatInSeqLiteral,
FixAsyncTaskInstantiation,
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
The `asyncio` [documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task) explicitly discourages manual instantiation of a `Task` instance and instead recommends calling `create_task`. Replacing direct `asyncio.Task(...)` calls with `create_task` keeps your code in line with recommended best practices and promotes maintainability.

Our changes look like the following:
```diff
import asyncio

- task = asyncio.Task(my_coroutine(), name="my task")
+ task = asyncio.create_task(my_coroutine(), name="my task")
```
169 changes: 169 additions & 0 deletions src/core_codemods/fix_async_task_instantiation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import libcst as cst
from libcst import MaybeSentinel
from typing import Optional
from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod, Reference
from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
from codemodder.codemods.utils import BaseType, infer_expression_type


class FixAsyncTaskInstantiation(SimpleCodemod, NameAndAncestorResolutionMixin):
metadata = Metadata(
name="fix-async-task-instantiation",
summary="Use High-Level `asyncio` API Functions to Create Tasks",
review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW,
references=[
Reference(
url="https://docs.python.org/3/library/asyncio-task.html#asyncio.Task"
),
],
)
change_description = "Replace instantiation of `asyncio.Task` with higher-level functions to create tasks."
_module_name = "asyncio"

# pylint: disable=too-many-return-statements
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
if not self.filter_by_path_includes_or_excludes(
self.node_position(original_node)
):
return updated_node

if self.find_base_name(original_node) != "asyncio.Task":
return updated_node
coroutine_arg = original_node.args[0]
loop_arg, eager_start_arg, other_args = self._split_args(original_node.args[1:])

loop_type = (
infer_expression_type(self.resolve_expression(loop_arg.value))
if loop_arg
else None
)

eager_start_type = (
infer_expression_type(self.resolve_expression(eager_start_arg.value))
if eager_start_arg
else None
)

if eager_start_type == BaseType.TRUE:
if not loop_arg or self._is_invalid_loop_value(loop_type):
# asking for eager_start without a loop or incorrectly setting loop is bad.
# We won't do anything.
return updated_node

loop_arg = loop_arg.with_changes(keyword=None, equal=MaybeSentinel.DEFAULT)
return self.node_eager_task(
original_node,
updated_node,
replacement_args=[loop_arg, coroutine_arg] + other_args,
)

if loop_arg:
if loop_type == BaseType.NONE:
return self.node_create_task(
original_node,
updated_node,
replacement_args=[coroutine_arg] + other_args,
)
if self._is_invalid_loop_value(loop_type):
# incorrectly assigned loop kwarg to something that is not a loop.
# We won't do anything.
return updated_node

return self.node_loop_create_task(
original_node, coroutine_arg, loop_arg, other_args
)
return self.node_create_task(
original_node, updated_node, replacement_args=[coroutine_arg] + other_args
)

def node_create_task(
self,
original_node: cst.Call,
updated_node: cst.Call,
replacement_args=list[cst.Arg],
) -> cst.Call:
"""Convert `asyncio.Task(...)` to `asyncio.create_task(...)`"""
self.report_change(original_node)
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)

if len(replacement_args) == 1:
replacement_args[0] = replacement_args[0].with_changes(
comma=MaybeSentinel.DEFAULT
)
return self.update_call_target(
updated_node, maybe_name, "create_task", replacement_args=replacement_args
)

def node_eager_task(
self,
original_node: cst.Call,
updated_node: cst.Call,
replacement_args=list[cst.Arg],
) -> cst.Call:
"""Convert `asyncio.Task(...)` to `asyncio.eager_task_factory(loop, coro...)`"""
self.report_change(original_node)
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
return self.update_call_target(
updated_node,
maybe_name,
"eager_task_factory",
replacement_args=replacement_args,
)

def node_loop_create_task(
self,
original_node: cst.Call,
coroutine_arg: cst.Arg,
loop_arg: cst.Arg,
other_args: list[cst.Arg],
) -> cst.Call:
"""Convert `asyncio.Task(..., loop=loop,...)` to `loop.create_task(...)`"""
self.report_change(original_node)
coroutine_arg = coroutine_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
loop_attr = loop_arg.value
new_call = cst.Call(
func=cst.Attribute(value=loop_attr, attr=cst.Name("create_task")),
args=[coroutine_arg] + other_args,
)
self.remove_unused_import(original_node)
return new_call

def _split_args(
self, args: list[cst.Arg]
) -> tuple[Optional[cst.Arg], Optional[cst.Arg], list[cst.Arg]]:
"""Find the loop kwarg and the eager_start kwarg from a list of args.
Return any args or non-None kwargs.
"""
loop_arg, eager_start_arg = None, None
other_args = []
for arg in args:
match arg:
case cst.Arg(keyword=cst.Name(value="loop")):
loop_arg = arg
case cst.Arg(keyword=cst.Name(value="eager_start")):
eager_start_arg = arg
case cst.Arg(keyword=cst.Name() as k) if k.value != "None":
# keep kwarg that are not set to None
other_args.append(arg)
case cst.Arg(keyword=None):
# keep post args
other_args.append(arg)

return loop_arg, eager_start_arg, other_args

def _is_invalid_loop_value(self, loop_type):
return loop_type in (
BaseType.NONE,
BaseType.NUMBER,
BaseType.LIST,
BaseType.STRING,
BaseType.BYTES,
BaseType.TRUE,
BaseType.FALSE,
)
Loading

0 comments on commit 6fa0615

Please sign in to comment.