From 38581e3730b5022b023763a1671457714f0aac1c Mon Sep 17 00:00:00 2001 From: Wes Okes Date: Fri, 17 Feb 2017 13:17:02 -0500 Subject: [PATCH 1/5] fix upsert pk issue --- docs/release_notes.rst | 4 ++ querybuilder/query.py | 34 +++++++++-- querybuilder/tests/upsert_tests.py | 97 +++++++++++++++++++++++++++++- querybuilder/version.py | 2 +- settings.py | 8 +-- 5 files changed, 133 insertions(+), 12 deletions(-) diff --git a/docs/release_notes.rst b/docs/release_notes.rst index ed04574..ba8c279 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -1,6 +1,10 @@ Release Notes ============= +v0.14.1 +------- +* Fix upsert to handle case when the uniqueness constraint is the pk field + v0.14.0 ------- * Drop support for django 1.7, add official support for python 3.5 diff --git a/querybuilder/query.py b/querybuilder/query.py index e17ba4d..ca02c4b 100644 --- a/querybuilder/query.py +++ b/querybuilder/query.py @@ -1187,7 +1187,7 @@ def get_update_sql(self, rows): def get_upsert_sql(self, rows, unique_fields, update_fields): """ - Performs postgres upsert with multiple rows + Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT) INSERT INTO table_name (field1, field2) VALUES (1, 'two') @@ -1195,7 +1195,13 @@ def get_upsert_sql(self, rows, unique_fields, update_fields): """ ModelClass = self.tables[0].model pk_name = ModelClass._meta.pk.column + + # Use all fields except pk unless the uniqueness constraint is the pk field. 
Null pk field rows will be + # excluded in the upsert method before calling this method all_fields = [field for field in ModelClass._meta.fields if field.column != pk_name] + if len(unique_fields) == 1 and unique_fields[0] == pk_name: + all_fields = [field for field in ModelClass._meta.fields] + all_field_names = [field.column for field in all_fields] all_field_names_sql = ', '.join(all_field_names) @@ -1698,21 +1704,39 @@ def update(self, rows): def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_models=False): """ - Performs an upsert on the set of models defined in rows. + Performs an upsert with the set of models defined in rows. If the unique field which is meant + to cause a conflict is an auto increment field, then the field should be excluded when its value is null. + In this case, an upsert will be performed followed by a bulk_create """ if len(rows) == 0: return + ModelClass = self.tables[0].model + pk_name = ModelClass._meta.pk.column + + rows_without_pk = [] + if len(unique_fields) == 1 and unique_fields[0] == pk_name: + rows_without_pk = [row for row in rows if getattr(row, pk_name) is None] + rows = [row for row in rows if getattr(row, pk_name) is not None] + sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields) # get the cursor to execute the query cursor = self.get_cursor() - # execute the query + # execute the upsert query cursor.execute(sql, sql_args) + # execute the bulk create query if needed + bulk_created_records = [] + if rows_without_pk: + bulk_created_records = ModelClass.objects.bulk_create(rows_without_pk) + if return_rows: - return self._fetch_all_as_dict(cursor) + return self._fetch_all_as_dict(cursor) + [ + record.__dict__ + for record in bulk_created_records + ] if return_models: row_dicts = self._fetch_all_as_dict(cursor) @@ -1727,7 +1751,7 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m model_object._state.adding = False model_object._state.db = 
'default' - return model_objects + return model_objects + bulk_created_records return [] diff --git a/querybuilder/tests/upsert_tests.py b/querybuilder/tests/upsert_tests.py index d4f233f..96a22b2 100644 --- a/querybuilder/tests/upsert_tests.py +++ b/querybuilder/tests/upsert_tests.py @@ -1,14 +1,15 @@ from django.test.utils import override_settings from django import VERSION +from django_dynamic_fixture import G from querybuilder.logger import Logger from querybuilder.query import Query -from querybuilder.tests.models import Uniques +from querybuilder.tests.models import Uniques, User from querybuilder.tests.query_tests import QueryTestCase @override_settings(DEBUG=True) -class TestUpdate(QueryTestCase): +class TestUpsert(QueryTestCase): def setUp(self): self.logger = Logger() @@ -123,3 +124,95 @@ def test_upsert(self): self.assertEqual(models[1].field5, 'not null') self.assertEqual(models[1].field6, '2.6') self.assertEqual(models[1].field7, '2.7') + + def test_upsert_pk(self): + """ + Makes sure upserting is possible when the only uniqueness constraint is the pk. + """ + user1 = G(User, email='user1') + user1.email = 'user1change' + user2 = User(email='user2') + user3 = User(email='user3') + + self.assertEqual(User.objects.count(), 1) + Query().from_table(User).upsert( + [user1, user2, user3], + unique_fields=['id'], + update_fields=['email'], + ) + self.assertEqual(User.objects.count(), 3) + + users = list(User.objects.order_by('id')) + + self.assertEqual(users[0].email, 'user1change') + self.assertEqual(users[1].email, 'user2') + self.assertEqual(users[2].email, 'user3') + + def test_upsert_pk_return_dicts(self): + """ + Makes sure upserting is possible when the only uniqueness constraint is the pk. Should return dicts. 
+ """ + user1 = G(User, email='user1') + user1.email = 'user1change' + user2 = User(email='user2') + user3 = User(email='user3') + + self.assertEqual(User.objects.count(), 1) + rows = Query().from_table(User).upsert( + [user1, user2, user3], + unique_fields=['id'], + update_fields=['email'], + return_rows=True, + ) + self.assertEqual(User.objects.count(), 3) + self.assertEqual(len(rows), 3) + + # Check ids + for row in rows: + self.assertIsNotNone(row['id']) + + # Check emails + email_set = { + row['email'] for row in rows + } + self.assertEqual(email_set, {'user1change', 'user2', 'user3'}) + + # Check fields from db + users = list(User.objects.order_by('id')) + self.assertEqual(users[0].email, 'user1change') + self.assertEqual(users[1].email, 'user2') + self.assertEqual(users[2].email, 'user3') + + def test_upsert_pk_return_models(self): + """ + Makes sure upserting is possible when the only uniqueness constraint is the pk. Should return models. + """ + user1 = G(User, email='user1') + user1.email = 'user1change' + user2 = User(email='user2') + user3 = User(email='user3') + + self.assertEqual(User.objects.count(), 1) + records = Query().from_table(User).upsert( + [user1, user2, user3], + unique_fields=['id'], + update_fields=['email'], + return_models=True, + ) + self.assertEqual(len(records), 3) + + # Check ids + for record in records: + self.assertIsNotNone(record.id) + + # Check emails + email_set = { + record.email for record in records + } + self.assertEqual(email_set, {'user1change', 'user2', 'user3'}) + + # Check fields from db + users = list(User.objects.order_by('id')) + self.assertEqual(users[0].email, 'user1change') + self.assertEqual(users[1].email, 'user2') + self.assertEqual(users[2].email, 'user3') diff --git a/querybuilder/version.py b/querybuilder/version.py index ef91994..092052c 100644 --- a/querybuilder/version.py +++ b/querybuilder/version.py @@ -1 +1 @@ -__version__ = '0.14.0' +__version__ = '0.14.1' diff --git a/settings.py b/settings.py 
index ea4591b..f92d61d 100644 --- a/settings.py +++ b/settings.py @@ -13,10 +13,10 @@ def configure_settings(): if test_db is None: db_config = { 'ENGINE': 'django.db.backends.postgresql_psycopg2', - 'NAME': 'ambition_dev', - 'USER': 'ambition_dev', - 'PASSWORD': 'ambition_dev', - 'HOST': 'localhost' + 'NAME': 'ambition', + 'USER': 'ambition', + 'PASSWORD': 'ambition', + 'HOST': 'db' } elif test_db == 'postgres': db_config = { From 22b716d855b13af3e23e2624585a67f31bb36f74 Mon Sep 17 00:00:00 2001 From: Wes Okes Date: Fri, 17 Feb 2017 13:58:10 -0500 Subject: [PATCH 2/5] check auto field instead of assume pk --- querybuilder/query.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/querybuilder/query.py b/querybuilder/query.py index ca02c4b..9585e8f 100644 --- a/querybuilder/query.py +++ b/querybuilder/query.py @@ -1,7 +1,7 @@ from copy import deepcopy from django.db import connection as default_django_connection -from django.db.models import Q +from django.db.models import Q, AutoField from django.db.models.query import QuerySet from django.db.models.constants import LOOKUP_SEP try: @@ -1712,12 +1712,21 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m return ModelClass = self.tables[0].model - pk_name = ModelClass._meta.pk.column - rows_without_pk = [] - if len(unique_fields) == 1 and unique_fields[0] == pk_name: - rows_without_pk = [row for row in rows if getattr(row, pk_name) is None] - rows = [row for row in rows if getattr(row, pk_name) is not None] + rows_with_null_auto_field_value = [] + + # Get auto field name (a model can only have one AutoField) + auto_field_name = None + for field in ModelClass._meta.fields: + if isinstance(field, AutoField): + auto_field_name = field.column + break + + # Check if unique fields list contains an auto field + if auto_field_name in set(unique_fields): + # Separate the rows that need to be inserted vs the rows that need to be upserted + 
rows_with_null_auto_field_value = [row for row in rows if getattr(row, auto_field_name) is None] + rows = [row for row in rows if getattr(row, auto_field_name) is not None] sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields) @@ -1729,8 +1738,8 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m # execute the bulk create query if needed bulk_created_records = [] - if rows_without_pk: - bulk_created_records = ModelClass.objects.bulk_create(rows_without_pk) + if rows_with_null_auto_field_value: + bulk_created_records = ModelClass.objects.bulk_create(rows_with_null_auto_field_value) if return_rows: return self._fetch_all_as_dict(cursor) + [ From 5c0b3582f766281f09ecbcca7ba14a6ebea0cf72 Mon Sep 17 00:00:00 2001 From: Wes Okes Date: Fri, 17 Feb 2017 16:28:52 -0500 Subject: [PATCH 3/5] properly use auto field --- querybuilder/query.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/querybuilder/query.py b/querybuilder/query.py index 9585e8f..749d759 100644 --- a/querybuilder/query.py +++ b/querybuilder/query.py @@ -1185,7 +1185,7 @@ def get_update_sql(self, rows): return self.sql, sql_args - def get_upsert_sql(self, rows, unique_fields, update_fields): + def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None): """ Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT) @@ -1194,12 +1194,11 @@ def get_upsert_sql(self, rows, unique_fields, update_fields): ON CONFLICT (unique_field) DO UPDATE SET field2 = EXCLUDED.field2; """ ModelClass = self.tables[0].model - pk_name = ModelClass._meta.pk.column # Use all fields except pk unless the uniqueness constraint is the pk field. 
Null pk field rows will be # excluded in the upsert method before calling this method - all_fields = [field for field in ModelClass._meta.fields if field.column != pk_name] - if len(unique_fields) == 1 and unique_fields[0] == pk_name: + all_fields = [field for field in ModelClass._meta.fields if field.column != auto_field_name] + if auto_field_name in unique_fields: all_fields = [field for field in ModelClass._meta.fields] all_field_names = [field.column for field in all_fields] @@ -1702,6 +1701,17 @@ def update(self, rows): # execute the query cursor.execute(sql, sql_args) + def get_auto_field_name(self, model_class): + """ + Return the column name of the model's AutoField, or None if the model has no AutoField + """ + # Get auto field name (a model can only have one AutoField) + for field in model_class._meta.fields: + if isinstance(field, AutoField): + return field.column + + return None + def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_models=False): """ Performs an upsert with the set of models defined in rows. 
If the unique field which is meant @@ -1716,25 +1726,22 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m rows_with_null_auto_field_value = [] # Get auto field name (a model can only have one AutoField) - auto_field_name = None - for field in ModelClass._meta.fields: - if isinstance(field, AutoField): - auto_field_name = field.column - break + auto_field_name = self.get_auto_field_name(ModelClass) # Check if unique fields list contains an auto field - if auto_field_name in set(unique_fields): + if auto_field_name in unique_fields: # Separate the rows that need to be inserted vs the rows that need to be upserted rows_with_null_auto_field_value = [row for row in rows if getattr(row, auto_field_name) is None] rows = [row for row in rows if getattr(row, auto_field_name) is not None] - sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields) + if rows: + sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name) - # get the cursor to execute the query - cursor = self.get_cursor() + # get the cursor to execute the query + cursor = self.get_cursor() - # execute the upsert query - cursor.execute(sql, sql_args) + # execute the upsert query + cursor.execute(sql, sql_args) # execute the bulk create query if needed bulk_created_records = [] @@ -1742,13 +1749,14 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m bulk_created_records = ModelClass.objects.bulk_create(rows_with_null_auto_field_value) if return_rows: - return self._fetch_all_as_dict(cursor) + [ + upserted_rows = self._fetch_all_as_dict(cursor) if rows else [] + return upserted_rows + [ record.__dict__ for record in bulk_created_records ] if return_models: - row_dicts = self._fetch_all_as_dict(cursor) + row_dicts = self._fetch_all_as_dict(cursor) if rows else [] ModelClass = self.tables[0].model model_objects = [ ModelClass(**row_dict) From 1d0c85700cd5fc82059b25580e1ac677d1f990ba Mon Sep 17 00:00:00 2001 
From: Wes Okes Date: Mon, 20 Feb 2017 11:43:11 -0500 Subject: [PATCH 4/5] handle bulk create for older djangos --- querybuilder/query.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/querybuilder/query.py b/querybuilder/query.py index 749d759..96eecde 100644 --- a/querybuilder/query.py +++ b/querybuilder/query.py @@ -1185,7 +1185,7 @@ def get_update_sql(self, rows): return self.sql, sql_args - def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None): + def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None, only_insert=False): """ Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT) @@ -1198,7 +1198,7 @@ def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=Non # Use all fields except pk unless the uniqueness constraint is the pk field. Null pk field rows will be # excluded in the upsert method before calling this method all_fields = [field for field in ModelClass._meta.fields if field.column != auto_field_name] - if auto_field_name in unique_fields: + if auto_field_name in unique_fields and not only_insert: all_fields = [field for field in ModelClass._meta.fields] all_field_names = [field.column for field in all_fields] @@ -1734,8 +1734,10 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m rows_with_null_auto_field_value = [row for row in rows if getattr(row, auto_field_name) is None] rows = [row for row in rows if getattr(row, auto_field_name) is not None] + return_value = [] + if rows: - sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name) + sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name=auto_field_name) # get the cursor to execute the query cursor = self.get_cursor() @@ -1743,24 +1745,32 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m # execute the 
upsert query cursor.execute(sql, sql_args) - # execute the bulk create query if needed - bulk_created_records = [] + if return_rows or return_models: + return_value.extend(self._fetch_all_as_dict(cursor)) + if rows_with_null_auto_field_value: - bulk_created_records = ModelClass.objects.bulk_create(rows_with_null_auto_field_value) + sql, sql_args = self.get_upsert_sql( + rows_with_null_auto_field_value, + unique_fields, + update_fields, + auto_field_name=auto_field_name, + only_insert=True, + ) - if return_rows: - upserted_rows = self._fetch_all_as_dict(cursor) if rows else [] - return upserted_rows + [ - record.__dict__ - for record in bulk_created_records - ] + # get the cursor to execute the query + cursor = self.get_cursor() + + # execute the upsert query + cursor.execute(sql, sql_args) + + if return_rows or return_models: + return_value.extend(self._fetch_all_as_dict(cursor)) if return_models: - row_dicts = self._fetch_all_as_dict(cursor) if rows else [] ModelClass = self.tables[0].model model_objects = [ ModelClass(**row_dict) - for row_dict in row_dicts + for row_dict in return_value ] # Set the state to indicate the object has been loaded from db @@ -1768,9 +1778,7 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m model_object._state.adding = False model_object._state.db = 'default' - return model_objects + bulk_created_records - - return [] + return return_value def sql_delete(self): """ From 613e66cf1ac20cf22006f5e843e21ed81d8aab65 Mon Sep 17 00:00:00 2001 From: Wes Okes Date: Mon, 20 Feb 2017 12:52:55 -0500 Subject: [PATCH 5/5] fix return models --- querybuilder/query.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/querybuilder/query.py b/querybuilder/query.py index 96eecde..8303d2c 100644 --- a/querybuilder/query.py +++ b/querybuilder/query.py @@ -1778,6 +1778,8 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m model_object._state.adding = False model_object._state.db = 'default' 
+ return_value = model_objects + return return_value def sql_delete(self):