Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add task and label database repository for get task API #10

Open
wants to merge 6 commits into
base: feat-get-todo-api-add-models
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions todo/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Added this because without this file Django isn't able to auto detect the test files
20 changes: 20 additions & 0 deletions todo/repositories/common/mongo_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from abc import ABC

from todo_project.db.config import DatabaseManager


class MongoRepository(ABC):
collection = None
collection_name = None
database_manager = DatabaseManager()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not hasattr(cls, "collection_name") or not isinstance(cls.collection_name, str):
raise TypeError(f"Class {cls.__name__} must define a static `collection_name` field as a string.")

@classmethod
def get_collection(cls):
if cls.collection is None:
cls.collection = cls.database_manager.get_collection(cls.collection_name)
return cls.collection
17 changes: 17 additions & 0 deletions todo/repositories/label_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import List

from bson import ObjectId
from todo.models.label import LabelModel
from todo.repositories.common.mongo_repository import MongoRepository


class LabelRepository(MongoRepository):
collection_name = LabelModel.collection_name

@classmethod
def list_by_ids(cls, ids: List[ObjectId]) -> List[LabelModel]:
if len(ids) == 0:
return []
labels_collection = cls.get_collection()
labels_cursor = labels_collection.find({"_id": {"$in": ids}})
return [LabelModel(**label) for label in labels_cursor]
18 changes: 18 additions & 0 deletions todo/repositories/task_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List
from todo.models.task import TaskModel
from todo.repositories.common.mongo_repository import MongoRepository


class TaskRepository(MongoRepository):
collection_name = TaskModel.collection_name

@classmethod
def list(cls, page: int, limit: int) -> List[TaskModel]:
tasks_collection = cls.get_collection()
tasks_cursor = tasks_collection.find().skip((page - 1) * limit).limit(limit)
return [TaskModel(**task) for task in tasks_cursor]

@classmethod
def count(cls) -> int:
tasks_collection = cls.get_collection()
return tasks_collection.count_documents({})
2 changes: 2 additions & 0 deletions todo/tests/fixtures/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
"createdBy": "qMbT6M2GB65W7UHgJS4g",
},
]

label_models = [LabelModel(**data) for data in label_db_data]
2 changes: 2 additions & 0 deletions todo/tests/fixtures/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@
"updatedBy": "qMbT6M2GB65W7UHgJS4g",
},
]

tasks_models = [TaskModel(**data) for data in tasks_db_data]
1 change: 1 addition & 0 deletions todo/tests/unit/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Added this because without this file Django isn't able to auto detect the test files
1 change: 1 addition & 0 deletions todo/tests/unit/repositories/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Added this because without this file Django isn't able to auto detect the test files
57 changes: 57 additions & 0 deletions todo/tests/unit/repositories/common/test_mongo_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest import TestCase
from unittest.mock import MagicMock, patch
from todo.repositories.common.mongo_repository import MongoRepository
from todo_project.db.config import DatabaseManager


class MongoRepositoryTests(TestCase):
def test_subclass_without_collection_name_raises_error(self):
with self.assertRaises(TypeError) as context:

class InvalidRepository(MongoRepository):
pass

self.assertIn(
"Class InvalidRepository must define a static `collection_name` field as a string.", str(context.exception)
)

def test_subclass_with_invalid_collection_name_raises_error(self):
with self.assertRaises(TypeError) as context:

class InvalidRepository(MongoRepository):
collection_name = 123

self.assertIn(
"Class InvalidRepository must define a static `collection_name` field as a string.", str(context.exception)
)

def test_subclass_with_valid_collection_name_passes(self):
try:

class ValidRepository(MongoRepository):
collection_name = "valid_collection"
except TypeError:
self.fail("TypeError raised for a valid subclass with collection_name")

@patch.object(DatabaseManager, "get_collection")
def test_get_collection_initializes_collection(self, mock_get_collection):
class TestRepository(MongoRepository):
collection_name = "test_collection"

mock_get_collection.return_value = MagicMock()

