Skip to content

Commit

Permalink
Merge pull request #148 from getzlab/drs_via_drshub
Browse files Browse the repository at this point in the history
Implement DRS URI localization.
  • Loading branch information
julianhess authored Sep 17, 2024
2 parents 9882007 + 86f3317 commit 421e505
Showing 1 changed file with 161 additions and 32 deletions.
193 changes: 161 additions & 32 deletions canine/localization/file_handlers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import abc
import google.cloud.storage
import google.auth
import glob, google_crc32c, json, hashlib, base64, binascii, os, re, requests, shlex, subprocess, threading
import pandas as pd

from google.auth.transport.requests import AuthorizedSession
from ..utils import sha1_base32, canine_logging

class FileType(abc.ABC):
Expand Down Expand Up @@ -125,7 +127,7 @@ def get_requester_pays(self) -> bool:
bucket = re.match(r"gs://(.*?)/.*", self.path)[1]

gcs_cl = gcloud_storage_client()
bucket_obj = google.cloud.storage.Bucket(gcs_cl, bucket, user_project = self.extra_args["project"] if "project" in self.extra_args else None)
bucket_obj = google.cloud.storage.Bucket(gcs_cl, bucket, user_project = self.extra_args.get("project"))

return bucket_obj.requester_pays

Expand Down Expand Up @@ -168,7 +170,7 @@ def blob(self):

gcs_cl = gcloud_storage_client()

bucket_obj = google.cloud.storage.Bucket(gcs_cl, bucket, user_project = self.extra_args["project"] if "project" in self.extra_args else None)
bucket_obj = google.cloud.storage.Bucket(gcs_cl, bucket, user_project = self.extra_args.get("project"))

# check whether this path exists, and whether it's a directory

Expand Down Expand Up @@ -247,6 +249,23 @@ def localization_command(self, dest):

# }}}

## GCP Authorized Session {{{

GCP_AUTH_SESSION = None
gcp_auth_session_creation_lock = threading.Lock()

def gcp_auth_session():
global GCP_AUTH_SESSION
with gcp_auth_session_creation_lock:
if GCP_AUTH_SESSION is None:
# this is the expensive operation
GCP_AUTH_SESSION = AuthorizedSession(
google.auth.default(['https://www.googleapis.com/auth/userinfo.profile',
'https://www.googleapis.com/auth/userinfo.email'])[0])
return GCP_AUTH_SESSION

# }}}

## AWS S3 {{{

class HandleAWSURL(FileType):
Expand All @@ -270,13 +289,13 @@ def __init__(self, path, **kwargs):

# keys get passed via environment variable
self.command_env = {}
self.command_env["AWS_ACCESS_KEY_ID"] = self.extra_args["aws_access_key_id"] if "aws_access_key_id" in self.extra_args else None
self.command_env["AWS_SECRET_ACCESS_KEY"] = self.extra_args["aws_secret_access_key"] if "aws_secret_access_key" in self.extra_args else None
self.command_env["AWS_ACCESS_KEY_ID"] = self.extra_args.get("aws_access_key_id")
self.command_env["AWS_SECRET_ACCESS_KEY"] = self.extra_args.get("aws_secret_access_key")
self.command_env_str = " ".join([f"{k}={v}" for k, v in self.command_env.items() if v is not None])

# compute extra arguments for s3 commands
# TODO: add requester pays check here
self.aws_endpoint_url = self.extra_args["aws_endpoint_url"] if "aws_endpoint_url" in self.extra_args else None
self.aws_endpoint_url = self.extra_args.get("aws_endpoint_url")

self.s3_extra_args = []
if self.command_env["AWS_ACCESS_KEY_ID"] is None and self.command_env["AWS_SECRET_ACCESS_KEY"] is None:
Expand Down Expand Up @@ -382,13 +401,14 @@ def localization_command(self, dest):
## GDC HTTPS URLs {{{
class HandleGDCHTTPURL(FileType):
localization_mode = "url"
gdc_drs_root = "drs://dg.4dfc:"

