Commit

fixes/improvements
* adds option to delete all vocab values from CLI
* fixes service writer write_many to work with celery
* adds some logging
jrcastro2 committed Jul 19, 2024
1 parent 50d4d32 commit ad40655
Showing 8 changed files with 163 additions and 110 deletions.
45 changes: 23 additions & 22 deletions invenio_vocabularies/cli.py
@@ -17,7 +17,7 @@

from .datastreams import DataStreamFactory
from .factories import get_vocabulary_config

from invenio_logging.structlog import LoggerFactory

@click.group()
def vocabularies():
@@ -26,43 +26,35 @@ def vocabularies():

def _process_vocab(config, num_samples=None):
"""Import a vocabulary."""
import time
start_time = time.time()
ds = DataStreamFactory.create(
readers_config=config["readers"],
transformers_config=config.get("transformers"),
writers_config=config["writers"],
)

cli_logger = LoggerFactory.get_logger("cli")
cli_logger.info("Starting processing")
success, errored, filtered = 0, 0, 0
left = num_samples or -1
for result in ds.process(batch_size=config["batch_size"] if "batch_size" in config else 100, write_many=config["write_many"] if "write_many" in config else False):
batch_size = config.get("batch_size", 1000)
write_many = config.get("write_many", False)

for result in ds.process(batch_size=batch_size, write_many=write_many):
left = left - 1
if result.filtered:
filtered += 1
cli_logger.info("Filtered", entry=result.entry, operation=result.op_type)
if result.errors:
for err in result.errors:
click.secho(err, fg="red")
cli_logger.error("Error", entry=result.entry, operation=result.op_type, errors=result.errors)
errored += 1
else:
success += 1
cli_logger.info("Success", entry=result.entry, operation=result.op_type)
if left == 0:
click.secho(f"Number of samples reached {num_samples}", fg="green")
break

end_time = time.time()

elapsed_time = end_time - start_time
friendly_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
friendly_time_per_record = 0
if success:
elapsed_time_per_record = elapsed_time/success * 1000
friendly_time_per_record = time.strftime("%H:%M:%S", time.gmtime(elapsed_time_per_record))

print(f"CLI elapsed time: {friendly_time} for {success} entries. An average of {friendly_time_per_record} per 1000 entry.\n")
with open("/tmp/elapsed_time.txt", "a") as file:
file.write(f"CLI elapsed time: {friendly_time} for {success} entries. An average of {friendly_time_per_record} per 1000 entry.\n")
cli_logger.info("Finished processing", success=success, errored=errored, filtered=filtered)

return success, errored, filtered

@@ -160,18 +152,27 @@ def convert(vocabulary, filepath=None, origin=None, target=None, num_samples=Non
type=click.STRING,
help="Identifier of the vocabulary item to delete.",
)
@click.option("--all", is_flag=True, default=False, help="Not supported yet.")
@click.option("--all", is_flag=True, default=False)
@with_appcontext
def delete(vocabulary, identifier, all):
"""Delete all items or a specific one of the vocabulary."""
if not id and not all:
if not identifier and not all:
click.secho("An identifier or the --all flag must be present.", fg="red")
exit(1)

vc = get_vocabulary_config(vocabulary)
service = vc.get_service()
if identifier:
try:
if service.delete(identifier, system_identity):
if service.delete(system_identity, identifier):
click.secho(f"{identifier} deleted from {vocabulary}.", fg="green")
except (PIDDeletedError, PIDDoesNotExistError):
click.secho(f"PID {identifier} not found.")
elif all:
items = service.scan(system_identity)
for item in items.hits:
try:
if service.delete(system_identity, item["id"]):
click.secho(f"{item['id']} deleted from {vocabulary}.", fg="green")
except (PIDDeletedError, PIDDoesNotExistError):
click.secho(f"PID {item['id']} not found.")
13 changes: 8 additions & 5 deletions invenio_vocabularies/config.py
@@ -156,10 +156,13 @@
}
"""Vocabulary type search configuration."""

VOCABULARIES_ORCID_ACCESS_KEY="CHANGE_ME"
VOCABULARIES_ORCID_SECRET_KEY="CHANGE_ME"
VOCABULARIES_ORCID_FOLDER="/tmp/ORCID_public_data_files/"
VOCABULARIES_ORCID_ACCESS_KEY="TOD"
"""ORCID access key to access the s3 bucket."""
VOCABULARIES_ORCID_SECRET_KEY="TODO"
"""ORCID secret key to access the s3 bucket."""
VOCABULARIES_ORCID_SUMMARIES_BUCKET="v3.0-summaries"
VOCABULARIES_DATASTREAM_BATCH_SIZE = 100
"""ORCID summaries bucket name."""
VOCABULARIES_ORCID_SYNC_MAX_WORKERS = 32
VOCABULARIES_ORCID_SYNC_DAYS = 1
"""ORCID max number of simultaneous workers/connections."""
VOCABULARIES_ORCID_SYNC_DAYS = 1
"""ORCID number of days to sync."""
134 changes: 72 additions & 62 deletions invenio_vocabularies/contrib/names/datastreams.py
@@ -14,7 +14,7 @@
from ...datastreams.readers import SimpleHTTPReader, BaseReader
from ...datastreams.transformers import BaseTransformer
from ...datastreams.writers import ServiceWriter
import boto3
import s3fs
from flask import current_app
from datetime import datetime
from datetime import timedelta
@@ -25,69 +25,88 @@
class OrcidDataSyncReader(BaseReader):
"""ORCiD Data Sync Reader."""

def _iter(self, fp, *args, **kwargs):
"""."""
raise NotImplementedError(
"OrcidDataSyncReader downloads one file and therefore does not iterate through items"
)

def read(self, item=None, *args, **kwargs):
"""Downloads the ORCiD lambda file and yields an in-memory binary stream of it."""

path = current_app.config["VOCABULARIES_ORCID_FOLDER"]
def _fetch_orcid_data(self, orcid_to_sync, fs, bucket):
"""Fetches a single ORCiD record from S3."""
# The ORCiD file key is located in a folder whose name corresponds to the last three digits of the ORCiD
suffix = orcid_to_sync[-3:]
key = f'{suffix}/{orcid_to_sync}.xml'
try:
with fs.open(f's3://{bucket}/{key}', 'rb') as f:
file_response = f.read()
return file_response
except Exception as e:
# TODO: log
return None

def _process_lambda_file(self, fileobj):
"""Process the ORCiD lambda file and returns a list of ORCiDs to sync.
The decoded fileobj looks like the following:
orcid,last_modified,created
0000-0001-5109-3700,2021-08-02 15:00:00.000,2021-08-02 15:00:00.000
Yield ORCiDs to sync until the last sync date is reached.
"""
date_format = '%Y-%m-%d %H:%M:%S.%f'
date_format_no_millis = '%Y-%m-%d %H:%M:%S'

s3client = boto3.client('s3', aws_access_key_id=current_app.config["VOCABULARIES_ORCID_ACCESS_KEY"], aws_secret_access_key=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"])
response = s3client.get_object(Bucket='orcid-lambda-file', Key='last_modified.csv.tar')
tar_content = response['Body'].read()


last_sync = datetime.now() - timedelta(days=current_app.config["VOCABULARIES_ORCID_SYNC_DAYS"])

file_content = fileobj.read().decode('utf-8')

for line in file_content.splitlines()[1:]: # Skip the header line
elements = line.split(',')
orcid = elements[0]

# Lambda file is ordered by last modified date
last_modified_str = elements[3]
try:
last_modified_date = datetime.strptime(last_modified_str, date_format)
except ValueError:
last_modified_date = datetime.strptime(last_modified_str, date_format_no_millis)

def process_file(fileobj):
file_content = fileobj.read().decode('utf-8')
orcids = []
for line in file_content.splitlines()[1:]: # Skip the header line
elements = line.split(',')
orcid = elements[0]

last_modified_str = elements[3]
try:
last_modified_date = datetime.strptime(last_modified_str, date_format)
except ValueError:
last_modified_date = datetime.strptime(last_modified_str, date_format_no_millis)

if last_modified_date >= last_sync:
orcids.append(orcid)
else:
break
return orcids
if last_modified_date >= last_sync:
yield orcid
else:
break

orcids_to_sync = []
with tarfile.open(fileobj=io.BytesIO(tar_content)) as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
if f:
orcids_to_sync.extend(process_file(f))

def fetch_orcid_data(orcid_to_sync, bucket):
suffix = orcid_to_sync[-3:]
key = f'{suffix}/{orcid_to_sync}.xml'
try:
file_response = s3client.get_object(Bucket=bucket, Key=key)
return file_response['Body'].read()
except Exception as e:
# TODO: log
return None

with ThreadPoolExecutor(max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]) as executor: # TODO allow to configure max_workers / test to use asyncio
futures = [executor.submit(fetch_orcid_data, orcid, current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"]) for orcid in orcids_to_sync]
def _iter(self, orcids, fs):
"""Iterates over the ORCiD records yielding each one."""

with ThreadPoolExecutor(max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]) as executor:
futures = [executor.submit(self._fetch_orcid_data, orcid, fs, current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"]) for orcid in orcids]
for future in as_completed(futures):
result = future.result()
if result is not None:
yield result


def read(self, item=None, *args, **kwargs):
"""Streams the ORCiD lambda file, process it to get the ORCiDS to sync and yields it's data."""
fs = s3fs.S3FileSystem(
key=current_app.config["VOCABULARIES_ORCID_ACCESS_KEY"],
secret=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"]
)
# Read the file from S3
with fs.open('s3://orcid-lambda-file/last_modified.csv.tar', 'rb') as f:
tar_content = f.read()

orcids_to_sync = []
# Opens tar file and process it
with tarfile.open(fileobj=io.BytesIO(tar_content)) as tar:
# Iterate over each member (file or directory) in the tar file
for member in tar.getmembers():
# Extract the file
extracted_file = tar.extractfile(member)
if extracted_file:
# Process the file and get the ORCiDs to sync
orcids_to_sync.extend(self._process_lambda_file(extracted_file))

yield from self._iter(orcids_to_sync, fs)



class OrcidHTTPReader(SimpleHTTPReader):
"""ORCiD HTTP Reader."""

@@ -184,26 +203,17 @@ def _entry_id(self, entry):
{"type": "xml"},
],
"transformers": [{"type": "orcid"}],
# "writers": [
# {
# "type": "names-service",
# "args": {
# "identity": system_identity,
# },
# }
# ],
"writers": [
{
"type": "async",
"args": {
"writer":{
"type": "names-service",
"args": {},
}
},
}
],
"batch_size": 1000, # TODO: current_app.config["VOCABULARIES_DATASTREAM_BATCH_SIZE"],
"batch_size": 1000,
"write_many": True,
}
"""ORCiD Data Stream configuration.
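
A minimal sketch of how a datastream configuration like the one above is consumed, following the _process_vocab flow from cli.py in this commit; the config dict is assumed to be the ORCiD names configuration shown here.

# Sketch only: `config` stands for the ORCiD names datastream configuration above.
from invenio_vocabularies.datastreams import DataStreamFactory

ds = DataStreamFactory.create(
    readers_config=config["readers"],
    transformers_config=config.get("transformers"),
    writers_config=config["writers"],
)

# With write_many=True the async writer is expected to hand whole batches to the
# write_many_entry Celery task instead of dispatching one task per entry.
for result in ds.process(
    batch_size=config.get("batch_size", 1000),
    write_many=config.get("write_many", False),
):
    if result.errors:
        print(result.errors)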
7 changes: 6 additions & 1 deletion invenio_vocabularies/datastreams/datastreams.py
@@ -9,6 +9,7 @@
"""Base data stream."""

from .errors import ReaderError, TransformerError, WriterError
from invenio_logging.structlog import LoggerFactory

class StreamEntry:
"""Object to encapsulate streams processing."""
@@ -66,15 +67,19 @@ def process_batch(self, batch, write_many=False):
print(f"write {len(transformed_entries)} entries.")
yield from (self.write(entry) for entry in transformed_entries)

def process(self, batch_size=100, write_many=False, *args, **kwargs):
def process(self, batch_size=100, write_many=False, logger=None, *args, **kwargs):
"""Iterates over the entries.
Uses the reader to get the raw entries and transforms them.
It will iterate over the `StreamEntry` objects returned by
the reader, apply the transformations and yield the result of
writing it.
"""
if not logger:
logger = LoggerFactory.get_logger("datastreams")

batch = []
logger.info(f"Start reading datastream with batch_size={batch_size} and write_many={write_many}")
for stream_entry in self.read():
batch.append(stream_entry)
if len(batch) >= batch_size:
14 changes: 11 additions & 3 deletions invenio_vocabularies/datastreams/tasks.py
@@ -12,7 +12,7 @@

from ..datastreams import StreamEntry
from ..datastreams.factories import WriterFactory

from invenio_logging.structlog import LoggerFactory

@shared_task(ignore_result=True)
def write_entry(writer_config, entry):
@@ -25,12 +25,20 @@ def write_entry(writer_config, entry):
writer.write(StreamEntry(entry))

@shared_task(ignore_result=True)
def write_many_entry(writer_config, entries):
def write_many_entry(writer_config, entries, logger=None):
"""Write many entries.
:param writer_config: writer configuration as accepted by the WriterFactory.
:param entries: list of dictionaries; StreamEntry is not serializable.
"""
if not logger:
logger = LoggerFactory.get_logger("write_many_entry")
writer = WriterFactory.create(config=writer_config)
stream_entries = [StreamEntry(entry) for entry in entries]
writer.write_many(stream_entries)
stream_entries_processed = writer.write_many(stream_entries)
errored = [entry for entry in stream_entries_processed if entry.errors]
succeeded = len(stream_entries_processed) - len(errored)
logger.info("Entries written", succeeded=succeeded)
if errored:
for entry in errored:
logger.error("Error writing entry", entry=entry.entry, errors=entry.errors)
9 changes: 6 additions & 3 deletions invenio_vocabularies/datastreams/writers.py
@@ -103,12 +103,15 @@ def write_many(self, stream_entries, *args, **kwargs):
entries = [entry.entry for entry in stream_entries]
entries_with_id = [(self._entry_id(entry), entry) for entry in entries]
records = self._service.create_or_update_many(self._identity, entries_with_id)
stream_entries = []
stream_entries_processed = []
for op_type, record, errors in records:
if errors == []:
yield StreamEntry(entry=record, op_type=op_type)
stream_entries_processed.append(StreamEntry(entry=record, op_type=op_type))
else:
yield StreamEntry(entry=record, errors=errors, op_type=op_type)
stream_entries_processed.append(StreamEntry(entry=record, errors=errors, op_type=op_type))

return stream_entries_processed


class YamlWriter(BaseWriter):
"""Writes the entries to a YAML file."""