Skip to content

Commit

Permalink
Merge pull request #54 from PGScatalog/fix-user-agent
Browse files Browse the repository at this point in the history
fix setting user-agent when downloading
  • Loading branch information
nebfield authored Sep 12, 2023
2 parents e885a72 + 0d92bc3 commit 6ca8bcc
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 33 deletions.
10 changes: 10 additions & 0 deletions pgscatalog_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
logger = logging.getLogger(__name__)


def headers() -> dict[str, str]:
if PGSC_CALC_VERSION is None:
raise Exception("Missing User-Agent when querying PGS Catalog")
else:
logger.info(f"User-Agent header: {PGSC_CALC_VERSION}")

header = {"User-Agent": PGSC_CALC_VERSION}
return header


def setup_tmpdir(outdir, combine=False):
if combine:
work_dir = "work_combine"
Expand Down
23 changes: 6 additions & 17 deletions pgscatalog_utils/download/Catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import requests

from pgscatalog_utils import __version__ as pgscatalog_utils_version
from pgscatalog_utils import config
from pgscatalog_utils.download.CatalogCategory import CatalogCategory
from pgscatalog_utils.download.ScoringFile import ScoringFile
from pgscatalog_utils.download.download_file import get_with_user_agent

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,8 +69,8 @@ def get_download_urls(self) -> dict[str: ScoringFile]:
case CatalogCategory.TRAIT | CatalogCategory.PUBLICATION:
# publications and traits have to query Catalog API again to grab score data
results: list[CatalogResult] = CatalogQuery(CatalogCategory.SCORE,
accession=list(self.pgs_ids),
pgsc_calc_version=config.PGSC_CALC_VERSION).get()
accession=list(
self.pgs_ids)).get()
for result in results:
for pgs in result.response.get("results"):
urls[pgs["id"]] = ScoringFile.from_result(pgs)
Expand All @@ -84,12 +84,9 @@ class CatalogQuery:
"""
category: CatalogCategory
accession: typing.Union[str, list[str]]
pgsc_calc_version: typing.Union[str, None]
include_children: bool = False
_rest_url_root: str = "https://www.pgscatalog.org/rest"
_max_retries: int = 5
_version: str = pgscatalog_utils_version
_user_agent: dict[str: str] = field(init=False)

def _resolve_query_url(self) -> typing.Union[str, list[str]]:
child_flag: int = int(self.include_children)
Expand All @@ -109,16 +106,8 @@ def _resolve_query_url(self) -> typing.Union[str, list[str]]:
case CatalogCategory.PUBLICATION, str():
return f"{self._rest_url_root}/publication/{self.accession}"
case _:
raise Exception(f"Invalid CatalogCategory and accession type: {self.category}, type({self.accession})")

def __post_init__(self):
ua: str
if self.pgsc_calc_version:
ua = pgscatalog_utils_version
else:
ua = f"pgscatalog_utils/{self._version}"

self._user_agent = {"User-Agent": ua}
raise Exception(
f"Invalid CatalogCategory and accession type: {self.category}, type({self.accession})")

def _query_api(self, url: str):
wait: int = 10
Expand All @@ -128,7 +117,7 @@ def _query_api(self, url: str):
while retry < self._max_retries:
try:
logger.info(f"Querying {url}")
r: requests.models.Response = requests.get(url, headers=self._user_agent)
r: requests.models.Response = get_with_user_agent(url)
r.raise_for_status()
results_json = r.json()
break
Expand Down
8 changes: 6 additions & 2 deletions pgscatalog_utils/download/download_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
import pathlib
import time
import urllib.parse
Expand All @@ -13,6 +12,10 @@
logger = logging.getLogger(__name__)


def get_with_user_agent(url: str) -> requests.Response:
return requests.get(url, headers=config.headers())


def download_file(url: str, local_path: str, overwrite: bool, ftp_fallback: bool) -> None:
if config.OUTDIR.joinpath(local_path).exists():
if not overwrite:
Expand All @@ -25,7 +28,7 @@ def download_file(url: str, local_path: str, overwrite: bool, ftp_fallback: bool
attempt: int = 0

while attempt < config.MAX_RETRIES:
response: requests.Response = requests.get(url)
response: requests.Response = get_with_user_agent(url)
match response.status_code:
case 200:
with open(config.OUTDIR.joinpath(local_path), "wb") as f:
Expand Down Expand Up @@ -69,3 +72,4 @@ def _ftp_fallback_download(url: str, local_path: str) -> None:
else:
logger.critical(f"Download failed: {e}")
raise Exception

21 changes: 12 additions & 9 deletions pgscatalog_utils/download/download_scorefile.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import argparse
import logging
import os
import pathlib
import textwrap
import typing

from pgscatalog_utils import __version__ as version
from pgscatalog_utils import config
from pgscatalog_utils.download.CatalogCategory import CatalogCategory
from pgscatalog_utils.download.Catalog import CatalogQuery, CatalogResult
from pgscatalog_utils.download.CatalogCategory import CatalogCategory
from pgscatalog_utils.download.GenomeBuild import GenomeBuild
from pgscatalog_utils.download.ScoringFileDownloader import ScoringFileDownloader


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -44,7 +43,11 @@ def download_scorefile() -> None:

if args.pgsc_calc:
config.PGSC_CALC_VERSION = args.pgsc_calc
logger.info(f"Setting user agent to {config.PGSC_CALC_VERSION} for PGS Catalog API queries")
logger.info(
f"Setting user agent to {config.PGSC_CALC_VERSION} for PGS Catalog API queries")
else:
config.PGSC_CALC_VERSION = f"pgscatalog_utils/{version}"
logger.warning(f"No user agent set, defaulting to {config.PGSC_CALC_VERSION}")

config.OUTDIR = pathlib.Path(args.outdir).resolve()
logger.info(f"Download directory: {config.OUTDIR}")
Expand All @@ -60,19 +63,19 @@ def download_scorefile() -> None:
else:
logger.debug("--trait set, querying traits")
for term in args.efo:
results.append(CatalogQuery(CatalogCategory.TRAIT, term, include_children=inc_child,
pgsc_calc_version=config.PGSC_CALC_VERSION).get())
results.append(CatalogQuery(CatalogCategory.TRAIT, term,
include_children=inc_child).get())

if args.pgp:
logger.debug("--pgp set, querying publications")
for term in args.pgp:
results.append(CatalogQuery(CatalogCategory.PUBLICATION, term, pgsc_calc_version=config.PGSC_CALC_VERSION).get())
results.append(CatalogQuery(CatalogCategory.PUBLICATION, term).get())

if args.pgs:
logger.debug("--id set, querying scores")
results.append(
CatalogQuery(CatalogCategory.SCORE, args.pgs,
pgsc_calc_version=config.PGSC_CALC_VERSION).get()) # pgs_lst: a list containing up to three flat lists
CatalogQuery(CatalogCategory.SCORE,
args.pgs).get()) # pgs_lst: a list containing up to three flat lists

flat_results = [element for sublist in results for element in sublist]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_fail_combine(scorefiles, tmp_path_factory):

@pytest.fixture
def _n_variants(pgs_accessions):
result = CatalogQuery(CatalogCategory.SCORE, accession=pgs_accessions, pgsc_calc_version=None).get()[0]
result = CatalogQuery(CatalogCategory.SCORE, accession=pgs_accessions).get()[0]
json = result.response
n: list[int] = jq.compile("[.results][][].variants_number").input(json).all()
return sum(n)
6 changes: 2 additions & 4 deletions tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,11 @@ def test_download_trait(tmp_path):

def test_query_publication():
# publications are relatively static
query: list[CatalogResult] = CatalogQuery(CatalogCategory.PUBLICATION, accession="PGP000001",
pgsc_calc_version=None).get()
query: list[CatalogResult] = CatalogQuery(CatalogCategory.PUBLICATION, accession="PGP000001").get()
assert not query[0].pgs_ids.difference({'PGS000001', 'PGS000002', 'PGS000003'})


def test_query_trait():
# new scores may be added to traits in the future
query: list[CatalogResult] = CatalogQuery(CatalogCategory.TRAIT, accession="EFO_0004329",
pgsc_calc_version=None).get()
query: list[CatalogResult] = CatalogQuery(CatalogCategory.TRAIT, accession="EFO_0004329").get()
assert not {'PGS001901', 'PGS002115'}.difference(query[0].pgs_ids)

0 comments on commit 6ca8bcc

Please sign in to comment.