Skip to content

Commit

Permalink
test base_url and api_key
Browse files Browse the repository at this point in the history
  • Loading branch information
efajardo-nv committed May 15, 2024
1 parent d3c739c commit ea8e9af
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/llm/services/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def openai_fixture(openai):
yield openai


@pytest.fixture(name="mock_chat_completion", autouse=True)
@pytest.fixture(name="mock_chat_completion", autouse=False)
def mock_chat_completion_fixture(mock_chat_completion):
yield mock_chat_completion

Expand Down
45 changes: 45 additions & 0 deletions tests/llm/services/test_openai_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.

import asyncio
import openai
import os
from unittest import mock

import pytest
Expand Down Expand Up @@ -44,6 +46,49 @@ def test_constructor_default_service_constructor(mock_chat_completion: tuple[moc
mock_client.assert_called_once_with(api_key=None, base_url=None, max_retries=max_retries)


@pytest.mark.usefixtures("openai")
@pytest.mark.parametrize('use_env', [True, False])
def test_constructor_api_key_base_url(use_env: bool):

if use_env:
env_api_key = "env-12345"
env_base_url = "http://env.openai.com/v1/"
os.environ["OPENAI_API_KEY"] = env_api_key
os.environ["OPENAI_BASE_URL"] = env_base_url

# Test when api_key and base_url are not passed
client = OpenAIChatService().get_client(model_name="test_model")
assert client._client.api_key == env_api_key
assert str(client._client.base_url) == env_base_url

# Test when api_key and base_url are passed
arg_api_key = "arg-12345"
arg_base_url = "http://arg.openai.com/v1/"
client = OpenAIChatService(api_key=arg_api_key, base_url=arg_base_url).get_client(model_name="test_model")
assert client._client.api_key == arg_api_key
assert str(client._client.base_url) == arg_base_url
else:
os.environ.pop("OPENAI_API_KEY")
os.environ.pop("OPENAI_BASE_URL")
# Test when api_key and base_url are not passed
with pytest.raises(openai.OpenAIError) as excinfo:
client = OpenAIChatService().get_client(model_name="test_model")

assert "api_key client option must be set" in str(excinfo.value)

# Test when only api_key is passed
arg_api_key = "arg-12345"
client = OpenAIChatService(api_key=arg_api_key).get_client(model_name="test_model")
assert client._client.api_key == arg_api_key
assert str(client._client.base_url) == "https://api.openai.com/v1/"

# Test when both api_key and base_url are passed
arg_base_url = "http://arg.openai.com/v1/"
client = OpenAIChatService(api_key=arg_api_key, base_url=arg_base_url).get_client(model_name="test_model")
assert client._client.api_key == arg_api_key
assert str(client._client.base_url) == arg_base_url


@pytest.mark.parametrize("use_async", [True, False])
@pytest.mark.parametrize(
"input_dict, set_assistant, expected_messages",
Expand Down

0 comments on commit ea8e9af

Please sign in to comment.