From 7acf10819d93b4810daa2e8e46996faccc17ed77 Mon Sep 17 00:00:00 2001 From: Maximilian Jugl Date: Thu, 16 May 2024 10:47:42 +0200 Subject: [PATCH] feat: implement upload of temporary files to hub --- project/routers/scratch.py | 105 ++++++++++++++++++++++++++++--------- tests/test_scratch.py | 27 ++++++---- 2 files changed, 99 insertions(+), 33 deletions(-) diff --git a/project/routers/scratch.py b/project/routers/scratch.py index 4fb7eef..5e32058 100644 --- a/project/routers/scratch.py +++ b/project/routers/scratch.py @@ -1,25 +1,69 @@ +import io import logging import uuid -from typing import Annotated +from typing import Annotated, Optional -from fastapi import APIRouter, UploadFile, Depends, HTTPException -from minio import Minio, S3Error +from fastapi import APIRouter, UploadFile, Depends, HTTPException, BackgroundTasks +from minio import Minio from pydantic import BaseModel, HttpUrl from starlette import status from starlette.requests import Request from starlette.responses import StreamingResponse from project.config import Settings -from project.dependencies import get_settings, get_local_minio, get_client_id +from project.dependencies import ( + get_settings, + get_local_minio, + get_client_id, + get_access_token, +) +from project.hub import AccessToken, ApiWrapper router = APIRouter() logger = logging.getLogger(__name__) +# TODO fix this jank +object_id_to_hub_bucket_dict: dict[str, Optional[str]] = {} + class ScratchUploadResponse(BaseModel): url: HttpUrl +def __bg_upload_to_remote( + minio: Minio, + bucket_name: str, + object_name: str, + api: ApiWrapper, + client_id: str, + object_id: str, +): + logger.info( + "__bg_upload_to_remote: bucket `%s`, object `%s`", bucket_name, object_name + ) + + minio_resp = None + + try: + minio_resp = minio.get_object(bucket_name, object_name) + bucket_file_lst = api.upload_to_bucket( + f"analysis-temp-files.{client_id}", + object_name, + io.BytesIO(minio_resp.data), + minio_resp.headers.get("Content-Type", "application/octet-stream"), + ) + + assert len(bucket_file_lst) == 1 + bucket_file = bucket_file_lst[0] + api.link_file_to_analysis(client_id, bucket_file.id, bucket_file.name, "TEMP") + object_id_to_hub_bucket_dict[object_id] = bucket_file.id + minio.remove_object(bucket_name, object_name) + finally: + if minio is not None: + minio_resp.close() + minio_resp.release_conn() + + @router.put( "/", response_model=ScratchUploadResponse, @@ -32,6 +76,8 @@ async def upload_to_scratch( settings: Annotated[Settings, Depends(get_settings)], minio: Annotated[Minio, Depends(get_local_minio)], request: Request, + api_access_token: Annotated[AccessToken, Depends(get_access_token)], + background_tasks: BackgroundTasks, ): """Upload a file to the local S3 instance. The file is not forwarded to the FLAME hub. @@ -40,15 +86,29 @@ async def upload_to_scratch( This endpoint is to be used for submitting intermediate results of a federated analysis. """ object_id = str(uuid.uuid4()) + object_name = f"temp/{client_id}/{object_id}" minio.put_object( settings.minio.bucket, - f"scratch/{client_id}/{object_id}", + object_name, data=file.file, length=file.size, content_type=file.content_type or "application/octet-stream", ) + api = ApiWrapper(str(settings.hub.api_base_url), api_access_token.access_token) + object_id_to_hub_bucket_dict[object_id] = None + + background_tasks.add_task( + __bg_upload_to_remote, + minio, + settings.minio.bucket, + object_name, + api, + client_id, + object_id, + ) + return ScratchUploadResponse( url=str( request.url_for( @@ -68,7 +128,7 @@ async def read_from_scratch( client_id: Annotated[str, Depends(get_client_id)], object_id: uuid.UUID, settings: Annotated[Settings, Depends(get_settings)], - minio: Annotated[Minio, Depends(get_local_minio)], + api_access_token: Annotated[AccessToken, Depends(get_access_token)], ): """Get a file from the local S3 instance. The file must have previously been uploaded using the PUT method of this endpoint. @@ -76,25 +136,22 @@ async def read_from_scratch( This endpoint is to be used for retrieving intermediate results of a federated analysis. """ - try: - response = minio.get_object( - settings.minio.bucket, f"scratch/{client_id}/{object_id}" - ) - except S3Error as e: - logger.exception(f"Could not get object `{object_id}` for client `{client_id}`") - - if e.code == "NoSuchKey": - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Object with ID {object_id} does not exist", - ) + api = ApiWrapper(str(settings.hub.api_base_url), api_access_token.access_token) + oid = str(object_id) + if ( + oid not in object_id_to_hub_bucket_dict + or object_id_to_hub_bucket_dict[oid] is None + ): raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="Unexpected error from object store", + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Object with ID {oid} does not exist", ) - return StreamingResponse( - response, - media_type=response.headers.get("Content-Type", "application/octet-stream"), - ) + bucket_file_id = object_id_to_hub_bucket_dict[oid] + + async def _stream_bucket_file(): + for b in api.stream_bucket_file(bucket_file_id): + yield b + + return StreamingResponse(_stream_bucket_file()) diff --git a/tests/test_scratch.py b/tests/test_scratch.py index 8907f9b..2a2b258 100644 --- a/tests/test_scratch.py +++ b/tests/test_scratch.py @@ -1,19 +1,22 @@ import re import uuid +import pytest from starlette import status from project.routers.scratch import ScratchUploadResponse from tests.common.auth import BearerAuth, issue_client_access_token -from tests.common.helpers import next_random_bytes +from tests.common.helpers import next_random_bytes, eventually from tests.common.rest import wrap_bytes_for_request, detail_of +pytestmark = pytest.mark.live -def test_200_submit_receive_from_scratch(test_client, rng): + +def test_200_submit_receive_from_scratch(test_client, rng, analysis_id): blob = next_random_bytes(rng) r = test_client.put( "/scratch", - auth=BearerAuth(issue_client_access_token()), + auth=BearerAuth(issue_client_access_token(analysis_id)), files=wrap_bytes_for_request(blob), ) @@ -27,13 +30,19 @@ def test_200_submit_receive_from_scratch(test_client, rng): assert re.fullmatch(path_regex, model.url.path) is not None - r = test_client.get( - model.url.path, - auth=BearerAuth(issue_client_access_token()), - ) + def _check_response_from_hub(): + r = test_client.get( + model.url.path, + auth=BearerAuth(issue_client_access_token()), + ) - assert r.status_code == status.HTTP_200_OK - assert r.read() == blob + if r.status_code != status.HTTP_200_OK: + return False + + assert r.read() == blob + return True + + assert eventually(_check_response_from_hub) def test_whatever(test_client):