collection = TestRepository.get_collection()
mock_get_collection.assert_called_once_with("test_collection")
self.assertEqual(TestRepository.collection, collection)

@patch.object(DatabaseManager, "get_collection")
def test_get_collection_uses_cached_collection(self, mock_get_collection):
class TestRepository(MongoRepository):
collection_name = "test_collection"

mock_get_collection.return_value = MagicMock()

TestRepository.get_collection()
TestRepository.get_collection()

mock_get_collection.assert_called_once_with("test_collection")
42 changes: 42 additions & 0 deletions todo/tests/unit/repositories/test_label_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from unittest import TestCase
from unittest.mock import patch, MagicMock
from pymongo.collection import Collection
from todo.models.label import LabelModel
from todo.repositories.label_repository import LabelRepository
from todo.tests.fixtures.label import label_db_data


class LabelRepositoryTests(TestCase):
def setUp(self):
self.label_ids = [label_data["_id"] for label_data in label_db_data]
self.label_data = label_db_data

self.patcher_get_collection = patch("todo.repositories.label_repository.LabelRepository.get_collection")
self.mock_get_collection = self.patcher_get_collection.start()
self.mock_collection = MagicMock(spec=Collection)
self.mock_get_collection.return_value = self.mock_collection

def tearDown(self):
self.patcher_get_collection.stop()

def test_list_by_ids_returns_label_models(self):
self.mock_collection.find.return_value = self.label_data

result = LabelRepository.list_by_ids(self.label_ids)

self.assertEqual(len(result), len(self.label_data))
self.assertTrue(all(isinstance(label, LabelModel) for label in result))

def test_list_by_ids_returns_empty_list_if_not_found(self):
self.mock_collection.find.return_value = []

result = LabelRepository.list_by_ids([self.label_ids[0]])

self.assertEqual(result, [])

def test_list_by_ids_skips_db_call_for_empty_input(self):
result = LabelRepository.list_by_ids([])

self.assertEqual(result, [])
self.mock_get_collection.assert_not_called()
self.mock_collection.assert_not_called()
51 changes: 51 additions & 0 deletions todo/tests/unit/repositories/test_task_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from unittest import TestCase
from unittest.mock import patch, MagicMock
from pymongo.collection import Collection
from todo.models.task import TaskModel
from todo.repositories.task_repository import TaskRepository
from todo.tests.fixtures.task import tasks_db_data


class TaskRepositoryTests(TestCase):
def setUp(self):
self.task_data = tasks_db_data

self.patcher_get_collection = patch("todo.repositories.task_repository.TaskRepository.get_collection")
self.mock_get_collection = self.patcher_get_collection.start()
self.mock_collection = MagicMock(spec=Collection)
self.mock_get_collection.return_value = self.mock_collection

def tearDown(self):
self.patcher_get_collection.stop()

def test_list_applies_pagination_correctly(self):
self.mock_collection.find.return_value.skip.return_value.limit.return_value = self.task_data

page = 1
limit = 10
result = TaskRepository.list(page, limit)

self.assertEqual(len(result), len(self.task_data))
self.assertTrue(all(isinstance(task, TaskModel) for task in result))

self.mock_collection.find.assert_called_once()
self.mock_collection.find.return_value.skip.assert_called_once_with(0)
self.mock_collection.find.return_value.skip.return_value.limit.assert_called_once_with(limit)

def test_list_returns_empty_list_for_no_tasks(self):
self.mock_collection.find.return_value.skip.return_value.limit.return_value = []

result = TaskRepository.list(2, 10)

self.assertEqual(result, [])
self.mock_collection.find.assert_called_once()
self.mock_collection.find.return_value.skip.assert_called_once_with(10)
self.mock_collection.find.return_value.skip.return_value.limit.assert_called_once_with(10)

def test_count_returns_total_task_count(self):
self.mock_collection.count_documents.return_value = 42

result = TaskRepository.count()

self.assertEqual(result, 42)
self.mock_collection.count_documents.assert_called_once_with({})
Loading