Skip to content

Commit

Permalink
feat: implement upload of temporary files to hub
Browse files Browse the repository at this point in the history
  • Loading branch information
mjugl committed May 16, 2024
1 parent 520a6dd commit 7acf108
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 33 deletions.
105 changes: 81 additions & 24 deletions project/routers/scratch.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,69 @@
import io
import logging
import uuid
from typing import Annotated
from typing import Annotated, Optional

from fastapi import APIRouter, UploadFile, Depends, HTTPException
from minio import Minio, S3Error
from fastapi import APIRouter, UploadFile, Depends, HTTPException, BackgroundTasks
from minio import Minio
from pydantic import BaseModel, HttpUrl
from starlette import status
from starlette.requests import Request
from starlette.responses import StreamingResponse

from project.config import Settings
from project.dependencies import get_settings, get_local_minio, get_client_id
from project.dependencies import (
get_settings,
get_local_minio,
get_client_id,
get_access_token,
)
from project.hub import AccessToken, ApiWrapper

router = APIRouter()
logger = logging.getLogger(__name__)

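# TODO fix this jank: process-local, in-memory map from scratch object ID to
# hub bucket file ID (None until the background upload to the hub has finished)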
object_id_to_hub_bucket_dict: dict[str, Optional[str]] = {}


class ScratchUploadResponse(BaseModel):
    """Response body returned by the scratch upload endpoint."""

    # URL that is handed back to the client after a successful upload.
    url: HttpUrl


def __bg_upload_to_remote(
    minio: Minio,
    bucket_name: str,
    object_name: str,
    api: ApiWrapper,
    client_id: str,
    object_id: str,
):
    """Forward an object from the local object store to the hub.

    Intended to be run as a background task. The object is fetched from the
    local MinIO bucket, uploaded to the hub bucket
    ``analysis-temp-files.{client_id}``, linked to the analysis as a ``TEMP``
    file, recorded in ``object_id_to_hub_bucket_dict`` and finally removed
    from the local bucket.

    Args:
        minio: Client for the local object store.
        bucket_name: Name of the local bucket holding the object.
        object_name: Name of the object within the local bucket; also used
            as the file name on the hub.
        api: Hub API wrapper used for the upload and the analysis link.
        client_id: ID of the requesting client. It is also passed as the
            first argument of ``link_file_to_analysis`` -- presumably the
            analysis ID; confirm against ``ApiWrapper``.
        object_id: Key under which the resulting hub bucket file ID is
            stored in ``object_id_to_hub_bucket_dict``.

    Raises:
        RuntimeError: If the hub does not report exactly one uploaded file.
    """
    logger.info(
        "__bg_upload_to_remote: bucket `%s`, object `%s`", bucket_name, object_name
    )

    minio_resp = None

    try:
        minio_resp = minio.get_object(bucket_name, object_name)
        bucket_file_lst = api.upload_to_bucket(
            f"analysis-temp-files.{client_id}",
            object_name,
            # The object body is buffered in memory before being handed on.
            io.BytesIO(minio_resp.data),
            minio_resp.headers.get("Content-Type", "application/octet-stream"),
        )

        # Explicit check instead of `assert`, which is stripped under `-O`.
        if len(bucket_file_lst) != 1:
            raise RuntimeError(
                f"Expected exactly one uploaded bucket file for object "
                f"`{object_name}`, got {len(bucket_file_lst)}"
            )

        bucket_file = bucket_file_lst[0]
        api.link_file_to_analysis(client_id, bucket_file.id, bucket_file.name, "TEMP")
        object_id_to_hub_bucket_dict[object_id] = bucket_file.id
        # Only drop the local copy once the hub has accepted and linked the file.
        minio.remove_object(bucket_name, object_name)
    finally:
        # Guard on the response, not the client: if `get_object` raised,
        # there is no response to close and the original error must surface.
        if minio_resp is not None:
            minio_resp.close()
            minio_resp.release_conn()


@router.put(
"/",
response_model=ScratchUploadResponse,
Expand All @@ -32,6 +76,8 @@ async def upload_to_scratch(
settings: Annotated[Settings, Depends(get_settings)],
minio: Annotated[Minio, Depends(get_local_minio)],
request: Request,
api_access_token: Annotated[AccessToken, Depends(get_access_token)],
background_tasks: BackgroundTasks,
):
"""Upload a file to the local S3 instance.
The file is not forwarded to the FLAME hub.
Expand All @@ -40,15 +86,29 @@ async def upload_to_scratch(
This endpoint is to be used for submitting intermediate results of a federated analysis.
"""
object_id = str(uuid.uuid4())
object_name = f"temp/{client_id}/{object_id}"

minio.put_object(
settings.minio.bucket,
f"scratch/{client_id}/{object_id}",
object_name,
data=file.file,
length=file.size,
content_type=file.content_type or "application/octet-stream",
)

api = ApiWrapper(str(settings.hub.api_base_url), api_access_token.access_token)
object_id_to_hub_bucket_dict[object_id] = None

background_tasks.add_task(
__bg_upload_to_remote,
minio,
settings.minio.bucket,
object_name,
api,
client_id,
object_id,
)

return ScratchUploadResponse(
url=str(
request.url_for(
Expand All @@ -68,33 +128,30 @@ async def read_from_scratch(
client_id: Annotated[str, Depends(get_client_id)],
object_id: uuid.UUID,
settings: Annotated[Settings, Depends(get_settings)],
minio: Annotated[Minio, Depends(get_local_minio)],
api_access_token: Annotated[AccessToken, Depends(get_access_token)],
):
"""Get a file from the local S3 instance.
The file must have previously been uploaded using the PUT method of this endpoint.
Responds with a 200 on success and the requested file in the response body.
This endpoint is to be used for retrieving intermediate results of a federated analysis.
"""
try:
response = minio.get_object(
settings.minio.bucket, f"scratch/{client_id}/{object_id}"
)
except S3Error as e:
logger.exception(f"Could not get object `{object_id}` for client `{client_id}`")

if e.code == "NoSuchKey":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Object with ID {object_id} does not exist",
)
api = ApiWrapper(str(settings.hub.api_base_url), api_access_token.access_token)
oid = str(object_id)

if (
oid not in object_id_to_hub_bucket_dict
or object_id_to_hub_bucket_dict[oid] is None
):
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unexpected error from object store",
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Object with ID {oid} does not exist",
)

return StreamingResponse(
response,
media_type=response.headers.get("Content-Type", "application/octet-stream"),
)
bucket_file_id = object_id_to_hub_bucket_dict[oid]

async def _stream_bucket_file():
for b in api.stream_bucket_file(bucket_file_id):
yield b

return StreamingResponse(_stream_bucket_file())
27 changes: 18 additions & 9 deletions tests/test_scratch.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import re
import uuid

import pytest
from starlette import status

from project.routers.scratch import ScratchUploadResponse
from tests.common.auth import BearerAuth, issue_client_access_token
from tests.common.helpers import next_random_bytes
from tests.common.helpers import next_random_bytes, eventually
from tests.common.rest import wrap_bytes_for_request, detail_of

pytestmark = pytest.mark.live

def test_200_submit_receive_from_scratch(test_client, rng):

def test_200_submit_receive_from_scratch(test_client, rng, analysis_id):
blob = next_random_bytes(rng)
r = test_client.put(
"/scratch",
auth=BearerAuth(issue_client_access_token()),
auth=BearerAuth(issue_client_access_token(analysis_id)),
files=wrap_bytes_for_request(blob),
)

Expand All @@ -27,13 +30,19 @@ def test_200_submit_receive_from_scratch(test_client, rng):

assert re.fullmatch(path_regex, model.url.path) is not None

r = test_client.get(
model.url.path,
auth=BearerAuth(issue_client_access_token()),
)
def _check_response_from_hub():
r = test_client.get(
model.url.path,
auth=BearerAuth(issue_client_access_token()),
)

assert r.status_code == status.HTTP_200_OK
assert r.read() == blob
if r.status_code != status.HTTP_200_OK:
return False

assert r.read() == blob
return True

assert eventually(_check_response_from_hub)


def test_whatever(test_client):
Expand Down

0 comments on commit 7acf108

Please sign in to comment.