diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index d7b5cab4..0758ea89 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -135,6 +135,7 @@ def run( original_cli_args: list[str] | None = None, codemod_registry: registry.CodemodRegistry | None = None, sast_only: bool = False, + ai_client: bool = True, ) -> tuple[CodeTF | None, int]: start = datetime.datetime.now() @@ -173,6 +174,7 @@ def run( path_exclude, tool_result_files_map, max_workers, + ai_client, ) except MisconfiguredAIClient as e: logger.error(e) diff --git a/src/codemodder/context.py b/src/codemodder/context.py index a951be25..4fe48541 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -66,6 +66,7 @@ def __init__( path_exclude: list[str] | None = None, tool_result_files_map: dict[str, list[Path]] | None = None, max_workers: int = 1, + ai_client: bool = True, ): self.directory = directory self.dry_run = dry_run @@ -84,8 +85,10 @@ def __init__( self.max_workers = max_workers self.tool_result_files_map = tool_result_files_map or {} self.semgrep_prefilter_results = None - self.openai_llm_client = setup_openai_llm_client() - self.azure_llama_llm_client = setup_azure_llama_llm_client() + self.openai_llm_client = setup_openai_llm_client() if ai_client else None + self.azure_llama_llm_client = ( + setup_azure_llama_llm_client() if ai_client else None + ) def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]): self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets) diff --git a/tests/test_context.py b/tests/test_context.py index d80b699c..abebad91 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -243,3 +243,58 @@ def test_get_api_version_from_env(self, mocker): ) assert isinstance(context.openai_llm_client, AzureOpenAI) assert context.openai_llm_client._api_version == version + + def test_disable_ai_client_openai(self, mocker): + mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"}) + context = Context( + mocker.Mock(), + True, + False, + load_registered_codemods(), + None, + PythonRepoManager(mocker.Mock()), + [], + [], + ai_client=False, + ) + assert context.openai_llm_client is None + + def test_disable_ai_client_azure(self, mocker): + mocker.patch.dict( + os.environ, + { + "CODEMODDER_AZURE_OPENAI_API_KEY": "test", + "CODEMODDER_AZURE_OPENAI_ENDPOINT": "test", + }, + ) + context = Context( + mocker.Mock(), + True, + False, + load_registered_codemods(), + None, + PythonRepoManager(mocker.Mock()), + [], + [], + ai_client=False, + ) + assert context.openai_llm_client is None + + @pytest.mark.parametrize( + "env_var", + ["CODEMODDER_AZURE_OPENAI_API_KEY", "CODEMODDER_AZURE_OPENAI_ENDPOINT"], + ) + def test_no_misconfiguration_ai_client_disabled(self, mocker, env_var): + mocker.patch.dict(os.environ, {env_var: "test"}) + context = Context( + mocker.Mock(), + True, + False, + load_registered_codemods(), + None, + PythonRepoManager(mocker.Mock()), + [], + [], + ai_client=False, + ) + assert context.openai_llm_client is None