Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix singlepart direct upload #8

Merged
merged 15 commits into from
Mar 4, 2024
153 changes: 109 additions & 44 deletions dvuploader/directupload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@
import aiohttp

from dvuploader.file import File
from dvuploader.nativeupload import file_sender
from dvuploader.utils import build_url

TESTING = bool(os.environ.get("DVUPLOADER_TESTING", False))
MAX_FILE_DISPLAY = int(os.environ.get("DVUPLOADER_MAX_FILE_DISPLAY", 50))
MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 10))

assert isinstance(
MAX_FILE_DISPLAY, int
), "DVUPLOADER_MAX_FILE_DISPLAY must be an integer"

assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer"

TICKET_ENDPOINT = "/api/datasets/:persistentId/uploadurls"
ADD_FILE_ENDPOINT = "/api/datasets/:persistentId/addFiles"
UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add?persistentId="
REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"
UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="


async def direct_upload(
Expand Down Expand Up @@ -44,6 +51,7 @@ async def direct_upload(
None
"""

leave_bar = len(files) < MAX_FILE_DISPLAY
connector = aiohttp.TCPConnector(limit=n_parallel_uploads)
async with aiohttp.ClientSession(connector=connector) as session:
tasks = [
Expand All @@ -56,6 +64,7 @@ async def direct_upload(
pbar=pbar,
progress=progress,
delay=0.0,
leave_bar=leave_bar,
)
for pbar, file in zip(pbars, files)
]
Expand All @@ -73,28 +82,20 @@ async def direct_upload(
"x-amz-tagging": "dv-state=temp",
}

connector = aiohttp.TCPConnector(limit=4)
pbar = progress.add_task("╰── [bold white]Registering files", total=len(files))
results = []
pbar = progress.add_task("╰── [bold white]Registering files", total=1)
connector = aiohttp.TCPConnector(limit=2)
async with aiohttp.ClientSession(
headers=headers,
connector=connector,
) as session:
for file in files:
results.append(
await _add_file_to_ds(
session=session,
file=file,
dataverse_url=dataverse_url,
pid=persistent_id,
)
)

progress.update(pbar, advance=1)

for file, status in zip(files, results):
if status is False:
print(f"❌ Failed to register file '{file.fileName}' at Dataverse")
await _add_files_to_ds(
session=session,
files=files,
dataverse_url=dataverse_url,
pid=persistent_id,
progress=progress,
pbar=pbar,
)


async def _upload_to_store(
Expand All @@ -106,6 +107,7 @@ async def _upload_to_store(
pbar,
progress,
delay: float,
leave_bar: bool,
):
"""
Uploads a file to a Dataverse collection using direct upload.
Expand All @@ -119,6 +121,7 @@ async def _upload_to_store(
pbar: The progress bar object.
progress: The progress object.
delay (float): The delay in seconds before starting the upload.
leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete.

Returns:
tuple: A tuple containing the upload status (bool) and the file object.
Expand Down Expand Up @@ -146,6 +149,7 @@ async def _upload_to_store(
pbar=pbar,
progress=progress,
api_token=api_token,
leave_bar=leave_bar,
)

else:
Expand Down Expand Up @@ -207,6 +211,7 @@ async def _upload_singlepart(
pbar,
progress,
api_token: str,
leave_bar: bool,
) -> Tuple[bool, str]:
"""
Uploads a single part of a file to a remote server using HTTP PUT method.
Expand All @@ -217,6 +222,7 @@ async def _upload_singlepart(
filepath (str): The path to the file to be uploaded.
pbar (tqdm): A progress bar object to track the upload progress.
progress: The progress object used to update the progress bar.
leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete.

Returns:
Tuple[bool, str]: A tuple containing the status of the upload (True for success, False for failure)
Expand All @@ -235,19 +241,25 @@ async def _upload_singlepart(
params = {
"headers": headers,
"url": ticket["url"],
"data": file_sender(
file_name=filepath,
progress=progress,
pbar=pbar,
),
"data": open(filepath, "rb"),
}

async with session.put(**params) as response:
status = response.status == 200
response.raise_for_status()

if status:
progress.update(pbar, advance=os.path.getsize(filepath))
progress.update(
pbar,
advance=os.path.getsize(filepath),
)

await asyncio.sleep(0.1)

progress.update(
pbar,
visible=leave_bar,
)

return status, storage_identifier

Expand Down Expand Up @@ -463,12 +475,14 @@ async def _abort_upload(
response.raise_for_status()


async def _add_file_to_ds(
async def _add_files_to_ds(
session: aiohttp.ClientSession,
dataverse_url: str,
pid: str,
file: File,
) -> bool:
files: List[File],
progress,
pbar,
) -> None:
"""
Adds a file to a Dataverse dataset.

Expand All @@ -481,26 +495,77 @@ async def _add_file_to_ds(
Returns:
bool: True if the file was added successfully, False otherwise.
"""
if not file.to_replace:
url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid)
else:
url = build_url(
dataverse_url=dataverse_url,
endpoint=urljoin(
dataverse_url,
REPLACE_ENDPOINT.format(FILE_ID=file.file_id),
),
)

json_data = file.model_dump_json(
by_alias=True,
exclude={"to_replace", "file_id"},
novel_url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid)
replace_url = urljoin(dataverse_url, REPLACE_ENDPOINT + pid)

novel_json_data = _prepare_registration(files, use_replace=False)
replace_json_data = _prepare_registration(files, use_replace=True)

await _multipart_json_data_request(
session=session,
json_data=novel_json_data,
url=novel_url,
)

await _multipart_json_data_request(
session=session,
json_data=replace_json_data,
url=replace_url,
)

progress.update(pbar, advance=1)


def _prepare_registration(files: List[File], use_replace: bool) -> str:
"""
Prepares the files for registration at the Dataverse instance.

Args:
files (List[File]): The list of files to prepare.

Returns:
str: A JSON string containing the file data.
"""

exclude = {"to_replace"} if use_replace else {"to_replace", "file_id"}

return json.dumps(
[
file.model_dump(
by_alias=True,
exclude=exclude,
exclude_none=True,
)
for file in files
if file.to_replace is use_replace
],
indent=2,
)


async def _multipart_json_data_request(
json_data: str,
url: str,
session: aiohttp.ClientSession,
):
"""
Sends a multipart/form-data POST request with JSON data to the specified URL using the provided session.

Args:
json_data (str): The JSON data to be sent in the request body.
url (str): The URL to send the request to.
session (aiohttp.ClientSession): The aiohttp client session to use for the request.

Raises:
aiohttp.ClientResponseError: If the response status code is not successful.

Returns:
None
"""
with aiohttp.MultipartWriter("form-data") as writer:
json_part = writer.append(json_data)
json_part.set_content_disposition("form-data", name="jsonData")

async with session.post(url, data=writer) as response:
return response.status == 200
response.raise_for_status()
Loading
Loading