diff --git a/integration_tests/test_multiple_codemods.py b/integration_tests/test_multiple_codemods.py index a21f5759..edf30880 100644 --- a/integration_tests/test_multiple_codemods.py +++ b/integration_tests/test_multiple_codemods.py @@ -56,5 +56,5 @@ def func(foo=None): ids = codemods.split(",") assert len(results["results"]) == 2 # Order matters - # assert results["results"][0]["codemod"] == f"pixee:python/{ids[0]}" - # assert results["results"][1]["codemod"] == f"pixee:python/{ids[1]}" + assert results["results"][0]["codemod"] == f"pixee:python/{ids[0]}" + assert results["results"][1]["codemod"] == f"pixee:python/{ids[1]}" diff --git a/src/codemodder/cli.py b/src/codemodder/cli.py index ea494aac..529101f8 100644 --- a/src/codemodder/cli.py +++ b/src/codemodder/cli.py @@ -43,8 +43,9 @@ class CsvListAction(argparse.Action): argparse Action to convert "a,b,c" into ["a", "b", "c"] """ - def __call__(self, parser, namespace, values, option_string=None): - items = set(values.split(",")) + def __call__(self, parser, namespace, values: str, option_string=None): + # Conversion to dict removes duplicates while preserving order + items = list(dict.fromkeys(values.split(",")).keys()) self.validate_items(items) setattr(namespace, self.dest, items) diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index a78e61da..aa3f3c6a 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -127,7 +127,6 @@ def run(original_args) -> int: codemod_registry, ) - # TODO: this needs to preserve the given order of codemods # TODO: this should be a method of CodemodExecutionContext codemods_to_run = codemod_registry.match_codemods( argv.codemod_include, argv.codemod_exclude @@ -137,7 +136,7 @@ def run(original_args) -> int: logger.warning("No codemods to run") return 0 - logger.debug("Codemods to run: %s", codemods_to_run) + logger.debug("Codemods to run: %s", [codemod.id for codemod in codemods_to_run]) # XXX: sarif files given on the command line are currently not used by any codemods @@ -152,7 +151,7 @@ def run(original_args) -> int: logger.debug("Matched files:\n%s", "\n".join(full_names)) # run codemods in sequence - for codemod in codemods_to_run.values(): + for codemod in codemods_to_run: results = codemod.apply(context) analyze_files( context, diff --git a/src/codemodder/registry.py b/src/codemodder/registry.py index dd55a5b0..f53d452a 100644 --- a/src/codemodder/registry.py +++ b/src/codemodder/registry.py @@ -18,36 +18,38 @@ class CodemodCollection: class CodemodRegistry: - _codemods: list[CodemodExecutorWrapper] + _codemods_by_name: dict[str, CodemodExecutorWrapper] + _codemods_by_id: dict[str, CodemodExecutorWrapper] def __init__(self): - self._codemods = [] + self._codemods_by_name = {} + self._codemods_by_id = {} @property def names(self): - return [codemod.name for codemod in self._codemods] + return list(self._codemods_by_name.keys()) @property def ids(self): - return [codemod.id for codemod in self._codemods] + return list(self._codemods_by_id.keys()) @property def codemods(self): - return self._codemods + return list(self._codemods_by_name.values()) def add_codemod_collection(self, collection: CodemodCollection): docs_module = files(collection.docs_module) semgrep_module = files(collection.semgrep_config_module) for codemod in collection.codemods: self._validate_codemod(codemod) - self._codemods.append( - CodemodExecutorWrapper( - codemod, - collection.origin, - docs_module, - semgrep_module, - ) + wrapper = CodemodExecutorWrapper( + codemod, + collection.origin, + docs_module, + semgrep_module, ) + self._codemods_by_name[wrapper.name] = wrapper + self._codemods_by_id[wrapper.id] = wrapper def _validate_codemod(self, codemod): for name in ["SUMMARY", "METADATA"]: @@ -74,9 +76,9 @@ def match_codemods( self, codemod_include: Optional[list] = None, codemod_exclude: Optional[list] = None, - ) -> dict: + ) -> list[CodemodExecutorWrapper]: if not codemod_include and not codemod_exclude: - return {codemod.name: codemod for codemod in self._codemods} + return self.codemods codemod_include = codemod_include or [] codemod_exclude = codemod_exclude or [] @@ -84,21 +86,18 @@ def match_codemods( # cli should've already prevented both include/exclude from being set. assert codemod_include or codemod_exclude - # TODO: preserve order of includes if codemod_exclude: - return { - name: codemod - for codemod in self._codemods - if (name := codemod.name) not in codemod_exclude + return [ + codemod + for codemod in self.codemods + if codemod.name not in codemod_exclude and codemod.id not in codemod_exclude - } - - return { - name: codemod - for codemod in self._codemods - if (name := codemod.name) in codemod_include - or codemod.id in codemod_include - } + ] + + return [ + self._codemods_by_name.get(name) or self._codemods_by_id[name] + for name in codemod_include + ] def load_registered_codemods() -> CodemodRegistry: diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 2b40fbc0..29fa2d86 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -64,7 +64,7 @@ def results_by_id_filepath(self, input_code, root, file_path): name = self.codemod.name() results = self.registry.match_codemods(codemod_include=[name]) - return semgrep_run(results[name].yaml_files, root) + return semgrep_run(results[0].yaml_files, root) def run_and_assert_filepath(self, root, file_path, input_code, expected): input_tree = cst.parse_module(input_code) diff --git a/tests/codemods/test_include_exclude.py b/tests/codemods/test_include_exclude.py index ba28d34a..c792145d 100644 --- a/tests/codemods/test_include_exclude.py +++ b/tests/codemods/test_include_exclude.py @@ -6,29 +6,30 @@ class TestMatchCodemods: @classmethod def setup_class(cls): cls.registry = load_registered_codemods() - cls.codemod_map = { - codemod.name: codemod - for codemod in cls.registry._codemods # pylint: disable=protected-access - } + cls.codemod_map = ( + cls.registry._codemods_by_name # pylint: disable=protected-access + ) def test_no_include_exclude(self): - assert self.registry.match_codemods(None, None) == self.codemod_map + assert self.registry.match_codemods(None, None) == self.registry.codemods @pytest.mark.parametrize( "input_str", ["secure-random", "secure-random,url-sandbox"] ) def test_include(self, input_str): includes = input_str.split(",") - assert self.registry.match_codemods(includes, None) == { - name: self.codemod_map[name] for name in includes - } + assert self.registry.match_codemods(includes, None) == [ + self.codemod_map[name] for name in includes + ] @pytest.mark.parametrize( "input_str", ["url-sandbox,secure-random", "secure-random,url-sandbox"] ) def test_include_preserve_order(self, input_str): includes = input_str.split(",") - assert list(self.registry.match_codemods(includes, None).keys()) == includes + assert [ + codemod.name for codemod in self.registry.match_codemods(includes, None) + ] == includes @pytest.mark.parametrize( "input_str", @@ -38,6 +39,7 @@ def test_include_preserve_order(self, input_str): ], ) def test_exclude(self, input_str): - assert self.registry.match_codemods(None, input_str) == { - k: v for (k, v) in self.codemod_map.items() if k not in input_str.split(",") - } + excludes = input_str.split(",") + assert self.registry.match_codemods(None, excludes) == [ + v for (k, v) in self.codemod_map.items() if k not in excludes + ] diff --git a/tests/test_codemod_docs.py b/tests/test_codemod_docs.py index 38bef1b0..57cc95f0 100644 --- a/tests/test_codemod_docs.py +++ b/tests/test_codemod_docs.py @@ -6,9 +6,7 @@ def pytest_generate_tests(metafunc): registry = load_registered_codemods() if "codemod" in metafunc.fixturenames: - metafunc.parametrize( - "codemod", registry._codemods # pylint: disable=protected-access - ) + metafunc.parametrize("codemod", registry.codemods) def test_load_codemod_description(codemod):