Skip to content

Commit

Permalink
Respect given order of included codemods
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Oct 6, 2023
1 parent aa15f97 commit 4883916
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 50 deletions.
4 changes: 2 additions & 2 deletions integration_tests/test_multiple_codemods.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def func(foo=None):
ids = codemods.split(",")
assert len(results["results"]) == 2
# Order matters
# assert results["results"][0]["codemod"] == f"pixee:python/{ids[0]}"
# assert results["results"][1]["codemod"] == f"pixee:python/{ids[1]}"
assert results["results"][0]["codemod"] == f"pixee:python/{ids[0]}"
assert results["results"][1]["codemod"] == f"pixee:python/{ids[1]}"
5 changes: 3 additions & 2 deletions src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ class CsvListAction(argparse.Action):
argparse Action to convert "a,b,c" into ["a", "b", "c"]
"""

def __call__(self, parser, namespace, values, option_string=None):
items = set(values.split(","))
def __call__(self, parser, namespace, values: str, option_string=None):
# Conversion to dict removes duplicates while preserving order
items = list(dict.fromkeys(values.split(",")).keys())
self.validate_items(items)
setattr(namespace, self.dest, items)

Expand Down
5 changes: 2 additions & 3 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def run(original_args) -> int:
codemod_registry,
)

# TODO: this needs to preserve the given order of codemods
# TODO: this should be a method of CodemodExecutionContext
codemods_to_run = codemod_registry.match_codemods(
argv.codemod_include, argv.codemod_exclude
Expand All @@ -137,7 +136,7 @@ def run(original_args) -> int:
logger.warning("No codemods to run")
return 0

logger.debug("Codemods to run: %s", codemods_to_run)
logger.debug("Codemods to run: %s", [codemod.id for codemod in codemods_to_run])

# XXX: sarif files given on the command line are currently not used by any codemods

Expand All @@ -152,7 +151,7 @@ def run(original_args) -> int:
logger.debug("Matched files:\n%s", "\n".join(full_names))

# run codemods in sequence
for codemod in codemods_to_run.values():
for codemod in codemods_to_run:
results = codemod.apply(context)
analyze_files(
context,
Expand Down
53 changes: 26 additions & 27 deletions src/codemodder/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,38 @@ class CodemodCollection:


class CodemodRegistry:
_codemods: list[CodemodExecutorWrapper]
_codemods_by_name: dict[str, CodemodExecutorWrapper]
_codemods_by_id: dict[str, CodemodExecutorWrapper]

def __init__(self):
self._codemods = []
self._codemods_by_name = {}
self._codemods_by_id = {}

@property
def names(self):
return [codemod.name for codemod in self._codemods]
return list(self._codemods_by_name.keys())

@property
def ids(self):
return [codemod.id for codemod in self._codemods]
return list(self._codemods_by_id.keys())

@property
def codemods(self):
return self._codemods
return list(self._codemods_by_name.values())

def add_codemod_collection(self, collection: CodemodCollection):
docs_module = files(collection.docs_module)
semgrep_module = files(collection.semgrep_config_module)
for codemod in collection.codemods:
self._validate_codemod(codemod)
self._codemods.append(
CodemodExecutorWrapper(
codemod,
collection.origin,
docs_module,
semgrep_module,
)
wrapper = CodemodExecutorWrapper(
codemod,
collection.origin,
docs_module,
semgrep_module,
)
self._codemods_by_name[wrapper.name] = wrapper
self._codemods_by_id[wrapper.id] = wrapper

def _validate_codemod(self, codemod):
for name in ["SUMMARY", "METADATA"]:
Expand All @@ -74,31 +76,28 @@ def match_codemods(
self,
codemod_include: Optional[list] = None,
codemod_exclude: Optional[list] = None,
) -> dict:
) -> list[CodemodExecutorWrapper]:
if not codemod_include and not codemod_exclude:
return {codemod.name: codemod for codemod in self._codemods}
return self.codemods

codemod_include = codemod_include or []
codemod_exclude = codemod_exclude or []

# cli should've already prevented both include/exclude from being set.
assert codemod_include or codemod_exclude

# TODO: preserve order of includes
if codemod_exclude:
return {
name: codemod
for codemod in self._codemods
if (name := codemod.name) not in codemod_exclude
return [
codemod
for codemod in self.codemods
if codemod.name not in codemod_exclude
and codemod.id not in codemod_exclude
}

return {
name: codemod
for codemod in self._codemods
if (name := codemod.name) in codemod_include
or codemod.id in codemod_include
}
]

return [
self._codemods_by_name.get(name) or self._codemods_by_id[name]
for name in codemod_include
]


def load_registered_codemods() -> CodemodRegistry:
Expand Down
2 changes: 1 addition & 1 deletion tests/codemods/base_codemod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def results_by_id_filepath(self, input_code, root, file_path):

name = self.codemod.name()
results = self.registry.match_codemods(codemod_include=[name])
return semgrep_run(results[name].yaml_files, root)
return semgrep_run(results[0].yaml_files, root)

def run_and_assert_filepath(self, root, file_path, input_code, expected):
input_tree = cst.parse_module(input_code)
Expand Down
26 changes: 14 additions & 12 deletions tests/codemods/test_include_exclude.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@ class TestMatchCodemods:
@classmethod
def setup_class(cls):
cls.registry = load_registered_codemods()
cls.codemod_map = {
codemod.name: codemod
for codemod in cls.registry._codemods # pylint: disable=protected-access
}
cls.codemod_map = (
cls.registry._codemods_by_name # pylint: disable=protected-access
)

def test_no_include_exclude(self):
assert self.registry.match_codemods(None, None) == self.codemod_map
assert self.registry.match_codemods(None, None) == self.registry.codemods

@pytest.mark.parametrize(
"input_str", ["secure-random", "secure-random,url-sandbox"]
)
def test_include(self, input_str):
includes = input_str.split(",")
assert self.registry.match_codemods(includes, None) == {
name: self.codemod_map[name] for name in includes
}
assert self.registry.match_codemods(includes, None) == [
self.codemod_map[name] for name in includes
]

@pytest.mark.parametrize(
"input_str", ["url-sandbox,secure-random", "secure-random,url-sandbox"]
)
def test_include_preserve_order(self, input_str):
includes = input_str.split(",")
assert list(self.registry.match_codemods(includes, None).keys()) == includes
assert [
codemod.name for codemod in self.registry.match_codemods(includes, None)
] == includes

@pytest.mark.parametrize(
"input_str",
Expand All @@ -38,6 +39,7 @@ def test_include_preserve_order(self, input_str):
],
)
def test_exclude(self, input_str):
assert self.registry.match_codemods(None, input_str) == {
k: v for (k, v) in self.codemod_map.items() if k not in input_str.split(",")
}
excludes = input_str.split(",")
assert self.registry.match_codemods(None, excludes) == [
v for (k, v) in self.codemod_map.items() if k not in excludes
]
4 changes: 1 addition & 3 deletions tests/test_codemod_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
def pytest_generate_tests(metafunc):
registry = load_registered_codemods()
if "codemod" in metafunc.fixturenames:
metafunc.parametrize(
"codemod", registry._codemods # pylint: disable=protected-access
)
metafunc.parametrize("codemod", registry.codemods)


def test_load_codemod_description(codemod):
Expand Down

0 comments on commit 4883916

Please sign in to comment.