diff --git a/tests/llm/services/test_llm_service.py b/tests/llm/services/test_llm_service.py index fecc1551a8..69ad4c7546 100644 --- a/tests/llm/services/test_llm_service.py +++ b/tests/llm/services/test_llm_service.py @@ -20,8 +20,30 @@ from morpheus.llm.services.llm_service import LLMClient from morpheus.llm.services.llm_service import LLMService +from morpheus.llm.services.nemo_llm_service import NeMoLLMService +from morpheus.llm.services.nvfoundation_llm_service import NVFoundationLLMService +from morpheus.llm.services.openai_chat_service import OpenAIChatService @pytest.mark.parametrize("cls", [LLMClient, LLMService]) def test_is_abstract(cls: ABC): assert inspect.isabstract(cls) + + +@pytest.mark.parametrize( + "service_name, expected_cls", + [("nemo", NeMoLLMService), ("openai", OpenAIChatService), + pytest.param("nvfoundation", NVFoundationLLMService, marks=pytest.mark.xfail(reason="missing dependency"))]) +def test_create(service_name: str, expected_cls: type): + service = LLMService.create(service_name) + assert isinstance(service, expected_cls) + + +@pytest.mark.parametrize( + "service_name, class_name", + [("nemo", "morpheus.llm.services.nemo_llm_service.NeMoLLMService"), + ("openai", "morpheus.llm.services.openai_chat_service.OpenAIChatService"), + ("nvfoundation", NVFoundationLLMService, marks=pytest.mark.xfail(reason="missing dependency"))]) +def test_create_mocked(service_name: str, class_name: str): + service = LLMService.create(service_name) + assert isinstance(service, expected_cls)