From b83a756ab422af3f6266bbd6b228e377c07afcaa Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 24 Jun 2024 15:32:33 -0700 Subject: [PATCH] Set api key env variables for the llm service classes that need them, remove the xfail for the NVFoundationLLMService --- tests/llm/services/conftest.py | 8 ++++++++ tests/llm/services/test_llm_service.py | 16 +++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/llm/services/conftest.py b/tests/llm/services/conftest.py index 8e128406b2..a802c6ec84 100644 --- a/tests/llm/services/conftest.py +++ b/tests/llm/services/conftest.py @@ -36,6 +36,14 @@ def openai_fixture(openai): yield openai +@pytest.fixture(name="nvfoundationllm", autouse=True, scope='session') +def nvfoundationllm_fixture(nvfoundationllm): + """ + All of the tests in this subdir require nvfoundationllm + """ + yield nvfoundationllm + + @pytest.fixture(name="mock_chat_completion", autouse=True) def mock_chat_completion_fixture(mock_chat_completion): yield mock_chat_completion diff --git a/tests/llm/services/test_llm_service.py b/tests/llm/services/test_llm_service.py index 9f874b95c0..871d1031cf 100644 --- a/tests/llm/services/test_llm_service.py +++ b/tests/llm/services/test_llm_service.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect +import os from abc import ABC from unittest import mock @@ -31,11 +32,16 @@ def test_is_abstract(cls: ABC): assert inspect.isabstract(cls) -@pytest.mark.parametrize( - "service_name, expected_cls", - [("nemo", NeMoLLMService), ("openai", OpenAIChatService), - pytest.param("nvfoundation", NVFoundationLLMService, marks=pytest.mark.xfail(reason="missing dependency"))]) -def test_create(service_name: str, expected_cls: type): +@pytest.mark.usefixtures("restore_environ") +@pytest.mark.parametrize("service_name, expected_cls, env_values", + [("nemo", NeMoLLMService, {}), ("openai", OpenAIChatService, { + 'OPENAI_API_KEY': 'test_api' + }), + pytest.param("nvfoundation", NVFoundationLLMService, {'NVIDIA_API_KEY': 'test_api'})]) +def test_create(service_name: str, expected_cls: type, env_values: dict[str, str]): + if env_values: + os.environ.update(env_values) + service = LLMService.create(service_name) assert isinstance(service, expected_cls)