Skip to content

Commit

Permalink
Merge pull request #80 from wesokes/hotfix/upsert-pk
Browse files Browse the repository at this point in the history
fix upsert pk issue
  • Loading branch information
somewes authored Feb 20, 2017
2 parents 28e7f9e + 613e66c commit e20addf
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 24 deletions.
4 changes: 4 additions & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Release Notes
=============

v0.14.1
-------
* Fix upsert to handle the case where the uniqueness constraint is the pk field

v0.14.0
-------
* Drop support for django 1.7, add official support for python 3.5
Expand Down
85 changes: 68 additions & 17 deletions querybuilder/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from copy import deepcopy

from django.db import connection as default_django_connection
from django.db.models import Q
from django.db.models import Q, AutoField
from django.db.models.query import QuerySet
from django.db.models.constants import LOOKUP_SEP
try:
Expand Down Expand Up @@ -1185,17 +1185,22 @@ def get_update_sql(self, rows):

return self.sql, sql_args

def get_upsert_sql(self, rows, unique_fields, update_fields):
def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None, only_insert=False):
"""
Performs postgres upsert with multiple rows
Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT)
INSERT INTO table_name (field1, field2)
VALUES (1, 'two')
ON CONFLICT (unique_field) DO UPDATE SET field2 = EXCLUDED.field2;
"""
ModelClass = self.tables[0].model
pk_name = ModelClass._meta.pk.column
all_fields = [field for field in ModelClass._meta.fields if field.column != pk_name]

# Use all fields except pk unless the uniqueness constraint is the pk field. Null pk field rows will be
# excluded in the upsert method before calling this method
all_fields = [field for field in ModelClass._meta.fields if field.column != auto_field_name]
if auto_field_name in unique_fields and not only_insert:
all_fields = [field for field in ModelClass._meta.fields]

all_field_names = [field.column for field in all_fields]
all_field_names_sql = ', '.join(all_field_names)

Expand Down Expand Up @@ -1696,40 +1701,86 @@ def update(self, rows):
# execute the query
cursor.execute(sql, sql_args)

def get_auto_field_name(self, model_class):
    """
    Look up the database column of the AutoField declared on ``model_class``.

    :param model_class: model class whose ``_meta.fields`` are inspected
    :return: the ``column`` of the first field that is an ``AutoField`` instance,
        or None when no such field exists
    """
    auto_columns = (
        candidate.column
        for candidate in model_class._meta.fields
        if isinstance(candidate, AutoField)
    )

    # A model can only have one AutoField, so the first match is the answer
    return next(auto_columns, None)

def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_models=False):
"""
Performs an upsert on the set of models defined in rows.
Performs an upsert with the set of models defined in rows. If the unique field which is meant
to cause a conflict is an auto increment field, then the field should be excluded when its value is null.
In this case, rows that have a value for that field are upserted first, and the rows without a value are
then sent through a second get_upsert_sql call with only_insert=True
"""
# NOTE(review): this span looks like a rendered diff in which replaced and replacement lines are
# interleaved and indentation was lost — confirm statement order against the repository before relying on it
# An empty batch exits early with a bare return (None), unlike the list returned at the end
if len(rows) == 0:
return

sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields)
ModelClass = self.tables[0].model

# get the cursor to execute the query
cursor = self.get_cursor()
rows_with_null_auto_field_value = []

# execute the query
cursor.execute(sql, sql_args)
# Get auto field name (a model can only have one AutoField)
auto_field_name = self.get_auto_field_name(ModelClass)

# Check if unique fields list contains an auto field
if auto_field_name in unique_fields:
# Separate the rows that need to be inserted vs the rows that need to be upserted
# NOTE(review): auto_field_name is a column name (see get_auto_field_name) but is used as an attribute
# name in getattr — presumably the two match; confirm for fields with a custom db_column
rows_with_null_auto_field_value = [row for row in rows if getattr(row, auto_field_name) is None]
rows = [row for row in rows if getattr(row, auto_field_name) is not None]

# Accumulates the row dicts fetched from each executed statement
return_value = []

# Rows still in `rows` at this point go through the regular upsert statement
if rows:
sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name=auto_field_name)

# get the cursor to execute the query
cursor = self.get_cursor()

# execute the upsert query
cursor.execute(sql, sql_args)

if return_rows or return_models:
return_value.extend(self._fetch_all_as_dict(cursor))

# Rows that had no auto field value are handled by a second statement built with only_insert=True
if rows_with_null_auto_field_value:
sql, sql_args = self.get_upsert_sql(
rows_with_null_auto_field_value,
unique_fields,
update_fields,
auto_field_name=auto_field_name,
only_insert=True,
)

# get the cursor to execute the query
cursor = self.get_cursor()

# execute the upsert query
cursor.execute(sql, sql_args)

if return_rows:
return self._fetch_all_as_dict(cursor)
if return_rows or return_models:
return_value.extend(self._fetch_all_as_dict(cursor))

# When models are requested, each fetched row dict is passed to the model constructor as keyword arguments
if return_models:
row_dicts = self._fetch_all_as_dict(cursor)
ModelClass = self.tables[0].model
model_objects = [
ModelClass(**row_dict)
for row_dict in row_dicts
for row_dict in return_value
]

