diff --git a/rest_framework_filters/filters.py b/rest_framework_filters/filters.py index 9d313dc..48fd6ad 100644 --- a/rest_framework_filters/filters.py +++ b/rest_framework_filters/filters.py @@ -44,9 +44,10 @@ class BlogFilter(filters.FilterSet): """ creation_counter = 0 - def __init__(self, field_name=None, *, lookups=None): + def __init__(self, field_name=None, *, lookups=None, method=None): self.field_name = field_name self.lookups = lookups or [] + self.method = method self.creation_counter = AutoFilter.creation_counter AutoFilter.creation_counter += 1 diff --git a/rest_framework_filters/filterset.py b/rest_framework_filters/filterset.py index a08ac42..9fcfdfd 100644 --- a/rest_framework_filters/filterset.py +++ b/rest_framework_filters/filterset.py @@ -103,6 +103,16 @@ def expand_auto_filter(cls, new_class, filter_name, f): # replace the field name with the param name from the filerset gen_name = gen_name.replace(f.field_name, filter_name, 1) + if f.method: + # Override method for auto-generated filters. + gen_f.method = f.method + + # Skip if lookup expression is `exact` since it is equivalent to no lookup + if gen_f.lookup_expr != "exact": + # Update field name to also include lookup expr. + gen_f.field_name = "{field_name}__{lookup_expr}".format(field_name=f.field_name, + lookup_expr=gen_f.lookup_expr) + # do not overwrite declared filters if gen_name not in orig_declared: expanded[gen_name] = gen_f diff --git a/tests/test_filtering.py b/tests/test_filtering.py index e9d3dbd..6279e19 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -1,3 +1,4 @@ +from django.db.models import CharField, Value from django.test import TestCase from django_filters import FilterSet as DFFilterSet @@ -78,6 +79,62 @@ class Meta: f = Subclass(GET, queryset=Note.objects.all()) self.assertEqual(len(list(f.qs)), 2) + def test_autofilter_with_method(self): + # Test that method param applies to all auto-generated filters. + def filter_iexact(qs, field_name, value): + # Test that the field name contains the lookup expression. + self.assertEqual(field_name, 'content__icontains') + + return qs.filter(**{field_name: value}).annotate(checksum=Value("3", output_field=CharField())) + + class Actual(FilterSet): + title = filters.AutoFilter(lookups='__all__', method='filter_title') + content = filters.AutoFilter(lookups=['icontains'], method=filter_iexact) + author = filters.AutoFilter(lookups='__all__', field_name='author__username', method='filter_author') + + class Meta: + model = Note + fields = [] + + def filter_title(self, qs, field_name, value): + return qs.filter(**{field_name: value}).annotate(checksum=Value("1", output_field=CharField())) + + def filter_author(self, qs, field_name, value): + return qs.filter(**{field_name: value}).annotate(checksum=Value("2", output_field=CharField())) + + # Test method as a function + GET = {'content__icontains': 'test content'} + f = Actual(GET, queryset=Note.objects.all()) + self.assertEqual(len(list(f.qs)), 4) + self.assertEqual(f.qs[0].checksum, "3") + + # Test method as a string reference to filterset method + GET = {'title__contains': 'Hello'} + f = Actual(GET, queryset=Note.objects.all()) + self.assertEqual(len(list(f.qs)), 2) + self.assertEqual(f.qs[0].checksum, "1") + + GET = {'title__iendswith': '4'} + f = Actual(GET, queryset=Note.objects.all()) + self.assertEqual(len(list(f.qs)), 1) + self.assertEqual(f.qs[0].checksum, "1") + + GET = {'title': 'Hello Test 3'} + f = Actual(GET, queryset=Note.objects.all()) + self.assertEqual(len(list(f.qs)), 1) + self.assertEqual(f.qs[0].checksum, "1") + + # Test method in Autofilter on related field + GET = {'author__contains': 'user2'} + f = Actual(GET, queryset=Note.objects.all()) + self.assertEqual(len(list(f.qs)), 1) + self.assertEqual(f.qs[0].checksum, "2") + + GET = {'author': 'user2'} + f = Actual(GET, queryset=Note.objects.all()) + self.assertEqual(len(list(f.qs)), 1) + self.assertEqual(f.qs[0].checksum, "2") + class RelatedFilterTests(TestCase): diff --git a/tests/test_filterset.py b/tests/test_filterset.py index 9278f88..8bbd0b8 100644 --- a/tests/test_filterset.py +++ b/tests/test_filterset.py @@ -261,6 +261,33 @@ class F(FilterSet): self.assertEqual(str(w[0].message), message) self.assertIs(w[0].category, DeprecationWarning) + def test_autofilter_can_be_generated_with_method(self): + # ensure AutoFilters are generated with the provided method. + def external_method(instance, qs, field, value): + pass + + class F(FilterSet): + id = filters.AutoFilter(lookups='__all__', method='filterset_method') + title = filters.AutoFilter(lookups=['exact'], method=external_method) + author = filters.AutoFilter(field_name='author__last_name', lookups='__all__', method='related_method') + + class Meta: + model = Note + fields = [] + + def filterset_method(self, qs, field, value): + pass + + def related_method(self, qs, field, value): + pass + + for field_name, lookup_filter in F.base_filters.items(): + # Ensure field name on filter is overridden to include lookup expression. + if lookup_filter.lookup_expr != 'exact': + self.assertTrue(lookup_filter.field_name.endswith(lookup_filter.lookup_expr)) + + self.assertIsNotNone(lookup_filter._method) + class GetRelatedFiltersetsTests(TestCase):