diff --git a/tests/llm/services/conftest.py b/tests/llm/services/conftest.py index 8e128406b2..a802c6ec84 100644 --- a/tests/llm/services/conftest.py +++ b/tests/llm/services/conftest.py @@ -36,6 +36,14 @@ def openai_fixture(openai): yield openai +@pytest.fixture(name="nvfoundationllm", autouse=True, scope='session') +def nvfoundationllm_fixture(nvfoundationllm): + """ + All of the tests in this subdir require nvfoundationllm + """ + yield nvfoundationllm + + @pytest.fixture(name="mock_chat_completion", autouse=True) def mock_chat_completion_fixture(mock_chat_completion): yield mock_chat_completion diff --git a/tests/llm/services/test_llm_service.py b/tests/llm/services/test_llm_service.py index 9f874b95c0..871d1031cf 100644 --- a/tests/llm/services/test_llm_service.py +++ b/tests/llm/services/test_llm_service.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect +import os from abc import ABC from unittest import mock @@ -31,11 +32,16 @@ def test_is_abstract(cls: ABC): assert inspect.isabstract(cls) -@pytest.mark.parametrize( - "service_name, expected_cls", - [("nemo", NeMoLLMService), ("openai", OpenAIChatService), - pytest.param("nvfoundation", NVFoundationLLMService, marks=pytest.mark.xfail(reason="missing dependency"))]) -def test_create(service_name: str, expected_cls: type): +@pytest.mark.usefixtures("restore_environ") +@pytest.mark.parametrize("service_name, expected_cls, env_values", + [("nemo", NeMoLLMService, {}), ("openai", OpenAIChatService, { + 'OPENAI_API_KEY': 'test_api' + }), + pytest.param("nvfoundation", NVFoundationLLMService, {'NVIDIA_API_KEY': 'test_api'})]) +def test_create(service_name: str, expected_cls: type, env_values: dict[str, str]): + if env_values: + os.environ.update(env_values) + service = LLMService.create(service_name) assert isinstance(service, expected_cls)