diff --git a/dvuploader/directupload.py b/dvuploader/directupload.py index e22c729..18c1607 100644 --- a/dvuploader/directupload.py +++ b/dvuploader/directupload.py @@ -6,6 +6,7 @@ from urllib.parse import urljoin import aiofiles import aiohttp +import aiohttp_retry from dvuploader.file import File from dvuploader.nativeupload import file_sender @@ -13,11 +14,14 @@ TESTING = bool(os.environ.get("DVUPLOADER_TESTING", False)) MAX_FILE_DISPLAY = int(os.environ.get("DVUPLOADER_MAX_FILE_DISPLAY", 50)) +MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 10)) assert isinstance( MAX_FILE_DISPLAY, int ), "DVUPLOADER_MAX_FILE_DISPLAY must be an integer" +assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer" + TICKET_ENDPOINT = "/api/datasets/:persistentId/uploadurls" ADD_FILE_ENDPOINT = "/api/datasets/:persistentId/addFiles" UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add?persistentId=" @@ -80,24 +84,27 @@ async def direct_upload( "x-amz-tagging": "dv-state=temp", } - connector = aiohttp.TCPConnector(limit=4) + connector = aiohttp.TCPConnector(limit=10, force_close=True) pbar = progress.add_task("╰── [bold white]Registering files", total=len(files)) results = [] async with aiohttp.ClientSession( headers=headers, connector=connector, ) as session: - for file in files: - results.append( - await _add_file_to_ds( - session=session, - file=file, - dataverse_url=dataverse_url, - pid=persistent_id, - ) + + register_tasks = [ + _add_file_to_ds( + session=session, + file=file, + dataverse_url=dataverse_url, + pid=persistent_id, + progress=progress, + pbar=pbar, ) + for file in files + ] - progress.update(pbar, advance=1) + results = await asyncio.gather(*register_tasks) for file, status in zip(files, results): if status is False: @@ -486,6 +493,8 @@ async def _add_file_to_ds( dataverse_url: str, pid: str, file: File, + progress, + pbar, ) -> bool: """ Adds a file to a Dataverse dataset. @@ -514,11 +523,23 @@ async def _add_file_to_ds( by_alias=True, exclude={"to_replace", "file_id"}, indent=2, + exclude_none=True, ) with aiohttp.MultipartWriter("form-data") as writer: json_part = writer.append(json_data) json_part.set_content_disposition("form-data", name="jsonData") - async with session.post(url, data=writer) as response: - return response.status == 200 + retry_options = aiohttp_retry.RandomRetry(attempts=MAX_RETRIES, statuses={400}) + retry_client = aiohttp_retry.RetryClient( + client_session=session, + retry_options=retry_options, + ) + + async with retry_client.post(url, data=writer) as response: + if response.status == 200: + progress.update(pbar, advance=1) + return True + + await retry_client.close() + return False