Skip to content

Commit

Permalink
Merge pull request #8 from gdcc/fix-singlepart-direct-upload
Browse files Browse the repository at this point in the history
Fix singlepart direct upload and direct upload file registration
  • Loading branch information
JR-1991 authored Mar 4, 2024
2 parents d760934 + 4bf1123 commit 3a4df04
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 63 deletions.
153 changes: 109 additions & 44 deletions dvuploader/directupload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@
import aiohttp

from dvuploader.file import File
from dvuploader.nativeupload import file_sender
from dvuploader.utils import build_url

# Runtime configuration, read once from environment variables at import time.
# NOTE(review): bool() of any non-empty string is True, so values such as
# "0" or "false" also switch TESTING on; only an unset/empty variable is False.
TESTING = bool(os.environ.get("DVUPLOADER_TESTING", False))
# Compared against len(files) in direct_upload to decide whether finished
# per-file progress bars stay visible.
MAX_FILE_DISPLAY = int(os.environ.get("DVUPLOADER_MAX_FILE_DISPLAY", 50))
MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 10))

# NOTE(review): int(...) above either returns an int or raises ValueError, so
# the two isinstance asserts below can never fail (and are stripped under -O).
assert isinstance(
MAX_FILE_DISPLAY, int
), "DVUPLOADER_MAX_FILE_DISPLAY must be an integer"

assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer"

# Dataverse API paths; they are joined onto the instance URL with urljoin.
TICKET_ENDPOINT = "/api/datasets/:persistentId/uploadurls"
ADD_FILE_ENDPOINT = "/api/datasets/:persistentId/addFiles"
# NOTE(review): UPLOAD_ENDPOINT and REPLACE_ENDPOINT are each assigned twice.
# When executed top to bottom the second assignment of each name wins; the
# first pair looks like the pre-change version shown by the diff view — confirm.
UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add?persistentId="
REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"
# Bulk variants: the dataset's persistent identifier is appended to these.
UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="


