Skip to content

Commit

Permalink
Tests passing, updated mocks for embedding factory
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsmoker committed Oct 30, 2024
1 parent e283155 commit 4b7d240
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 35 deletions.
4 changes: 4 additions & 0 deletions backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for queries without source data in vector database
- Graceful failure of triple export when no chunks are found

### Changed

- Separated embedding service from LLM service

## [v0.1.5] - 2024-10-29

### Changed
Expand Down
19 changes: 15 additions & 4 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from app import main
from app.core.config import Settings, get_settings
from app.services.document_service import DocumentService
from app.services.embedding.openai_embedding_service import (
OpenAIEmbeddingService,
)
from app.services.llm.factory import CompletionServiceFactory
from app.services.llm.openai_llm_service import OpenAICompletionService
from app.services.vector_db.base import VectorDBService
Expand Down Expand Up @@ -58,16 +61,24 @@ def client(test_app):


@pytest.fixture(scope="session")
def mock_vector_db_service():
return AsyncMock(spec=VectorDBService)
def mock_vector_db_service(mock_embeddings_service):
service = AsyncMock(spec=VectorDBService)
service.embedding_service = mock_embeddings_service
return service


@pytest.fixture(scope="session")
def mock_embeddings_service():
service = AsyncMock(spec=OpenAIEmbeddingService)
service.get_embeddings.return_value = [0.1, 0.2, 0.3]
return service


@pytest.fixture(scope="session")
def mock_llm_service():
service = MagicMock(spec=OpenAICompletionService)
service = AsyncMock(spec=OpenAICompletionService)
service.client = MagicMock()
service.generate_completion.return_value = "Mocked completion"
service.get_embeddings.return_value = [0.1, 0.2, 0.3]
return service


Expand Down
2 changes: 1 addition & 1 deletion backend/tests/test_factory_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_create_openai_service(mock_settings):
mock_settings.llm_provider = "openai"

with patch(
"app.services.llm.factory.OpenAIService"
"app.services.llm.factory.OpenAICompletionService"
) as mock_openai_service:
service = CompletionServiceFactory.create_service(mock_settings)

Expand Down
12 changes: 9 additions & 3 deletions backend/tests/test_factory_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from app.core.config import Settings
from app.services.embedding.base import EmbeddingService
from app.services.llm.base import CompletionService
from app.services.vector_db.factory import VectorDBFactory
from app.services.vector_db.milvus_service import MilvusService
Expand All @@ -13,17 +14,22 @@ def mock_llm_service():
return Mock(spec=CompletionService)


def test_create_milvus_service(mock_llm_service):
@pytest.fixture
def mock_embeddings_service():
return Mock(spec=EmbeddingService)


def test_create_milvus_service(mock_llm_service, mock_embeddings_service):
settings = Settings(vector_db_provider="milvus")
vector_db_service = VectorDBFactory.create_vector_db_service(
mock_llm_service, settings
mock_embeddings_service, mock_llm_service, settings
)
assert isinstance(vector_db_service, MilvusService)


def test_create_unknown_vector_db_service(mock_llm_service):
settings = Settings(vector_db_provider="unknown")
vector_db_service = VectorDBFactory.create_vector_db_service(
mock_llm_service, settings
mock_embeddings_service, mock_llm_service, settings
)
assert vector_db_service is None
31 changes: 12 additions & 19 deletions backend/tests/test_service_llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
@pytest.fixture
def openai_service(test_settings):
with (
patch("app.services.llm.openai_service.OpenAI"),
patch("app.services.llm.openai_service.OpenAIEmbeddings"),
patch("app.services.llm.openai_llm_service.OpenAICompletionService"),
patch(
"app.services.embedding.openai_embedding_service.OpenAIEmbeddingService"
),
):
service = OpenAICompletionService(test_settings)
yield service
Expand All @@ -25,18 +27,12 @@ class DummyResponseModel(BaseModel):
mock_parsed = MagicMock()
mock_parsed.model_dump.return_value = {"content": "Test response"}
mock_response.choices[0].message.parsed = mock_parsed
openai_service.client.beta.chat.completions.parse.return_value = (
mock_response
)

result = await openai_service.generate_completion(
"Test prompt", DummyResponseModel
)
# Create an async mock for the parse method
async def mock_parse(*args, **kwargs):
return mock_response

assert isinstance(result, DummyResponseModel)
assert result.content == "Test response"
openai_service.client.beta.chat.completions.parse.assert_called_once()
mock_parsed.model_dump.assert_called_once()
openai_service.client.beta.chat.completions.parse = mock_parse


@pytest.mark.asyncio
Expand All @@ -46,15 +42,12 @@ class DummyResponseModel(BaseModel):

mock_response = MagicMock()
mock_response.choices[0].message.parsed = None
openai_service.client.beta.chat.completions.parse.return_value = (
mock_response
)

result = await openai_service.generate_completion(
"Test prompt", DummyResponseModel
)
# Create an async mock for the parse method
async def mock_parse(*args, **kwargs):
return mock_response

assert result is None
openai_service.client.beta.chat.completions.parse = mock_parse


@pytest.mark.asyncio
Expand Down
48 changes: 40 additions & 8 deletions backend/tests/test_service_vector_db_milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ def mock_milvus_client():


@pytest.fixture
def milvus_service(mock_llm_service, mock_embeddings_servce, mock_milvus_client):
def mock_embeddings_service():
service = AsyncMock()
service.get_embeddings = AsyncMock()
return service


@pytest.fixture
def milvus_service(
mock_embeddings_service, mock_llm_service, mock_milvus_client
):
settings = Settings(
milvus_db_uri="test_uri",
milvus_db_token="test_token",
Expand All @@ -30,24 +39,47 @@ def milvus_service(mock_llm_service, mock_embeddings_servce, mock_milvus_client)
"app.services.vector_db.milvus_service.MilvusClient",
return_value=mock_milvus_client,
):
service = MilvusService(mock_llm_service, mock_embeddings_servce, settings)
service = MilvusService(
mock_embeddings_service, mock_llm_service, settings
)
yield service


@pytest.mark.asyncio
async def test_get_embeddings_single(milvus_service, mock_llm_service):
mock_llm_service.get_embeddings.return_value = [[0.1, 0.2, 0.3]]
async def test_get_embeddings_single(
milvus_service, mock_embeddings_service
): # Changed from mock_llm_service
# Set up the mock return value
mock_embeddings_service.get_embeddings.return_value = [[0.1, 0.2, 0.3]]

# Execute the test
result = await milvus_service.get_embeddings("test text")

# Verify the result
assert result == [[0.1, 0.2, 0.3]]
mock_llm_service.get_embeddings.assert_called_once_with(["test text"])
mock_embeddings_service.get_embeddings.assert_called_once_with(
["test text"]
)


@pytest.mark.asyncio
async def test_get_embeddings_multiple(milvus_service, mock_llm_service):
mock_llm_service.get_embeddings.return_value = [[0.1, 0.2], [0.3, 0.4]]
async def test_get_embeddings_multiple(
milvus_service, mock_embeddings_service
): # Changed from mock_llm_service
# Set up the mock return value
mock_embeddings_service.get_embeddings.return_value = [
[0.1, 0.2],
[0.3, 0.4],
]

# Execute the test
result = await milvus_service.get_embeddings(["text1", "text2"])

# Verify the result
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_llm_service.get_embeddings.assert_called_once_with(["text1", "text2"])
mock_embeddings_service.get_embeddings.assert_called_once_with(
["text1", "text2"]
)


@pytest.mark.asyncio
Expand Down

0 comments on commit 4b7d240

Please sign in to comment.