Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement file-level parallelization #100

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ def parse_args(argv, codemod_registry):
default=DEFAULT_INCLUDED_PATHS,
help="Comma-separated set of UNIX glob patterns to include",
)
parser.add_argument(
"--max-workers",
type=int,
default=1,
help="maximum number of workers (threads) to use for processing files in parallel",
)

# At this time we don't do anything with the sarif arg.
parser.add_argument(
Expand Down
108 changes: 66 additions & 42 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
import datetime
import difflib
import logging
Expand Down Expand Up @@ -29,18 +30,14 @@ def update_code(file_path, new_code):


def apply_codemod_to_file(
execution_context: CodemodExecutionContext,
base_directory: Path,
file_context,
codemod_kls: CodemodExecutorWrapper,
source_tree,
dry_run: bool = False,
):
name = codemod_kls.id
wrapper = cst.MetadataWrapper(source_tree)
codemod = codemod_kls(
CodemodContext(wrapper=wrapper),
execution_context,
file_context,
)
codemod = codemod_kls(CodemodContext(wrapper=wrapper), file_context)
drdavella marked this conversation as resolved.
Show resolved Hide resolved
if not codemod.should_transform:
return False

Expand All @@ -57,59 +54,86 @@ def apply_codemod_to_file(
)

change_set = ChangeSet(
str(file_context.file_path.relative_to(execution_context.directory)),
str(file_context.file_path.relative_to(base_directory)),
diff,
changes=file_context.codemod_changes,
)
execution_context.add_result(name, change_set)
file_context.add_result(change_set)

if not execution_context.dry_run:
if not dry_run:
update_code(file_context.file_path, output_tree.code)

return True


def process_file(
    idx: int,
    file_path: Path,
    base_directory: Path,
    codemod,
    sarif,
    cli_args,
):  # pylint: disable=too-many-arguments
    """
    Run a single codemod against a single file and return its FileContext.

    :param idx: position of the file in the list being analyzed; only used
        to emit a progress message every 100 files
    :param file_path: path of the file to parse and transform
    :param base_directory: directory handed to the FileContext and to
        ``apply_codemod_to_file``
    :param codemod: the codemod to apply
    :param sarif: mapping looked up by ``str(file_path)``; a missing or falsy
        entry is replaced by an empty dict
    :param cli_args: parsed arguments; ``path_exclude``, ``path_include`` and
        ``dry_run`` are read here
    :return: the FileContext built for this file. If the file cannot be
        opened or parsed, the failure is recorded on the context and the
        codemod is not applied.
    """
    logger.debug("scanning file %s", file_path)
    if idx != 0 and idx % 100 == 0:
        logger.info("scanned %s files...", idx)  # pragma: no cover

    # Arguments are evaluated left to right: exclude patterns, include
    # patterns, then the per-file sarif lookup.
    context = FileContext(
        base_directory,
        file_path,
        file_line_patterns(file_path, cli_args.path_exclude),
        file_line_patterns(file_path, cli_args.path_include),
        sarif.get(str(file_path)) or {},
    )

    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            module = cst.parse_module(handle.read())
    except Exception:
        # Any read/parse error is recorded on the context rather than raised.
        context.add_failure(file_path)
        logger.exception("error parsing file %s", file_path)
    else:
        apply_codemod_to_file(
            base_directory,
            context,
            codemod,
            module,
            cli_args.dry_run,
        )

    return context


def analyze_files(
execution_context: CodemodExecutionContext,
files_to_analyze,
codemod,
sarif,
cli_args,
):
# TODO: parallelize this loop
for idx, file_path in enumerate(files_to_analyze):
logger.debug("scanning file %s", file_path)
if idx and idx % 100 == 0:
logger.info("scanned %s files...", idx) # pragma: no cover

try:
with open(file_path, "r", encoding="utf-8") as f:
source_tree = cst.parse_module(f.read())
except Exception:
execution_context.add_failure(codemod.id, file_path)
logger.exception("error parsing file %s", file_path)
continue

line_exclude = file_line_patterns(file_path, cli_args.path_exclude)
line_include = file_line_patterns(file_path, cli_args.path_include)
sarif_for_file = sarif.get(str(file_path)) or {}

# NOTE: file context will become more important if/when we parallelize this loop
file_context = FileContext(
file_path,
line_exclude,
line_include,
sarif_for_file,
with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor:
logger.debug(
"using executor with %s threads",
cli_args.max_workers,
)

apply_codemod_to_file(
execution_context,
file_context,
codemod,
source_tree,
results = executor.map(
lambda args: process_file(
args[0],
args[1],
execution_context.directory,
codemod,
sarif,
cli_args,
),
enumerate(files_to_analyze),
)

execution_context.add_dependencies(codemod.id, file_context.dependencies)
executor.shutdown(wait=True)
execution_context.process_results(codemod.id, results)


def run(original_args) -> int:
Expand Down
21 changes: 5 additions & 16 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from codemodder.codemods.base_visitor import BaseTransformer
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from .helpers import Helpers

Expand Down Expand Up @@ -84,13 +83,8 @@ class BaseCodemod(
BaseTransformer,
Helpers,
):
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
_BaseCodemod.__init__(self, execution_context, file_context)
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
_BaseCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, [])

def report_change(self, original_node):
Expand All @@ -112,14 +106,9 @@ def __init_subclass__(cls):
super().__init_subclass__()
cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA)

def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
BaseCodemod.__init__(self, codemod_context, execution_context, file_context)
_SemgrepCodemod.__init__(self, execution_context, file_context)
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
BaseCodemod.__init__(self, codemod_context, file_context)
_SemgrepCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, self._results)

