Skip to content

Commit

Permalink
Respect return_models in upsert method when building upsert sql
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredlewis committed Apr 21, 2017
1 parent c88c135 commit aac80b2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Release Notes
=============

v0.14.3
-------
* Respect return_models in upsert method when building upsert sql

v0.14.2
-------
* Fix upsert to use the proper prepare method on django fields
Expand Down
27 changes: 21 additions & 6 deletions querybuilder/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,15 @@ def get_update_sql(self, rows):

return self.sql, sql_args

def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None, only_insert=False):
def get_upsert_sql(
self,
rows,
unique_fields,
update_fields,
auto_field_name=None,
only_insert=False,
return_rows=True
):
"""
Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT)
Expand Down Expand Up @@ -1235,22 +1243,22 @@ def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=Non
row_values_sql = ', '.join(row_values)

if update_fields:
self.sql = 'INSERT INTO {0} ({1}) VALUES {2} ON CONFLICT ({3}) DO UPDATE SET {4} RETURNING {5}'.format(
self.sql = 'INSERT INTO {0} ({1}) VALUES {2} ON CONFLICT ({3}) DO UPDATE SET {4} {5}'.format(
self.tables[0].get_identifier(),
all_field_names_sql,
row_values_sql,
unique_field_names_sql,
update_fields_sql,
'*'
'RETURNING *' if return_rows else ''
)
else:
self.sql = 'INSERT INTO {0} ({1}) VALUES {2} ON CONFLICT ({3}) {4} RETURNING {5}'.format(
self.sql = 'INSERT INTO {0} ({1}) VALUES {2} ON CONFLICT ({3}) {4} {5}'.format(
self.tables[0].get_identifier(),
all_field_names_sql,
row_values_sql,
unique_field_names_sql,
'DO UPDATE SET {0}=EXCLUDED.{0}'.format(unique_fields[0].column),
'*'
'RETURNING *' if return_rows else ''
)

return self.sql, sql_args
Expand Down Expand Up @@ -1737,7 +1745,13 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m
return_value = []

if rows:
sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name=auto_field_name)
sql, sql_args = self.get_upsert_sql(
rows,
unique_fields,
update_fields,
auto_field_name=auto_field_name,
return_rows=return_models
)

# get the cursor to execute the query
cursor = self.get_cursor()
Expand All @@ -1755,6 +1769,7 @@ def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_m
update_fields,
auto_field_name=auto_field_name,
only_insert=True,
return_rows=return_models
)

# get the cursor to execute the query
Expand Down
2 changes: 1 addition & 1 deletion querybuilder/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.14.2'
__version__ = '0.14.3'

0 comments on commit aac80b2

Please sign in to comment.