From 83b391011e15ad63d660f5df04dad10e470badb6 Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:24:16 +0530 Subject: [PATCH 1/6] Add base mongo repository --- todo/repositories/__init__.py | 1 + todo/repositories/common/mongo_repository.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 todo/repositories/__init__.py create mode 100644 todo/repositories/common/mongo_repository.py diff --git a/todo/repositories/__init__.py b/todo/repositories/__init__.py new file mode 100644 index 0000000..84a4d93 --- /dev/null +++ b/todo/repositories/__init__.py @@ -0,0 +1 @@ +# Added this because without this file Django isn't able to auto detect the test files diff --git a/todo/repositories/common/mongo_repository.py b/todo/repositories/common/mongo_repository.py new file mode 100644 index 0000000..7f8f8b5 --- /dev/null +++ b/todo/repositories/common/mongo_repository.py @@ -0,0 +1,20 @@ +from abc import ABC + +from todo_project.db.config import DatabaseManager + + +class MongoRepository(ABC): + collection = None + collection_name = None + database_manager = DatabaseManager() + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not hasattr(cls, "collection_name") or not isinstance(cls.collection_name, str): + raise TypeError(f"Class {cls.__name__} must define a static `collection_name` field as a string.") + + @classmethod + def get_collection(cls): + if cls.collection is None: + cls.collection = cls.database_manager.get_collection(cls.collection_name) + return cls.collection From ddd96abc7443727b425f5dadb51baeff67d997d7 Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:24:33 +0530 Subject: [PATCH 2/6] Add task repository --- todo/repositories/task_repository.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 todo/repositories/task_repository.py diff --git a/todo/repositories/task_repository.py b/todo/repositories/task_repository.py new file mode 100644 index 0000000..8e89ca0 --- /dev/null +++ b/todo/repositories/task_repository.py @@ -0,0 +1,18 @@ +from typing import List +from todo.models.task import TaskModel +from todo.repositories.common.mongo_repository import MongoRepository + + +class TaskRepository(MongoRepository): + collection_name = TaskModel.collection_name + + @classmethod + def list(cls, page: int, limit: int) -> List[TaskModel]: + tasks_collection = cls.get_collection() + tasks_cursor = tasks_collection.find().skip((page - 1) * limit).limit(limit) + return [TaskModel(**task) for task in tasks_cursor] + + @classmethod + def count(cls) -> int: + tasks_collection = cls.get_collection() + return tasks_collection.count_documents({}) From 285a154c79aa6a90b46b42aaf60ac0de51d374f3 Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:24:41 +0530 Subject: [PATCH 3/6] Add label repository --- todo/repositories/label_repository.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 todo/repositories/label_repository.py diff --git a/todo/repositories/label_repository.py b/todo/repositories/label_repository.py new file mode 100644 index 0000000..f65fe1a --- /dev/null +++ b/todo/repositories/label_repository.py @@ -0,0 +1,17 @@ +from typing import List + +from bson import ObjectId +from todo.models.label import LabelModel +from todo.repositories.common.mongo_repository import MongoRepository + + +class LabelRepository(MongoRepository): + collection_name = LabelModel.collection_name + + @classmethod + def list_by_ids(cls, ids: List[ObjectId]) -> List[LabelModel]: + if len(ids) == 0: + return [] + labels_collection = cls.get_collection() + labels_cursor = labels_collection.find({"_id": {"$in": ids}}) + return [LabelModel(**label) for label in labels_cursor] From 4b221804572cb8da65d54eef23a7f74675396e4d Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:25:08 +0530 Subject: [PATCH 4/6] Add tests for base mongo repository --- todo/tests/unit/repositories/__init__.py | 1 + .../unit/repositories/common/__init__.py | 1 + .../common/test_mongo_repository.py | 57 +++++++++++++++++++ 3 files changed, 59 insertions(+) create mode 100644 todo/tests/unit/repositories/__init__.py create mode 100644 todo/tests/unit/repositories/common/__init__.py create mode 100644 todo/tests/unit/repositories/common/test_mongo_repository.py diff --git a/todo/tests/unit/repositories/__init__.py b/todo/tests/unit/repositories/__init__.py new file mode 100644 index 0000000..84a4d93 --- /dev/null +++ b/todo/tests/unit/repositories/__init__.py @@ -0,0 +1 @@ +# Added this because without this file Django isn't able to auto detect the test files diff --git a/todo/tests/unit/repositories/common/__init__.py b/todo/tests/unit/repositories/common/__init__.py new file mode 100644 index 0000000..84a4d93 --- /dev/null +++ b/todo/tests/unit/repositories/common/__init__.py @@ -0,0 +1 @@ +# Added this because without this file Django isn't able to auto detect the test files diff --git a/todo/tests/unit/repositories/common/test_mongo_repository.py b/todo/tests/unit/repositories/common/test_mongo_repository.py new file mode 100644 index 0000000..3e963f0 --- /dev/null +++ b/todo/tests/unit/repositories/common/test_mongo_repository.py @@ -0,0 +1,57 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch +from todo.repositories.common.mongo_repository import MongoRepository +from todo_project.db.config import DatabaseManager + + +class MongoRepositoryTests(TestCase): + def test_subclass_without_collection_name_raises_error(self): + with self.assertRaises(TypeError) as context: + + class InvalidRepository(MongoRepository): + pass + + self.assertIn( + "Class InvalidRepository must define a static `collection_name` field as a string.", str(context.exception) + ) + + def test_subclass_with_invalid_collection_name_raises_error(self): + with self.assertRaises(TypeError) as context: + + class InvalidRepository(MongoRepository): + collection_name = 123 + + self.assertIn( + "Class InvalidRepository must define a static `collection_name` field as a string.", str(context.exception) + ) + + def test_subclass_with_valid_collection_name_passes(self): + try: + + class ValidRepository(MongoRepository): + collection_name = "valid_collection" + except TypeError: + self.fail("TypeError raised for a valid subclass with collection_name") + + @patch.object(DatabaseManager, "get_collection") + def test_get_collection_initializes_collection(self, mock_get_collection): + class TestRepository(MongoRepository): + collection_name = "test_collection" + + mock_get_collection.return_value = MagicMock() + + collection = TestRepository.get_collection() + mock_get_collection.assert_called_once_with("test_collection") + self.assertEqual(TestRepository.collection, collection) + + @patch.object(DatabaseManager, "get_collection") + def test_get_collection_uses_cached_collection(self, mock_get_collection): + class TestRepository(MongoRepository): + collection_name = "test_collection" + + mock_get_collection.return_value = MagicMock() + + TestRepository.get_collection() + TestRepository.get_collection() + + mock_get_collection.assert_called_once_with("test_collection") From 1e5d8bfee2908b2206bee32feef787a3731e53d5 Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:25:26 +0530 Subject: [PATCH 5/6] Add tests for task repository --- todo/tests/fixtures/task.py | 2 + .../unit/repositories/test_task_repository.py | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 todo/tests/unit/repositories/test_task_repository.py diff --git a/todo/tests/fixtures/task.py b/todo/tests/fixtures/task.py index ff87579..81a4a17 100644 --- a/todo/tests/fixtures/task.py +++ b/todo/tests/fixtures/task.py @@ -35,3 +35,5 @@ "updatedBy": "qMbT6M2GB65W7UHgJS4g", }, ] + +tasks_models = [TaskModel(**data) for data in tasks_db_data] diff --git a/todo/tests/unit/repositories/test_task_repository.py b/todo/tests/unit/repositories/test_task_repository.py new file mode 100644 index 0000000..e8c5f28 --- /dev/null +++ b/todo/tests/unit/repositories/test_task_repository.py @@ -0,0 +1,51 @@ +from unittest import TestCase +from unittest.mock import patch, MagicMock +from pymongo.collection import Collection +from todo.models.task import TaskModel +from todo.repositories.task_repository import TaskRepository +from todo.tests.fixtures.task import tasks_db_data + + +class TaskRepositoryTests(TestCase): + def setUp(self): + self.task_data = tasks_db_data + + self.patcher_get_collection = patch("todo.repositories.task_repository.TaskRepository.get_collection") + self.mock_get_collection = self.patcher_get_collection.start() + self.mock_collection = MagicMock(spec=Collection) + self.mock_get_collection.return_value = self.mock_collection + + def tearDown(self): + self.patcher_get_collection.stop() + + def test_list_applies_pagination_correctly(self): + self.mock_collection.find.return_value.skip.return_value.limit.return_value = self.task_data + + page = 1 + limit = 10 + result = TaskRepository.list(page, limit) + + self.assertEqual(len(result), len(self.task_data)) + self.assertTrue(all(isinstance(task, TaskModel) for task in result)) + + self.mock_collection.find.assert_called_once() + self.mock_collection.find.return_value.skip.assert_called_once_with(0) + self.mock_collection.find.return_value.skip.return_value.limit.assert_called_once_with(limit) + + def test_list_returns_empty_list_for_no_tasks(self): + self.mock_collection.find.return_value.skip.return_value.limit.return_value = [] + + result = TaskRepository.list(2, 10) + + self.assertEqual(result, []) + self.mock_collection.find.assert_called_once() + self.mock_collection.find.return_value.skip.assert_called_once_with(10) + self.mock_collection.find.return_value.skip.return_value.limit.assert_called_once_with(10) + + def test_count_returns_total_task_count(self): + self.mock_collection.count_documents.return_value = 42 + + result = TaskRepository.count() + + self.assertEqual(result, 42) + self.mock_collection.count_documents.assert_called_once_with({}) From 98798030aaca292527c21819c387aad59d8b1c8e Mon Sep 17 00:00:00 2001 From: Samarpan Harit Date: Wed, 18 Dec 2024 01:25:40 +0530 Subject: [PATCH 6/6] Add tests for label repository --- todo/tests/fixtures/label.py | 2 + .../repositories/test_label_repository.py | 42 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 todo/tests/unit/repositories/test_label_repository.py diff --git a/todo/tests/fixtures/label.py b/todo/tests/fixtures/label.py index 3017688..05de140 100644 --- a/todo/tests/fixtures/label.py +++ b/todo/tests/fixtures/label.py @@ -18,3 +18,5 @@ "createdBy": "qMbT6M2GB65W7UHgJS4g", }, ] + +label_models = [LabelModel(**data) for data in label_db_data] diff --git a/todo/tests/unit/repositories/test_label_repository.py b/todo/tests/unit/repositories/test_label_repository.py new file mode 100644 index 0000000..a6ea123 --- /dev/null +++ b/todo/tests/unit/repositories/test_label_repository.py @@ -0,0 +1,42 @@ +from unittest import TestCase +from unittest.mock import patch, MagicMock +from pymongo.collection import Collection +from todo.models.label import LabelModel +from todo.repositories.label_repository import LabelRepository +from todo.tests.fixtures.label import label_db_data + + +class LabelRepositoryTests(TestCase): + def setUp(self): + self.label_ids = [label_data["_id"] for label_data in label_db_data] + self.label_data = label_db_data + + self.patcher_get_collection = patch("todo.repositories.label_repository.LabelRepository.get_collection") + self.mock_get_collection = self.patcher_get_collection.start() + self.mock_collection = MagicMock(spec=Collection) + self.mock_get_collection.return_value = self.mock_collection + + def tearDown(self): + self.patcher_get_collection.stop() + + def test_list_by_ids_returns_label_models(self): + self.mock_collection.find.return_value = self.label_data + + result = LabelRepository.list_by_ids(self.label_ids) + + self.assertEqual(len(result), len(self.label_data)) + self.assertTrue(all(isinstance(label, LabelModel) for label in result)) + + def test_list_by_ids_returns_empty_list_if_not_found(self): + self.mock_collection.find.return_value = [] + + result = LabelRepository.list_by_ids([self.label_ids[0]]) + + self.assertEqual(result, []) + + def test_list_by_ids_skips_db_call_for_empty_input(self): + result = LabelRepository.list_by_ids([]) + + self.assertEqual(result, []) + self.mock_get_collection.assert_not_called() + self.mock_collection.assert_not_called()