From dbfa8029f92ee499f3877e9982bb4cd50f8fdbb8 Mon Sep 17 00:00:00 2001 From: Elias Gabriel Date: Wed, 28 Feb 2024 15:21:39 -0500 Subject: [PATCH] [SNC-14] add: pagination and filtering to ListDatasets (#279) * add: filterable and pageable datasets * add: unit test * fix: happy path * fix: import fixture --------- Co-authored-by: micah cahill --- indico/filters/__init__.py | 18 +++++- indico/queries/datasets.py | 70 ++++++++++++++++++----- tests/integration/queries/test_dataset.py | 25 +++++++- 3 files changed, 94 insertions(+), 19 deletions(-) diff --git a/indico/filters/__init__.py b/indico/filters/__init__.py index 1e446d6d..b24af746 100644 --- a/indico/filters/__init__.py +++ b/indico/filters/__init__.py @@ -28,6 +28,22 @@ def __init__(self, **kwargs): self.update(and_(kwargs) if len(kwargs) > 1 else kwargs) +class DatasetFilter(Filter): + """ + Create a Filter when querying for Datasets via datasetsPage. + + Args: + name (str): dataset name by which to filter + Returns: + dict containing query filter parameters + """ + + __options__ = ("name",) + + def __init__(self, name: str): + super().__init__(name=name) + + class SubmissionReviewFilter(Filter): __options__ = ("rejected", "created_by", "review_type") @@ -164,7 +180,6 @@ def __init__( updated_at_start_date: datetime = None, updated_at_end_date: datetime = None, ): - kwargs = {"workflowId": workflow_id, "id": submission_id, "status": status} if created_at_end_date and not created_at_start_date: raise IndicoInputError("Must specify created_at_start_date") @@ -176,7 +191,6 @@ def __init__( else datetime.datetime.now().strftime("%Y-%m-%d"), } - if updated_at_end_date and not updated_at_start_date: raise IndicoInputError("Must specify updated_at_start_date") if updated_at_start_date is not None: diff --git a/indico/queries/datasets.py b/indico/queries/datasets.py index b12265f3..d7c8a562 100644 --- a/indico/queries/datasets.py +++ b/indico/queries/datasets.py @@ -4,7 +4,7 @@ import jsons import tempfile from pathlib import Path -from typing import List +from typing import List, Union, Dict, Optional import pandas as pd import deprecation @@ -14,6 +14,7 @@ GraphQLRequest, HTTPMethod, HTTPRequest, + PagedRequest, RequestChain, ) from indico.errors import IndicoNotFound, IndicoInputError @@ -25,9 +26,10 @@ ReadApiOcrOptionsInput, OcrInputLanguage, ) +from indico.filters import DatasetFilter -class ListDatasets(GraphQLRequest): +class ListDatasets(PagedRequest): """ List all of your datasets @@ -39,21 +41,52 @@ class ListDatasets(GraphQLRequest): """ query = """ - query ListDatasets($limit: Int){ - datasetsPage(limit: $limit) { + query ListDatasets( + $filters: DatasetFilter, + $limit: Int, + $orderBy: DATASET_COLUMN_ENUM, + $desc: Boolean, + $after: Int + ){ + datasetsPage( + filters: $filters, + limit: $limit + orderBy: $orderBy, + desc: $desc, + after: $after + ) { datasets { id name rowCount } + pageInfo { + hasNextPage + endCursor + } } } """ - def __init__(self, *, limit: int = 100): - super().__init__(self.query, variables={"limit": limit}) + def __init__( + self, + *, + filters: Optional[Union[Dict, DatasetFilter]] = None, + limit: int = 100, + order_by: str = "ID", + desc: bool = False, + ): + super().__init__( + self.query, + variables={ + "filters": filters, + "limit": limit, + "orderBy": order_by, + "desc": desc, + }, + ) - def process_response(self, response) -> Dataset: + def process_response(self, response) -> List[Dataset]: response = super().process_response(response) return [Dataset(**dataset) for dataset in response["datasetsPage"]["datasets"]] @@ -240,7 +273,9 @@ def requests(self): omnipage_ocr_options=self.omnipage_ocr_options, ocr_engine=self.ocr_engine, ) - yield _AddFiles(dataset_id=self.previous.id, metadata=file_metadata, autoprocess=True) + yield _AddFiles( + dataset_id=self.previous.id, metadata=file_metadata, autoprocess=True + ) dataset_id = self.previous.id yield GetDatasetFileStatus(id=dataset_id) debouncer = Debouncer() @@ -441,7 +476,9 @@ def requests(self): yield GetDatasetFileStatus(id=self.dataset_id) if self.wait: debouncer = Debouncer() - while not all(f.status in self.expected_statuses for f in self.previous.files): + while not all( + f.status in self.expected_statuses for f in self.previous.files + ): yield GetDatasetFileStatus(id=self.previous.id) debouncer.backoff() @@ -492,8 +529,10 @@ def __init__(self, dataset_id: int, datafile_ids: List[int]): def process_response(self, response): return Dataset(**super().process_response(response)["addDataCsv"]) -@deprecation.deprecated(deprecated_in="5.3", - details="Use AddFiles wtih autoprocess=True instead") + +@deprecation.deprecated( + deprecated_in="5.3", details="Use AddFiles wtih autoprocess=True instead" +) class ProcessFiles(RequestChain): """ Process files associated with a dataset and add corresponding data to the dataset @@ -529,8 +568,10 @@ def requests(self): yield GetDatasetFileStatus(id=self.dataset_id) debouncer.backoff() -@deprecation.deprecated(deprecated_in="5.3", - details="Use AddFiles wtih autoprocess=True instead") + +@deprecation.deprecated( + deprecated_in="5.3", details="Use AddFiles wtih autoprocess=True instead" +) class ProcessCSV(RequestChain): """ Process CSV associated with a dataset and add corresponding data to the dataset @@ -566,6 +607,7 @@ class GetAvailableOcrEngines(GraphQLRequest): """ Fetches and lists the available OCR engines """ + query = """query{ ocrOptions { engines{ @@ -581,6 +623,7 @@ def process_response(self, response): engines = super().process_response(response)["ocrOptions"]["engines"] return [OcrEngine[e["name"]] for e in engines] + class GetOcrEngineLanguageCodes(GraphQLRequest): """ Fetches and lists the available languages by name and code for the given OCR Engine @@ -601,7 +644,6 @@ class GetOcrEngineLanguageCodes(GraphQLRequest): } }""" - def __init__(self, engine: OcrEngine): self.engine = engine super().__init__(self.query) diff --git a/tests/integration/queries/test_dataset.py b/tests/integration/queries/test_dataset.py index 7d807cde..a6bdff18 100644 --- a/tests/integration/queries/test_dataset.py +++ b/tests/integration/queries/test_dataset.py @@ -1,7 +1,7 @@ import time import json +from indico.filters import DatasetFilter import pytest -import unittest import pandas as pd from pathlib import Path import os @@ -15,7 +15,6 @@ CreateEmptyDataset, AddFiles, ProcessFiles, - ProcessCSV, ) from indico.queries.export import CreateExport, DownloadExport from indico.types.dataset import ( @@ -26,7 +25,7 @@ ReadApiOcrOptionsInput, ) from indico.errors import IndicoRequestError -from tests.integration.data.datasets import airlines_dataset +from tests.integration.data.datasets import airlines_dataset # noqa: F401 def test_create_dataset(indico): @@ -78,6 +77,26 @@ def test_list_datasets(indico, airlines_dataset): assert type(datasets[0]) == Dataset +def test_list_datasets_filtered(indico, airlines_dataset): + client = IndicoClient() + datasets = client.call( + ListDatasets(filters=DatasetFilter(name=f"{time.time()}_bananas")) + ) + + assert isinstance(datasets, list) + assert len(datasets) == 0 + + # happy path + datasets = client.call( + ListDatasets(filters=DatasetFilter(name=airlines_dataset.name)) + ) + + assert isinstance(datasets, list) + assert len(datasets) == 1 + assert type(datasets[0]) == Dataset + assert datasets[0].name == airlines_dataset.name + + def test_images(indico): client = IndicoClient()