
Commit

Merge pull request #13 from gdcc/set-file-meta-after-native-upload
Set file meta after native upload
JR-1991 authored May 9, 2024
2 parents eda8180 + 6d04bf2 commit af4c3b0
Showing 3 changed files with 185 additions and 28 deletions.
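The heart of the change sits in dvuploader/nativeupload.py below: once the native upload (or replace) requests have finished, the uploader now issues a second round of requests that writes each file's metadata through the /api/files/{FILE_ID}/metadata endpoint. A minimal sketch of that second step, assuming an X-Dataverse-key header for authentication and illustrative values throughout (neither the helper name nor the values come from the diff):

import asyncio
import json
import aiohttp

async def set_file_metadata(file_id: int, description: str, dataverse_url: str, api_token: str) -> None:
    # Sketch of the post-upload metadata call this commit introduces.
    url = f"{dataverse_url}/api/files/{file_id}/metadata"
    form = aiohttp.FormData()
    form.add_field(
        "jsonData",
        json.dumps({"description": description}),
        content_type="application/json",
    )
    # Assumption: the API token is passed as the usual Dataverse header.
    async with aiohttp.ClientSession(headers={"X-Dataverse-key": api_token}) as session:
        async with session.post(url, data=form) as response:
            response.raise_for_status()

# Hypothetical call:
# asyncio.run(set_file_metadata(42, "A test file", "https://demo.dataverse.org", "xxxx-xxxx"))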
22 changes: 10 additions & 12 deletions dvuploader/dvuploader.py
@@ -1,4 +1,5 @@
import asyncio
import json
from urllib.parse import urljoin
import requests
import os
@@ -19,6 +20,7 @@
from dvuploader.nativeupload import native_upload
from dvuploader.utils import build_url, retrieve_dataset_files, setup_pbar


class DVUploader(BaseModel):
"""
A class for uploading files to a Dataverse repository.
@@ -153,10 +155,7 @@ async def _validate_and_hash_files(self, verbose: bool):

if not verbose:
tasks = [
self._validate_and_hash_file(
file=file,
verbose=self.verbose
)
self._validate_and_hash_file(file=file, verbose=self.verbose)
for file in self.files
]

@@ -175,10 +174,7 @@ async def _validate_and_hash_files(self, verbose: bool):

tasks = [
self._validate_and_hash_file(
file=file,
progress=progress,
task_id=task,
verbose=self.verbose
file=file, progress=progress, task_id=task, verbose=self.verbose
)
for file in self.files
]
@@ -197,7 +193,7 @@ async def _validate_and_hash_file(
file.extract_file_name_hash_file()

if verbose:
progress.update(task_id, advance=1) # type: ignore
progress.update(task_id, advance=1) # type: ignore

def _check_duplicates(
self,
@@ -240,7 +236,7 @@ def _check_duplicates(
map(lambda dsFile: self._check_hashes(file, dsFile), ds_files)
)

if has_same_hash and file.checksum:
if has_same_hash:
n_skip_files += 1
table.add_row(
file.file_name, "[bright_black]Same hash", "[bright_black]Skip"
@@ -316,12 +312,14 @@ def _check_hashes(file: File, dsFile: Dict):
return False

hash_algo, hash_value = tuple(dsFile["dataFile"]["checksum"].values())
path = os.path.join(
dsFile.get("directoryLabel", ""), dsFile["dataFile"]["filename"]
)

return (
file.checksum.value == hash_value
and file.checksum.type == hash_algo
and file.file_name == dsFile["label"]
and file.directory_label == dsFile.get("directoryLabel", "")
and path == os.path.join(file.directory_label, file.file_name) # type: ignore
)

@staticmethod
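Note the tightened duplicate check above: a dataset file now counts as "same" only if checksum type, checksum value, and the full path (directoryLabel joined with the filename) all match, rather than comparing label and directory separately. A standalone sketch of that comparison with an invented dataset entry (field names follow the file listing the diff reads; the values are illustrative):

import os

def is_same_file(checksum_type, checksum_value, directory_label, file_name, ds_file):
    # Mirrors the path-based comparison in the updated _check_hashes (sketch, not the original code).
    hash_algo, hash_value = tuple(ds_file["dataFile"]["checksum"].values())
    ds_path = os.path.join(ds_file.get("directoryLabel", ""), ds_file["dataFile"]["filename"])
    return (
        checksum_value == hash_value
        and checksum_type == hash_algo
        and ds_path == os.path.join(directory_label or "", file_name)
    )

# Invented entry shaped like the native API file listing:
ds_file = {
    "directoryLabel": "subdir",
    "dataFile": {
        "filename": "file.txt",
        "checksum": {"type": "MD5", "value": "65a8e27d8879283831b664bd8b7f0ad4"},
    },
}

print(is_same_file("MD5", "65a8e27d8879283831b664bd8b7f0ad4", "subdir", "file.txt", ds_file))  # True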
165 changes: 154 additions & 11 deletions dvuploader/nativeupload.py
@@ -3,18 +3,20 @@
import os
import tempfile
from typing import List, Tuple
from typing_extensions import Dict
import aiofiles
import aiohttp

from rich.progress import Progress, TaskID

from dvuploader.file import File
from dvuploader.packaging import distribute_files, zip_files
from dvuploader.utils import build_url
from dvuploader.utils import build_url, retrieve_dataset_files

MAX_RETRIES = os.environ.get("DVUPLOADER_MAX_RETRIES", 15)
NATIVE_UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add"
NATIVE_REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"
NATIVE_METADATA_ENDPOINT = "/api/files/{FILE_ID}/metadata"

assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer"

@@ -74,6 +76,22 @@ async def native_upload(
]

responses = await asyncio.gather(*tasks)
_validate_upload_responses(responses, files)

await _update_metadata(
session=session,
files=files,
persistent_id=persistent_id,
dataverse_url=dataverse_url,
api_token=api_token,
)


def _validate_upload_responses(
responses: List[Tuple],
files: List[File],
) -> None:
"""Validates the responses of the native upload requests."""

for (status, response), file in zip(responses, files):
if status == 200:
@@ -174,20 +192,21 @@ async def _single_native_upload(
endpoint=NATIVE_REPLACE_ENDPOINT.format(FILE_ID=file.file_id),
)

json_data = {
"description": file.description,
"forceReplace": True,
"directoryLabel": file.directory_label,
"categories": file.categories,
"restrict": file.restrict,
"forceReplace": True,
}
json_data = _get_json_data(file)

for _ in range(MAX_RETRIES):

formdata = aiohttp.FormData()
formdata.add_field("jsonData", json.dumps(json_data), content_type="application/json")
formdata.add_field("file", file.handler, filename=file.file_name)
formdata.add_field(
"jsonData",
json.dumps(json_data),
content_type="application/json",
)
formdata.add_field(
"file",
file.handler,
filename=file.file_name,
)

async with session.post(endpoint, data=formdata) as response:
status = response.status
@@ -234,3 +253,127 @@ def file_sender(
yield chunk
chunk = file.handler.read(chunk_size)
progress.advance(pbar, advance=chunk_size)


def _get_json_data(file: File) -> Dict:
"""Returns the JSON data for the native upload request."""
return {
"description": file.description,
"directoryLabel": file.directory_label,
"categories": file.categories,
"restrict": file.restrict,
"forceReplace": True,
}
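# Illustration, not part of the commit: for a hypothetical File with
# description="A test file", directory_label="subdir", categories=["Data"] and
# restrict=False, _get_json_data returns
#     {"description": "A test file", "directoryLabel": "subdir",
#      "categories": ["Data"], "restrict": False, "forceReplace": True}
# The same helper feeds both the replace request and, below, the metadata
# update, which first removes the "forceReplace" and "restrict" keys.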


async def _update_metadata(
session: aiohttp.ClientSession,
files: List[File],
dataverse_url: str,
api_token: str,
persistent_id: str,
):
"""Updates the metadata of the given files in a Dataverse repository.
Args:
session (aiohttp.ClientSession): The aiohttp client session.
files (List[File]): The files to update the metadata for.
dataverse_url (str): The URL of the Dataverse repository.
api_token (str): The API token of the Dataverse repository.
persistent_id (str): The persistent identifier of the dataset.
"""

file_mapping = _retrieve_file_ids(
persistent_id=persistent_id,
dataverse_url=dataverse_url,
api_token=api_token,
)

tasks = []

for file in files:
dv_path = os.path.join(file.directory_label, file.file_name) # type: ignore

try:
file_id = file_mapping[dv_path]
except KeyError:
raise ValueError(
(
f"File {dv_path} not found in Dataverse repository.",
"This may be due to the file not being uploaded to the repository.",
)
)

task = _update_single_metadata(
session=session,
url=NATIVE_METADATA_ENDPOINT.format(FILE_ID=file_id),
file=file,
)

tasks.append(task)

await asyncio.gather(*tasks)


async def _update_single_metadata(
session: aiohttp.ClientSession,
url: str,
file: File,
) -> None:
"""Updates the metadata of a single file in a Dataverse repository."""

json_data = _get_json_data(file)

del json_data["forceReplace"]
del json_data["restrict"]

formdata = aiohttp.FormData()
formdata.add_field(
"jsonData",
json.dumps(json_data),
content_type="application/json",
)

async with session.post(url, data=formdata) as response:
response.raise_for_status()


def _retrieve_file_ids(
persistent_id: str,
dataverse_url: str,
api_token: str,
) -> Dict[str, str]:
"""Retrieves the file IDs of the given files.
Args:
files (List[File]): The files to retrieve the IDs for.
persistent_id (str): The persistent identifier of the dataset.
dataverse_url (str): The URL of the Dataverse repository.
api_token (str): The API token of the Dataverse repository.
Returns:
Dict[str, str]: The list of file IDs.
"""

# Fetch file metadata
ds_files = retrieve_dataset_files(
persistent_id=persistent_id,
dataverse_url=dataverse_url,
api_token=api_token,
)

return _create_file_id_path_mapping(ds_files)


def _create_file_id_path_mapping(files):
"""Creates dictionary that maps from directoryLabel + filename to ID"""
mapping = {}

for file in files:
directory_label = file.get("directoryLabel", "")
file = file["dataFile"]
path = os.path.join(directory_label, file["filename"])
mapping[path] = file["id"]

return mapping
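
The helpers above tie the uploaded files back to their server-side IDs: the dataset's file listing is flattened into a directoryLabel/filename → id map, and a file missing from that map aborts the metadata step with an error. A self-contained sketch of the mapping (the listing is invented, but shaped like the structure the code iterates):

import os

def path_to_id(ds_files):
    # Same idea as _create_file_id_path_mapping: directoryLabel/filename -> dataFile id.
    mapping = {}
    for entry in ds_files:
        directory_label = entry.get("directoryLabel", "")
        data_file = entry["dataFile"]
        mapping[os.path.join(directory_label, data_file["filename"])] = data_file["id"]
    return mapping

# Invented listing:
listing = [
    {"directoryLabel": "subdir", "dataFile": {"filename": "file.txt", "id": 101}},
    {"dataFile": {"filename": "biggerfile.txt", "id": 102}},
]

ids = path_to_id(listing)
print(ids[os.path.join("subdir", "file.txt")])  # 101
print(ids["biggerfile.txt"])                    # 102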
26 changes: 21 additions & 5 deletions tests/integration/test_native_upload.py
@@ -106,7 +106,6 @@ def test_forced_native_upload(
assert len(files) == 3
assert sorted([file["label"] for file in files]) == sorted(expected_files)


def test_native_upload_by_handler(
self,
credentials,
@@ -116,8 +115,16 @@ def test_native_upload_by_handler(
# Arrange
byte_string = b"Hello, World!"
files = [
File(filepath="subdir/file.txt", handler=BytesIO(byte_string)),
File(filepath="biggerfile.txt", handler=BytesIO(byte_string*10000)),
File(
filepath="subdir/file.txt",
handler=BytesIO(byte_string),
description="This is a test",
),
File(
filepath="biggerfile.txt",
handler=BytesIO(byte_string * 10000),
description="This is a test",
),
]

# Create Dataset
@@ -154,5 +161,14 @@

file = next(file for file in files if file["label"] == ex_f)

assert file["label"] == ex_f, f"File label does not match for file {json.dumps(file)}"
assert file.get("directoryLabel", "") == ex_dir, f"Directory label does not match for file {json.dumps(file)}"
assert (
file["label"] == ex_f
), f"File label does not match for file {json.dumps(file)}"

assert (
file.get("directoryLabel", "") == ex_dir
), f"Directory label does not match for file {json.dumps(file)}"

assert (
file["description"] == "This is a test"
), f"Description does not match for file {json.dumps(file)}"