def _new_or_updated_node(self, original_node, updated_node):
Expand Down
6 changes: 1 addition & 5 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from libcst._position import CodeRange

from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.dependency import Dependency
from codemodder.file_context import FileContext
from codemodder.semgrep import run as semgrep_run
Expand Down Expand Up @@ -46,12 +45,9 @@ class BaseCodemod:
SUMMARY: ClassVar[str] = NotImplemented
is_semgrep: bool = False
adds_dependency: bool = False

execution_context: CodemodExecutionContext
file_context: FileContext

def __init__(self, execution_context: CodemodExecutionContext, file_context):
self.execution_context = execution_context
def __init__(self, file_context: FileContext):
self.file_context = file_context

@classmethod
Expand Down
22 changes: 15 additions & 7 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from pathlib import Path
import itertools
from textwrap import indent
from typing import List, Iterator

from codemodder.change import ChangeSet
from codemodder.dependency import Dependency
from codemodder.dependency_manager import DependencyManager
from codemodder.executor import CodemodExecutorWrapper
from codemodder.file_context import FileContext
from codemodder.logging import logger, log_list
from codemodder.registry import CodemodRegistry
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down Expand Up @@ -52,16 +54,16 @@ def __init__(
self.registry = registry
self.repo_manager = repo_manager

def add_result(self, codemod_name, change_set):
self._results_by_codemod.setdefault(codemod_name, []).append(change_set)
def add_results(self, codemod_name: str, change_sets: List[ChangeSet]):
self._results_by_codemod.setdefault(codemod_name, []).extend(change_sets)

def add_failure(self, codemod_name, file_path):
self._failures_by_codemod.setdefault(codemod_name, []).append(file_path)
def add_failures(self, codemod_name: str, failed_files: List[Path]):
self._failures_by_codemod.setdefault(codemod_name, []).extend(failed_files)

def add_dependencies(self, codemod_id: str, dependencies: set[Dependency]):
self.dependencies.setdefault(codemod_id, set()).update(dependencies)

def get_results(self, codemod_name):
def get_results(self, codemod_name: str):
return self._results_by_codemod.get(codemod_name, [])

def get_changed_files(self):
Expand All @@ -71,7 +73,7 @@ def get_changed_files(self):
for change_set in changes
]

