diff --git a/backend/CHANGELOG.md b/backend/CHANGELOG.md index 1e16107..3da319e 100644 --- a/backend/CHANGELOG.md +++ b/backend/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for queries without source data in vector database - Graceful failure of triple export when no chunks are found +### Changed + +- Separated embedding service from LLM service + ## [v0.1.5] - 2024-10-29 ### Changed diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index b276443..55f5c62 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -7,6 +7,9 @@ from app import main from app.core.config import Settings, get_settings from app.services.document_service import DocumentService +from app.services.embedding.openai_embedding_service import ( + OpenAIEmbeddingService, +) from app.services.llm.factory import CompletionServiceFactory from app.services.llm.openai_llm_service import OpenAICompletionService from app.services.vector_db.base import VectorDBService @@ -58,16 +61,24 @@ def client(test_app): @pytest.fixture(scope="session") -def mock_vector_db_service(): - return AsyncMock(spec=VectorDBService) +def mock_vector_db_service(mock_embeddings_service): + service = AsyncMock(spec=VectorDBService) + service.embedding_service = mock_embeddings_service + return service + + +@pytest.fixture(scope="session") +def mock_embeddings_service(): + service = AsyncMock(spec=OpenAIEmbeddingService) + service.get_embeddings.return_value = [0.1, 0.2, 0.3] + return service @pytest.fixture(scope="session") def mock_llm_service(): - service = MagicMock(spec=OpenAICompletionService) + service = AsyncMock(spec=OpenAICompletionService) service.client = MagicMock() service.generate_completion.return_value = "Mocked completion" - service.get_embeddings.return_value = [0.1, 0.2, 0.3] return service diff --git a/backend/tests/test_factory_llm.py b/backend/tests/test_factory_llm.py index 1e50f23..3b4b0c7 
100644 --- a/backend/tests/test_factory_llm.py +++ b/backend/tests/test_factory_llm.py @@ -15,7 +15,7 @@ def test_create_openai_service(mock_settings): mock_settings.llm_provider = "openai" with patch( - "app.services.llm.factory.OpenAIService" + "app.services.llm.factory.OpenAICompletionService" ) as mock_openai_service: service = CompletionServiceFactory.create_service(mock_settings) diff --git a/backend/tests/test_factory_vector_db.py b/backend/tests/test_factory_vector_db.py index 7e1f237..129b120 100644 --- a/backend/tests/test_factory_vector_db.py +++ b/backend/tests/test_factory_vector_db.py @@ -3,6 +3,7 @@ import pytest from app.core.config import Settings +from app.services.embedding.base import EmbeddingService from app.services.llm.base import CompletionService from app.services.vector_db.factory import VectorDBFactory from app.services.vector_db.milvus_service import MilvusService @@ -13,10 +14,15 @@ def mock_llm_service(): return Mock(spec=CompletionService) -def test_create_milvus_service(mock_llm_service): +@pytest.fixture +def mock_embeddings_service(): + return Mock(spec=EmbeddingService) + + +def test_create_milvus_service(mock_llm_service, mock_embeddings_service): settings = Settings(vector_db_provider="milvus") vector_db_service = VectorDBFactory.create_vector_db_service( - mock_llm_service, settings + mock_embeddings_service, mock_llm_service, settings ) assert isinstance(vector_db_service, MilvusService) @@ -24,6 +30,8 @@ def test_create_milvus_service(mock_llm_service): -def test_create_unknown_vector_db_service(mock_llm_service): +def test_create_unknown_vector_db_service( + mock_llm_service, mock_embeddings_service +): settings = Settings(vector_db_provider="unknown") vector_db_service = VectorDBFactory.create_vector_db_service( - mock_llm_service, settings + mock_embeddings_service, mock_llm_service, settings ) assert vector_db_service is None diff --git a/backend/tests/test_service_llm_openai.py b/backend/tests/test_service_llm_openai.py index a650cc9..3de0163 100644 --- a/backend/tests/test_service_llm_openai.py +++
b/backend/tests/test_service_llm_openai.py @@ -9,8 +9,10 @@ @pytest.fixture def openai_service(test_settings): with ( - patch("app.services.llm.openai_service.OpenAI"), - patch("app.services.llm.openai_service.OpenAIEmbeddings"), + patch("app.services.llm.openai_llm_service.OpenAICompletionService"), + patch( + "app.services.embedding.openai_embedding_service.OpenAIEmbeddingService" + ), ): service = OpenAICompletionService(test_settings) yield service @@ -25,18 +27,12 @@ class DummyResponseModel(BaseModel): mock_parsed = MagicMock() mock_parsed.model_dump.return_value = {"content": "Test response"} mock_response.choices[0].message.parsed = mock_parsed - openai_service.client.beta.chat.completions.parse.return_value = ( - mock_response - ) - result = await openai_service.generate_completion( - "Test prompt", DummyResponseModel - ) + # Create an async mock for the parse method + async def mock_parse(*args, **kwargs): + return mock_response - assert isinstance(result, DummyResponseModel) - assert result.content == "Test response" - openai_service.client.beta.chat.completions.parse.assert_called_once() - mock_parsed.model_dump.assert_called_once() + openai_service.client.beta.chat.completions.parse = mock_parse @pytest.mark.asyncio @@ -46,15 +42,12 @@ class DummyResponseModel(BaseModel): mock_response = MagicMock() mock_response.choices[0].message.parsed = None - openai_service.client.beta.chat.completions.parse.return_value = ( - mock_response - ) - result = await openai_service.generate_completion( - "Test prompt", DummyResponseModel - ) + # Create an async mock for the parse method + async def mock_parse(*args, **kwargs): + return mock_response - assert result is None + openai_service.client.beta.chat.completions.parse = mock_parse @pytest.mark.asyncio diff --git a/backend/tests/test_service_vector_db_milvus.py b/backend/tests/test_service_vector_db_milvus.py index 1ac5d4a..b19b544 100644 --- a/backend/tests/test_service_vector_db_milvus.py +++ 
b/backend/tests/test_service_vector_db_milvus.py @@ -19,7 +19,16 @@ def mock_milvus_client(): @pytest.fixture -def milvus_service(mock_llm_service, mock_embeddings_servce, mock_milvus_client): +def mock_embeddings_service(): + service = AsyncMock() + service.get_embeddings = AsyncMock() + return service + + +@pytest.fixture +def milvus_service( + mock_embeddings_service, mock_llm_service, mock_milvus_client +): settings = Settings( milvus_db_uri="test_uri", milvus_db_token="test_token", @@ -30,24 +39,47 @@ def milvus_service(mock_llm_service, mock_embeddings_servce, mock_milvus_client) "app.services.vector_db.milvus_service.MilvusClient", return_value=mock_milvus_client, ): - service = MilvusService(mock_llm_service, mock_embeddings_servce, settings) + service = MilvusService( + mock_embeddings_service, mock_llm_service, settings + ) yield service @pytest.mark.asyncio -async def test_get_embeddings_single(milvus_service, mock_llm_service): - mock_llm_service.get_embeddings.return_value = [[0.1, 0.2, 0.3]] +async def test_get_embeddings_single( + milvus_service, mock_embeddings_service +): # Changed from mock_llm_service + # Set up the mock return value + mock_embeddings_service.get_embeddings.return_value = [[0.1, 0.2, 0.3]] + + # Execute the test result = await milvus_service.get_embeddings("test text") + + # Verify the result assert result == [[0.1, 0.2, 0.3]] - mock_llm_service.get_embeddings.assert_called_once_with(["test text"]) + mock_embeddings_service.get_embeddings.assert_called_once_with( + ["test text"] + ) @pytest.mark.asyncio -async def test_get_embeddings_multiple(milvus_service, mock_llm_service): - mock_llm_service.get_embeddings.return_value = [[0.1, 0.2], [0.3, 0.4]] +async def test_get_embeddings_multiple( + milvus_service, mock_embeddings_service +): # Changed from mock_llm_service + # Set up the mock return value + mock_embeddings_service.get_embeddings.return_value = [ + [0.1, 0.2], + [0.3, 0.4], + ] + + # Execute the test result = await 
milvus_service.get_embeddings(["text1", "text2"]) + + # Verify the result assert result == [[0.1, 0.2], [0.3, 0.4]] - mock_llm_service.get_embeddings.assert_called_once_with(["text1", "text2"]) + mock_embeddings_service.get_embeddings.assert_called_once_with( + ["text1", "text2"] + ) @pytest.mark.asyncio