Merge pull request #38 from PrivateAIM/34-implement-temporary-file-upload-to-hub

Implement temporary file upload and download to Hub
mjugl authored May 16, 2024
2 parents 6c63ad9 + 7acf108 commit 1b4bdda
Showing 5 changed files with 124 additions and 39 deletions.
19 changes: 16 additions & 3 deletions project/hub.py
@@ -1,5 +1,5 @@
from io import BytesIO
from typing import NamedTuple
from typing import NamedTuple, Literal
from urllib.parse import urljoin

import httpx
@@ -150,6 +150,15 @@ def get_bucket_file(self, bucket_file_id: str) -> BucketFile | None:
bucket_id=j["bucket_id"],
)

def stream_bucket_file(self, bucket_file_id: str):
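        """Stream the contents of the bucket file with the given ID from the Hub in 1 KiB chunks."""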
with httpx.stream(
"GET",
urljoin(self.base_url, f"/storage/bucket-files/{bucket_file_id}/stream"),
headers=self.__auth_header(),
) as r:
for b in r.iter_bytes(chunk_size=1024):
yield b
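
A minimal usage sketch (not part of this diff; api is assumed to be an authenticated ApiWrapper instance and bucket_file_id an existing Hub bucket file ID): the generator can be consumed chunk by chunk, e.g. to write a bucket file to disk.

    # Hypothetical consumer of the new streaming method.
    with open("downloaded.bin", "wb") as f:
        for chunk in api.stream_bucket_file(bucket_file_id):
            f.write(chunk)  # chunks arrive in 1024-byte blocks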

def upload_to_bucket(
self,
bucket_name: str,
@@ -178,7 +187,11 @@ def upload_to_bucket(
]

def link_file_to_analysis(
self, analysis_id: str, bucket_file_id: str, bucket_file_name: str
self,
analysis_id: str,
bucket_file_id: str,
bucket_file_name: str,
bucket_file_type: Literal["CODE", "RESULT", "TEMP"],
) -> AnalysisFile:
"""Link the file associated with the given ID and name to the analysis associated with the given ID.
Currently, this function only supports linking result files."""
@@ -187,7 +200,7 @@
headers=self.__auth_header(),
json={
"analysis_id": analysis_id,
"type": "RESULT",
"type": bucket_file_type,
"bucket_file_id": bucket_file_id,
"name": bucket_file_name,
"root": True,
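
A usage sketch of the extended signature (assuming an authenticated ApiWrapper named api and a file previously uploaded via upload_to_bucket; it mirrors the call added to tests/test_hub.py below):

    analysis_file = api.link_file_to_analysis(
        analysis_id,
        bucket_file.id,
        bucket_file.name,
        "TEMP",  # was previously hard-coded to "RESULT"
    )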
105 changes: 81 additions & 24 deletions project/routers/scratch.py
@@ -1,25 +1,69 @@
import io
import logging
import uuid
from typing import Annotated
from typing import Annotated, Optional

from fastapi import APIRouter, UploadFile, Depends, HTTPException
from minio import Minio, S3Error
from fastapi import APIRouter, UploadFile, Depends, HTTPException, BackgroundTasks
from minio import Minio
from pydantic import BaseModel, HttpUrl
from starlette import status
from starlette.requests import Request
from starlette.responses import StreamingResponse

from project.config import Settings
from project.dependencies import get_settings, get_local_minio, get_client_id
from project.dependencies import (
get_settings,
get_local_minio,
get_client_id,
get_access_token,
)
from project.hub import AccessToken, ApiWrapper

router = APIRouter()
logger = logging.getLogger(__name__)

# TODO fix this jank
object_id_to_hub_bucket_dict: dict[str, Optional[str]] = {}


class ScratchUploadResponse(BaseModel):
url: HttpUrl


def __bg_upload_to_remote(
minio: Minio,
bucket_name: str,
object_name: str,
api: ApiWrapper,
client_id: str,
object_id: str,
):
logger.info(
"__bg_upload_to_remote: bucket `%s`, object `%s`", bucket_name, object_name
)

minio_resp = None

try:
minio_resp = minio.get_object(bucket_name, object_name)
bucket_file_lst = api.upload_to_bucket(
f"analysis-temp-files.{client_id}",
object_name,
io.BytesIO(minio_resp.data),
minio_resp.headers.get("Content-Type", "application/octet-stream"),
)

assert len(bucket_file_lst) == 1
bucket_file = bucket_file_lst[0]
api.link_file_to_analysis(client_id, bucket_file.id, bucket_file.name, "TEMP")
object_id_to_hub_bucket_dict[object_id] = bucket_file.id
minio.remove_object(bucket_name, object_name)
finally:
        if minio_resp is not None:  # guard the response handle, not the Minio client
minio_resp.close()
minio_resp.release_conn()


@router.put(
"/",
response_model=ScratchUploadResponse,
@@ -32,6 +76,8 @@ async def upload_to_scratch(
settings: Annotated[Settings, Depends(get_settings)],
minio: Annotated[Minio, Depends(get_local_minio)],
request: Request,
api_access_token: Annotated[AccessToken, Depends(get_access_token)],
background_tasks: BackgroundTasks,
):
"""Upload a file to the local S3 instance.
    The file is forwarded to the FLAME Hub by a background task.
@@ -40,15 +86,29 @@
This endpoint is to be used for submitting intermediate results of a federated analysis.
"""
object_id = str(uuid.uuid4())
object_name = f"temp/{client_id}/{object_id}"

minio.put_object(
settings.minio.bucket,
f"scratch/{client_id}/{object_id}",
object_name,
data=file.file,
length=file.size,
content_type=file.content_type or "application/octet-stream",
)

api = ApiWrapper(str(settings.hub.api_base_url), api_access_token.access_token)
object_id_to_hub_bucket_dict[object_id] = None

background_tasks.add_task(
__bg_upload_to_remote,
minio,
settings.minio.bucket,
object_name,
api,
client_id,
object_id,
)

return ScratchUploadResponse(
url=str(
request.url_for(
@@ -68,33 +128,30 @@ async def read_from_scratch(
client_id: Annotated[str, Depends(get_client_id)],
object_id: uuid.UUID,
settings: Annotated[Settings, Depends(get_settings)],
minio: Annotated[Minio, Depends(get_local_minio)],
api_access_token: Annotated[AccessToken, Depends(get_access_token)],
):
"""Get a file from the local S3 instance.
The file must have previously been uploaded using the PUT method of this endpoint.
Responds with a 200 on success and the requested file in the response body.
This endpoint is to be used for retrieving intermediate results of a federated analysis.
"""
try:
response = minio.get_object(
settings.minio.bucket, f"scratch/{client_id}/{object_id}"
)
except S3Error as e:
logger.exception(f"Could not get object `{object_id}` for client `{client_id}`")

if e.code == "NoSuchKey":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Object with ID {object_id} does not exist",
)
api = ApiWrapper(str(settings.hub.api_base_url), api_access_token.access_token)
oid = str(object_id)

if (
oid not in object_id_to_hub_bucket_dict
or object_id_to_hub_bucket_dict[oid] is None
):
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unexpected error from object store",
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Object with ID {oid} does not exist",
)

return StreamingResponse(
response,
media_type=response.headers.get("Content-Type", "application/octet-stream"),
)
bucket_file_id = object_id_to_hub_bucket_dict[oid]

async def _stream_bucket_file():
for b in api.stream_bucket_file(bucket_file_id):
yield b

return StreamingResponse(_stream_bucket_file())
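
Since the transfer to the Hub runs in a background task, a GET issued right after the PUT may return 404 until object_id_to_hub_bucket_dict is populated. A client-side sketch of the intended flow (the base URL and token are placeholders, not part of this change):

    import time

    import httpx

    headers = {"Authorization": "Bearer <token>"}  # placeholder token
    blob = b"intermediate result"

    r = httpx.put(
        "http://localhost:8000/scratch",
        headers=headers,
        files={"file": ("blob.bin", blob)},
    )
    r.raise_for_status()
    url = r.json()["url"]

    # Poll until the background upload to the Hub completes.
    for _ in range(30):
        r = httpx.get(url, headers=headers)
        if r.status_code == 200:
            assert r.content == blob
            break
        time.sleep(1)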
2 changes: 1 addition & 1 deletion project/routers/upload.py
@@ -51,7 +51,7 @@ def __bg_upload_to_remote(
# fetch file s.t. it can be linked
bucket_file = bucket_file_lst[0]
# link file to analysis
api.link_file_to_analysis(client_id, bucket_file.id, bucket_file.name)
api.link_file_to_analysis(client_id, bucket_file.id, bucket_file.name, "RESULT")
# remove from local minio
minio.remove_object(bucket_name, object_name)
finally:
10 changes: 8 additions & 2 deletions tests/test_hub.py
@@ -11,8 +11,10 @@ def test_upload_to_bucket(api, rng, analysis_id):
# 1) upload file with random name to bucket
result_bucket_name = f"analysis-result-files.{analysis_id}"
file_name = next_prefixed_name()
file_blob = next_random_bytes(rng)

bucket_file_lst = api.upload_to_bucket(
result_bucket_name, file_name, BytesIO(next_random_bytes(rng))
result_bucket_name, file_name, BytesIO(file_blob)
)

# 2) check that the endpoint returned a single file
@@ -26,7 +28,7 @@

# 4) link uploaded file to analysis
analysis_file = api.link_file_to_analysis(
analysis_id, bucket_file.id, bucket_file.name
analysis_id, bucket_file.id, bucket_file.name, "RESULT"
)

assert analysis_file.name == bucket_file.name
@@ -38,3 +40,7 @@
analysis_file_list = api.get_analysis_files()

assert analysis_file.id in [f.id for f in analysis_file_list]

# 6) download the file and check that it's identical with the submitted bytes
bucket_file_data = next(api.stream_bucket_file(bucket_file.id))
assert bucket_file_data == file_blob
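
Note that next() only retrieves the first 1024-byte chunk here, which suffices because the random test payloads are small. A sketch for comparing the full content regardless of size (same fixtures assumed):

    downloaded = b"".join(api.stream_bucket_file(bucket_file.id))
    assert downloaded == file_blob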
27 changes: 18 additions & 9 deletions tests/test_scratch.py
@@ -1,19 +1,22 @@
import re
import uuid

import pytest
from starlette import status

from project.routers.scratch import ScratchUploadResponse
from tests.common.auth import BearerAuth, issue_client_access_token
from tests.common.helpers import next_random_bytes
from tests.common.helpers import next_random_bytes, eventually
from tests.common.rest import wrap_bytes_for_request, detail_of

pytestmark = pytest.mark.live

def test_200_submit_receive_from_scratch(test_client, rng):

def test_200_submit_receive_from_scratch(test_client, rng, analysis_id):
blob = next_random_bytes(rng)
r = test_client.put(
"/scratch",
auth=BearerAuth(issue_client_access_token()),
auth=BearerAuth(issue_client_access_token(analysis_id)),
files=wrap_bytes_for_request(blob),
)

@@ -27,13 +30,19 @@ def test_200_submit_receive_from_scratch(test_client, rng):

assert re.fullmatch(path_regex, model.url.path) is not None

r = test_client.get(
model.url.path,
auth=BearerAuth(issue_client_access_token()),
)
def _check_response_from_hub():
r = test_client.get(
model.url.path,
auth=BearerAuth(issue_client_access_token()),
)

assert r.status_code == status.HTTP_200_OK
assert r.read() == blob
if r.status_code != status.HTTP_200_OK:
return False

assert r.read() == blob
return True

assert eventually(_check_response_from_hub)


def test_whatever(test_client):
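
The eventually helper used above is imported from tests.common.helpers but not shown in this diff; a minimal polling implementation might look like the following sketch (an assumption, not the project's actual code):

    import time

    def eventually(predicate, attempts=30, delay=1.0):
        # Retry a boolean check until it passes or the attempts run out.
        for _ in range(attempts):
            if predicate():
                return True
            time.sleep(delay)
        return False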
