Skip to content

Commit

Permalink
Set api key env variables for the llm service classes that need them,…
Browse files Browse the repository at this point in the history
… remove the xfail for the NVFoundationLLMService
  • Loading branch information
dagardner-nv committed Jun 24, 2024
1 parent d563105 commit b83a756
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
8 changes: 8 additions & 0 deletions tests/llm/services/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def openai_fixture(openai):
yield openai


@pytest.fixture(name="nvfoundationllm", autouse=True, scope='session')
def nvfoundationllm_fixture(nvfoundationllm):
"""
All of the tests in this subdir require nvfoundationllm
"""
yield nvfoundationllm


@pytest.fixture(name="mock_chat_completion", autouse=True)
def mock_chat_completion_fixture(mock_chat_completion):
yield mock_chat_completion
Expand Down
16 changes: 11 additions & 5 deletions tests/llm/services/test_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import inspect
import os
from abc import ABC
from unittest import mock

Expand All @@ -31,11 +32,16 @@ def test_is_abstract(cls: ABC):
assert inspect.isabstract(cls)


@pytest.mark.parametrize(
"service_name, expected_cls",
[("nemo", NeMoLLMService), ("openai", OpenAIChatService),
pytest.param("nvfoundation", NVFoundationLLMService, marks=pytest.mark.xfail(reason="missing dependency"))])
def test_create(service_name: str, expected_cls: type):
@pytest.mark.usefixtures("restore_environ")
@pytest.mark.parametrize("service_name, expected_cls, env_values",
[("nemo", NeMoLLMService, {}), ("openai", OpenAIChatService, {
'OPENAI_API_KEY': 'test_api'
}),
pytest.param("nvfoundation", NVFoundationLLMService, {'NVIDIA_API_KEY': 'test_api'})])
def test_create(service_name: str, expected_cls: type, env_values: dict[str, str]):
if env_values:
os.environ.update(env_values)

service = LLMService.create(service_name)
assert isinstance(service, expected_cls)

Expand Down

0 comments on commit b83a756

Please sign in to comment.