diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 44dd69c..8a00db5 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -1,6 +1,11 @@ Release Notes ============= +v1.3.1 +------ +* Fix bulk_update not properly casting fields +* Add support for returning upserts with multiple unique fields for non native + v1.3.0 ------ * Updated version 2 interface diff --git a/manager_utils/manager_utils.py b/manager_utils/manager_utils.py index 1f622ca..d092c22 100644 --- a/manager_utils/manager_utils.py +++ b/manager_utils/manager_utils.py @@ -1,5 +1,6 @@ -from itertools import chain +import itertools +from django.db import connection from django.db.models import Manager from django.db.models.query import QuerySet from django.dispatch import Signal @@ -38,22 +39,31 @@ def _get_upserts_distinct(queryset, model_objs_updated, model_objs_created, uniq Given a list of model objects that were updated and model objects that were created, fetch the pks of the newly created models and return the two lists in a tuple """ + + # Keep track of the created models created_models = [] - # Fetch the objects that were created based on the uniqueness constraint. Note - only support the case - # where there is one update field so that we can perform an in query. TODO perform an OR query to gather - # the created values when there is more than one update field - if len(unique_fields) == 1: - unique_field = unique_fields[0] + + # If we created new models query for them + if model_objs_created: created_models.extend( - queryset.filter(**{'{0}__in'.format(unique_field): ( - getattr(model_obj, unique_field) for model_obj in model_objs_created - )}) + queryset.extra( + where=['({unique_fields_sql}) in %s'.format( + unique_fields_sql=', '.join(unique_fields) + )], + params=[ + tuple([ + tuple([ + getattr(model_obj, field) + for field in unique_fields + ]) + for model_obj in model_objs_created + ]) + ] + ) ) - else: - raise NotImplementedError( - 'bulk_upsert currently doesnt support returning upserts with more than one update field') - return (model_objs_updated, created_models) + # Return the models + return model_objs_updated, created_models def _get_upserts(queryset, model_objs_updated, model_objs_created, unique_fields): @@ -91,10 +101,15 @@ def _get_prepped_model_field(model_obj, field): """ Gets the value of a field of a model obj that is prepared for the db. """ - try: - return model_obj._meta.get_field(field).get_prep_value(getattr(model_obj, field)) - except: # noqa - return getattr(model_obj, field) + + # Get the field + field = model_obj._meta.get_field(field) + + # Get the value + value = field.get_db_prep_save(getattr(model_obj, field.attname), connection) + + # Return the value + return value def bulk_upsert( @@ -509,19 +524,70 @@ def bulk_update(manager, model_objs, fields_to_update): 10, 20.0 """ - updated_rows = [ - [model_obj.pk] + [_get_prepped_model_field(model_obj, field_name) for field_name in fields_to_update] + + # Add the pk to the value fields so we can join + value_fields = [manager.model._meta.pk.attname] + fields_to_update + + # Build the row values + row_values = [ + [_get_prepped_model_field(model_obj, field_name) for field_name in value_fields] for model_obj in model_objs ] - if len(updated_rows) == 0 or len(fields_to_update) == 0: + + # If we do not have any values or fields to update just return + if len(row_values) == 0 or len(fields_to_update) == 0: return - # Execute the bulk update - Query().from_table( - table=manager.model, - fields=chain([manager.model._meta.pk.attname] + fields_to_update), - ).update(updated_rows) + # Create a map of db types + db_types = [ + manager.model._meta.get_field(field).db_type(connection) + for field in value_fields + ] + # Build the value fields sql + value_fields_sql = ', '.join(value_fields) + + # Build the set sql + update_fields_sql = ', '.join([ + '{field} = new_values.{field}'.format(field=field) + for field in fields_to_update + ]) + + # Build the values sql + values_sql = ', '.join([ + '({0})'.format( + ', '.join([ + '%s::{0}'.format( + db_types[i] + ) if not row_number and i else '%s' + for i, _ in enumerate(row) + ]) + ) + for row_number, row in enumerate(row_values) + ]) + + # Start building the query + update_sql = ( + 'UPDATE {table} ' + 'SET {update_fields_sql} ' + 'FROM (VALUES {values_sql}) AS new_values ({value_fields_sql}) ' + 'WHERE {table}.{pk_field} = new_values.{pk_field}' + ).format( + table=manager.model._meta.db_table, + pk_field=manager.model._meta.pk.attname, + update_fields_sql=update_fields_sql, + values_sql=values_sql, + value_fields_sql=value_fields_sql + ) + + # Combine all the row values + update_sql_params = list(itertools.chain(*row_values)) + + # Run the update query + with connection.cursor() as cursor: + cursor.execute(update_sql, update_sql_params) + + # call the bulk operation signal post_bulk_operation.send(sender=manager.model, model=manager.model) diff --git a/manager_utils/tests/manager_utils_tests.py b/manager_utils/tests/manager_utils_tests.py index 9b7ca17..afe67cf 100644 --- a/manager_utils/tests/manager_utils_tests.py +++ b/manager_utils/tests/manager_utils_tests.py @@ -15,7 +15,7 @@ class TestGetPreppedModelField(TestCase): def test_invalid_field(self): t = models.TestModel() - with self.assertRaises(AttributeError): + with self.assertRaises(Exception): _get_prepped_model_field(t, 'non_extant_field') @@ -380,43 +380,69 @@ def test_return_upserts_distinct_none_native(self): models.TestModel.objects.bulk_upsert( [], ['float_field'], ['float_field'], return_upserts_distinct=True, native=True) - def test_return_multi_unique_fields_not_supported(self): - """ - Current manager utils doesn't support returning bulk upserts when there are multiple unique fields. - """ - with self.assertRaises(NotImplementedError): - models.TestModel.objects.bulk_upsert([], ['float_field', 'int_field'], ['float_field'], return_upserts=True) - - def test_return_multi_unique_distinct_fields_not_supported(self): - """ - Current manager utils doesn't support returning bulk upserts when there are multiple unique fields. - """ - with self.assertRaises(NotImplementedError): - models.TestModel.objects.bulk_upsert( - [], ['float_field', 'int_field'], ['float_field'], return_upserts_distinct=True) - def test_return_created_values(self): """ Tests that values that are created are returned properly when return_upserts is True. """ + return_values = models.TestModel.objects.bulk_upsert( - [models.TestModel(int_field=1), models.TestModel(int_field=3), models.TestModel(int_field=4)], - ['int_field'], ['float_field'], return_upserts=True + [ + models.TestModel(int_field=1, char_field='1'), + models.TestModel(int_field=3, char_field='3'), + models.TestModel(int_field=4, char_field='4') + ], + ['int_field', 'char_field'], + ['float_field'], + return_upserts=True ) + # Assert that we properly returned the models self.assertEquals(len(return_values), 3) for test_model, expected_int in zip(sorted(return_values, key=lambda k: k.int_field), [1, 3, 4]): self.assertEquals(test_model.int_field, expected_int) self.assertIsNotNone(test_model.id) self.assertEquals(models.TestModel.objects.count(), 3) + # Run additional upserts + return_values = models.TestModel.objects.bulk_upsert( + [ + models.TestModel(int_field=1, char_field='1', float_field=10), + models.TestModel(int_field=3, char_field='3'), + models.TestModel(int_field=4, char_field='4'), + models.TestModel(int_field=5, char_field='5', float_field=50), + ], + ['int_field', 'char_field'], + ['float_field'], + return_upserts=True + ) + self.assertEquals(len(return_values), 4) + self.assertEqual( + [ + [1, '1', 10], + [3, '3', None], + [4, '4', None], + [5, '5', 50], + ], + [ + [test_model.int_field, test_model.char_field, test_model.float_field] + for test_model in return_values + ] + ) + def test_return_created_values_native(self): """ Tests that values that are created are returned properly when return_upserts is True. """ return_values = models.TestModel.objects.bulk_upsert( - [models.TestModel(int_field=1), models.TestModel(int_field=3), models.TestModel(int_field=4)], - ['int_field'], ['float_field'], return_upserts=True, native=True + [ + models.TestModel(int_field=1, char_field='1'), + models.TestModel(int_field=3, char_field='3'), + models.TestModel(int_field=4, char_field='4') + ], + ['int_field', 'char_field'], + ['float_field'], + return_upserts=True, + native=True ) self.assertEquals(len(return_values), 3) @@ -1793,6 +1819,50 @@ def test_objs_two_fields_to_update(self): self.assertEquals(test_obj_1.float_field, 3.0) self.assertEquals(test_obj_2.float_field, 4.0) + def test_updating_objects_with_custom_db_field_types(self): + """ + Tests when objects are updated that have custom field types + """ + test_obj_1 = G( + models.TestModel, + int_field=1, + float_field=1.0, + json_field={'test': 'test'}, + array_field=['one', 'two'] + ) + test_obj_2 = G( + models.TestModel, + int_field=2, + float_field=2.0, + json_field={'test2': 'test2'}, + array_field=['three', 'four'] + ) + + # Change the fields on the models + test_obj_1.json_field = {'test': 'updated'} + test_obj_1.array_field = ['one', 'two', 'updated'] + + test_obj_2.json_field = {'test2': 'updated'} + test_obj_2.array_field = ['three', 'four', 'updated'] + + # Do a bulk update with the int fields + models.TestModel.objects.bulk_update( + [test_obj_1, test_obj_2], + ['json_field', 'array_field'] + ) + + # Refetch the objects + test_obj_1 = models.TestModel.objects.get(id=test_obj_1.id) + test_obj_2 = models.TestModel.objects.get(id=test_obj_2.id) + + # Assert that the json field was updated + self.assertEquals(test_obj_1.json_field, {'test': 'updated'}) + self.assertEquals(test_obj_2.json_field, {'test2': 'updated'}) + + # Assert that the array field was updated + self.assertEquals(test_obj_1.array_field, ['one', 'two', 'updated']) + self.assertEquals(test_obj_2.array_field, ['three', 'four', 'updated']) + class UpsertTest(TestCase): """ diff --git a/manager_utils/tests/models.py b/manager_utils/tests/models.py index 5c2dbeb..a5d345a 100644 --- a/manager_utils/tests/models.py +++ b/manager_utils/tests/models.py @@ -1,3 +1,4 @@ +from django.contrib.postgres.fields import JSONField, ArrayField from django.db import models from manager_utils import ManagerUtilsManager from timezone_field import TimeZoneField @@ -10,6 +11,8 @@ class TestModel(models.Model): int_field = models.IntegerField(null=True, unique=True) char_field = models.CharField(max_length=128, null=True) float_field = models.FloatField(null=True) + json_field = JSONField(default=dict) + array_field = ArrayField(models.CharField(max_length=128), default=list) time_zone = TimeZoneField(default='UTC') objects = ManagerUtilsManager() diff --git a/manager_utils/version.py b/manager_utils/version.py index 19b4f1d..72837bd 100644 --- a/manager_utils/version.py +++ b/manager_utils/version.py @@ -1 +1 @@ -__version__ = '1.3.0' +__version__ = '1.3.1' diff --git a/run_tests.py b/run_tests.py index 0a2db9e..a2ac039 100644 --- a/run_tests.py +++ b/run_tests.py @@ -13,7 +13,7 @@ from django_nose import NoseTestSuiteRunner -def run_tests(*test_args, **kwargs): +def run(*test_args, **kwargs): if not test_args: test_args = ['manager_utils'] @@ -30,4 +30,4 @@ def run_tests(*test_args, **kwargs): parser.add_option('--verbosity', dest='verbosity', action='store', default=1, type=int) (options, args) = parser.parse_args() - run_tests(*args, **options.__dict__) + run(*args, **options.__dict__) diff --git a/settings.py b/settings.py index c90ee15..c028a2f 100644 --- a/settings.py +++ b/settings.py @@ -16,7 +16,7 @@ def configure_settings(): 'NAME': 'ambition', 'USER': 'ambition', 'PASSWORD': 'ambition', - 'HOST': 'localhost' + 'HOST': 'db' } elif test_db == 'postgres': db_config = { diff --git a/setup.py b/setup.py index 7acd168..4b41226 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,6 @@ def get_version(): 'django-timezone-field', 'parameterized', ], - test_suite='run_tests.run_tests', + test_suite='run_tests.run', include_package_data=True, )