import requests
import json
from urllib.parse import urlparse
import yaml
import pandas as pd
from asyncio_throttle import Throttler
import aiohttp
import asyncio
from pathlib import Path
from alma_batch_utils import (APIException, InvalidParameters, MissingConfiguration,
                              chunk_list, convert_column_case)
from uuid import uuid4
from datetime import datetime
import traceback
import sys
import logging
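
# Illustrative example of the YAML config expected by AlmaBatch. The keys shown are among those this
# module reads (see _load_config and the attributes used throughout); the values are hypothetical and
# will differ per setup.
#
#   apikey: <your Ex Libris API key>
#   endpoint: /users/{user_id}/loans
#   operation: get
#   openapi: ./loans_openapi.json                                # or a URL to the endpoint's OpenAPI (JSON) doc
#   exl_base: https://api-na.hosted.exlibrisgroup.com/almaws/v1  # base URL prepended to the endpoint
#   content_type: application/json                               # required only for put/post operations
#   output_file: ./progress.csv                                  # optional: progress CSV written between batches
#   path_for_api_output: ./api_output                            # optional: directory for serialized API results
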
class AlmaBatch:
def __init__(self, config: str):
        '''Initializes the class.
        :param config: path to a YAML config file (see the illustrative example near the top of this module). Required keys:
          - apikey: the user's Ex Libris API key
          - endpoint: the API endpoint path, as given in the Ex Libris API documentation
          - operation: the HTTP operation to perform (e.g., get, put, post)
          - openapi: a URL or local path to the endpoint's OpenAPI documentation (JSON)
'''
self.logger = logging.getLogger(__name__)
self._load_config(config)
self.api_doc = self._load_openapi(self.openapi)
self.header = self._create_header()
self.results = [] # Store batch results
self.errors = [] # Store batch errors
self.batch_idx = 0 # Tracks the number processed in each batch (when batching requests)
self.num_workers = 25 # Default value for number of async workers
self.limit = 100 # Number of results per page (default is Ex Libris maximum)
def _load_config(self, config_path: str='', config: dict=None):
'''Loads the config file.
:param config_path: should point to a YAML file
:param config: should be a dict of options
'''
# Required elements in the config file/object
required_keys = ['apikey',
'endpoint',
'operation',
'openapi']
try:
if config_path:
with open(config_path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if not config:
raise MissingConfiguration('No configuration specified.')
# Check for required config elements
if not all(k in config for k in required_keys):
                raise MissingConfiguration('One or more required configuration elements are missing.')
# Dynamically set config values
for key, value in config.items():
setattr(self, key, value)
except Exception as e:
self.logger.exception('Error loading configuration.')
raise
def _load_openapi(self, openapi: str):
'''Loads the OpenAPI documentation for the desired endpoint.
:param openapi: either a URL or a local file path. The file should be in JSON format.
'''
try:
# tests whether the argument is a URL or not
if urlparse(openapi).scheme:
doc = requests.get(openapi)
doc.raise_for_status()
return doc.json()
else:
with open(openapi, 'r') as f:
return json.load(f)
except Exception as e:
self.logger.exception('Failed to load OpenAPI documentation.')
raise
def load_csv(self, path_to_csv: str, clean_columns: bool = True):
'''Loads the CSV file at the supplied path and assigns to self.
:param path_to_csv: path to the CSV file (local or via http).
:param clean_columns: a flag for whether the columns should be converted to snake case.
'''
try:
self.data = pd.read_csv(path_to_csv)
if clean_columns:
self.data.columns = convert_column_case(self.data.columns)
return self
except Exception as e:
self.logger.exception('Failed to load CSV.')
raise e
def validate_data(self):
'''Checks that the columns or keys in self.data correspond to the parameters in the specified API.'''
# Check to see if the data object has "columns," i.e., is a DataFrame
if hasattr(self.data, 'columns'):
columns = set(self.data.columns)
# Otherwise, should be a list of dictionaries -- get the set of the keys across all entries
else:
columns = {k for row in self.data for k in row.keys()}
try:
# Get the params associated with the specified endpoint
params = self.api_doc['paths'][self.endpoint][self.operation]['parameters']
query_params = {p['name'] for p in params if p['in'] == 'query'}
path_params = {p['name'] for p in params if p['in'] == 'path'}
# Path parameters should be a subset of the columns
            if (not path_params <= columns) and ('link' not in columns):
raise InvalidParameters('One or more path parameters missing from the data you supplied.')
# Intersection of query params and columns should not be empty if no path params
elif (not path_params) and not (query_params & columns):
raise InvalidParameters('No valid parameters present in the data you supplied.')
self.path_params = path_params
self.query_params = query_params
return self
except Exception as e:
self.logger.exception('Error validating parameters')
raise
def _create_header(self):
        '''Creates the headers for the API request, using the API key from the config file.
        Optionally uses a configured "accepts" value (defaults to application/json) and a "content_type" value (required when the operation is PUT or POST).
        '''
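        # Illustrative result (hypothetical values): {'Authorization': 'apikey <your-api-key>',
        # 'Accept': 'application/json'}, plus a 'Content-Type' entry for PUT/POST operations.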
try:
header = {'Authorization': f'apikey {self.apikey}'}
if hasattr(self, 'accepts'):
header['Accept'] = self.accepts
else:
header['Accept'] = self.accepts = 'application/json'
# If a put or post operation, we need to include the content type
if self.operation in ['post', 'put']:
header['Content-Type'] = self.content_type
return header
except AttributeError:
self.logger.error('Configuration error. "content_type" is a required parameter for POST or PUT operations.')
raise
def _construct_request(self, row_data: dict, page: int):
'''Returns a url and a set of query parameters for an API call.
:param row_data: should contain URL and/or query string parameters for the request, or a complete URL assigned to the key "link."
        :param page: the page number for paginated results (used to compute an offset), or 0 for the first page.
'''
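        # Illustrative example (hypothetical endpoint and values): with self.exl_base set to
        # 'https://api-na.hosted.exlibrisgroup.com/almaws/v1', self.endpoint set to '/users/{user_id}/loans',
        # and row_data equal to {'user_id': '12345'}, the formatted URL becomes '.../users/12345/loans';
        # any remaining keys in row_data that appear in self.query_params are sent as query-string parameters.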
# In some cases, we may be able to use the pre-constructed URL returned by another ExL API.
if 'link' in row_data:
url = row_data['link']
else:
# Construct the URL
url = self.exl_base + self.endpoint
url = url.format(**row_data)
        # Include the remaining key-value pairs from the row if they are query parameters for this endpoint
params = {k: v for k,v in row_data.items() if k in self.query_params}
# Check for an offset to include
if page > 0:
params['offset'] = self.limit * page
params['limit'] = self.limit
# Check for fixed parameters, i.e., params that every call should include
if hasattr(self, 'fixed_params'):
params.update(self.fixed_params)
# Return the parameters together with the headers as an object for the aiohttp client, along with the formatted URL
args = {'params': params,
'headers': self.header}
return url, args
def _make_batch(self):
'''
Batches self.data, using self.batch_size. Each batch of requests will be run concurrently.
'''
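        # Illustrative shape of the return value (hypothetical values), assuming chunk_list splits a list
        # into consecutive chunks of at most batch_size items: with batch_size = 2 and three rows of data,
        # this returns [[{'idx': 0, 'page': 0, 'row_data': {...}}, {'idx': 1, 'page': 0, 'row_data': {...}}],
        #               [{'idx': 2, 'page': 0, 'row_data': {...}}]]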
# If self.data is a DataFrame, convert to native Python structure (list of dicts) for processing
# Include the index of the row as a separate value
# For each row, we initialize the page to 0 (for paginated results)
if hasattr(self.data, 'columns'):
rows = [{'idx': i,
'page': 0,
'row_data': row._asdict()} for i, row in enumerate(self.data.itertuples(index=False))]
else:
rows = [{'idx': i,
'page': 0,
'row_data': row} for i, row in enumerate(self.data)]
if self.batch_size > 0:
batches = chunk_list(rows, self.batch_size)
# If batch_size = 0, use a single batch
else:
batches = [rows]
return batches
def _do_after_requests(self, iteration: int):
'''
Called after a completed batch of requests.
:param iteration: the number of the just-completed batch.
'''
if hasattr(self, 'output_file'):
self._update_data()
if hasattr(self, 'path_for_api_output') and self.serialize_return:
self.logger.info(f'Saving API output for batch {iteration+1}')
self.dump_output(batch=iteration+1)
def _update_data(self):
'''Updates the data supplied as parameters to the API calls, indicating which rows have been completed.
'''
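        # Illustrative output (hypothetical rows): the CSV written to self.output_file contains the original
        # input columns plus a 'percentage_complete' column: a row whose pages were all retrieved shows 100.0,
        # while a row with no successful results shows 0.0.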
if len(self.results) == 0:
self.logger.debug('No successful results to update.')
return
try:
            # If self.data is not already a DataFrame, convert it to one
if not hasattr(self.data, 'columns'):
data = pd.DataFrame.from_records(self.data)
else:
data = self.data.copy()
# Create another out of the successful results
# For results without pagination, the total_record_count will be absent, so we default to 1
successes = pd.DataFrame.from_records([{'idx': result['idx'],
'page': result['page'] + 1,
'total_results': result['result'].get('total_record_count', 1)}
for result in self.results]).sort_values(by=['idx', 'page'])
            # Compute the percentage of results successfully retrieved
            # First, translate the number of pages retrieved into a result count
summary = successes.groupby(['idx', 'total_results']).page.apply(lambda x: len(x)*self.limit)
# Then represent that as a percentage of the total results as indicated by the API
summary = summary / summary.index.get_level_values('total_results')
# Finally, correct for values > 1 (since the last page of results may not be a full page)
summary = summary.where(summary < 1, 1) * 100
# Give a name to the percentage column
summary.name = 'percentage_complete'
# Join the summary to the input data
summary_df = summary.reset_index().join(data, on='idx', how='right')
# Drop the index column and replace missing values with zeroes
summary_df = summary_df.drop('idx', axis=1)
summary_df.percentage_complete = summary_df.percentage_complete.fillna(0)
summary_df.to_csv(self.output_file, index=False)
        except Exception as e:
            self.logger.exception('Error updating the output data.')
            raise
return self
def dump_output(self, batch: int = None):
        '''Saves the API output to disk as a JSON file containing request metadata and a list of result objects (each including its row index and page).
        If a batch parameter is supplied and batch_size is non-zero, saves only the results added since the last completed batch.
        :param batch: number of the most recently completed batch
'''
if self.batch_size > 0:
# Save everything added since the last batch, if anything was added
            if len(self.results) <= self.batch_idx:
return self
api_data = self.results[self.batch_idx:]
# Update the counter for the next iteration
self.batch_idx = len(self.results)
else:
api_data = self.results
# Timestamp to add as part of metadata
timestamp = datetime.now().strftime("%d-%m-%Y %H:%M")
# Unique part of filename
file_id = uuid4()
api_metadata = {'timestamp': timestamp,
'endpoint': self.endpoint,
'operation': self.operation}
api_output = {'metadata': api_metadata,
'data': api_data}
try:
with open(Path(self.path_for_api_output) / f'api_results_batch-{batch}-{file_id}.json', 'w') as f:
json.dump(api_output, f)
except Exception as e:
self.logger.exception("Error saving API output.")
raise
return self
def _check_for_pagination(self, result: dict):
        '''Checks a result from the API for paginated results. If the total_record_count attribute exceeds self.limit, returns the number of additional pages still needed. Otherwise, returns 0.
:param result: result of an API call.
'''
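        # Illustrative arithmetic (hypothetical values): with self.limit = 100, a total_record_count of 250
        # needs two additional pages (offsets 100 and 200), a count of 200 needs one, and a count of 80 needs none.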
        total_results = result.get('total_record_count')
        if total_results and total_results > self.limit:
            # Number of additional pages needed to retrieve the rest of the results; subtracting 1
            # avoids requesting an empty extra page when the total is an exact multiple of the limit
            return (total_results - 1) // self.limit
        return 0
async def _async_request(self, row_data: dict, idx: int, page: int=0, payload=None):
'''
Make a single asynchronous request.
        :param row_data: a mapping of path and/or query parameters for the request
:param idx: a row index from the original dataset
:param page: current page of results (0 if unpaginated)
:param payload: for PUT/POST operations
'''
url, args = self._construct_request(row_data, page)
# Optional payload (for PUT/POST)
# TO DO: Accept XML payloads
if payload and (getattr(self, 'content_type', None) == 'application/json'):
args['json'] = payload
# Capture the index and page for the results/errors
output = {'idx': idx,
'page': page}
# Get the client method corresponding to the desired HTTP method
client_method = getattr(self.client, self.operation)
self.logger.debug(f'Making request for row {idx}, page {page}.')
try:
# Wrap the call in the throttler context manager
async with self.throttler:
                async with client_method(url, **args) as response:
                    # Check for an API error (error responses will not be sent as JSON or XML)
                    if (response.content_type != self.accepts) or (response.status != 200):
                        error = await response.text()
                        raise APIException(error)
                    # Parse according to the expected Accept type; content_type applies only to PUT/POST payloads
                    elif self.accepts == 'application/json':
                        result = await response.json()
                        self.logger.debug(f'Received data for row {idx}, page {page}.')
                    else:
                        result = await response.text()
# Save the result
output['result'] = result
self.results.append(output)
if page == 0:
return self._check_for_pagination(result)
# If this isn't the first page of results, assume that the pagination has already been accounted for; we shouldn't need to add more tasks to the queue
return 0
except Exception as e:
# If this row throws an exception, record it in the errors list
exc_info = sys.exc_info()
output['error'] = traceback.format_exception(*exc_info)
self.logger.exception(f'Error {e} on row {idx} page {page}.')
self.errors.append(output)
# Can't extract pagination in this case
return 0
async def _worker(self):
'''
Worker to consume tasks from the queue; each task is a single request.
If the results of the request are paginated, the worker adds new tasks to the queue, one per page.
'''
while True:
            request_task = await self.queue.get()  # Get a task to process: a dict containing a row of data (row_data), an index (idx), and a page
self.logger.debug(f"Task acquired: index {request_task['idx']}, page {request_task['page']}.")
try:
more_pages = await self._async_request(**request_task)
                # If there are more pages, add one new task per additional page to the queue
                for page in range(more_pages):
                    new_task = request_task.copy()
                    # Update the page parameter for each additional page of results
                    new_task['page'] = page + 1
                    self.queue.put_nowait(new_task)
                    self.logger.debug(f"Task queued: index {new_task['idx']}, page {new_task['page']}.")
                # Mark the current request task as done
                self.queue.task_done()
except Exception as e:
self.exception(f"Aborting request {request_task['idx']} without completion.")
#self.queue.task_done()
raise e
async def _main(self, batch):
'''
        Adds tasks to the queue, one per request, and then runs the queue with the event loop.
        This coroutine blocks per batch: each batch kicks off a fresh asynchronous queue of tasks, which must finish before the next batch proceeds.
:param batch: a batch of requests to process
'''
self.queue = asyncio.Queue() # Async task queue => keeps track of the number of tasks left to perform
# Initialize a new client session
self.client = aiohttp.ClientSession(connector_owner=False)
for row in batch:
self.queue.put_nowait(row) # Initialize the queue with the initial call to each endpoint (page=0)
        # Create a list of coroutine workers to process the tasks
        tasks = [asyncio.create_task(self._worker()) for _ in range(self.num_workers)]
# Main call that blocks until all queue items have been marked done
async with self.client:
# This usage will cause an exception raised by a worker to bubble up
# If the worker (as opposed to the request) raises an exception, this will cause the wait function to return.
# Otherwise, it will return when all tasks in the queue have been consumed (marked done).
done, _ = await asyncio.wait([asyncio.create_task(self.queue.join()), *tasks],
return_when=asyncio.FIRST_COMPLETED)
            # Test for the existence of an exception (i.e., a worker task that exited because it raised)
tasks_with_exceptions = set(done) & set(tasks)
if tasks_with_exceptions:
# propagate the exception -- this should raise an Exception and exit
await tasks_with_exceptions.pop()
# Cancel our workers if no exceptions
for task in tasks:
task.cancel()
# Wait until all worker tasks are cancelled.
            await asyncio.gather(*tasks, return_exceptions=True)
        # The client session is closed automatically on exiting the async with block above
async def amake_requests(self, rate_limit: int = 25, batch_size: int = 1000, serialize_return: bool = False):
        '''Manages asynchronous requests in batches.
        :param rate_limit: defaults to the maximum allowed by the Ex Libris API (25 requests/sec).
        :param batch_size: set to 0 to operate without batching. Otherwise, asynchronous requests are batched, and between batches, if an output file path is specified in the config file, the data is updated and saved to disk as a CSV file. This lets the user track which rows have been processed in the event of an application crash, serious network interruption, etc.
        :param serialize_return: set to True to save the current API output to disk between batches.
'''
self.batch_size = batch_size
self.serialize_return = serialize_return
# Context manager for throttling the requests
self.throttler = Throttler(rate_limit)
batches = self._make_batch()
# Run in batches, pausing between each to update the data
for j, batch in enumerate(batches):
try:
self.logger.info(f'Running batch {j+1}...')
await self._main(batch)
self._do_after_requests(iteration=j)
            except Exception as e:
                self.logger.exception(f'Exception encountered on batch {j+1}.')
                # To skip a failed batch and proceed to the next one, replace the raise below with a continue
                raise
self.logger.info('All requests completed.')
return self
    def make_requests(self, *args, **kwargs):
        '''Synchronous wrapper around amake_requests; accepts the same arguments.'''
        return asyncio.run(self.amake_requests(*args, **kwargs))
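

# Minimal usage sketch (illustrative only): the config path and CSV file below are hypothetical and
# assume a YAML config shaped like the example near the top of this module.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    ab = AlmaBatch('config.yml')   # hypothetical path to a YAML config file
    ab.load_csv('rows.csv')        # hypothetical CSV whose columns supply path/query parameters
    ab.validate_data()
    ab.make_requests(rate_limit=25, batch_size=1000, serialize_return=False)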