diff --git a/conda-store-server/conda_store_server/_internal/server/app.py b/conda-store-server/conda_store_server/_internal/server/app.py index e380f654a..5633356f7 100644 --- a/conda-store-server/conda_store_server/_internal/server/app.py +++ b/conda-store-server/conda_store_server/_internal/server/app.py @@ -16,6 +16,7 @@ from fastapi.responses import FileResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from fastapi_pagination import add_pagination from sqlalchemy.pool import QueuePool from starlette.middleware.sessions import SessionMiddleware from traitlets import ( @@ -231,6 +232,8 @@ def trim_slash(url): }, ) + add_pagination(app) + app.add_middleware( CORSMiddleware, allow_origins=self.cors_allow_origins, diff --git a/conda-store-server/conda_store_server/_internal/server/views/api.py b/conda-store-server/conda_store_server/_internal/server/views/api.py index 999f053a2..90450d25d 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/api.py +++ b/conda-store-server/conda_store_server/_internal/server/views/api.py @@ -3,13 +3,18 @@ # license that can be found in the LICENSE file. import datetime -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, List, Optional, Tuple, TypedDict import pydantic import yaml from celery.result import AsyncResult from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request from fastapi.responses import PlainTextResponse, RedirectResponse +from fastapi_pagination import set_params +from fastapi_pagination.cursor import CursorPage, CursorParams +from fastapi_pagination.ext.sqlalchemy import paginate +from pydantic import BaseModel +from sqlalchemy.orm import Query as SqlQuery from conda_store_server import __version__, api, app from conda_store_server._internal import orm, schema, utils @@ -18,6 +23,8 @@ from conda_store_server._internal.server import dependencies from conda_store_server.server.auth import Authentication +set_params(CursorParams(size=10, cursor=None)) + class PaginatedArgs(TypedDict): """Dictionary type holding information about paginated requests.""" @@ -54,19 +61,40 @@ def get_paginated_args( def filter_distinct_on( - query, - distinct_on: List[str] = [], - allowed_distinct_ons: Dict = {}, - default_distinct_on: List[str] = [], -): - distinct_on = distinct_on or default_distinct_on - distinct_on = [ - allowed_distinct_ons[d] for d in distinct_on if d in allowed_distinct_ons - ] + query: SqlQuery, + distinct_on: List[str] | None = None, + allowed_distinct_ons: Dict | None = None, +) -> Tuple[List[str], SqlQuery]: + """Filter the query using the distinct fields. - if distinct_on: - return distinct_on, query.distinct(*distinct_on) - return distinct_on, query + Parameters + ---------- + query : SqlQuery + Query to filter + distinct_on : List[str] | None + Parameter to pass to the FILTER DISTINCT statement + allowed_distinct_ons : Dict | None + Allowed values of the parameter + + Returns + ------- + SqlQuery + Query containing filtered results + """ + if distinct_on is None: + distinct_on = [] + + disallowed = set(distinct_on) - set(allowed_distinct_ons) + if disallowed: + raise HTTPException( + status_code=400, + detail=( + f"Requested distinct_on terms ({disallowed}) are not allowed. " + f"Valid terms are {set(allowed_distinct_ons)}" + ), + ) + + return query.distinct(*[allowed_distinct_ons[item] for item in distinct_on]) def get_sorts( @@ -93,6 +121,15 @@ def get_sorts( return [order_mapping[order](k) for k in sort_by] +def paginate_response( + query: SqlQuery, + obj_schema: BaseModel, + order: str = "asc", + sort_by: List[str] = None, +) -> CursorPage: + return + + def paginated_api_response( query, paginated_args, @@ -103,7 +140,7 @@ def paginated_api_response( required_sort_bys: List = [], default_sort_by: List = [], default_order: str = "asc", -): +) -> CursorPage: sorts = get_sorts( order=paginated_args["order"], sort_by=paginated_args["sort_by"], @@ -113,15 +150,17 @@ def paginated_api_response( default_order=default_order, ) - count = query.count() - query = ( - query.order_by(*sorts) - .limit(paginated_args["limit"]) - .offset(paginated_args["offset"]) + print( + query, + paginated_args, + object_schema, + sorts, ) + + count = query.count() return { "status": "ok", - "data": [object_schema.from_orm(_).dict(exclude=exclude) for _ in query.all()], + "data": paginate(query.order_by(*sorts)), "page": (paginated_args["offset"] // paginated_args["limit"]) + 1, "size": paginated_args["limit"], "count": count, @@ -712,9 +751,9 @@ async def api_list_environments( paginated_args, schema.Environment, exclude={"current_build"}, - allowed_sort_bys={ - "scheduled_on": orm.Environment.current_build.scheduled_on, - }, + # allowed_sort_bys={ + # "scheduled_on": orm.Environment.current_build.scheduled_on, + # }, default_sort_by=["scheduled_on"], default_order="asc", ) diff --git a/conda-store-server/pyproject.toml b/conda-store-server/pyproject.toml index 318466b03..e2f1e25fb 100644 --- a/conda-store-server/pyproject.toml +++ b/conda-store-server/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "alembic", "celery", "fastapi", + "fastapi_pagination", "filelock", "flower", "itsdangerous", @@ -67,6 +68,7 @@ dependencies = [ "pydantic >=1.10.16,<2.0a0", "python-multipart", "sqlalchemy<2.0a0", + "sqlakeyset", "traitlets", "uvicorn", "yarl",