def __init__(self, path, **kwargs):
super().__init__(path, **kwargs)

self.token = self.extra_args["token"] if "token" in self.extra_args else None
self.token = self.extra_args.get("token")
self.token_flag = f'--header "X-Auth-Token: {self.token}"' if self.token is not None else ''
self.check_md5 = self.extra_args["check_md5"] if "check_md5" in self.extra_args else False
self.check_md5 = self.extra_args.get("check_md5", False)

# parse URL
self.url = self.path
Expand All @@ -399,34 +419,50 @@ def __init__(self, path, **kwargs):
self.prefix = url_parse[1]
self.uuid = url_parse[2]

# the actual filename is encoded in the content-disposition header;
# save this to self.path
# since the filesize and hashes are also encoded in the header, populate
# these fields now
resp_headers = subprocess.run(
'curl -s -D - -o /dev/full {token_flag} {file}'.format(
token_flag = self.token_flag,
file = self.path
),
shell = True,
capture_output = True
)
try:
headers = pd.DataFrame(
[x.split(": ") for x in resp_headers.stdout.decode().split("\r\n")[1:]],
columns=["header", "value"],
).set_index("header")["value"]

self.path = re.match(".*filename=(.*)$", headers["Content-Disposition"])[1]
self._size = int(headers["Content-Length"])
self._hash = headers["Content-MD5"]
self.uri = type(self).gdc_drs_root + self.uuid
self.drs_obj = HandleDRSURI(self.uri, **self.extra_args)
except:
canine_logging.error("Error resolving GDC file; see details:")
canine_logging.error(resp_headers.stdout.decode())
raise
canine_logging.warning("Re-attempting with GDC API")
self.drs_obj = None

# the actual filename is encoded in the content-disposition header;
# save this to self.path
# since the filesize and hashes are also encoded in the header, populate
# these fields now
resp_headers = subprocess.run(
'curl -s -D - -o /dev/full {token_flag} {file}'.format(
token_flag = self.token_flag,
file = self.path
),
shell = True,
capture_output = True
)
try:
headers = pd.DataFrame(
[x.split(": ") for x in resp_headers.stdout.decode().split("\r\n")[1:]],
columns=["header", "value"],
).set_index("header")["value"]

self.path = re.match(".*filename=(.*)$", headers["Content-Disposition"])[1]
self._size = int(headers["Content-Length"])
self._hash = headers["Content-MD5"]
except:
canine_logging.error("Error resolving GDC file; see details:")
canine_logging.error(resp_headers.stdout.decode())
raise
if self.drs_obj is not None:
# if we have a DRS object, use its properties
self.path = self.drs_obj.path
self._size = self.drs_obj.size
self._hash = self.drs_obj.hash
self.url = self.drs_obj.uri
self.token = None
self.localized_path = self.path

def localization_command(self, dest):
if self.drs_obj is not None:
return self.drs_obj.localization_command(dest)
dest_dir = shlex.quote(os.path.dirname(dest))
dest_file = shlex.quote(os.path.basename(dest))
self.localized_path = os.path.join(dest_dir, dest_file)
Expand Down Expand Up @@ -471,6 +507,98 @@ def localization_command(self, dest):

# }}}

class HandleDRSURI(FileType):
localization_mode = "url"
drs_resolver = "https://drshub.dsde-prod.broadinstitute.org/api/v4/drs/resolve"

def __init__(self, path, **kwargs):
super().__init__(path, **kwargs)

self.check_md5 = self.extra_args.get("check_md5", False)

# parse URL
self.uri = self.path
uri_parse = re.match(r"^drs://(?:[A-Za-z0-9._]+/)?[A-Za-z0-9._]+:[A-Za-z0-9.-_~%]+",
self.uri)
if uri_parse is None:
raise ValueError(f"Invalid DRS URI '{self.uri}'")

