From d758db850b8ba685059ec1d7cfba2ab0cb07d2df Mon Sep 17 00:00:00 2001 From: Kari Noriy Date: Fri, 8 Dec 2023 13:11:04 +0000 Subject: [PATCH] reduced duplicate code in api --- .../audio_recorder/views.py | 134 +++++++++--------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/django_dataset_collection_tool/audio_recorder/views.py b/django_dataset_collection_tool/audio_recorder/views.py index 4b0ddbe..1dd55cf 100644 --- a/django_dataset_collection_tool/audio_recorder/views.py +++ b/django_dataset_collection_tool/audio_recorder/views.py @@ -6,7 +6,7 @@ import pandas as pd from collections import defaultdict -from django.http import HttpResponse, HttpResponseForbidden +from django.http import HttpResponse, HttpResponseForbidden, StreamingHttpResponse from django.http.response import JsonResponse from django.shortcuts import render, get_object_or_404, redirect @@ -581,86 +581,86 @@ def get(self, request, *args, **kwargs): return render(request, 'audio_recorder/stats.html', context) -class DownloadView(APIView): - authentication_classes = [TokenAuthGet] - permission_classes = [IsAuthenticated, CanUsePaidParameter] - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] - - def get(self, request, *args, **kwargs): - # Add a download parameter check - download_param = request.GET.get('download', None) - if download_param not in [None, 'csv', 'parquet']: - raise PermissionDenied(detail="Invalid 'download' parameter value.") - - # Query the database for all utterances - utterances = Utterances.objects.filter(status='Awaiting Review').values() - if download_param == 'csv': - return self.create_csv_response(utterances) - - return self.create_parquet_response(utterances) - - def create_csv_response(self, data): - response = HttpResponse(content_type='text/csv') - response['Content-Disposition'] = 'attachment; filename="utterances.csv"' - - if data: - writer = csv.writer(response) - writer.writerow(data[0].keys()) # column headers - for item in data: - writer.writerow(item.values()) - - return response - - def create_parquet_response(self, data): - df = pd.DataFrame(data) - response = HttpResponse(content_type='application/octet-stream') - response['Content-Disposition'] = 'attachment; filename="utterances.parquet"' - if not df.empty: - df.to_parquet(response, index=False) - - return response +class Echo: + """An object that implements just the write method of the file-like interface.""" + def write(self, value): + """Write the value by returning it, instead of storing in a buffer.""" + return value + +class DownloadMixin: + def create_csv_response(self, data, filename='utterances.csv'): + pseudo_buffer = Echo() + writer = csv.writer(pseudo_buffer) + response = StreamingHttpResponse((writer.writerow(item.values()) for item in data), content_type="text/csv") + response['Content-Disposition'] = f'attachment; filename="{filename}"' + return response + + def create_parquet_response(self, data, filename='utterances.parquet'): + df = pd.DataFrame(list(data)) + response = HttpResponse(content_type='application/octet-stream') + response['Content-Disposition'] = f'attachment; filename="{filename}"' + df.to_parquet(response, index=False) + return response + +class DownloadView(APIView, DownloadMixin): + authentication_classes = [TokenAuthGet] + permission_classes = [IsAuthenticated, CanUsePaidParameter] + throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + + def get(self, request, *args, **kwargs): + download_param = request.GET.get('download', None) + if download_param not in [None, 'csv', 'parquet']: + raise PermissionDenied(detail="Invalid 'download' parameter value.") + + utterances = Utterances.objects.filter(status='Awaiting Review').values() + if download_param == 'csv': + return self.create_csv_response(utterances) + elif download_param == 'parquet': + return self.create_parquet_response(utterances) + else: + return RestResponse(utterances) -class GetUtterancesURLsView(APIView): +class GetUtterancesURLsView(APIView, DownloadMixin): authentication_classes = [TokenAuthGet] permission_classes = [IsAuthenticated] throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] def get(self, request, *args, **kwargs): - # Retrieve the status parameter from the query string + status_param, limit_param, download_param = self.get_parameters(request) + utterances = self.query_utterances(status_param, limit_param) + + if download_param == 'csv': + return self.create_csv_response(utterances) + elif download_param == 'parquet': + return self.create_parquet_response(utterances) + else: + response_data = [{'id': utt['pk'], 'url': utt['audio_recording']} for utt in utterances] + return RestResponse(response_data) + + def get_parameters(self, request): + # Retrieve parameters status_param = request.GET.get('status', 'Awaiting Review') + limit_param = request.GET.get('limit', None) + download_param = request.GET.get('download', None) - # Validate the status parameter + # Validate status valid_statuses = ['Pending', 'Awaiting Review', 'Complete', 'Needs Updating'] if status_param not in valid_statuses: - return RestResponse({'error': 'Invalid status parameter'}, status=400) + raise ValidationError('Invalid status parameter') - # Retrieve and validate the limit parameter - limit_param = request.GET.get('limit', None) + # Validate and convert limit if limit_param: try: - limit = int(limit_param) - MaxValueValidator(1000)(limit) # Assuming a maximum limit of 1000 + limit_param = int(limit_param) + MaxValueValidator(1000)(limit_param) # Max limit except (ValueError, ValidationError): - return RestResponse({'error': 'Invalid limit parameter'}, status=400) - else: - limit = None - - # Query the database for utterances with the specified status - query = Utterances.objects.filter(status=status_param) - if limit: - query = query[:limit] - utterances = query.values('pk', 'audio_recording') - - # Prepare the data for response - response_data = [ - { - 'id': utterance['pk'], - 'url': utterance['audio_recording'] - } - for utterance in utterances - ] - - return RestResponse(response_data) + raise ValidationError('Invalid limit parameter') + + return status_param, limit_param, download_param + + def query_utterances(self, status, limit): + query = Utterances.objects.filter(status=status).values('pk', 'audio_recording') + return query[:limit] if limit else query def report_utterance(request, utterance_id): utterance = get_object_or_404(Utterances, pk=utterance_id)