Skip to content

Commit

Permalink
Use codemod execution context to store results
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Sep 13, 2023
1 parent ca4bd91 commit 2922572
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 51 deletions.
62 changes: 29 additions & 33 deletions src/codemodder/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,14 @@
from codemodder.cli import parse_args
from codemodder.code_directory import file_line_patterns, match_files
from codemodder.codemods import match_codemods
from codemodder.codemods.change import Change
from codemodder.context import CodemodExecutionContext, ChangeSet
from codemodder.dependency_manager import DependencyManager
from codemodder.report.codetf_reporter import report_default
from codemodder.semgrep import run as semgrep_run
from codemodder.sarifs import parse_sarif_files

# Must use from import here to point to latest state
from codemodder import global_state

RESULTS_BY_CODEMOD = []
from dataclasses import dataclass


@dataclass
class ChangeSet:
"""A set of changes made to a file at `path`"""

path: str
diff: str
changes: list[Change]

def to_json(self):
return {"path": self.path, "diff": self.diff, "changes": self.changes}
from codemodder import global_state # TODO: should not use global state


def update_code(file_path, new_code):
Expand All @@ -46,6 +31,7 @@ def update_code(file_path, new_code):


def run_codemods_for_file(
execution_context: CodemodExecutionContext,
file_context,
codemods_to_run,
source_tree,
Expand All @@ -54,13 +40,16 @@ def run_codemods_for_file(
wrapper = cst.MetadataWrapper(source_tree)
codemod = codemod_kls(
CodemodContext(wrapper=wrapper),
# TODO: eventually pass execution context here
# It will be used for things like dependency management
file_context,
)
if not codemod.should_transform:
continue

logger.info("Running codemod %s for %s", name, file_context.file_path)
output_tree = codemod.transform_module(source_tree)
# TODO: there has got to be a more efficient way to check this?
changed_file = not output_tree.deep_equals(source_tree)

if changed_file:
Expand All @@ -72,21 +61,21 @@ def run_codemods_for_file(
logger.debug("CHANGED %s with codemod %s", file_context.file_path, name)
logger.debug(diff)

codemod_kls.CHANGESET_ALL_FILES.append(
ChangeSet(
# TODO: we should not be using global state here
str(file_context.file_path.relative_to(global_state.DIRECTORY)),
diff,
changes=file_context.codemod_changes,
).to_json()
change_set = ChangeSet(
str(file_context.file_path.relative_to(execution_context.directory)),
diff,
changes=file_context.codemod_changes,
)
if file_context.dry_run:
execution_context.add_result(name, change_set)

if execution_context.dry_run:
logger.info("Dry run, not changing files")
else:
update_code(file_context.file_path, output_tree.code)


def analyze_files(
execution_context: CodemodExecutionContext,
files_to_analyze,
codemods_to_run,
sarif,
Expand All @@ -109,34 +98,39 @@ def analyze_files(

file_context = FileContext(
file_path,
cli_args.dry_run,
line_exclude,
line_include,
sarif_for_file,
)

run_codemods_for_file(
execution_context,
file_context,
codemods_to_run,
source_tree,
)


def compile_results(codemods):
def compile_results(execution_context: CodemodExecutionContext, codemods):
results = []
for name, codemod_kls in codemods.items():
if not codemod_kls.CHANGESET_ALL_FILES:
if not (changeset := execution_context.results_by_codemod.get(name)):
continue

data = {
# TODO: this prefix should not be hardcoded
"codemod": f"pixee:python/{name}",
"summary": codemod_kls.SUMMARY,
"description": codemod_kls.METADATA.DESCRIPTION,
"references": [],
"properties": {},
"failedFiles": [],
"changeset": codemod_kls.CHANGESET_ALL_FILES,
"changeset": [change.to_json() for change in changeset],
}

RESULTS_BY_CODEMOD.append(data)
results.append(data)

return results


def run(argv, original_args) -> int:
Expand All @@ -147,6 +141,7 @@ def run(argv, original_args) -> int:
return 1

global_state.set_directory(Path(argv.directory))
context = CodemodExecutionContext(Path(argv.directory), argv.dry_run)

codemods_to_run = match_codemods(argv.codemod_include, argv.codemod_exclude)
if not codemods_to_run:
Expand Down Expand Up @@ -179,18 +174,19 @@ def run(argv, original_args) -> int:
logger.debug("Matched files:\n%s", "\n".join(full_names))

analyze_files(
context,
files_to_analyze,
codemods_to_run,
sarif_results,
argv,
)

compile_results(codemods_to_run)
results = compile_results(context, codemods_to_run)

DependencyManager().write(dry_run=argv.dry_run)
DependencyManager().write(dry_run=context.dry_run)
elapsed = datetime.datetime.now() - start
elapsed_ms = int(elapsed.total_seconds() * 1000)
report_default(elapsed_ms, argv, original_args, RESULTS_BY_CODEMOD)
report_default(elapsed_ms, argv, original_args, results)
return 0


Expand Down
3 changes: 3 additions & 0 deletions src/codemodder/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
CODEMOD_NAMES = [codemod.name() for codemod in DEFAULT_CODEMODS]


# TODO: codemod registry


def match_codemods(codemod_include: list, codemod_exclude: list) -> dict:
if not codemod_include and not codemod_exclude:
return {codemod.name(): codemod for codemod in DEFAULT_CODEMODS}
Expand Down
2 changes: 0 additions & 2 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ class BaseCodemod(
BaseTransformer,
Helpers,
):
CHANGESET_ALL_FILES: list = []

def report_change(self, original_node):
line_number = self.lineno_for_node(original_node)
self.file_context.codemod_changes.append(
Expand Down
13 changes: 7 additions & 6 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import Enum
import itertools
from typing import List, ClassVar

from codemodder.file_context import FileContext
from codemodder.semgrep import rule_ids_from_yaml_files


Expand All @@ -24,11 +26,10 @@ class BaseCodemod:
SUMMARY: ClassVar[str] = NotImplemented
IS_SEMGREP = False

file_context: FileContext

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# TODO: this should not belong to codemod classes
# Instead it should be owned by a global context that gets passed in
cls.CHANGESET_ALL_FILES: list = []

if "codemodder.codemods.base_codemod.SemgrepCodemod" in str(cls):
# hack: SemgrepCodemod won't NotImplementedError but all other child
Expand Down Expand Up @@ -86,11 +87,11 @@ def __init_subclass__(cls, **kwargs):

cls.RULE_IDS = rule_ids_from_yaml_files(cls.YAML_FILES)

def __init__(self, file_context):
super().__init__(file_context)
def __init__(self, *args):
super().__init__(*args)
self._results = list(
itertools.chain.from_iterable(
map(lambda rId: file_context.results_by_id[rId], self.RULE_IDS)
map(lambda rId: self.file_context.results_by_id[rId], self.RULE_IDS)
)
)

Expand Down
30 changes: 30 additions & 0 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
from dataclasses import dataclass

from codemodder.codemods.change import Change


@dataclass
class ChangeSet:
"""A set of changes made to a file at `path`"""

path: str
diff: str
changes: list[Change]

def to_json(self):
return {"path": self.path, "diff": self.diff, "changes": self.changes}


class CodemodExecutionContext:
results_by_codemod: dict[str, list[ChangeSet]] = {}
directory: Path
dry_run: bool = False

def __init__(self, directory, dry_run):
self.directory = directory
self.dry_run = dry_run
self.results_by_codemod = {}

def add_result(self, codemod_name, change_set):
self.results_by_codemod.setdefault(codemod_name, []).append(change_set)
1 change: 0 additions & 1 deletion src/codemodder/file_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class FileContext:
"""

file_path: Path
dry_run: bool
line_exclude: List[int]
line_include: List[int]
results_by_id: DefaultDict
Expand Down
5 changes: 2 additions & 3 deletions tests/codemods/base_codemod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from libcst.codemod import CodemodContext
from pathlib import Path
import os
from collections import defaultdict
from codemodder.file_context import FileContext
from codemodder.semgrep import run_on_directory as semgrep_run
from codemodder.semgrep import find_all_yaml_files
Expand All @@ -23,10 +24,9 @@ def run_and_assert_filepath(self, _, file_path, input_code, expected):
input_tree = cst.parse_module(input_code)
self.file_context = FileContext(
file_path,
False,
[],
[],
[],
defaultdict(list),
)
command_instance = self.codemod(CodemodContext(), self.file_context)
output_tree = command_instance.transform_module(input_tree)
Expand All @@ -49,7 +49,6 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
results = all_results[str(file_path)]
self.file_context = FileContext(
file_path,
False,
[],
[],
results,
Expand Down
6 changes: 1 addition & 5 deletions tests/shared.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Module to store shared utilities for both unit and integration tests"""
import pytest
from codemodder import global_state
from codemodder.__main__ import RESULTS_BY_CODEMOD
from codemodder.codemods import ALL_CODEMODS
from codemodder.dependency_manager import DependencyManager


# TODO: should not have any global state
@pytest.fixture(autouse=True, scope="function")
def reset_global_state():
"""
Expand All @@ -16,6 +15,3 @@ def reset_global_state():
yield
DependencyManager.clear_instance()
global_state.set_directory("")
RESULTS_BY_CODEMOD.clear()
for codemod_kls in ALL_CODEMODS:
codemod_kls.CHANGESET_ALL_FILES = []
2 changes: 1 addition & 1 deletion tests/test_file_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


def test_file_context():
file_context = FileContext(None, None, None, None, None)
file_context = FileContext(None, None, None, None)
assert file_context.line_exclude == []
assert file_context.line_include == []

0 comments on commit 2922572

Please sign in to comment.