diff --git a/tests/llm/services/conftest.py b/tests/llm/services/conftest.py index 8e128406b2..11c5d8f27c 100644 --- a/tests/llm/services/conftest.py +++ b/tests/llm/services/conftest.py @@ -36,7 +36,7 @@ def openai_fixture(openai): yield openai -@pytest.fixture(name="mock_chat_completion", autouse=True) +@pytest.fixture(name="mock_chat_completion", autouse=False) def mock_chat_completion_fixture(mock_chat_completion): yield mock_chat_completion diff --git a/tests/llm/services/test_openai_chat_client.py b/tests/llm/services/test_openai_chat_client.py index 4976eb1655..aa28643afa 100644 --- a/tests/llm/services/test_openai_chat_client.py +++ b/tests/llm/services/test_openai_chat_client.py @@ -14,6 +14,8 @@ # limitations under the License. import asyncio +import openai +import os from unittest import mock import pytest @@ -44,6 +46,49 @@ def test_constructor_default_service_constructor(mock_chat_completion: tuple[moc mock_client.assert_called_once_with(api_key=None, base_url=None, max_retries=max_retries) +@pytest.mark.usefixtures("openai") +@pytest.mark.parametrize('use_env', [True, False]) +def test_constructor_api_key_base_url(use_env: bool): + + if use_env: + env_api_key = "env-12345" + env_base_url = "http://env.openai.com/v1/" + os.environ["OPENAI_API_KEY"] = env_api_key + os.environ["OPENAI_BASE_URL"] = env_base_url + + # Test when api_key and base_url are not passed + client = OpenAIChatService().get_client(model_name="test_model") + assert client._client.api_key == env_api_key + assert str(client._client.base_url) == env_base_url + + # Test when api_key and base_url are passed + arg_api_key = "arg-12345" + arg_base_url = "http://arg.openai.com/v1/" + client = OpenAIChatService(api_key=arg_api_key, base_url=arg_base_url).get_client(model_name="test_model") + assert client._client.api_key == arg_api_key + assert str(client._client.base_url) == arg_base_url + else: + os.environ.pop("OPENAI_API_KEY") + os.environ.pop("OPENAI_BASE_URL") + # Test when api_key and base_url are not passed + with pytest.raises(openai.OpenAIError) as excinfo: + client = OpenAIChatService().get_client(model_name="test_model") + + assert "api_key client option must be set" in str(excinfo.value) + + # Test when only api_key is passed + arg_api_key = "arg-12345" + client = OpenAIChatService(api_key=arg_api_key).get_client(model_name="test_model") + assert client._client.api_key == arg_api_key + assert str(client._client.base_url) == "https://api.openai.com/v1/" + + # Test when both api_key and base_url are passed + arg_base_url = "http://arg.openai.com/v1/" + client = OpenAIChatService(api_key=arg_api_key, base_url=arg_base_url).get_client(model_name="test_model") + assert client._client.api_key == arg_api_key + assert str(client._client.base_url) == arg_base_url + + @pytest.mark.parametrize("use_async", [True, False]) @pytest.mark.parametrize( "input_dict, set_assistant, expected_messages",