Skip to content

Commit

Permalink
[SNC-14] add: pagination and filtering to ListDatasets (#279)
Browse files Browse the repository at this point in the history
* add: filterable and pageable datasets

* add: unit test

* fix: happy path

* fix: import fixture

---------

Co-authored-by: micah cahill <[email protected]>
  • Loading branch information
thearchitector and goatrocks authored Feb 28, 2024
1 parent 87d4deb commit dbfa802
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 19 deletions.
18 changes: 16 additions & 2 deletions indico/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ def __init__(self, **kwargs):
self.update(and_(kwargs) if len(kwargs) > 1 else kwargs)


class DatasetFilter(Filter):
    """
    Filter for dataset queries issued through the ``datasetsPage`` endpoint.

    Args:
        name (str): dataset name used to match datasets when listing them.
    """

    # Only the "name" key is accepted by the datasetsPage filter input.
    __options__ = ("name",)

    def __init__(self, name: str):
        super().__init__(name=name)

class SubmissionReviewFilter(Filter):
__options__ = ("rejected", "created_by", "review_type")

Expand Down Expand Up @@ -164,7 +180,6 @@ def __init__(
updated_at_start_date: datetime = None,
updated_at_end_date: datetime = None,
):

kwargs = {"workflowId": workflow_id, "id": submission_id, "status": status}
if created_at_end_date and not created_at_start_date:
raise IndicoInputError("Must specify created_at_start_date")
Expand All @@ -176,7 +191,6 @@ def __init__(
else datetime.datetime.now().strftime("%Y-%m-%d"),
}


if updated_at_end_date and not updated_at_start_date:
raise IndicoInputError("Must specify updated_at_start_date")
if updated_at_start_date is not None:
Expand Down
70 changes: 56 additions & 14 deletions indico/queries/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jsons
import tempfile
from pathlib import Path
from typing import List
from typing import List, Union, Dict, Optional

import pandas as pd
import deprecation
Expand All @@ -14,6 +14,7 @@
GraphQLRequest,
HTTPMethod,
HTTPRequest,
PagedRequest,
RequestChain,
)
from indico.errors import IndicoNotFound, IndicoInputError
Expand All @@ -25,9 +26,10 @@
ReadApiOcrOptionsInput,
OcrInputLanguage,
)
from indico.filters import DatasetFilter


class ListDatasets(GraphQLRequest):
class ListDatasets(PagedRequest):
"""
List all of your datasets
Expand All @@ -39,21 +41,52 @@ class ListDatasets(GraphQLRequest):
"""

query = """
query ListDatasets($limit: Int){
datasetsPage(limit: $limit) {
query ListDatasets(
$filters: DatasetFilter,
$limit: Int,
$orderBy: DATASET_COLUMN_ENUM,
$desc: Boolean,
$after: Int
){
datasetsPage(
filters: $filters,
limit: $limit
orderBy: $orderBy,
desc: $desc,
after: $after
) {
datasets {
id
name
rowCount
}
pageInfo {
hasNextPage
endCursor
}
}
}
"""

def __init__(self, *, limit: int = 100):
super().__init__(self.query, variables={"limit": limit})
def __init__(
    self,
    *,
    filters: Optional[Union[Dict, DatasetFilter]] = None,
    limit: int = 100,
    order_by: str = "ID",
    desc: bool = False,
):
    """
    Build the paged ListDatasets query.

    Args:
        filters: optional DatasetFilter (or plain dict) applied server-side.
        limit: maximum number of datasets per page.
        order_by: dataset column to sort on (DATASET_COLUMN_ENUM value).
        desc: sort descending when True.
    """
    variables = {
        "filters": filters,
        "limit": limit,
        "orderBy": order_by,
        "desc": desc,
    }
    super().__init__(self.query, variables=variables)

def process_response(self, response) -> Dataset:
def process_response(self, response) -> List[Dataset]:
    """Unpack the datasetsPage payload into a list of Dataset objects."""
    page = super().process_response(response)["datasetsPage"]
    return [Dataset(**entry) for entry in page["datasets"]]

