From d6efced70a21b4af42fdc805d334ee037e5558c8 Mon Sep 17 00:00:00 2001 From: Maximilian Haye Date: Tue, 20 Aug 2024 19:33:17 +0200 Subject: [PATCH] refactor: ensure_package_and_question_state_exist --- questionpy_server/__init__.py | 6 +- questionpy_server/api/routes/_attempts.py | 16 +- questionpy_server/api/routes/_files.py | 12 +- questionpy_server/api/routes/_packages.py | 30 +- questionpy_server/api/routes/_status.py | 9 +- questionpy_server/app.py | 12 +- questionpy_server/cache.py | 18 +- questionpy_server/collector/__init__.py | 2 +- ...e_collection.py => _package_collection.py} | 27 +- questionpy_server/decorators.py | 378 +++++++++++++----- questionpy_server/package.py | 11 +- questionpy_server/repository/__init__.py | 4 +- questionpy_server/web.py | 97 +---- tests/questionpy_common/test_elements.py | 12 +- .../api/routes/test_packages.py | 3 +- .../collector/test_package_collection.py | 7 +- .../repository/test_repository.py | 4 +- tests/questionpy_server/test_cache.py | 4 +- 18 files changed, 372 insertions(+), 280 deletions(-) rename questionpy_server/collector/{package_collection.py => _package_collection.py} (86%) diff --git a/questionpy_server/__init__.py b/questionpy_server/__init__.py index 5e11d8c5..e3c13219 100644 --- a/questionpy_server/__init__.py +++ b/questionpy_server/__init__.py @@ -1,9 +1,9 @@ -__version__ = "0.1.0" - # This file is part of the QuestionPy Server. (https://questionpy.org) # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus from questionpy_server.worker.pool import WorkerPool -__all__ = ["WorkerPool"] +__version__ = "0.1.0" + +__all__ = ["WorkerPool", "__version__"] diff --git a/questionpy_server/api/routes/_attempts.py b/questionpy_server/api/routes/_attempts.py index f34f3074..f4d6445f 100644 --- a/questionpy_server/api/routes/_attempts.py +++ b/questionpy_server/api/routes/_attempts.py @@ -8,13 +8,13 @@ from questionpy_common.environment import RequestUser from questionpy_server.api.models import AttemptScoreArguments, AttemptStartArguments, AttemptViewArguments -from questionpy_server.decorators import ensure_package_and_question_state_exist +from questionpy_server.app import QPyServer +from questionpy_server.decorators import ensure_required_parts from questionpy_server.package import Package from questionpy_server.web import json_response from questionpy_server.worker.runtime.package_location import ZipPackageLocation if TYPE_CHECKING: - from questionpy_server.app import QPyServer from questionpy_server.worker.worker import Worker @@ -22,11 +22,11 @@ @attempt_routes.post(r"/packages/{package_hash:\w+}/attempt/start") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_required_parts async def post_attempt_start( request: web.Request, package: Package, question_state: bytes, data: AttemptStartArguments ) -> web.Response: - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] package_path = await package.get_path() worker: Worker @@ -37,11 +37,11 @@ async def post_attempt_start( @attempt_routes.post(r"/packages/{package_hash:\w+}/attempt/view") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_required_parts async def post_attempt_view( request: web.Request, package: Package, question_state: bytes, data: AttemptViewArguments ) -> web.Response: - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] package_path = await package.get_path() worker: Worker @@ -58,11 +58,11 @@ async def post_attempt_view( @attempt_routes.post(r"/packages/{package_hash:\w+}/attempt/score") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_required_parts async def post_attempt_score( request: web.Request, package: Package, question_state: bytes, data: AttemptScoreArguments ) -> web.Response: - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] package_path = await package.get_path() worker: Worker diff --git a/questionpy_server/api/routes/_files.py b/questionpy_server/api/routes/_files.py index a3c2bf82..1c447857 100644 --- a/questionpy_server/api/routes/_files.py +++ b/questionpy_server/api/routes/_files.py @@ -6,35 +6,35 @@ from aiohttp import web from aiohttp.web_exceptions import HTTPNotImplemented -from questionpy_server.decorators import ensure_package_and_question_state_exist +from questionpy_server.app import QPyServer +from questionpy_server.decorators import ensure_package from questionpy_server.package import Package from questionpy_server.worker.runtime.package_location import ZipPackageLocation if TYPE_CHECKING: - from questionpy_server.app import QPyServer from questionpy_server.worker.worker import Worker file_routes = web.RouteTableDef() @file_routes.post(r"/packages/{package_hash}/file/{namespace}/{short_name}/{path:static/.*}") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_package async def post_attempt_start(request: web.Request, package: Package) -> web.Response: - qpy_server: QPyServer = request.app["qpy_server_app"] + qpy_server = request.app[QPyServer.APP_KEY] namespace = request.match_info["namespace"] short_name = request.match_info["short_name"] path = request.match_info["path"] if package.manifest.namespace != namespace or package.manifest.short_name != short_name: # TODO: Support static files in non-main packages by using namespace and short_name. - raise HTTPNotImplemented(reason="Static file retrieval from non-main packages is not supported yet.") + raise HTTPNotImplemented(text="Static file retrieval from non-main packages is not supported yet.") worker: Worker async with qpy_server.worker_pool.get_worker(ZipPackageLocation(await package.get_path()), 0, None) as worker: try: file = await worker.get_static_file(path) except FileNotFoundError as e: - raise web.HTTPNotFound(reason="File not found.") from e + raise web.HTTPNotFound(text="File not found.") from e return web.Response( body=file.data, diff --git a/questionpy_server/api/routes/_packages.py b/questionpy_server/api/routes/_packages.py index 4232cc89..19dd8e76 100644 --- a/questionpy_server/api/routes/_packages.py +++ b/questionpy_server/api/routes/_packages.py @@ -9,13 +9,13 @@ from questionpy_common.environment import RequestUser from questionpy_server.api.models import QuestionCreateArguments, QuestionEditFormResponse, RequestBaseData -from questionpy_server.decorators import ensure_package_and_question_state_exist +from questionpy_server.app import QPyServer +from questionpy_server.decorators import ensure_package, ensure_required_parts from questionpy_server.package import Package from questionpy_server.web import json_response from questionpy_server.worker.runtime.package_location import ZipPackageLocation if TYPE_CHECKING: - from questionpy_server.app import QPyServer from questionpy_server.worker.worker import Worker package_routes = web.RouteTableDef() @@ -23,7 +23,7 @@ @package_routes.get("/packages") async def get_packages(request: web.Request) -> web.Response: - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] packages = qpyserver.package_collection.get_packages() data = [package.get_info() for package in packages] @@ -33,22 +33,22 @@ async def get_packages(request: web.Request) -> web.Response: @package_routes.get(r"/packages/{package_hash:\w+}") async def get_package(request: web.Request) -> web.Response: - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] - try: - package = qpyserver.package_collection.get(request.match_info["package_hash"]) - return json_response(data=package.get_info()) - except FileNotFoundError as error: - raise HTTPNotFound from error + package = qpyserver.package_collection.get(request.match_info["package_hash"]) + if not package: + raise HTTPNotFound + + return json_response(data=package.get_info()) @package_routes.post(r"/packages/{package_hash:\w+}/options") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_required_parts async def post_options( - request: web.Request, package: Package, question_state: bytes | None, data: RequestBaseData + request: web.Request, package: Package, data: RequestBaseData, question_state: bytes | None = None ) -> web.Response: """Get the options form definition that allow a question creator to customize a question.""" - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] package_path = await package.get_path() worker: Worker @@ -61,11 +61,11 @@ async def post_options( @package_routes.post(r"/packages/{package_hash:\w+}/question") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_required_parts async def post_question( request: web.Request, data: QuestionCreateArguments, package: Package, question_state: bytes | None = None ) -> web.Response: - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] package_path = await package.get_path() worker: Worker @@ -84,7 +84,7 @@ async def post_question_migrate(_request: web.Request) -> web.Response: @package_routes.post(r"/package-extract-info") # type: ignore[arg-type] -@ensure_package_and_question_state_exist +@ensure_package async def package_extract_info(_request: web.Request, package: Package) -> web.Response: """Get package information.""" return json_response(data=package.get_info(), status=201) diff --git a/questionpy_server/api/routes/_status.py b/questionpy_server/api/routes/_status.py index b7e6977b..27c65b59 100644 --- a/questionpy_server/api/routes/_status.py +++ b/questionpy_server/api/routes/_status.py @@ -2,25 +2,20 @@ # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus -from typing import TYPE_CHECKING - from aiohttp import web from questionpy_server import __version__ from questionpy_server.api.models import ServerStatus, Usage +from questionpy_server.app import QPyServer from questionpy_server.web import json_response -if TYPE_CHECKING: - from questionpy_server.app import QPyServer - - status_routes = web.RouteTableDef() @status_routes.get(r"/status") async def get_server_status(request: web.Request) -> web.Response: """Get server status.""" - qpyserver: QPyServer = request.app["qpy_server_app"] + qpyserver = request.app[QPyServer.APP_KEY] status = ServerStatus( version=__version__, allow_lms_packages=qpyserver.settings.webservice.allow_lms_packages, diff --git a/questionpy_server/app.py b/questionpy_server/app.py index 3ea68698..7381d9a4 100644 --- a/questionpy_server/app.py +++ b/questionpy_server/app.py @@ -3,24 +3,28 @@ # (c) Technische Universität Berlin, innoCampus from asyncio import create_task -from typing import Any +from typing import Any, ClassVar from aiohttp import web from . import __version__ -from .api.routes import routes from .cache import FileLimitLRU from .collector import PackageCollection from .settings import Settings from .worker.pool import WorkerPool -class QPyServer: +class QPyServer(web.AppKey["QPyServer"]): + APP_KEY: ClassVar[web.AppKey["QPyServer"]] = web.AppKey("qpy_server_app") + def __init__(self, settings: Settings): + # We import here, so we don't have to work around circular imports. + from .api.routes import routes # noqa: PLC0415 + self.settings: Settings = settings self.web_app = web.Application(client_max_size=settings.webservice.max_main_size) self.web_app.add_routes(routes) - self.web_app["qpy_server_app"] = self + self.web_app[self.APP_KEY] = self self.worker_pool = WorkerPool( settings.worker.max_workers, settings.worker.max_memory, worker_type=settings.worker.type diff --git a/questionpy_server/cache.py b/questionpy_server/cache.py index ffd4d407..2f26e472 100644 --- a/questionpy_server/cache.py +++ b/questionpy_server/cache.py @@ -19,9 +19,14 @@ class File(NamedTuple): size: int -class SizeError(Exception): - def __init__(self, message: str = "", max_size: int = 0, actual_size: int = 0): - super().__init__(message) +class CacheItemTooLargeError(Exception): + def __init__(self, key: str, actual_size: int, max_size: int): + readable_actual = ByteSize(actual_size).human_readable() + readable_max = ByteSize(max_size).human_readable() + super().__init__( + f"Unable to cache item '{key}' with size '{readable_actual}' because it exceeds the maximum " + f"allowed size of '{readable_max}'" + ) self.max_size = max_size self.actual_size = actual_size @@ -146,12 +151,7 @@ async def put(self, key: str, value: bytes) -> Path: if size > self.max_size: # If we allowed this, the loop at the end would remove all items from the dictionary, # so we raise an error to allow exceptions for this case. - msg = f"Item itself exceeds maximum allowed size of {ByteSize(self.max_size).human_readable()}" - raise SizeError( - msg, - max_size=self.max_size, - actual_size=size, - ) + raise CacheItemTooLargeError(key, size, self.max_size) async with self._lock: # Save the bytes on filesystem. diff --git a/questionpy_server/collector/__init__.py b/questionpy_server/collector/__init__.py index 00d0c1ce..89607fb3 100644 --- a/questionpy_server/collector/__init__.py +++ b/questionpy_server/collector/__init__.py @@ -2,7 +2,7 @@ # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus -from questionpy_server.collector.package_collection import PackageCollection +from questionpy_server.collector._package_collection import PackageCollection __all__ = [ "PackageCollection", diff --git a/questionpy_server/collector/package_collection.py b/questionpy_server/collector/_package_collection.py similarity index 86% rename from questionpy_server/collector/package_collection.py rename to questionpy_server/collector/_package_collection.py index 3e2fa39e..3b79da26 100644 --- a/questionpy_server/collector/package_collection.py +++ b/questionpy_server/collector/_package_collection.py @@ -80,20 +80,13 @@ async def put(self, package_container: "HashContainer") -> "Package": """ return await self._lms_collector.put(package_container) - def get(self, package_hash: str) -> "Package": + def get(self, package_hash: str) -> "Package | None": """Returns a package if it exists. Args: - package_hash (str): hash value of the package - - Returns: - path to the package + package_hash: hash value of the package """ - # Check if package was indexed - if package := self._indexer.get_by_hash(package_hash): - return package - - raise FileNotFoundError + return self._indexer.get_by_hash(package_hash) def get_by_identifier(self, identifier: str) -> dict[SemVer, "Package"]: """Returns a dict of packages with the given identifier and available versions. @@ -106,20 +99,14 @@ def get_by_identifier(self, identifier: str) -> dict[SemVer, "Package"]: """ return self._indexer.get_by_identifier(identifier) - def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> "Package": + def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> "Package | None": """Returns a package with the given identifier and version. Args: - identifier (str): identifier of the package - version (str): version of the package - - Returns: - package + identifier: identifier of the package + version: version of the package """ - if package := self._indexer.get_by_identifier_and_version(identifier, version): - return package - - raise FileNotFoundError + return self._indexer.get_by_identifier_and_version(identifier, version) def get_packages(self) -> set["Package"]: """Returns a set of all available packages. diff --git a/questionpy_server/decorators.py b/questionpy_server/decorators.py index ea4e73d9..f195bc58 100644 --- a/questionpy_server/decorators.py +++ b/questionpy_server/decorators.py @@ -1,107 +1,297 @@ -import functools -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, cast, get_type_hints +import inspect +from collections.abc import Awaitable, Callable +from functools import wraps +from inspect import Parameter +from typing import Concatenate, NamedTuple, ParamSpec, TypeAlias -from aiohttp.abc import Request +from aiohttp import BodyPartReader, web from aiohttp.log import web_logger -from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound, HTTPUnsupportedMediaType +from aiohttp.web_exceptions import HTTPBadRequest +from pydantic import BaseModel, ValidationError +from questionpy_common import constants from questionpy_server.api.models import MainBaseModel, NotFoundStatus, NotFoundStatusWhat -from questionpy_server.types import RouteHandler -from questionpy_server.web import create_model_from_json, get_or_save_package, parse_form_data +from questionpy_server.app import QPyServer +from questionpy_server.cache import CacheItemTooLargeError +from questionpy_server.package import Package +from questionpy_server.types import M +from questionpy_server.web import ( + HashContainer, + read_part, +) -if TYPE_CHECKING: - from questionpy_server.app import QPyServer +_P = ParamSpec("_P") +_HandlerFunc: TypeAlias = Callable[Concatenate[web.Request, _P], Awaitable[web.StreamResponse]] -# TODO: refactor to reduce complexity -def ensure_package_and_question_state_exist( # noqa: C901 - _func: RouteHandler | None = None, -) -> RouteHandler | Callable[[RouteHandler], RouteHandler]: - """Decorator that ensures package and question state exist. +class _RequestBodyParts(NamedTuple): + main: bytes | None + package: HashContainer | None + question_state: bytes | None - Ensures that the package and question state exist (if needed by func) and that the json corresponds to the model - given as a type annotation in func. - This decorator assumes that: - * func may want an argument named 'data' (with a subclass of MainBaseModel) - * func may want an argument named 'question_state' (bytes or bytes | None) - * every func wants a package with an argument named 'package' +def _get_main_body_param(handler: _HandlerFunc, signature: inspect.Signature) -> inspect.Parameter | None: + candidates = [ + param + for param in signature.parameters.values() + if isinstance(param.annotation, type) and issubclass(param.annotation, MainBaseModel) + ] + + if not candidates: + # Handler doesn't use the main body. + return None + + if len(candidates) > 1: + msg = f"Handler function '{handler.__name__}' ambiguously takes multiple MainBaseModel parameters" + raise TypeError(msg) + + return candidates[0] + + +def _get_package_param(handler: _HandlerFunc, signature: inspect.Signature) -> inspect.Parameter | None: + candidates = [param for param in signature.parameters.values() if param.annotation is Package] + + if not candidates: + # Handler doesn't use the package. + return None + + if len(candidates) > 1: + msg = f"Handler function '{handler.__name__}' ambiguously takes multiple Package parameters" + raise TypeError(msg) + + return candidates[0] + + +_PARTS_REQUEST_KEY = "qpy-request-parts" + + +async def _read_body_parts(request: web.Request) -> _RequestBodyParts: + # We can only read the body once, and we have to read all of it at once (since we make no assumption about the order + # of the parts). Since we want to otherwise decouple main body, package, and question state handling logic, we cache + # the read body as a request variable. + parts: _RequestBodyParts = request.get(_PARTS_REQUEST_KEY, None) + if parts: + return parts + + if not request.body_exists: + # No body sent at all. + parts = _RequestBodyParts(None, None, None) + elif request.content_type == "multipart/form-data": + # Multiple parts. + parts = await _parse_form_data(request) + elif request.content_type == "application/json": + # Just the main body part. + parts = _RequestBodyParts(await request.read(), None, None) + else: + msg = ( + f"Wrong content type, expected multipart/form-data, application/json or no body, got " + f"'{request.content_type}'" + ) + web_logger.info(msg) + raise web.HTTPUnsupportedMediaType(text=msg) + + request[_PARTS_REQUEST_KEY] = parts + return parts + + +class _ExceptionMixin(web.HTTPException): + def __init__(self, msg: str, body: BaseModel | None = None) -> None: + if body: + # Send structured error body as JSON. + super().__init__(reason=type(self).__name__, text=body.model_dump_json(), content_type="application/json") + else: + # Send the detailed message. + super().__init__(reason=type(self).__name__, text=msg) + + # web.HTTPException uses the HTTP reason (which should be very short) as the exception message (which should be + # detailed). This sets the message to our detailed one. + Exception.__init__(self, msg) + + web_logger.info(msg) + + +class MainBodyMissingError(web.HTTPBadRequest, _ExceptionMixin): + def __init__(self) -> None: + super().__init__("The main body is required but was not provided.") + + +class PackageMissingWithoutHashError(web.HTTPBadRequest, _ExceptionMixin): + def __init__(self) -> None: + super().__init__("The package is required but was not provided.") + + +class PackageMissingByHashError(web.HTTPNotFound, _ExceptionMixin): + def __init__(self, package_hash: str) -> None: + super().__init__( + f"The package was not provided, is not cached and could not be found by its hash. ('{package_hash}')", + NotFoundStatus(what=NotFoundStatusWhat.PACKAGE), + ) + + +class PackageHashMismatchError(web.HTTPBadRequest, _ExceptionMixin): + def __init__(self, from_uri: str, from_body: str) -> None: + super().__init__( + f"The request URI specifies a package with hash '{from_uri}', but the sent package has a hash of " + f"'{from_body}'." + ) + + +class QuestionStateMissingError(web.HTTPBadRequest, _ExceptionMixin): + def __init__(self) -> None: + super().__init__( + "A question state part is required but was not provided.", + NotFoundStatus(what=NotFoundStatusWhat.QUESTION_STATE), + ) + + +def ensure_package(handler: _HandlerFunc, *, param: inspect.Parameter | None = None) -> _HandlerFunc: + """Decorator ensuring that the package needed by the handler is present.""" + if not param: + signature = inspect.signature(handler) + param = _get_package_param(handler, signature) + + if not param: + msg = f"Handler '{handler.__name__}' does not have a package param but is decorated with ensure_package." + raise TypeError(msg) + + @wraps(handler) + async def wrapper(request: web.Request, *args: _P.args, **kwargs: _P.kwargs) -> web.StreamResponse: + server = request.app[QPyServer.APP_KEY] + + uri_package_hash: str | None = request.match_info.get("package_hash", None) + parts = await _read_body_parts(request) + + if parts.package and uri_package_hash and uri_package_hash != parts.package.hash: + raise PackageHashMismatchError(uri_package_hash, parts.package.hash) + + package = None + if uri_package_hash: + package = server.package_collection.get(uri_package_hash) + + if not package and parts.package: + try: + package = await server.package_collection.put(parts.package) + except CacheItemTooLargeError as e: + raise web.HTTPRequestEntityTooLarge(max_size=e.max_size, actual_size=e.actual_size, text=str(e)) from e + + if not package: + if uri_package_hash: + raise PackageMissingByHashError(uri_package_hash) + raise PackageMissingWithoutHashError + + kwargs[param.name] = package # type: ignore[union-attr] # we narrowed package_param earlier + + return await handler(request, *args, **kwargs) + + return wrapper + + +def ensure_question_state(handler: _HandlerFunc, *, param: inspect.Parameter | None = None) -> _HandlerFunc: + """Decorator ensuring that the question state is present if needed by the handler.""" + if not param: + signature = inspect.signature(handler) + param = signature.parameters.get("question_state", None) + + if not param: + msg = ( + f"Handler '{handler.__name__}' does not have a question state param but is decorated with " + f"ensure_question_state." + ) + raise TypeError(msg) + + @wraps(handler) + async def wrapper(request: web.Request, *args: _P.args, **kwargs: _P.kwargs) -> web.StreamResponse: + parts = await _read_body_parts(request) + + if parts.question_state is not None: + kwargs[param.name] = parts.question_state + elif param.default is Parameter.empty: + raise QuestionStateMissingError + + return await handler(request, *args, **kwargs) + + return wrapper + + +def ensure_main_body(handler: _HandlerFunc, *, param: inspect.Parameter | None = None) -> _HandlerFunc: + if not param: + signature = inspect.signature(handler) + param = _get_main_body_param(handler, signature) + + if not param: + msg = ( + f"Handler '{handler.__name__}' does not have a MainBaseModel param but is decorated with " + f"ensure_main_body." + ) + raise TypeError(msg) + + @wraps(handler) + async def wrapper(request: web.Request, *args: _P.args, **kwargs: _P.kwargs) -> web.StreamResponse: + parts = await _read_body_parts(request) + + if parts.main is None: + raise MainBodyMissingError + + kwargs[param.name] = _validate_from_http(parts.main, param.annotation) + return await handler(request, *args, **kwargs) + + return wrapper + + +def ensure_required_parts(handler: _HandlerFunc) -> _HandlerFunc: + signature = inspect.signature(handler) + + main_body_param = _get_main_body_param(handler, signature) + question_state_param = signature.parameters.get("question_state", None) + package_param = _get_package_param(handler, signature) + + if main_body_param: + handler = ensure_main_body(handler, param=main_body_param) + + if question_state_param: + handler = ensure_question_state(handler, param=question_state_param) + + if package_param: + handler = ensure_package(handler, param=package_param) + + return handler + + +async def _parse_form_data(request: web.Request) -> _RequestBodyParts: + """Parses a multipart/form-data request. Args: - _func (Optional[RouteHandler]): Control parameter; allows using the decorator with or without arguments. - If this decorator is used with any arguments, this will always be the decorated function itself. (Default - value = None) + request (Request): The request to be parsed. + + Returns: tuple of main field, package, and question state """ + server = request.app[QPyServer.APP_KEY] + main = package = question_state = None + + reader = await request.multipart() + while part := await reader.next(): + if not isinstance(part, BodyPartReader): + continue + + if part.name == "main": + main = await read_part(part, server.settings.webservice.max_main_size, calculate_hash=False) + elif part.name == "package": + package = await read_part(part, server.settings.webservice.max_package_size, calculate_hash=True) + elif part.name == "question_state": + question_state = await read_part(part, constants.MAX_QUESTION_STATE_SIZE, calculate_hash=False) + + return _RequestBodyParts(main, package, question_state) - def decorator(function: RouteHandler) -> RouteHandler: # noqa: C901 - """Internal decorator function.""" - type_hints = get_type_hints(function) - question_state_type = type_hints.get("question_state") - takes_question_state = question_state_type is not None - require_question_state = question_state_type is bytes - main_part_json_model: type[MainBaseModel] | None = type_hints.get("data") - - if main_part_json_model and not issubclass(main_part_json_model, MainBaseModel): - msg = f"Parameter 'data' of function {function.__name__} has unexpected type." - raise TypeError(msg) - - @functools.wraps(function) - async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any: # noqa: C901 - """Wrapper around the actual function call.""" - server: QPyServer = request.app["qpy_server_app"] - package_hash: str = request.match_info.get("package_hash", "") - - if not request.body_exists: - main, sent_package, sent_question_state = None, None, None - elif request.content_type == "multipart/form-data": - main, sent_package, sent_question_state = await parse_form_data(request) - elif request.content_type == "application/json": - main, sent_package, sent_question_state = await request.read(), None, None - else: - web_logger.info("Wrong content type, multipart/form-data expected, got %s", request.content_type) - raise HTTPUnsupportedMediaType - - if main_part_json_model: - if main is None: - msg = "Multipart/form-data field 'main' is not set" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - - model = create_model_from_json(main.decode(), main_part_json_model) - kwargs["data"] = model - - # Check if func wants a question state and if it is provided. - if takes_question_state: - if require_question_state and sent_question_state is None: - msg = "Multipart/form-data field 'question_state' is not set" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - kwargs["question_state"] = sent_question_state - - # Check if a package is provided and if it matches the optional hash given in the URL. - if sent_package and package_hash and package_hash != sent_package.hash: - msg = f"Package hash does not match: {package_hash} != {sent_package.hash}" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - - package = await get_or_save_package(server.package_collection, package_hash, sent_package) - if package is None: - if package_hash: - raise HTTPNotFound( - text=NotFoundStatus(what=NotFoundStatusWhat.PACKAGE).model_dump_json(), - content_type="application/json", - ) - - msg = "No package found in multipart/form-data" - web_logger.warning(msg) - raise HTTPBadRequest(text=msg) - - kwargs["package"] = package - return await function(request, *args, **kwargs) - - return cast(RouteHandler, wrapper) - - if _func is None: - return decorator - return decorator(_func) + +def _validate_from_http(raw_body: str | bytes, param_class: type[M]) -> M: + """Validates the given json which was presumably an HTTP body to the given Pydantic model. + + Args: + raw_body: raw json body + param_class: the [pydantic.BaseModel][] subclass to valdiate to + """ + try: + return param_class.model_validate_json(raw_body) + except ValidationError as error: + web_logger.info("JSON does not match model: %s", error) + raise HTTPBadRequest(reason="Invalid JSON Body") from error diff --git a/questionpy_server/package.py b/questionpy_server/package.py index 2cbc1e3b..5fc7bf08 100644 --- a/questionpy_server/package.py +++ b/questionpy_server/package.py @@ -4,14 +4,17 @@ import contextlib from pathlib import Path +from typing import TYPE_CHECKING from questionpy_server.api.models import PackageInfo -from questionpy_server.collector.abc import BaseCollector from questionpy_server.collector.lms_collector import LMSCollector from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.collector.repo_collector import RepoCollector from questionpy_server.utils.manifest import ComparableManifest +if TYPE_CHECKING: + from questionpy_server.collector.abc import BaseCollector + class PackageSources: """A container for all package sources.""" @@ -28,7 +31,7 @@ def __len__(self) -> int: lms_collector = 1 if self._lms_collector else 0 return local_collector + len(self._repo_collectors) + lms_collector - def add(self, collector: BaseCollector) -> None: + def add(self, collector: "BaseCollector") -> None: """Adds a collector to the package sources. Args: @@ -44,7 +47,7 @@ def add(self, collector: BaseCollector) -> None: msg = f"Invalid collector type: {type(collector)}" raise TypeError(msg) - def remove(self, collector: BaseCollector) -> None: + def remove(self, collector: "BaseCollector") -> None: """Removes a collector from the package sources. Args: @@ -106,7 +109,7 @@ def __init__( self, package_hash: str, manifest: ComparableManifest, - source: BaseCollector | None = None, + source: "BaseCollector | None" = None, path: Path | None = None, ): self.hash = package_hash diff --git a/questionpy_server/repository/__init__.py b/questionpy_server/repository/__init__.py index bb294130..d64ffb6a 100644 --- a/questionpy_server/repository/__init__.py +++ b/questionpy_server/repository/__init__.py @@ -7,7 +7,7 @@ from gzip import decompress from urllib.parse import urljoin -from questionpy_server.cache import FileLimitLRU, SizeError +from questionpy_server.cache import CacheItemTooLargeError, FileLimitLRU from questionpy_server.repository.helper import download from questionpy_server.repository.models import RepoMeta, RepoPackage, RepoPackageIndex from questionpy_server.utils.logger import URLAdapter @@ -52,7 +52,7 @@ async def get_packages(self, meta: RepoMeta) -> dict[str, RepoPackage]: raw_index_zip = await download(self._url_index, size=meta.size, expected_hash=meta.sha256) try: await self._cache.put(meta.sha256, raw_index_zip) - except SizeError: + except CacheItemTooLargeError: self._log.warning("Package index is too big to be cached.") raw_index = decompress(raw_index_zip) diff --git a/questionpy_server/web.py b/questionpy_server/web.py index 49faa8be..ec54dd9a 100644 --- a/questionpy_server/web.py +++ b/questionpy_server/web.py @@ -5,33 +5,23 @@ from collections.abc import Sequence from hashlib import sha256 from io import BytesIO -from json import JSONDecodeError, loads -from typing import TYPE_CHECKING, Literal, NamedTuple, overload +from typing import Literal, NamedTuple, overload from aiohttp import BodyPartReader -from aiohttp.abc import Request from aiohttp.log import web_logger -from aiohttp.web_exceptions import HTTPBadRequest, HTTPRequestEntityTooLarge +from aiohttp.web_exceptions import HTTPRequestEntityTooLarge from aiohttp.web_response import Response -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel -from questionpy_common import constants from questionpy_common.constants import KiB -from questionpy_server.cache import SizeError -from questionpy_server.collector import PackageCollection -from questionpy_server.package import Package -from questionpy_server.types import M - -if TYPE_CHECKING: - from questionpy_server.app import QPyServer def json_response(data: Sequence[BaseModel] | BaseModel, status: int = 200) -> Response: """Creates a json response from a single BaseModel or a list of BaseModels. Args: - data (Union[Sequence[BaseModel]): A BaseModel or a list of BaseModels. - status (int): The HTTP status code. + data: A BaseModel or a list of BaseModels. + status: The HTTP status code. Returns: Response: A response object. @@ -42,29 +32,6 @@ def json_response(data: Sequence[BaseModel] | BaseModel, status: int = 200) -> R return Response(text=data.model_dump_json(), status=status, content_type="application/json") -def create_model_from_json(json: object | str, param_class: type[M]) -> M: - """Creates a BaseModel from an object. - - Args: - json (Union[object, str]): object containing the parsed json - param_class (type[M]): BaseModel class - - Returns: - M: BaseModel - """ - try: - if isinstance(json, str): - json = loads(json) - model = param_class.model_validate(json) - except ValidationError as error: - web_logger.warning("JSON does not match model: %s", error) - raise HTTPBadRequest from error - except JSONDecodeError as error: - web_logger.warning("Invalid JSON in request") - raise HTTPBadRequest from error - return model - - class HashContainer(NamedTuple): data: bytes hash: str @@ -112,57 +79,3 @@ async def read_part(part: BodyPartReader, max_size: int, *, calculate_hash: bool if calculate_hash: return HashContainer(data=buffer.getvalue(), hash=hash_object.hexdigest()) return buffer.getvalue() - - -async def parse_form_data(request: Request) -> tuple[bytes | None, HashContainer | None, bytes | None]: - """Parses a multipart/form-data request. - - Args: - request (Request): The request to be parsed. - - Returns: - tuple of main field, package, and question state - """ - server: QPyServer = request.app["qpy_server_app"] - main = package = question_state = None - - reader = await request.multipart() - while part := await reader.next(): - if not isinstance(part, BodyPartReader): - continue - - if part.name == "main": - main = await read_part(part, server.settings.webservice.max_main_size, calculate_hash=False) - elif part.name == "package": - package = await read_part(part, server.settings.webservice.max_package_size, calculate_hash=True) - elif part.name == "question_state": - question_state = await read_part(part, constants.MAX_QUESTION_STATE_SIZE, calculate_hash=False) - - return main, package, question_state - - -async def get_or_save_package( - collection: PackageCollection, hash_value: str, container: HashContainer | None -) -> Package | None: - """Gets a package from or saves it in the package collection. - - Args: - collection (PackageCollection): package collection - hash_value (str): The hash of the package. - container (Optional[HashContainer]): container with the package data and its hash - - Returns: - package - """ - try: - if not container: - package = collection.get(hash_value) - else: - package = await collection.put(container) - except SizeError as error: - raise HTTPRequestEntityTooLarge( - max_size=error.max_size, actual_size=error.actual_size, text=str(error) - ) from error - except FileNotFoundError: - return None - return package diff --git a/tests/questionpy_common/test_elements.py b/tests/questionpy_common/test_elements.py index f6ad8fd0..3331afc2 100644 --- a/tests/questionpy_common/test_elements.py +++ b/tests/questionpy_common/test_elements.py @@ -3,6 +3,7 @@ # (c) Technische Universität Berlin, innoCampus from io import BytesIO +from unittest.mock import patch import pytest from aiohttp import FormData @@ -38,6 +39,7 @@ TextInputElement, is_form_element, ) +from questionpy_server.collector import PackageCollection from tests.conftest import get_file_hash, package_dir, test_data_path PACKAGE = package_dir / "package_1.qpy" @@ -51,10 +53,12 @@ QUESTION_STATE_REQUEST = (path / "main.json").read_text() -async def test_optional_question_state(client: TestClient) -> None: - # Even though the question state is optional, the body is still required to be valid JSON. - res = await client.request(METHOD, URL, data=b"{not_valid!}", headers={"Content-Type": "application/json"}) - assert res.status == 400 +async def test_should_validate_main_body_when_question_state_is_not_given(client: TestClient) -> None: + with patch.object(PackageCollection, "get"): + # Even though the question state is optional, the body is still required to be valid JSON. + res = await client.request(METHOD, URL, data=b"{not_valid!}", headers={"Content-Type": "application/json"}) + assert res.status == 400 + assert res.reason == "Invalid JSON Body" async def test_no_package(client: TestClient) -> None: diff --git a/tests/questionpy_server/api/routes/test_packages.py b/tests/questionpy_server/api/routes/test_packages.py index 282ff51d..501cfbb7 100644 --- a/tests/questionpy_server/api/routes/test_packages.py +++ b/tests/questionpy_server/api/routes/test_packages.py @@ -40,5 +40,4 @@ async def test_extract_info_faulty(client: TestClient) -> None: res = await client.request("POST", "/package-extract-info", data=payload) assert res.status == 400 - text = await res.text() - assert text == "No package found in multipart/form-data" + assert res.reason == "PackageMissingWithoutHashError" diff --git a/tests/questionpy_server/collector/test_package_collection.py b/tests/questionpy_server/collector/test_package_collection.py index 44b2367d..e439c152 100644 --- a/tests/questionpy_server/collector/test_package_collection.py +++ b/tests/questionpy_server/collector/test_package_collection.py @@ -5,7 +5,6 @@ from pathlib import Path from unittest.mock import Mock, patch -import pytest from _pytest.tmpdir import TempPathFactory from semver import VersionInfo @@ -53,8 +52,7 @@ def test_get_package() -> None: # Package does not exist. with patch.object(Indexer, "get_by_hash", return_value=None) as get_by_hash: - with pytest.raises(FileNotFoundError): - package_collection.get("hash") + assert package_collection.get("hash") is None get_by_hash.assert_called_once_with("hash") @@ -78,8 +76,7 @@ def test_get_package_by_identifier_and_version() -> None: # Package does not exist. with patch.object(Indexer, "get_by_identifier_and_version", return_value=None) as get_by_identifier_and_version: version = VersionInfo.parse("0.1.0") - with pytest.raises(FileNotFoundError): - package_collection.get_by_identifier_and_version("@default/name", version) + assert package_collection.get_by_identifier_and_version("@default/name", version) is None get_by_identifier_and_version.assert_called_once_with("@default/name", version) diff --git a/tests/questionpy_server/repository/test_repository.py b/tests/questionpy_server/repository/test_repository.py index b5813271..07e21685 100644 --- a/tests/questionpy_server/repository/test_repository.py +++ b/tests/questionpy_server/repository/test_repository.py @@ -11,7 +11,7 @@ from _pytest.tmpdir import TempPathFactory from questionpy_common.constants import KiB -from questionpy_server.cache import FileLimitLRU, SizeError +from questionpy_server.cache import CacheItemTooLargeError, FileLimitLRU from questionpy_server.repository import RepoMeta, RepoPackage, RepoPackageIndex, Repository from questionpy_server.utils.manifest import ComparableManifest from tests.test_data.factories import ManifestFactory, RepoMetaFactory, RepoPackageVersionsFactory @@ -119,7 +119,7 @@ async def test_log_warning_when_package_index_is_too_big_for_cache( with ( patch("questionpy_server.repository.download") as mock_download, - patch.object(cache, "put", side_effect=SizeError), + patch.object(cache, "put", side_effect=CacheItemTooLargeError("key", 2, 1)), ): parsed = package_index.model_dump_json() mock_download.return_value = compress(parsed.encode()) diff --git a/tests/questionpy_server/test_cache.py b/tests/questionpy_server/test_cache.py index 29952d30..9e4f9742 100644 --- a/tests/questionpy_server/test_cache.py +++ b/tests/questionpy_server/test_cache.py @@ -11,7 +11,7 @@ import pytest from _pytest.tmpdir import TempPathFactory -from questionpy_server.cache import FileLimitLRU, SizeError +from questionpy_server.cache import CacheItemTooLargeError, FileLimitLRU @dataclass @@ -186,7 +186,7 @@ async def test_put(cache: FileLimitLRU, settings: Settings) -> None: assert get_file_count(settings.cache.directory) == settings.items.num_of_items # Content size is bigger than cache size. - with pytest.raises(SizeError): + with pytest.raises(CacheItemTooLargeError): await cache.put("new", b"." * (settings.cache.size + 1)) # Replace existing file.