async def direct_upload(
Expand Down Expand Up @@ -44,6 +51,7 @@ async def direct_upload(
None
"""

leave_bar = len(files) < MAX_FILE_DISPLAY
connector = aiohttp.TCPConnector(limit=n_parallel_uploads)
async with aiohttp.ClientSession(connector=connector) as session:
tasks = [
Expand All @@ -56,6 +64,7 @@ async def direct_upload(
pbar=pbar,
progress=progress,
delay=0.0,
leave_bar=leave_bar,
)
for pbar, file in zip(pbars, files)
]
Expand All @@ -73,28 +82,20 @@ async def direct_upload(
"x-amz-tagging": "dv-state=temp",
}

connector = aiohttp.TCPConnector(limit=4)
pbar = progress.add_task("╰── [bold white]Registering files", total=len(files))
results = []
pbar = progress.add_task("╰── [bold white]Registering files", total=1)
connector = aiohttp.TCPConnector(limit=2)
async with aiohttp.ClientSession(
headers=headers,
connector=connector,
) as session:
for file in files:
results.append(
await _add_file_to_ds(
session=session,
file=file,
dataverse_url=dataverse_url,
pid=persistent_id,
)
)

progress.update(pbar, advance=1)

for file, status in zip(files, results):
if status is False:
print(f"❌ Failed to register file '{file.fileName}' at Dataverse")
await _add_files_to_ds(
session=session,
files=files,
dataverse_url=dataverse_url,
pid=persistent_id,
progress=progress,
pbar=pbar,
)


async def _upload_to_store(
Expand All @@ -106,6 +107,7 @@ async def _upload_to_store(
pbar,
progress,
delay: float,
leave_bar: bool,
):
"""
Uploads a file to a Dataverse collection using direct upload.
Expand All @@ -119,6 +121,7 @@ async def _upload_to_store(
pbar: The progress bar object.
progress: The progress object.
delay (float): The delay in seconds before starting the upload.
leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete.
Returns:
tuple: A tuple containing the upload status (bool) and the file object.
Expand Down Expand Up @@ -146,6 +149,7 @@ async def _upload_to_store(
pbar=pbar,
progress=progress,
api_token=api_token,
leave_bar=leave_bar,
)

else:
Expand Down Expand Up @@ -207,6 +211,7 @@ async def _upload_singlepart(
pbar,
progress,
api_token: str,
leave_bar: bool,
) -> Tuple[bool, str]:
"""
Uploads a single part of a file to a remote server using HTTP PUT method.
Expand All @@ -217,6 +222,7 @@ async def _upload_singlepart(
filepath (str): The path to the file to be uploaded.
pbar (tqdm): A progress bar object to track the upload progress.
progress: The progress object used to update the progress bar.
leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete.
Returns:
Tuple[bool, str]: A tuple containing the status of the upload (True for success, False for failure)
Expand All @@ -235,19 +241,25 @@ async def _upload_singlepart(
params = {
"headers": headers,
"url": ticket["url"],
"data": file_sender(
file_name=filepath,
progress=progress,
pbar=pbar,
),
"data": open(filepath, "rb"),
}

async with session.put(**params) as response:
status = response.status == 200
response.raise_for_status()

if status:
progress.update(pbar, advance=os.path.getsize(filepath))
progress.update(
pbar,
advance=os.path.getsize(filepath),
)

await asyncio.sleep(0.1)

progress.update(
pbar,
visible=leave_bar,
)

return status, storage_identifier

Expand Down Expand Up @@ -463,12 +475,14 @@ async def _abort_upload(
response.raise_for_status()


async def _add_file_to_ds(
async def _add_files_to_ds(
session: aiohttp.ClientSession,
dataverse_url: str,
pid: str,
file: File,
) -> bool:
files: List[File],
progress,
pbar,
) -> None:
"""
Adds a file to a Dataverse dataset.
Expand All @@ -481,26 +495,77 @@ async def _add_file_to_ds(
Returns:
None: Nothing is returned; the progress bar is advanced once both registration requests have been sent.
"""
if not file.to_replace:
url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid)
else:
url = build_url(
dataverse_url=dataverse_url,
endpoint=urljoin(
dataverse_url,
REPLACE_ENDPOINT.format(FILE_ID=file.file_id),
),
)

json_data = file.model_dump_json(
by_alias=True,
exclude={"to_replace", "file_id"},
novel_url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid)
replace_url = urljoin(dataverse_url, REPLACE_ENDPOINT + pid)

novel_json_data = _prepare_registration(files, use_replace=False)
replace_json_data = _prepare_registration(files, use_replace=True)

await _multipart_json_data_request(
session=session,
json_data=novel_json_data,
url=novel_url,
)

await _multipart_json_data_request(
session=session,
json_data=replace_json_data,
url=replace_url,
)

progress.update(pbar, advance=1)


def _prepare_registration(files: List[File], use_replace: bool) -> str:
    """
    Prepares the files for registration at the Dataverse instance.

    Only the files whose ``to_replace`` flag matches ``use_replace`` are
    serialized, so calling this once with ``False`` and once with ``True``
    splits ``files`` into a "new files" batch and a "replacements" batch.

    Args:
        files (List[File]): The list of files to prepare.
        use_replace (bool): ``True`` to select the files that replace an
            existing file (their ``file_id`` is kept in the output),
            ``False`` to select newly added files (``file_id`` is dropped).

    Returns:
        str: A JSON string containing a list with one object per selected
            file; ``"[]"`` when no file matches.
    """

    # ``to_replace`` is internal bookkeeping and is never sent. ``file_id`` is
    # kept only for replacements.
    exclude = {"to_replace"} if use_replace else {"to_replace", "file_id"}

    # Compare by truthiness rather than identity: with ``is``, a file whose
    # flag is e.g. ``None`` would match neither batch and silently never be
    # registered.
    selected = [
        file for file in files if bool(file.to_replace) == bool(use_replace)
    ]

    return json.dumps(
        [
            file.model_dump(
                by_alias=True,
                exclude=exclude,
                exclude_none=True,
            )
            for file in selected
        ],
        indent=2,
    )


async def _multipart_json_data_request(
    json_data: str,
    url: str,
    session: aiohttp.ClientSession,
) -> None:
    """
    Sends a multipart/form-data POST request with JSON data to the specified URL using the provided session.

    The JSON string is attached as a single form field named ``jsonData``.

    Args:
        json_data (str): The JSON data to be sent in the request body.
        url (str): The URL to send the request to.
        session (aiohttp.ClientSession): The aiohttp client session to use for the request.

    Raises:
        aiohttp.ClientResponseError: If the response status code is not successful.

    Returns:
        None
    """
    with aiohttp.MultipartWriter("form-data") as writer:
        json_part = writer.append(json_data)
        json_part.set_content_disposition("form-data", name="jsonData")

        async with session.post(url, data=writer) as response:
            # Surface HTTP errors to the caller. An early
            # ``return response.status == 200`` here made this call
            # unreachable, so failed registrations went unnoticed.
            response.raise_for_status()
Loading

0 comments on commit 3a4df04

Please sign in to comment.