forked from google/osv.dev
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add script to refresh IDs to the new format.
- Loading branch information
1 parent
60b4ec8
commit 1b83241
Showing
1 changed file
with
125 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
#!/usr/bin/env python3 | ||
""" Utility to update the datastore key of each Bug to the new format determined by the pre put hook. | ||
Does this by deleting and reputting each Bug entry. | ||
""" | ||
from google.cloud import ndb | ||
import osv | ||
|
||
import argparse | ||
import json | ||
import functools | ||
import time | ||
|
||
MAX_BATCH_SIZE = 500 | ||
|
||
|
||
def get_relevant_ids(verbose: bool) -> list[str]: | ||
relevant_ids = [] | ||
query = osv.Bug.query() | ||
|
||
query: ndb.Query = query.filter() | ||
query.projection = ["db_id"] | ||
print(f"Running initial query on {query.kind}...") | ||
result: list[osv.Bug] = list(query.fetch(limit=5000)) | ||
|
||
print(f"Retrieved {len(result)} bugs to examine for reputting") | ||
|
||
for res in result: | ||
if res.key.id() != res.db_id: # type: ignore | ||
relevant_ids.append(res.db_id) | ||
if verbose: | ||
print(res.db_id + ' - ' + res.key.id()) # type: ignore | ||
|
||
print(str(len(relevant_ids)) + " / " + str(len(result))) | ||
return relevant_ids | ||
|
||
|
||
def reput_bugs(dryrun: bool, verbose: bool) -> None: | ||
""" Reput all bugs from a given source.""" | ||
|
||
# Uncomment below to load the state and skip the get_relevant_ids func | ||
# relevant_ids = json.load(open('relevant_ids.json', 'r')) | ||
relevant_ids = get_relevant_ids(verbose) | ||
|
||
# Store the state incase we cancel halfway to avoid having to do the initial query again | ||
json.dump(relevant_ids, open('relevant_ids.json', 'w')) | ||
|
||
relevant_ids = relevant_ids[:2] | ||
print(relevant_ids) | ||
num_reputted = 0 | ||
time_start = time.perf_counter() | ||
|
||
# This handles the actual transaction of reputting the bugs with ndb | ||
def _reput_ndb(batch: int): | ||
buf: list[osv.Bug] = [ | ||
osv.Bug.get_by_id(r) for r in relevant_ids[batch:batch + MAX_BATCH_SIZE] | ||
] | ||
|
||
# Delete the existing entries. This must be done in a transaction to avoid losing data if interrupted | ||
ndb.delete_multi([r.key for r in buf]) | ||
|
||
# Clear the key so the key name will be regenerated to the new key format | ||
for i in range(len(buf)): | ||
buf[i].key = None | ||
|
||
# Reput the bug back in | ||
ndb.put_multi_async(buf) | ||
|
||
if dryrun: | ||
print("Dry run mode. Preventing transaction from committing") | ||
raise Exception("Dry run mode") # pylint: disable=broad-exception-raised | ||
|
||
print(f"Time elapsed: {(time.perf_counter() - time_start):.2f} seconds.") | ||
|
||
# Chunk the results to reput in acceptibly sized batches for the API. | ||
for batch in range(0, len(relevant_ids), MAX_BATCH_SIZE): | ||
try: | ||
num_reputted += len(relevant_ids[batch:batch + MAX_BATCH_SIZE]) | ||
print( | ||
f"Reput {num_reputted} bugs... - {num_reputted/len(relevant_ids)*100:.2f}%" | ||
) | ||
ndb.transaction(functools.partial(_reput_ndb, batch)) | ||
except Exception as e: | ||
# Don't have the first batch's transaction-aborting exception stop | ||
# subsequent batches from being attempted. | ||
if dryrun and e.args[0].startswith("Dry run mode"): | ||
print("Dry run mode. Preventing transaction from committing") | ||
else: | ||
print([r for r in relevant_ids[batch:batch + MAX_BATCH_SIZE]]) | ||
print(f"Exception {e} occurred. Continuing to next batch.") | ||
|
||
print("Reputted!") | ||
|
||
|
||
def main() -> None: | ||
parser = argparse.ArgumentParser( | ||
description="Reput all bugs from a given source.") | ||
parser.add_argument( | ||
"--dry-run", | ||
action=argparse.BooleanOptionalAction, | ||
dest="dryrun", | ||
default=True, | ||
help="Abort before making changes") | ||
parser.add_argument( | ||
"--verbose", | ||
action=argparse.BooleanOptionalAction, | ||
dest="verbose", | ||
default=False, | ||
help="Print each ID that needs to be processed") | ||
parser.add_argument( | ||
"--project", | ||
action="store", | ||
dest="project", | ||
default="oss-vdb-test", | ||
help="GCP project to operate on") | ||
args = parser.parse_args() | ||
|
||
client = ndb.Client(project=args.project) | ||
print(f"Running on project {args.project}.") | ||
with client.context(): | ||
reput_bugs(args.dryrun, args.verbose) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |