diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 30198fc..c108eaa 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -1,6 +1,10 @@ Release Notes ============= +v3.1.3 +------ +* Added object ordering to upserts and updates to help reduce the number of deadlocks + v3.1.2 ------ * Bump django-query-builder for psycopg3 support diff --git a/manager_utils/manager_utils.py b/manager_utils/manager_utils.py index e4cb65d..77e8f88 100644 --- a/manager_utils/manager_utils.py +++ b/manager_utils/manager_utils.py @@ -1,7 +1,7 @@ import itertools from typing import List -from django.db import connection +from django.db import connection, connections, models from django.db.models import Manager, Model from django.db.models.query import QuerySet from django.dispatch import Signal @@ -86,6 +86,40 @@ def _fetch_models_by_pk(queryset: QuerySet, models: List[Model]) -> List[Model]: ) +def _get_field_db_val(queryset, field, value, connection): + if hasattr(value, "resolve_expression"): # pragma: no cover + # Handle cases when the field is of type "Func" and other expressions. + # This is useful for libraries like django-rdkit that can't easily be tested + return value.resolve_expression(queryset.query, allow_joins=False, for_save=True) + else: + return field.get_db_prep_save(value, connection) + + +def _model_fields(model: models.Model) -> List[models.Field]: + """Return the fields of a model, excluding generated and non-concrete ones.""" + return [f for f in model._meta.fields if not getattr(f, "generated", False) and f.concrete] + + +def _sort_by_unique_fields(queryset, model_objs, unique_fields): + """ + Sort a list of models by their unique fields. + + Sorting models in an upsert greatly reduces the chances of deadlock + when doing concurrent upserts + """ + model = queryset.model + connection = connections[queryset.db] + unique_fields = [field for field in _model_fields(model) if field.attname in unique_fields] + + def sort_key(model_obj): + return tuple( + _get_field_db_val(queryset, field, getattr(model_obj, field.attname), connection) + for field in unique_fields + ) + + return sorted(model_objs, key=sort_key) + + def bulk_upsert( queryset, model_objs, unique_fields, update_fields=None, return_upserts=False, return_upserts_distinct=False, sync=False, native=False @@ -190,6 +224,9 @@ def bulk_upsert( raise ValueError('Must provide unique_fields argument') update_fields = update_fields or [] + # Sore the models to prevent deadlocks + model_objs = _sort_by_unique_fields(queryset, model_objs, unique_fields) + if native: if return_upserts_distinct: raise NotImplementedError('return upserts distinct not supported with native postgres upsert') @@ -500,6 +537,9 @@ def bulk_update(manager, model_objs, fields_to_update): """ + # Sort the model objects to reduce the likelihood of deadlocks + model_objs = sorted(model_objs, key=lambda obj: obj.pk) + # Add the pk to the value fields so we can join value_fields = [manager.model._meta.pk.attname] + fields_to_update diff --git a/manager_utils/version.py b/manager_utils/version.py index f71b21a..dc0a25b 100644 --- a/manager_utils/version.py +++ b/manager_utils/version.py @@ -1 +1 @@ -__version__ = '3.1.2' +__version__ = '3.1.3'