Expand Down Expand Up @@ -240,7 +273,9 @@ def requests(self):
omnipage_ocr_options=self.omnipage_ocr_options,
ocr_engine=self.ocr_engine,
)
yield _AddFiles(dataset_id=self.previous.id, metadata=file_metadata, autoprocess=True)
yield _AddFiles(
dataset_id=self.previous.id, metadata=file_metadata, autoprocess=True
)
dataset_id = self.previous.id
yield GetDatasetFileStatus(id=dataset_id)
debouncer = Debouncer()
Expand Down Expand Up @@ -441,7 +476,9 @@ def requests(self):
yield GetDatasetFileStatus(id=self.dataset_id)
if self.wait:
debouncer = Debouncer()
while not all(f.status in self.expected_statuses for f in self.previous.files):
while not all(
f.status in self.expected_statuses for f in self.previous.files
):
yield GetDatasetFileStatus(id=self.previous.id)
debouncer.backoff()

Expand Down Expand Up @@ -492,8 +529,10 @@ def __init__(self, dataset_id: int, datafile_ids: List[int]):
def process_response(self, response):
return Dataset(**super().process_response(response)["addDataCsv"])

@deprecation.deprecated(deprecated_in="5.3",
details="Use AddFiles wtih autoprocess=True instead")

@deprecation.deprecated(
    deprecated_in="5.3", details="Use AddFiles with autoprocess=True instead"
)
class ProcessFiles(RequestChain):
"""
Process files associated with a dataset and add corresponding data to the dataset
Expand Down Expand Up @@ -529,8 +568,10 @@ def requests(self):
yield GetDatasetFileStatus(id=self.dataset_id)
debouncer.backoff()

@deprecation.deprecated(deprecated_in="5.3",
details="Use AddFiles wtih autoprocess=True instead")

@deprecation.deprecated(
    deprecated_in="5.3", details="Use AddFiles with autoprocess=True instead"
)
class ProcessCSV(RequestChain):
"""
Process CSV associated with a dataset and add corresponding data to the dataset
Expand Down Expand Up @@ -566,6 +607,7 @@ class GetAvailableOcrEngines(GraphQLRequest):
"""
Fetches and lists the available OCR engines
"""

query = """query{
ocrOptions {
engines{
Expand All @@ -581,6 +623,7 @@ def process_response(self, response):
engines = super().process_response(response)["ocrOptions"]["engines"]
return [OcrEngine[e["name"]] for e in engines]


class GetOcrEngineLanguageCodes(GraphQLRequest):
"""
Fetches and lists the available languages by name and code for the given OCR Engine
Expand All @@ -601,7 +644,6 @@ class GetOcrEngineLanguageCodes(GraphQLRequest):
}
}"""


def __init__(self, engine: OcrEngine):
self.engine = engine
super().__init__(self.query)
Expand Down
25 changes: 22 additions & 3 deletions tests/integration/queries/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import json
from indico.filters import DatasetFilter
import pytest
import unittest
import pandas as pd
from pathlib import Path
import os
Expand All @@ -15,7 +15,6 @@
CreateEmptyDataset,
AddFiles,
ProcessFiles,
ProcessCSV,
)
from indico.queries.export import CreateExport, DownloadExport
from indico.types.dataset import (
Expand All @@ -26,7 +25,7 @@
ReadApiOcrOptionsInput,
)
from indico.errors import IndicoRequestError
from tests.integration.data.datasets import airlines_dataset
from tests.integration.data.datasets import airlines_dataset # noqa: F401


def test_create_dataset(indico):
Expand Down Expand Up @@ -78,6 +77,26 @@ def test_list_datasets(indico, airlines_dataset):
assert type(datasets[0]) == Dataset


def test_list_datasets_filtered(indico, airlines_dataset):
    """
    ListDatasets with a DatasetFilter returns no results for a name that
    cannot exist, and exactly the matching dataset for a real name.
    """
    client = IndicoClient()

    # A timestamp-based name is effectively guaranteed not to match anything.
    datasets = client.call(
        ListDatasets(filters=DatasetFilter(name=f"{time.time()}_bananas"))
    )

    assert isinstance(datasets, list)
    assert len(datasets) == 0

    # happy path: filtering on an existing dataset's name returns it
    datasets = client.call(
        ListDatasets(filters=DatasetFilter(name=airlines_dataset.name))
    )

    assert isinstance(datasets, list)
    assert len(datasets) == 1
    # isinstance instead of `type(x) == T`: idiomatic and subclass-tolerant
    assert isinstance(datasets[0], Dataset)
    assert datasets[0].name == airlines_dataset.name


def test_images(indico):
client = IndicoClient()

Expand Down

0 comments on commit dbfa802

Please sign in to comment.