Skip to content

Commit

Permalink
Merge pull request #107 from ambitioninc/develop
Browse files Browse the repository at this point in the history
v1.3.1
  • Loading branch information
jaredlewis authored Mar 4, 2019
2 parents bee328b + c0a8586 commit d56a4e9
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 50 deletions.
5 changes: 5 additions & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Release Notes
=============

v1.3.1
------
* Fix bulk_update not properly casting fields
* Add support for returning upserts with multiple unique fields for non native

v1.3.0
------
* Updated version 2 interface
Expand Down
116 changes: 91 additions & 25 deletions manager_utils/manager_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from itertools import chain
import itertools

from django.db import connection
from django.db.models import Manager
from django.db.models.query import QuerySet
from django.dispatch import Signal
Expand Down Expand Up @@ -38,22 +39,31 @@ def _get_upserts_distinct(queryset, model_objs_updated, model_objs_created, uniq
Given a list of model objects that were updated and model objects that were created,
fetch the pks of the newly created models and return the two lists in a tuple
"""

# Keep track of the created models
created_models = []
# Fetch the objects that were created based on the uniqueness constraint. Note - only support the case
# where there is one update field so that we can perform an in query. TODO perform an OR query to gather
# the created values when there is more than one update field
if len(unique_fields) == 1:
unique_field = unique_fields[0]

# If we created new models query for them
if model_objs_created:
created_models.extend(
queryset.filter(**{'{0}__in'.format(unique_field): (
getattr(model_obj, unique_field) for model_obj in model_objs_created
)})
queryset.extra(
where=['({unique_fields_sql}) in %s'.format(
unique_fields_sql=', '.join(unique_fields)
)],
params=[
tuple([
tuple([
getattr(model_obj, field)
for field in unique_fields
])
for model_obj in model_objs_created
])
]
)
)
else:
raise NotImplementedError(
'bulk_upsert currently doesnt support returning upserts with more than one update field')

return (model_objs_updated, created_models)
# Return the models
return model_objs_updated, created_models


def _get_upserts(queryset, model_objs_updated, model_objs_created, unique_fields):
Expand Down Expand Up @@ -91,10 +101,15 @@ def _get_prepped_model_field(model_obj, field):
"""
Gets the value of a field of a model obj that is prepared for the db.
"""
try:
return model_obj._meta.get_field(field).get_prep_value(getattr(model_obj, field))
except: # noqa
return getattr(model_obj, field)

# Get the field
field = model_obj._meta.get_field(field)

# Get the value
value = field.get_db_prep_save(getattr(model_obj, field.attname), connection)

# Return the value
return value


def bulk_upsert(
Expand Down Expand Up @@ -509,19 +524,70 @@ def bulk_update(manager, model_objs, fields_to_update):
10, 20.0
"""
updated_rows = [
[model_obj.pk] + [_get_prepped_model_field(model_obj, field_name) for field_name in fields_to_update]

# Add the pk to the value fields so we can join
value_fields = [manager.model._meta.pk.attname] + fields_to_update

# Build the row values
row_values = [
[_get_prepped_model_field(model_obj, field_name) for field_name in value_fields]
for model_obj in model_objs
]
if len(updated_rows) == 0 or len(fields_to_update) == 0:

# If we do not have any values or fields to update just return
if len(row_values) == 0 or len(fields_to_update) == 0:
return

# Execute the bulk update
Query().from_table(
table=manager.model,
fields=chain([manager.model._meta.pk.attname] + fields_to_update),
).update(updated_rows)
# Create a map of db types
db_types = [
manager.model._meta.get_field(field).db_type(connection)
for field in value_fields
]

# Build the value fields sql
value_fields_sql = ', '.join(value_fields)

# Build the set sql
update_fields_sql = ', '.join([
'{field} = new_values.{field}'.format(field=field)
for field in fields_to_update
])

# Build the values sql
values_sql = ', '.join([
'({0})'.format(
', '.join([
'%s::{0}'.format(
db_types[i]
) if not row_number and i else '%s'
for i, _ in enumerate(row)
])
)
for row_number, row in enumerate(row_values)
])

# Start building the query
update_sql = (
'UPDATE {table} '
'SET {update_fields_sql} '
'FROM (VALUES {values_sql}) AS new_values ({value_fields_sql}) '
'WHERE {table}.{pk_field} = new_values.{pk_field}'
).format(
table=manager.model._meta.db_table,
pk_field=manager.model._meta.pk.attname,
update_fields_sql=update_fields_sql,
values_sql=values_sql,
value_fields_sql=value_fields_sql
)

# Combine all the row values
update_sql_params = list(itertools.chain(*row_values))

# Run the update query
with connection.cursor() as cursor:
cursor.execute(update_sql, update_sql_params)

# call the bulk operation signal
post_bulk_operation.send(sender=manager.model, model=manager.model)


Expand Down
110 changes: 90 additions & 20 deletions manager_utils/tests/manager_utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class TestGetPreppedModelField(TestCase):
def test_invalid_field(self):
t = models.TestModel()
with self.assertRaises(AttributeError):
with self.assertRaises(Exception):
_get_prepped_model_field(t, 'non_extant_field')


Expand Down Expand Up @@ -380,43 +380,69 @@ def test_return_upserts_distinct_none_native(self):
models.TestModel.objects.bulk_upsert(
[], ['float_field'], ['float_field'], return_upserts_distinct=True, native=True)

def test_return_multi_unique_fields_not_supported(self):
"""
Current manager utils doesn't support returning bulk upserts when there are multiple unique fields.
"""
with self.assertRaises(NotImplementedError):
models.TestModel.objects.bulk_upsert([], ['float_field', 'int_field'], ['float_field'], return_upserts=True)

def test_return_multi_unique_distinct_fields_not_supported(self):
"""
Current manager utils doesn't support returning bulk upserts when there are multiple unique fields.
"""
with self.assertRaises(NotImplementedError):
models.TestModel.objects.bulk_upsert(
[], ['float_field', 'int_field'], ['float_field'], return_upserts_distinct=True)

def test_return_created_values(self):
"""
Tests that values that are created are returned properly when return_upserts is True.
"""

return_values = models.TestModel.objects.bulk_upsert(
[models.TestModel(int_field=1), models.TestModel(int_field=3), models.TestModel(int_field=4)],
['int_field'], ['float_field'], return_upserts=True
[
models.TestModel(int_field=1, char_field='1'),
models.TestModel(int_field=3, char_field='3'),
models.TestModel(int_field=4, char_field='4')
],
['int_field', 'char_field'],
['float_field'],
return_upserts=True
)

# Assert that we properly returned the models
self.assertEquals(len(return_values), 3)
for test_model, expected_int in zip(sorted(return_values, key=lambda k: k.int_field), [1, 3, 4]):
self.assertEquals(test_model.int_field, expected_int)
self.assertIsNotNone(test_model.id)
self.assertEquals(models.TestModel.objects.count(), 3)

# Run additional upserts
return_values = models.TestModel.objects.bulk_upsert(
[
models.TestModel(int_field=1, char_field='1', float_field=10),
models.TestModel(int_field=3, char_field='3'),
models.TestModel(int_field=4, char_field='4'),
models.TestModel(int_field=5, char_field='5', float_field=50),
],
['int_field', 'char_field'],
['float_field'],
return_upserts=True
)
self.assertEquals(len(return_values), 4)
self.assertEqual(
[
[1, '1', 10],
[3, '3', None],
[4, '4', None],
[5, '5', 50],
],
[
[test_model.int_field, test_model.char_field, test_model.float_field]
for test_model in return_values
]
)

def test_return_created_values_native(self):
"""
Tests that values that are created are returned properly when return_upserts is True.
"""
return_values = models.TestModel.objects.bulk_upsert(
[models.TestModel(int_field=1), models.TestModel(int_field=3), models.TestModel(int_field=4)],
['int_field'], ['float_field'], return_upserts=True, native=True
[
models.TestModel(int_field=1, char_field='1'),
models.TestModel(int_field=3, char_field='3'),
models.TestModel(int_field=4, char_field='4')
],
['int_field', 'char_field'],
['float_field'],
return_upserts=True,
native=True
)

self.assertEquals(len(return_values), 3)
Expand Down Expand Up @@ -1793,6 +1819,50 @@ def test_objs_two_fields_to_update(self):
self.assertEquals(test_obj_1.float_field, 3.0)
self.assertEquals(test_obj_2.float_field, 4.0)

def test_updating_objects_with_custom_db_field_types(self):
"""
Tests when objects are updated that have custom field types
"""
test_obj_1 = G(
models.TestModel,
int_field=1,
float_field=1.0,
json_field={'test': 'test'},
array_field=['one', 'two']
)
test_obj_2 = G(
models.TestModel,
int_field=2,
float_field=2.0,
json_field={'test2': 'test2'},
array_field=['three', 'four']
)

# Change the fields on the models
test_obj_1.json_field = {'test': 'updated'}
test_obj_1.array_field = ['one', 'two', 'updated']

test_obj_2.json_field = {'test2': 'updated'}
test_obj_2.array_field = ['three', 'four', 'updated']

# Do a bulk update with the int fields
models.TestModel.objects.bulk_update(
[test_obj_1, test_obj_2],
['json_field', 'array_field']
)

# Refetch the objects
test_obj_1 = models.TestModel.objects.get(id=test_obj_1.id)
test_obj_2 = models.TestModel.objects.get(id=test_obj_2.id)

# Assert that the json field was updated
self.assertEquals(test_obj_1.json_field, {'test': 'updated'})
self.assertEquals(test_obj_2.json_field, {'test2': 'updated'})

# Assert that the array field was updated
self.assertEquals(test_obj_1.array_field, ['one', 'two', 'updated'])
self.assertEquals(test_obj_2.array_field, ['three', 'four', 'updated'])


class UpsertTest(TestCase):
"""
Expand Down
3 changes: 3 additions & 0 deletions manager_utils/tests/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.contrib.postgres.fields import JSONField, ArrayField
from django.db import models
from manager_utils import ManagerUtilsManager
from timezone_field import TimeZoneField
Expand All @@ -10,6 +11,8 @@ class TestModel(models.Model):
int_field = models.IntegerField(null=True, unique=True)
char_field = models.CharField(max_length=128, null=True)
float_field = models.FloatField(null=True)
json_field = JSONField(default=dict)
array_field = ArrayField(models.CharField(max_length=128), default=list)
time_zone = TimeZoneField(default='UTC')

objects = ManagerUtilsManager()
Expand Down
2 changes: 1 addition & 1 deletion manager_utils/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.3.0'
__version__ = '1.3.1'
4 changes: 2 additions & 2 deletions run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from django_nose import NoseTestSuiteRunner


def run_tests(*test_args, **kwargs):
def run(*test_args, **kwargs):
if not test_args:
test_args = ['manager_utils']

Expand All @@ -30,4 +30,4 @@ def run_tests(*test_args, **kwargs):
parser.add_option('--verbosity', dest='verbosity', action='store', default=1, type=int)
(options, args) = parser.parse_args()

run_tests(*args, **options.__dict__)
run(*args, **options.__dict__)
2 changes: 1 addition & 1 deletion settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def configure_settings():
'NAME': 'ambition',
'USER': 'ambition',
'PASSWORD': 'ambition',
'HOST': 'localhost'
'HOST': 'db'
}
elif test_db == 'postgres':
db_config = {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ def get_version():
'django-timezone-field',
'parameterized',
],
test_suite='run_tests.run_tests',
test_suite='run_tests.run',
include_package_data=True,
)

0 comments on commit d56a4e9

Please sign in to comment.