def get_failures(self, codemod_name):
def get_failures(self, codemod_name: str):
return self._failures_by_codemod.get(codemod_name, [])

def get_failed_files(self):
Expand All @@ -96,7 +98,7 @@ def process_dependencies(self, codemod_id: str):

dm.add(list(dependencies))
if (changeset := dm.write(self.dry_run)) is not None:
self.add_result(codemod_id, changeset)
self.add_results(codemod_id, [changeset])

def add_description(self, codemod: CodemodExecutorWrapper):
description = codemod.description
Expand All @@ -105,6 +107,12 @@ def add_description(self, codemod: CodemodExecutorWrapper):

return description

def process_results(self, codemod_id: str, results: Iterator[FileContext]):
    """
    Fold per-file contexts into this execution context.

    For every FileContext yielded by ``results``, its change sets, failed
    files and dependencies are registered under ``codemod_id`` (in that
    order) via ``add_results``, ``add_failures`` and ``add_dependencies``.
    """
    for per_file in results:
        self.add_results(codemod_id, per_file.results)
        self.add_failures(codemod_id, per_file.failures)
        self.add_dependencies(codemod_id, per_file.dependencies)

def compile_results(self, codemods: list[CodemodExecutorWrapper]):
results = []
for codemod in codemods:
Expand Down
13 changes: 11 additions & 2 deletions src/codemodder/file_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@
from pathlib import Path
from typing import Dict, List

from codemodder.change import Change
from codemodder.change import Change, ChangeSet
from codemodder.dependency import Dependency


@dataclass
class FileContext:
class FileContext: # pylint: disable=too-many-instance-attributes
"""
Extra context for running codemods on a given file based on the cli parameters.
"""

base_directory: Path
file_path: Path
line_exclude: List[int] = field(default_factory=list)
line_include: List[int] = field(default_factory=list)
results_by_id: Dict = field(default_factory=dict)
dependencies: set[Dependency] = field(default_factory=set)
codemod_changes: List[Change] = field(default_factory=list)
results: List[ChangeSet] = field(default_factory=list)
failures: List[Path] = field(default_factory=list)

def add_dependency(self, dependency: Dependency):
    """Add ``dependency`` to this file's ``dependencies`` set."""
    self.dependencies.add(dependency)

def add_result(self, result: ChangeSet):
    """Append ``result`` to this file's ``results`` list."""
    self.results.append(result)

def add_failure(self, filename: Path):
    """Append ``filename`` to this file's ``failures`` list."""
    self.failures.append(filename)
3 changes: 2 additions & 1 deletion src/core_codemods/order_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
filtered_blocks.append(block)
if filtered_blocks:
order_transformer = OrderImportsBlocksTransform(
self.execution_context.directory, filtered_blocks
self.file_context.base_directory,
filtered_blocks,
)
result_tree = tree.visit(order_transformer)

Expand Down
4 changes: 1 addition & 3 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext

parameter_token = "?"
Expand Down Expand Up @@ -61,14 +60,13 @@ class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod):
def __init__(
self,
context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
*codemod_args,
) -> None:
self.changed_nodes: dict[
cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel
] = {}
BaseCodemod.__init__(self, execution_context, file_context, *codemod_args)
BaseCodemod.__init__(self, file_context, *codemod_args)
UtilsMixin.__init__(self, context, {})
Codemod.__init__(self, context)

Expand Down
10 changes: 2 additions & 8 deletions src/core_codemods/upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext


Expand Down Expand Up @@ -43,13 +42,8 @@ class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer):
PROTOCOL_ARG_INDEX = 0
PROTOCOL_KWARG_NAME = "protocol"

def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
SemgrepCodemod.__init__(self, execution_context, file_context)
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
SemgrepCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, self._results)

# TODO: apply unused import remover
Expand Down
Loading