fields = ["size", "fileName"]
if self.check_md5:
fields += ["hashes"]
data = {"url": self.uri, "fields": fields}

drshub_session = gcp_auth_session()
resp = drshub_session.post(type(self).drs_resolver,
headers={"Content-type": "application/json"}, json=data)

try:
metadata = resp.json()
self.path = metadata["fileName"]
self._size = metadata["size"]
self._hash = metadata.get("hashes", {}).get("md5")
except:
try:
msg = json.dumps(resp.json())
except:
msg = resp.text
canine_logging.error("Error resolving DRS URI; see details:")
canine_logging.error(f"Response code: {resp.status_code}")
canine_logging.error(msg)
raise
self.localized_path = self.path

def localization_command(self, dest):
dest_dir = shlex.quote(os.path.dirname(dest))
dest_file = shlex.quote(os.path.basename(dest))
self.localized_path = os.path.join(dest_dir, dest_file)
data_str = json.dumps({"url": self.uri, "fields": ["accessUrl"]})
signed_url = f'$(curl -S -X POST --url "{type(self).drs_resolver}" ' + \
'-H "authorization: Bearer $(gcloud auth print-access-token)" ' + \
f'-H "content-type: application/json" --data \'{data_str}\' | ' + \
'python3 -c \'import json,sys; print(json.load(sys.stdin)["accessUrl"]["url"])\')'
cmd = [f'signed_url={signed_url}',
f'[ ! -d {dest_dir} ] && mkdir -p {dest_dir} || :; curl -C - -o {self.localized_path} "$signed_url"']

# ensure that file downloaded properly
if self.check_md5:
cmd += [f"[[ $(md5sum {self.localized_path} | sed -r 's/ .*$//') == {self.hash} ]] || {{ echo 'deleting corrupted file' ; rm -f {self.localized_path} ; exit 1 ; }}"]

return "\n".join(cmd)

class HandleDRSURIStream(HandleDRSURI):
localization_mode = "stream"

def localization_command(self, dest):

dest_dir = shlex.quote(os.path.dirname(dest))
dest_file = shlex.quote(os.path.basename(dest))
self.localized_path = os.path.join(dest_dir, dest_file)
cmd = []

# clean existing file if it exists
cmd += ['if [[ -e {0} ]]; then rm {0}; fi'.format(dest)]

# create dir if it doesn't exist
cmd += ["[ ! -d {dest_dir} ] && mkdir -p {dest_dir} || :;".format(dest_dir=dest_dir)]

# create fifo object
cmd += ['mkfifo {}'.format(dest)]

# get signed URL
data_str = json.dumps({"url": self.uri, "fields": ["accessUrl"]})
signed_url = f'$(curl -S -X POST --url "{type(self).drs_resolver}" ' + \
'-H "authorization: Bearer $(gcloud auth print-access-token)" ' + \
f'-H "content-type: application/json" --data \'{data_str}\' | ' + \
'python3 -c \'import json,sys; print(json.load(sys.stdin)["accessUrl"]["url"])\')'
cmd += [f'signed_url={signed_url}']

# stream into fifo object
cmd += ['curl -C - -o {path} "$signed_url" &'.format(path=self.localized_path)]

return "\n".join(cmd)


class HandleOtherURL(FileType):
localization_mode = "url"

Expand Down Expand Up @@ -612,8 +740,9 @@ def _get_hash(self):

def get_file_handler(path, url_map = None, **kwargs):
url_map = {
r"^gs://" : HandleGSURL,
r"^s3://" : HandleAWSURL,
r"^gs://" : HandleGSURL,
r"^s3://" : HandleAWSURL,
r"^drs://" : HandleDRSURI,
r"^https://api.gdc.cancer.gov" : HandleGDCHTTPURL,
r"^https://api.awg.gdc.cancer.gov" : HandleGDCHTTPURL,
r"^rodisk://" : HandleRODISKURL,
Expand Down

0 comments on commit 421e505

Please sign in to comment.