diff --git a/README.md b/README.md index eb8a8ba..c38ffb8 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ An overview of each util is below with links to more in-depth documentation and - [get_or_none](#get_or_none): Performs a get on a queryset and returns None if the object does not exist. - [upsert](#upsert): Performs an upsert (update or insert) to a model. - [bulk_update](#bulk_update): Bulk updates a list of models and the fields that have been updated. - +- [post_bulk_operation](#post_bulk_operation): A signal that is fired when a bulk operation happens. ## single() Assumes that the model only has one element in the table or queryset and returns that value. If the table has more than one or no value, an exception is raised. @@ -130,5 +130,20 @@ Performs an bulk update on an list of objects. Any fields listed in the fields_t print model_obj2.int_field, model_obj2.float_field 10, 20.0 +## post_bulk_operation(providing_args=['model']) +A signal that is emitted at the end of a bulk operation. The current bulk operations are Django's update and bulk_create methods and this package's bulk_update method. The signal provides the model that was updated. + +**Examples** + + from manager_utils import post_bulk_operation + + def signal_handler(self, *args, **kwargs): + print kwargs['model'] + + post_bulk_operation.connect(signal_handler) + + TestModel.objects.all().update(int_field=1) + + ## License MIT License (See the LICENSE file included in this repository) diff --git a/manager_utils/__init__.py b/manager_utils/__init__.py index c4f3e4f..0bb13be 100644 --- a/manager_utils/__init__.py +++ b/manager_utils/__init__.py @@ -1 +1 @@ -from .manager_utils import ManagerUtilsMixin, ManagerUtilsManager +from .manager_utils import ManagerUtilsMixin, ManagerUtilsManager, post_bulk_operation diff --git a/manager_utils/manager_utils.py b/manager_utils/manager_utils.py index 7104094..b0cc543 100644 --- a/manager_utils/manager_utils.py +++ b/manager_utils/manager_utils.py @@ -2,9 +2,14 @@ from django.db.models import Manager from django.db.models.query import QuerySet +from django.dispatch import Signal from querybuilder.query import Query +# A signal that is emitted when any bulk operation occurs +post_bulk_operation = Signal(providing_args=['model']) + + class ManagerUtilsQuerySet(QuerySet): """ Defines the methods in the manager utils that can also be applied to querysets. @@ -27,7 +32,15 @@ def single(self): Assumes that this model only has one element in the table and returns it. If the table has more than one or no value, an exception is raised. """ - return self.get(id__gte=0) + return self.get() + + def update(self, **kwargs): + """ + Overrides Django's update method to emit a post_bulk_operation signal when it completes. + """ + ret_val = super(ManagerUtilsQuerySet, self).update(**kwargs) + post_bulk_operation.send(sender=self, model=self.model) + return ret_val class ManagerUtilsMixin(object): @@ -38,6 +51,15 @@ class ManagerUtilsMixin(object): def get_queryset(self): return ManagerUtilsQuerySet(self.model) + def bulk_create(self, objs, batch_size=None): + """ + Overrides Django's bulk_create function to emit a post_bulk_operation signal when bulk_create + is finished. + """ + ret_val = super(ManagerUtilsMixin, self).bulk_create(objs, batch_size=batch_size) + post_bulk_operation.send(sender=self, model=self.model) + return ret_val + def bulk_update(self, model_objs, fields_to_update): """ Bulk updates a list of model objects that are already saved. @@ -46,6 +68,8 @@ def bulk_update(self, model_objs, fields_to_update): model_objs: A list of model objects that have been updated. fields_to_update: A list of fields to be updated. Only these fields will be updated + Sianals: Emits a post_bulk_operation signal when completed. + Examples: # Create a couple test models model_obj1 = TestModel.objects.create(int_field=1, float_field=2.0, char_field='Hi') @@ -80,6 +104,8 @@ def bulk_update(self, model_objs, fields_to_update): fields=chain(['id'] + fields_to_update), ).update(updated_rows) + post_bulk_operation.send(sender=self, model=self.model) + def upsert(self, defaults=None, updates=None, **kwargs): """ Performs an update on an object or an insert if the object does not exist. diff --git a/test_project/tests/manager_utils_tests.py b/test_project/tests/manager_utils_tests.py index ab67e77..08a4461 100644 --- a/test_project/tests/manager_utils_tests.py +++ b/test_project/tests/manager_utils_tests.py @@ -1,9 +1,82 @@ from django.test import TestCase from django_dynamic_fixture import G +from manager_utils import post_bulk_operation from test_project.models import TestModel +class PostBulkOperationSignalTest(TestCase): + """ + Tests that the post_bulk_operation signal is emitted on all functions that emit the signal. + """ + def setUp(self): + """ + Defines a siangl handler that collects information about fired signals + """ + class SignalHandler(object): + num_times_called = 0 + model = None + + def __call__(self, *args, **kwargs): + self.num_times_called += 1 + self.model = kwargs['model'] + + self.signal_handler = SignalHandler() + post_bulk_operation.connect(self.signal_handler) + + def tearDown(self): + """ + Disconnect the siangl to make sure it doesn't get connected multiple times. + """ + post_bulk_operation.disconnect(self.signal_handler) + + def test_post_bulk_operation_queryset_update(self): + """ + Tests that the update operation on a queryset emits the post_bulk_operation signal. + """ + TestModel.objects.all().update(int_field=1) + + self.assertEquals(self.signal_handler.model, TestModel) + self.assertEquals(self.signal_handler.num_times_called, 1) + + def test_post_bulk_operation_manager_update(self): + """ + Tests that the update operation on a manager emits the post_bulk_operation signal. + """ + TestModel.objects.update(int_field=1) + + self.assertEquals(self.signal_handler.model, TestModel) + self.assertEquals(self.signal_handler.num_times_called, 1) + + def test_post_bulk_operation_bulk_update(self): + """ + Tests that the bulk_update operation emits the post_bulk_operation signal. + """ + model_obj = TestModel.objects.create(int_field=2) + TestModel.objects.bulk_update([model_obj], ['int_field']) + + self.assertEquals(self.signal_handler.model, TestModel) + self.assertEquals(self.signal_handler.num_times_called, 1) + + def test_post_bulk_operation_bulk_create(self): + """ + Tests that the bulk_create operation emits the post_bulk_operation signal. + """ + TestModel.objects.bulk_create([TestModel(int_field=2)]) + + self.assertEquals(self.signal_handler.model, TestModel) + self.assertEquals(self.signal_handler.num_times_called, 1) + + def test_save_doesnt_emit_signal(self): + """ + Tests that a non-bulk operation doesn't emit the signal. + """ + model_obj = TestModel.objects.create(int_field=2) + model_obj.save() + + self.assertEquals(self.signal_handler.num_times_called, 0) + + class GetOrNoneTests(TestCase): """ Tests the get_or_none function in the manager utils