# Set the state to indicate the object has been loaded from db
# NOTE(review): the database alias is hard-coded to 'default' — confirm this is intended when the
# query runs against another connection
for model_object in model_objects:
model_object._state.adding = False
model_object._state.db = 'default'

return model_objects
# The model instances replace the row dicts as the value handed back to the caller
return_value = model_objects

return []
return return_value

def sql_delete(self):
"""
Expand Down
97 changes: 95 additions & 2 deletions querybuilder/tests/upsert_tests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from django.test.utils import override_settings
from django import VERSION
from django_dynamic_fixture import G

from querybuilder.logger import Logger
from querybuilder.query import Query
from querybuilder.tests.models import Uniques
from querybuilder.tests.models import Uniques, User
from querybuilder.tests.query_tests import QueryTestCase


@override_settings(DEBUG=True)
class TestUpdate(QueryTestCase):
class TestUpsert(QueryTestCase):

def setUp(self):
self.logger = Logger()
Expand Down Expand Up @@ -123,3 +124,95 @@ def test_upsert(self):
self.assertEqual(models[1].field5, 'not null')
self.assertEqual(models[1].field6, '2.6')
self.assertEqual(models[1].field7, '2.7')

def test_upsert_pk(self):
    """
    Upserting with the pk as the only unique field must update the saved row and insert the unsaved ones.
    """
    # One row already exists in the table; its email is changed in memory only
    saved_user = G(User, email='user1')
    saved_user.email = 'user1change'

    # These two instances are never saved before the upsert
    pending_users = [User(email='user2'), User(email='user3')]

    self.assertEqual(User.objects.count(), 1)
    Query().from_table(User).upsert(
        [saved_user] + pending_users,
        unique_fields=['id'],
        update_fields=['email'],
    )
    self.assertEqual(User.objects.count(), 3)

    # Compare the stored rows, ordered by id, against the expected emails
    stored_users = list(User.objects.order_by('id'))
    for position, expected_email in enumerate(('user1change', 'user2', 'user3')):
        self.assertEqual(stored_users[position].email, expected_email)

def test_upsert_pk_return_dicts(self):
    """
    Upserting with the pk as the only unique field and return_rows=True must hand back one dict per row.
    """
    # One row already exists in the table; its email is changed in memory only
    saved_user = G(User, email='user1')
    saved_user.email = 'user1change'

    # These two instances are never saved before the upsert
    pending_users = [User(email='user2'), User(email='user3')]

    self.assertEqual(User.objects.count(), 1)
    returned_rows = Query().from_table(User).upsert(
        [saved_user] + pending_users,
        unique_fields=['id'],
        update_fields=['email'],
        return_rows=True,
    )
    self.assertEqual(User.objects.count(), 3)
    self.assertEqual(len(returned_rows), 3)

    # Every returned dict must carry a non-null id
    for returned_row in returned_rows:
        self.assertIsNotNone(returned_row['id'])

    # Only the set of returned emails is checked, not their order
    returned_emails = set(returned_row['email'] for returned_row in returned_rows)
    self.assertEqual(returned_emails, {'user1change', 'user2', 'user3'})

    # Compare the stored rows, ordered by id, against the expected emails
    stored_users = list(User.objects.order_by('id'))
    for position, expected_email in enumerate(('user1change', 'user2', 'user3')):
        self.assertEqual(stored_users[position].email, expected_email)

def test_upsert_pk_return_models(self):
    """
    Makes sure upserting is possible when the only uniqueness constraint is the pk. Should return models.

    Besides the returned records, the table row count is verified after the upsert (as in the sibling
    pk upsert tests), and each returned model is checked for the loaded-from-db state that upsert sets.
    """
    user1 = G(User, email='user1')
    user1.email = 'user1change'
    user2 = User(email='user2')
    user3 = User(email='user3')

    self.assertEqual(User.objects.count(), 1)
    records = Query().from_table(User).upsert(
        [user1, user2, user3],
        unique_fields=['id'],
        update_fields=['email'],
        return_models=True,
    )
    # The existing row is updated and the two unsaved rows are inserted
    self.assertEqual(User.objects.count(), 3)
    self.assertEqual(len(records), 3)

    # Check ids and that upsert flagged each model as no longer being added
    for record in records:
        self.assertIsNotNone(record.id)
        self.assertFalse(record._state.adding)

    # Check emails (order of the returned models is not asserted)
    email_set = {
        record.email for record in records
    }
    self.assertEqual(email_set, {'user1change', 'user2', 'user3'})

    # Check fields from db
    users = list(User.objects.order_by('id'))
    self.assertEqual(users[0].email, 'user1change')
    self.assertEqual(users[1].email, 'user2')
    self.assertEqual(users[2].email, 'user3')
2 changes: 1 addition & 1 deletion querybuilder/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.14.0'
__version__ = '0.14.1'
8 changes: 4 additions & 4 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def configure_settings():
if test_db is None:
db_config = {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': 'ambition_dev',
'USER': 'ambition_dev',
'PASSWORD': 'ambition_dev',
'HOST': 'localhost'
'NAME': 'ambition',
'USER': 'ambition',
'PASSWORD': 'ambition',
'HOST': 'db'
}
elif test_db == 'postgres':
db_config = {
Expand Down

0 comments on commit e20addf

Please sign in to comment.