Skip to content

Commit

Permalink
Enable library callers to disable AI clients (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella authored Nov 4, 2024
1 parent 65129f7 commit 199d696
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def run(
original_cli_args: list[str] | None = None,
codemod_registry: registry.CodemodRegistry | None = None,
sast_only: bool = False,
ai_client: bool = True,
) -> tuple[CodeTF | None, int]:
start = datetime.datetime.now()

Expand Down Expand Up @@ -173,6 +174,7 @@ def run(
path_exclude,
tool_result_files_map,
max_workers,
ai_client,
)
except MisconfiguredAIClient as e:
logger.error(e)
Expand Down
7 changes: 5 additions & 2 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
path_exclude: list[str] | None = None,
tool_result_files_map: dict[str, list[Path]] | None = None,
max_workers: int = 1,
ai_client: bool = True,
):
self.directory = directory
self.dry_run = dry_run
Expand All @@ -84,8 +85,10 @@ def __init__(
self.max_workers = max_workers
self.tool_result_files_map = tool_result_files_map or {}
self.semgrep_prefilter_results = None
self.openai_llm_client = setup_openai_llm_client()
self.azure_llama_llm_client = setup_azure_llama_llm_client()
self.openai_llm_client = setup_openai_llm_client() if ai_client else None
self.azure_llama_llm_client = (
setup_azure_llama_llm_client() if ai_client else None
)

def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
Expand Down
55 changes: 55 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,58 @@ def test_get_api_version_from_env(self, mocker):
)
assert isinstance(context.openai_llm_client, AzureOpenAI)
assert context.openai_llm_client._api_version == version

def test_disable_ai_client_openai(self, mocker):
mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
None,
PythonRepoManager(mocker.Mock()),
[],
[],
ai_client=False,
)
assert context.openai_llm_client is None

def test_disable_ai_client_azure(self, mocker):
mocker.patch.dict(
os.environ,
{
"CODEMODDER_AZURE_OPENAI_API_KEY": "test",
"CODEMODDER_AZURE_OPENAI_ENDPOINT": "test",
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
None,
PythonRepoManager(mocker.Mock()),
[],
[],
ai_client=False,
)
assert context.openai_llm_client is None

@pytest.mark.parametrize(
"env_var",
["CODEMODDER_AZURE_OPENAI_API_KEY", "CODEMODDER_AZURE_OPENAI_ENDPOINT"],
)
def test_no_misconfiguration_ai_client_disabled(self, mocker, env_var):
mocker.patch.dict(os.environ, {env_var: "test"})
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
None,
PythonRepoManager(mocker.Mock()),
[],
[],
ai_client=False,
)
assert context.openai_llm_client is None

0 comments on commit 199d696

Please sign in to comment.