diff --git a/questionpy_server/api/models.py b/questionpy_server/api/models.py index 9212ad11..5c4f564f 100644 --- a/questionpy_server/api/models.py +++ b/questionpy_server/api/models.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Annotated, Any -from pydantic import BaseModel, ByteSize, ConfigDict, Field, FilePath, HttpUrl +from pydantic import BaseModel, ByteSize, ConfigDict, Field, HttpUrl from questionpy_common.api.attempt import AttemptModel from questionpy_common.api.question import QuestionModel @@ -16,10 +16,21 @@ class PackageInfo(BaseModel): model_config = ConfigDict(use_enum_values=True) - package_hash: str short_name: str namespace: str name: dict[str, str] + type: PackageType + author: str | None + url: str | None + languages: set[str] | None + description: dict[str, str] | None + icon: str | None + license: str | None + tags: set[str] | None + + +class PackageVersionSpecificInfo(BaseModel): + package_hash: str version: Annotated[ str, Field( @@ -29,14 +40,15 @@ class PackageInfo(BaseModel): r"(\+([0-9a-zA-Z-]+(\.[0-9a-zA-Z-]+)*))?$" ), ] - type: PackageType - author: str | None - url: HttpUrl | None - languages: list[str] | None - description: dict[str, str] | None - icon: FilePath | HttpUrl | None - license: str | None - tags: list[str] | None + + +class PackageVersionInfo(PackageInfo, PackageVersionSpecificInfo): + pass + + +class PackageVersionsInfo(BaseModel): + manifest: PackageInfo + versions: list[PackageVersionSpecificInfo] class MainBaseModel(BaseModel): diff --git a/questionpy_server/api/routes.py b/questionpy_server/api/routes.py index 8da8161f..ef8d54c0 100644 --- a/questionpy_server/api/routes.py +++ b/questionpy_server/api/routes.py @@ -36,10 +36,8 @@ async def get_packages(request: web.Request) -> web.Response: qpyserver: "QPyServer" = request.app["qpy_server_app"] - packages = qpyserver.package_collection.get_packages() - data = [package.get_info() for package in packages] - - return json_response(data=data) + package_versions_infos = qpyserver.package_collection.get_package_versions_infos() + return json_response(data=package_versions_infos) @routes.get(r"/packages/{package_hash:\w+}") diff --git a/questionpy_server/collector/indexer.py b/questionpy_server/collector/indexer.py index d5eff701..4dfe6a69 100644 --- a/questionpy_server/collector/indexer.py +++ b/questionpy_server/collector/indexer.py @@ -8,6 +8,7 @@ from typing import overload from questionpy_server import WorkerPool +from questionpy_server.api.models import PackageInfo, PackageVersionsInfo, PackageVersionSpecificInfo from questionpy_server.collector.abc import BaseCollector from questionpy_server.collector.local_collector import LocalCollector from questionpy_server.collector.repo_collector import RepoCollector @@ -30,6 +31,8 @@ def __init__(self, worker_pool: WorkerPool): self._index_by_identifier: dict[str, dict[SemVer, Package]] = {} """dict[identifier, dict[version, Package]]""" + self._package_versions_infos: list[PackageVersionsInfo] | None = None + self._lock: Lock | None = None def get_by_hash(self, package_hash: str) -> Package | None: @@ -66,13 +69,38 @@ def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> Pac """ return self._index_by_identifier.get(identifier, {}).get(version, None) - def get_packages(self) -> set[Package]: - """Returns all packages in the index (excluding packages from LMSs). + def get_package_versions_infos(self) -> list[PackageVersionsInfo]: + """Returns an overview of every package and its versions (excluding packages from LMSs). + + TODO: optimize further? Returns: - set of packages + list of PackageVersionsInfo """ - return {package for packages in self._index_by_identifier.values() for package in packages.values()} + if self._package_versions_infos is not None: + return self._package_versions_infos + + package_versions_infos = [] + + for package_versions in self._index_by_identifier.values(): + versions = [] + sorted_package_versions = sorted(package_versions, reverse=True) + for version in sorted_package_versions: + package_version = package_versions[version] + versions.append(PackageVersionSpecificInfo(package_hash=package_version.hash, version=str(version))) + + # A package should always have at least one package version, we try-except just in case. + try: + latest_package_version = package_versions[sorted_package_versions[0]] + package_info = PackageInfo(**latest_package_version.manifest.model_dump()) + except KeyError: + continue + + package_versions_info = PackageVersionsInfo(manifest=package_info, versions=versions) + package_versions_infos.append(package_versions_info) + + self._package_versions_infos = package_versions_infos + return self._package_versions_infos @overload async def register_package( @@ -128,6 +156,9 @@ async def register_package( else: package_versions[package.manifest.version] = package + # Force recalculation of list[PackageVersionsInfo]. + self._package_versions_infos = None + return package async def unregister_package(self, package_hash: str, source: BaseCollector) -> None: @@ -158,6 +189,9 @@ async def unregister_package(self, package_hash: str, source: BaseCollector) -> if not package_versions: self._index_by_identifier.pop(package.manifest.identifier, None) + # Force recalculation of list[PackageVersionsInfo]. + self._package_versions_infos = None + if len(package.sources) == 0: # Package has no more sources; remove it from the index. self._index_by_hash.pop(package_hash, None) diff --git a/questionpy_server/collector/package_collection.py b/questionpy_server/collector/package_collection.py index 3e2fa39e..a117c712 100644 --- a/questionpy_server/collector/package_collection.py +++ b/questionpy_server/collector/package_collection.py @@ -10,6 +10,7 @@ from pydantic import HttpUrl from questionpy_server import WorkerPool +from questionpy_server.api.models import PackageVersionsInfo from questionpy_server.cache import FileLimitLRU from questionpy_server.collector.indexer import Indexer from questionpy_server.collector.lms_collector import LMSCollector @@ -121,10 +122,10 @@ def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> "Pa raise FileNotFoundError - def get_packages(self) -> set["Package"]: - """Returns a set of all available packages. + def get_package_versions_infos(self) -> list[PackageVersionsInfo]: + """Returns an overview of every package and its versions. Returns: - set of packages + list of PackageVersionsInfo """ - return self._indexer.get_packages() + return self._indexer.get_package_versions_infos() diff --git a/questionpy_server/factories/__init__.py b/questionpy_server/factories/__init__.py index 4b602a30..bbcc2bc2 100644 --- a/questionpy_server/factories/__init__.py +++ b/questionpy_server/factories/__init__.py @@ -3,7 +3,7 @@ # (c) Technische Universität Berlin, innoCampus from .attempt import AttemptScoredFactory -from .package import PackageInfoFactory +from .package import PackageVersionInfoFactory from .question_state import RequestBaseDataFactory -__all__ = ["AttemptScoredFactory", "PackageInfoFactory", "RequestBaseDataFactory"] +__all__ = ["AttemptScoredFactory", "PackageVersionInfoFactory", "RequestBaseDataFactory"] diff --git a/questionpy_server/factories/package.py b/questionpy_server/factories/package.py index 32a461cf..f6d37116 100644 --- a/questionpy_server/factories/package.py +++ b/questionpy_server/factories/package.py @@ -6,14 +6,14 @@ from faker import Faker from polyfactory.factories.pydantic_factory import ModelFactory -from questionpy_server.api.models import PackageInfo +from questionpy_server.api.models import PackageVersionInfo languages = ["en", "de"] fake = Faker() -class PackageInfoFactory(ModelFactory): - __model__ = PackageInfo +class PackageVersionInfoFactory(ModelFactory): + __model__ = PackageVersionInfo @staticmethod def author() -> str: diff --git a/questionpy_server/package.py b/questionpy_server/package.py index 2cbc1e3b..cdbd6e57 100644 --- a/questionpy_server/package.py +++ b/questionpy_server/package.py @@ -5,7 +5,7 @@ import contextlib from pathlib import Path -from questionpy_server.api.models import PackageInfo +from questionpy_server.api.models import PackageVersionInfo from questionpy_server.collector.abc import BaseCollector from questionpy_server.collector.lms_collector import LMSCollector from questionpy_server.collector.local_collector import LocalCollector @@ -99,7 +99,7 @@ class Package: sources: PackageSources - _info: PackageInfo | None + _info: PackageVersionInfo | None _path: Path | None def __init__( @@ -127,7 +127,7 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.hash == other.hash - def get_info(self) -> PackageInfo: + def get_info(self) -> PackageVersionInfo: """Returns the package info. Returns: @@ -136,7 +136,7 @@ def get_info(self) -> PackageInfo: if not self._info: tmp = self.manifest.model_dump() tmp["version"] = str(tmp["version"]) - self._info = PackageInfo(**tmp, package_hash=self.hash) + self._info = PackageVersionInfo(**tmp, package_hash=self.hash) return self._info async def get_path(self) -> Path: diff --git a/questionpy_server/web.py b/questionpy_server/web.py index a6e15e24..8e0332b7 100644 --- a/questionpy_server/web.py +++ b/questionpy_server/web.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from hashlib import sha256 from io import BytesIO -from json import JSONDecodeError, loads +from json import JSONDecodeError, dumps, loads from typing import TYPE_CHECKING, Literal, NamedTuple, overload from aiohttp import BodyPartReader @@ -13,7 +13,9 @@ from aiohttp.log import web_logger from aiohttp.web_exceptions import HTTPBadRequest, HTTPRequestEntityTooLarge from aiohttp.web_response import Response +from aiohttp.web_response import json_response as aiohttp_json_response from pydantic import BaseModel, ValidationError +from pydantic_core import to_jsonable_python from questionpy_common import constants from questionpy_common.constants import KiB @@ -36,10 +38,7 @@ def json_response(data: Sequence[BaseModel] | BaseModel, status: int = 200) -> R Returns: Response: A response object. """ - if isinstance(data, Sequence): - json_list = f'[{",".join(x.json() for x in data)}]' - return Response(text=json_list, status=status, content_type="application/json") - return Response(text=data.model_dump_json(), status=status, content_type="application/json") + return aiohttp_json_response(data, status=status, dumps=lambda model: dumps(model, default=to_jsonable_python)) def create_model_from_json(json: object | str, param_class: type[M]) -> M: diff --git a/tests/questionpy_server/api/test_models.py b/tests/questionpy_server/api/test_models.py index 282ff51d..2352707d 100644 --- a/tests/questionpy_server/api/test_models.py +++ b/tests/questionpy_server/api/test_models.py @@ -1,23 +1,87 @@ # This file is part of the QuestionPy Server. (https://questionpy.org) # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus - +from hashlib import sha256 from io import BytesIO +from itertools import pairwise, starmap +from operator import ge +from unittest.mock import Mock +import pytest from aiohttp import FormData +from aiohttp.pytest_plugin import AiohttpClient from aiohttp.test_utils import TestClient from pydantic import TypeAdapter -from questionpy_server.api.models import PackageInfo +from questionpy_server.api.models import PackageVersionInfo, PackageVersionsInfo +from questionpy_server.app import QPyServer +from questionpy_server.collector.local_collector import LocalCollector +from questionpy_server.utils.manifest import ComparableManifest from tests.conftest import PACKAGE +from tests.test_data.factories import ManifestFactory + + +@pytest.mark.parametrize( + "packages", + [ + # No packages. + {}, + # One package. + {"ns1": {"0.1.0"}}, + # Two packages. + {"ns1": {"0.1.0"}, "ns2": {"0.1.0"}}, + # Multiple versions. + {"ns1": {"1.0.0", "0.0.1"}, "ns2": {"1.0.0", "0.1.0", "0.0.1"}}, + # Multiple versions, unsorted. + {"ns1": {"0.0.1", "1.0.0"}, "ns2": {"0.1.0", "0.0.1", "1.0.0"}}, + ], +) +async def test_packages(qpy_server: QPyServer, aiohttp_client: AiohttpClient, packages: dict[str, set[str]]) -> None: + async def add_package_version(server: QPyServer, manifest: ComparableManifest) -> None: + package_hash = sha256((manifest.short_name + manifest.namespace + str(manifest.version)).encode()).hexdigest() + await server.package_collection._indexer.register_package(package_hash, manifest, Mock(spec=LocalCollector)) + manifests: dict[str, dict[str, ComparableManifest]] = {} + for namespace, versions in packages.items(): + for version in versions: + expected_manifest = ManifestFactory.build(namespace=namespace, short_name=namespace, version=version) + manifests.setdefault(namespace, {})[version] = expected_manifest + await add_package_version(qpy_server, expected_manifest) -async def test_packages(client: TestClient) -> None: + client = await aiohttp_client(qpy_server.web_app) res = await client.request("GET", "/packages") + # Assert that a valid list of PackageVersionsInfo is returned. assert res.status == 200 data = await res.json() - TypeAdapter(list[PackageInfo]).validate_python(data) + package_versions_infos: list[PackageVersionsInfo] = TypeAdapter(list[PackageVersionsInfo]).validate_python(data) + + expected_package_count = len(packages) + assert len(package_versions_infos) == expected_package_count + + if expected_package_count <= 0: + return + + actual_namespaces = [] + + # Iterate over all actual packages. + for package_versions_info in package_versions_infos: + actual_package_info = package_versions_info.manifest + actual_versions = [version.version for version in package_versions_info.versions] + # Assert that each package version is available and in the correct order. + assert set(actual_versions) == packages[actual_package_info.namespace] + assert all(starmap(ge, pairwise(actual_versions))), "The package versions are not sorted in descending order." + # Assert that the actual package info is a subset of the manifest of the latest package version. + actual_package_info_items = actual_package_info.model_dump().items() + latest_manifest_items = manifests[actual_package_info.namespace][actual_versions[0]].model_dump().items() + assert actual_package_info_items <= latest_manifest_items, ( + "Actual package info was not derived from the " "latest package version." + ) + + actual_namespaces.append(actual_package_info.namespace) + + # Assert that every expected package is returned. + assert set(actual_namespaces) == packages.keys() async def test_extract_info(client: TestClient) -> None: @@ -29,7 +93,7 @@ async def test_extract_info(client: TestClient) -> None: assert res.status == 201 data = await res.json() - PackageInfo.model_validate(data) + PackageVersionInfo.model_validate(data) async def test_extract_info_faulty(client: TestClient) -> None: diff --git a/tests/questionpy_server/collector/test_indexer.py b/tests/questionpy_server/collector/test_indexer.py index 83712526..36a48473 100644 --- a/tests/questionpy_server/collector/test_indexer.py +++ b/tests/questionpy_server/collector/test_indexer.py @@ -49,7 +49,7 @@ async def test_register_package_from_lms(collector: LMSCollector) -> None: assert len(packages_by_identifier) == 0 # Package is not accessible by retrieving all packages. - packages = indexer.get_packages() + packages = indexer.get_package_versions_infos() assert len(packages) == 0 @@ -78,9 +78,9 @@ async def test_register_package_from_local_and_repo_collector(collector: BaseCol assert packages_by_identifier[package.manifest.version] is package # Package is accessible by retrieving all packages. - packages = indexer.get_packages() + packages = indexer.get_package_versions_infos() assert len(packages) == 1 - assert next(iter(packages)) is package + assert next(iter(packages)).manifest.model_dump().items() <= package.manifest.model_dump().items() async def test_register_package_with_same_hash_as_existing_package() -> None: @@ -109,9 +109,9 @@ async def test_register_package_with_same_hash_as_existing_package() -> None: assert len(packages_by_identifier) == 1 assert packages_by_identifier[package.manifest.version] is package - packages = indexer.get_packages() + packages = indexer.get_package_versions_infos() assert len(packages) == 1 - assert next(iter(packages)) is package + assert next(iter(packages)).manifest.model_dump().items() <= package.manifest.model_dump().items() async def test_register_two_packages_with_same_manifest_but_different_hashes(caplog: pytest.LogCaptureFixture) -> None: diff --git a/tests/questionpy_server/collector/test_package_collection.py b/tests/questionpy_server/collector/test_package_collection.py index 44b2367d..5e7e7187 100644 --- a/tests/questionpy_server/collector/test_package_collection.py +++ b/tests/questionpy_server/collector/test_package_collection.py @@ -87,9 +87,9 @@ def test_get_packages() -> None: package_collection = PackageCollection(None, {}, Mock(), Mock(), Mock()) # Package does exist. - with patch.object(Indexer, "get_packages") as get_packages: - package_collection.get_packages() - get_packages.assert_called_once() + with patch.object(Indexer, "get_package_versions_infos") as get_package_versions_infos: + package_collection.get_package_versions_infos() + get_package_versions_infos.assert_called_once() async def test_notify_indexer_on_cache_deletion(tmp_path_factory: TempPathFactory) -> None: diff --git a/tests/test_data/factories.py b/tests/test_data/factories.py index 3d1801fb..ddeb3d0d 100644 --- a/tests/test_data/factories.py +++ b/tests/test_data/factories.py @@ -5,6 +5,7 @@ from collections.abc import Callable from typing import Any +from polyfactory import Use from polyfactory.factories.pydantic_factory import ModelFactory from semver import Version @@ -32,3 +33,8 @@ class RepoPackageVersionsFactory(CustomFactory): class ManifestFactory(CustomFactory): __model__ = ComparableManifest + + short_name = Use(lambda: ModelFactory.__faker__.word().lower() + "_sn") + namespace = Use(lambda: ModelFactory.__faker__.word().lower() + "_ns") + url = Use(ModelFactory.__faker__.url) + icon = None