From 6892f7bb54118003c9e2bf8ea0cffca1693b32b2 Mon Sep 17 00:00:00 2001
From: Mim Hastie
Date: Wed, 20 Dec 2023 10:23:56 -0800
Subject: [PATCH 1/7] fix: unescape CrossRef journal (#6373)

---
 backend/common/providers/crossref_provider.py | 10 +++--
 .../layers/thirdparty/crossref_provider.py    |  9 ++--
 .../thirdparty/test_crossref_provider.py      | 41 ++++++++++++++++++-
 3 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/backend/common/providers/crossref_provider.py b/backend/common/providers/crossref_provider.py
index 421f221af5130..8fdea7ebdbafa 100644
--- a/backend/common/providers/crossref_provider.py
+++ b/backend/common/providers/crossref_provider.py
@@ -1,3 +1,4 @@
+import html
 import logging
 from datetime import datetime
 from urllib.parse import urlparse
@@ -101,14 +102,16 @@ def fetch_metadata(self, doi: str) -> dict:
             # Journal
             try:
                 if "short-container-title" in message and message["short-container-title"]:
-                    journal = message["short-container-title"][0]
+                    raw_journal = message["short-container-title"][0]
                 elif "container-title" in message and message["container-title"]:
-                    journal = message["container-title"][0]
+                    raw_journal = message["container-title"][0]
                 elif "institution" in message:
-                    journal = message["institution"][0]["name"]
+                    raw_journal = message["institution"][0]["name"]
             except Exception:
                 raise CrossrefParseException("Journal node missing") from None
 
+            journal = html.unescape(raw_journal)
+
             # Authors
             # Note: make sure that the order is preserved, as it is a relevant information
             authors = message["author"]
@@ -138,7 +141,6 @@ def fetch_metadata(self, doi: str) -> dict:
             raise CrossrefParseException("Cannot parse metadata from Crossref") from e
 
     def fetch_preprint_published_doi(self, doi):
-
        res = self._fetch_crossref_payload(doi)
        message = res.json()["message"]
        is_preprint = message.get("subtype") == "preprint"
diff --git a/backend/layers/thirdparty/crossref_provider.py b/backend/layers/thirdparty/crossref_provider.py
index 8a7a80d7946a1..3f97ce70cbf1b 100644
--- a/backend/layers/thirdparty/crossref_provider.py
+++ b/backend/layers/thirdparty/crossref_provider.py
@@ -1,3 +1,4 @@
+import html
 import logging
 from datetime import datetime
 from urllib.parse import urlparse
@@ -107,14 +108,16 @@ def fetch_metadata(self, doi: str) -> dict:
             # Journal
             try:
                 if "short-container-title" in message and message["short-container-title"]:
-                    journal = message["short-container-title"][0]
+                    raw_journal = message["short-container-title"][0]
                 elif "container-title" in message and message["container-title"]:
-                    journal = message["container-title"][0]
+                    raw_journal = message["container-title"][0]
                 elif "institution" in message:
-                    journal = message["institution"][0]["name"]
+                    raw_journal = message["institution"][0]["name"]
             except Exception:
                 raise CrossrefParseException("Journal node missing") from None
 
+            journal = html.unescape(raw_journal)
+
             # Authors
             # Note: make sure that the order is preserved, as it is a relevant information
             authors = message["author"]
diff --git a/tests/unit/backend/layers/thirdparty/test_crossref_provider.py b/tests/unit/backend/layers/thirdparty/test_crossref_provider.py
index 29644cb6caa54..4567001bac1b9 100644
--- a/tests/unit/backend/layers/thirdparty/test_crossref_provider.py
+++ b/tests/unit/backend/layers/thirdparty/test_crossref_provider.py
@@ -25,7 +25,6 @@ def test__provider_does_not_call_crossref_in_test(self, mock_get):
     @patch("backend.common.providers.crossref_provider.requests.get")
     @patch("backend.common.providers.crossref_provider.CorporaConfig")
     def test__provider_calls_crossref_if_api_key_defined(self, mock_config, mock_get):
-
         # Defining a mocked CorporaConfig will allow the provider to consider the `crossref_api_key`
         # not None, so it will go ahead and do the mocked call.
 
@@ -75,7 +74,6 @@ def test__provider_calls_crossref_if_api_key_defined(self, mock_config, mock_get
     @patch("backend.common.providers.crossref_provider.requests.get")
     @patch("backend.common.providers.crossref_provider.CorporaConfig")
     def test__provider_parses_authors_and_dates_correctly(self, mock_config, mock_get):
-
         response = Response()
         response.status_code = 200
         response._content = str.encode(
@@ -138,6 +136,45 @@ def test__provider_parses_authors_and_dates_correctly(self, mock_config, mock_ge
 
         self.assertDictEqual(expected_response, res)
 
+    @patch("backend.common.providers.crossref_provider.requests.get")
+    @patch("backend.common.providers.crossref_provider.CorporaConfig")
+    def test__provider_unescapes_journal_correctly(self, mock_config, mock_get):
+        response = Response()
+        response.status_code = 200
+        response._content = str.encode(
+            json.dumps(
+                {
+                    "status": "ok",
+                    "message": {
+                        "author": [
+                            {"name": "A consortium"},
+                        ],
+                        "published-online": {"date-parts": [[2021, 11]]},
+                        "container-title": ["Clinical &amp; Translational Med"],
+                    },
+                }
+            )
+        )
+
+        mock_get.return_value = response
+        provider = CrossrefProvider()
+        res = provider.fetch_metadata("test_doi")
+        mock_get.assert_called_once()
+
+        expected_response = {
+            "authors": [
+                {"name": "A consortium"},
+            ],
+            "published_year": 2021,
+            "published_month": 11,
+            "published_day": 1,
+            "published_at": 1635724800.0,
+            "journal": "Clinical & Translational Med",
+            "is_preprint": False,
+        }
+
+        self.assertDictEqual(expected_response, res)
+
     @patch("backend.common.providers.crossref_provider.requests.get")
     @patch("backend.common.providers.crossref_provider.CorporaConfig")
     def test__provider_throws_exception_if_request_fails(self, mock_config, mock_get):

From 260b986d64f7b5e928c8123172a691a11e2f095d Mon Sep 17 00:00:00 2001
From: Joyce Yan <5653616+joyceyan@users.noreply.github.com>
Date: Thu, 21 Dec 2023 10:44:54 -0500
Subject: [PATCH 2/7] fix: source data badge overlapping download modal (#6382)

---
 .../components/GeneSearchBar/components/SaveExport/style.ts | 2 ++
 .../GeneSearchBar/components/SourceDataButton/style.ts      | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SaveExport/style.ts b/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SaveExport/style.ts
index 8dc62d509798f..55f0034a96814 100644
--- a/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SaveExport/style.ts
+++ b/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SaveExport/style.ts
@@ -33,6 +33,8 @@ export const StyledModal = styled(Modal)`
     font-size: 24px !important;
     margin: 0px !important;
   }
+  z-index: 10;
+  position: relative;
 `;
 
 export const StyledTitle = styled.div`
diff --git a/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SourceDataButton/style.ts b/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SourceDataButton/style.ts
index a2b9071aad7f0..78a8fe28413f5 100644
--- a/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SourceDataButton/style.ts
+++ b/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/SourceDataButton/style.ts
@@ -23,7 +23,7 @@ export const BadgeCounter = styled(Badge)`
     height: 16px;
     text-align: center;
     position: relative;
-    z-index: 22;
+    z-index: 5;
     top: 4px;
     left: 14px;

From 2f985e29b3e7478a53f83c3f21fd0a8b94ff25b1 Mon Sep 17 00:00:00 2001
From: Timmy Huang
Date: Fri, 22 Dec 2023 09:22:30 -0800
Subject: [PATCH 3/7] fix: share link and test (#6385)

---
 .../components/ShareButton/utils.tsx          |  7 ++-
 .../features/wheresMyGene/shareLink.test.ts   | 51 ++++++++++---------
 2 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/ShareButton/utils.tsx b/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/ShareButton/utils.tsx
index 551bd5b01347d..dfe0066a5940e 100644
--- a/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/ShareButton/utils.tsx
+++ b/frontend/src/views/WheresMyGeneV2/components/GeneSearchBar/components/ShareButton/utils.tsx
@@ -218,7 +218,11 @@ function getNewSelectedFilters({
     Object.entries(tissues ?? {}).map(([id, tissue]) => [tissue.name, id])
   );
 
-  const allTissueNames = Object.keys(tissueIdsByName);
+  /**
+   * (thuang): Warning - Map doesn't work with Object.keys(), so we need to use
+   * Map.keys() instead
+   */
+  const allTissueNames = Array.from(tissueIdsByName.keys());
   const allTissueIds = Object.keys(tissues ?? {});
 
   Object.keys(selectedFilters).forEach((key) => {
@@ -235,6 +239,7 @@ function getNewSelectedFilters({
       const tissueParams = value.split(delimiter);
       const tissueIds = [];
       const tissueNames = [];
+
       for (const tissueParam of tissueParams) {
         if (
           tissueParam.includes("UBERON:") &&
diff --git a/frontend/tests/features/wheresMyGene/shareLink.test.ts b/frontend/tests/features/wheresMyGene/shareLink.test.ts
index d3342db5e5439..807cfc3a077c6 100644
--- a/frontend/tests/features/wheresMyGene/shareLink.test.ts
+++ b/frontend/tests/features/wheresMyGene/shareLink.test.ts
@@ -25,14 +25,11 @@ const tissueIds = TISSUES.map((tissue) => tissue.id);
 
 const GENES = ["DPM1", "TNMD", "TSPAN6"];
 
-const DATASETS = [
+const PUBLICATIONS = [
+  // (thuang): This publication has blood and lung tissues
   {
-    id: "d8da613f-e681-4c69-b463-e94f5e66847f",
-    text: "A molecular single-cell lung atlas of lethal COVID-19",
-  },
-  {
-    id: "de2c780c-1747-40bd-9ccf-9588ec186cee",
-    text: "Immunophenotyping of COVID-19 and influenza highlights the role of type I interferons in development of severe COVID-19",
+    id: "Ren et al. Cell 2021",
+    text: "Ren et al. (2021) Cell",
   },
 ];
 
@@ -53,8 +50,8 @@ const CELL_TYPES = ["natural killer cell"];
 const SHARE_LINK_SEARCH_PARAMS = new URLSearchParams();
 SHARE_LINK_SEARCH_PARAMS.set("compare", COMPARE);
 SHARE_LINK_SEARCH_PARAMS.set(
-  "datasets",
-  DATASETS.map((dataset) => dataset.id).join()
+  "publications",
+  PUBLICATIONS.map((publication) => publication.id).join()
 );
 SHARE_LINK_SEARCH_PARAMS.set(
   "diseases",
@@ -113,7 +110,7 @@ describe("Share link tests", () => {
     });
   });
 
-  test.skip("Should generate share link with correct format for all query param types", async ({
+  test("Should generate share link with correct format for all query param types", async ({
     page,
     browserName,
   }) => {
@@ -130,8 +127,7 @@ describe("Share link tests", () => {
       linkVersion: LATEST_SHARE_LINK_VERSION,
       tissueIds,
      genes: GENES,
-      // TODO(seve): #6131 test is currently failing on dataset param, should investigate and reenable
-      // datasets: DATASETS,
+      publications: PUBLICATIONS,
       sexes: SEXES,
       diseases: DISEASES,
       ethnicities: ETHNICITIES,
@@ -168,7 +164,7 @@ async function verifyShareLink({
   linkVersion,
   tissueIds,
   genes,
-  datasets,
+  publications,
   sexes,
   diseases,
   ethnicities,
@@ -179,7 +175,7 @@ async function verifyShareLink({
   linkVersion: string;
   tissueIds?: string[];
   genes?: string[];
-  datasets?: ExpectedParam[];
+  publications?: ExpectedParam[];
   sexes?: ExpectedParam[];
   diseases?: ExpectedParam[];
   ethnicities?: string[];
@@ -195,6 +191,11 @@ async function verifyShareLink({
     "navigator.clipboard.readText()"
   );
 
+  /**
+   * (thuang): The param order below needs to match the order from the ShareButton
+   * component
+   */
+
   // split parameters
   const urlParams = new URLSearchParams(
     // (thuang): We only want the query params part of the URL, so we split by "?"
@@ -210,15 +211,6 @@ async function verifyShareLink({
     searchParams.set(param, compare);
   }
 
-  // datasets
-  if (datasets !== undefined) {
-    const param = "datasets";
-
-    const data = await verifyParameter(page, urlParams, param, datasets);
-
-    searchParams.set(param, String(data));
-  }
-
   // diseases
   if (diseases !== undefined) {
     const param = "diseases";
@@ -237,6 +229,15 @@ async function verifyShareLink({
     searchParams.set(param, String(data));
   }
 
+  // publications
+  if (publications !== undefined) {
+    const param = "publications";
+
+    const data = await verifyParameter(page, urlParams, param, publications);
+
+    searchParams.set(param, String(data));
+  }
+
   // sexes
   if (sexes !== undefined) {
     const param = "sexes";
@@ -299,10 +300,10 @@ async function verifyParameter(
   const expectedIds = expectedParams.map((expectedParam) => expectedParam.id);
 
   switch (param) {
-    case "datasets": {
+    case "publications": {
       const paramValues = getParamValues(param);
 
-      // verify datasets have been selected
+      // verify publications have been selected
       paramValues.forEach(async (_id: string) => {
         const item = expectedParams.find(
           (expectedParam) => expectedParam.id === _id

From a7a26009cb872f89735e83e4821f7d17b6d2ae98 Mon Sep 17 00:00:00 2001
From: atarashansky
Date: Wed, 3 Jan 2024 09:36:28 -0800
Subject: [PATCH 4/7] fix: update dependencies to be able to read latest census version (#6409)

---
 requirements-wmg-pipeline.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements-wmg-pipeline.txt b/requirements-wmg-pipeline.txt
index 73e0e9e5b5689..aa7e14694d7df 100644
--- a/requirements-wmg-pipeline.txt
+++ b/requirements-wmg-pipeline.txt
@@ -5,7 +5,7 @@ black==22.3.0 # Must be kept in sync with black version in .pre-commit-config.y
 boto3==1.28.7
 botocore>=1.31.7, <1.32.0
 click==8.1.3
-cellxgene-census==1.6.0
+cellxgene-census==1.9.1
 coverage==7.2.7
 dataclasses-json==0.5.7
 ddtrace==2.1.4
@@ -35,5 +35,5 @@ scipy==1.10.1
 SQLAlchemy==1.4.49
 SQLAlchemy-Utils==0.41.1
 tenacity==8.2.2
-tiledbsoma==1.4.4
+tiledbsoma==1.6.1
 dask==2023.8.1

From 6e44c587febc2b14fbce3ac544bda68ef0c1e7a1 Mon Sep 17 00:00:00 2001
From: Timmy Huang
Date: Wed, 3 Jan 2024 10:42:43 -0800
Subject: [PATCH 5/7] test: Fix CellGuide tree view e2e test (#6410)

---
 .../common/OntologyDagView/constants.ts       |  3 ++
 .../common/OntologyDagView/index.tsx          |  2 +
 .../features/cellGuide/cellGuide.test.ts      | 41 ++++++++++++++-----
 3 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/frontend/src/views/CellGuide/components/common/OntologyDagView/constants.ts b/frontend/src/views/CellGuide/components/common/OntologyDagView/constants.ts
index c6364e4c539b7..707c65da1c4dc 100644
--- a/frontend/src/views/CellGuide/components/common/OntologyDagView/constants.ts
+++ b/frontend/src/views/CellGuide/components/common/OntologyDagView/constants.ts
@@ -10,6 +10,9 @@ export const CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_FULLSCREEN_BUTTON =
 export const CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_HOVER_CONTAINER =
   "cell-guide-card-ontology-dag-view-hover-container";
 
+export const CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_CONTENT =
+  "cell-guide-card-ontology-dag-view-content";
+
 export const CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_DEACTIVATE_MARKER_GENE_MODE =
   "cell-guide-card-ontology-dag-view-deactivate-marker-gene-mode";
 export const MINIMUM_NUMBER_OF_HIDDEN_CHILDREN_FOR_DUMMY_NODE = 3;
diff --git a/frontend/src/views/CellGuide/components/common/OntologyDagView/index.tsx b/frontend/src/views/CellGuide/components/common/OntologyDagView/index.tsx
index 7e154af902892..b536a1d18eded 100644
--- a/frontend/src/views/CellGuide/components/common/OntologyDagView/index.tsx
+++ b/frontend/src/views/CellGuide/components/common/OntologyDagView/index.tsx
@@ -58,6 +58,7 @@ import {
   CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_DEACTIVATE_MARKER_GENE_MODE,
   MINIMUM_NUMBER_OF_HIDDEN_CHILDREN_FOR_DUMMY_NODE,
   ANIMAL_CELL_ID,
+  CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_CONTENT,
 } from "src/views/CellGuide/components/common/OntologyDagView/constants";
 import {
   ALL_TISSUES,
@@ -518,6 +519,7 @@ export default function OntologyDagView({
         height={height}
         ref={zoom.containerRef}
         isDragging={zoom.isDragging}
+        data-testid={CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_CONTENT}
       >
diff --git a/frontend/tests/features/cellGuide/cellGuide.test.ts b/frontend/tests/features/cellGuide/cellGuide.test.ts
--- a/frontend/tests/features/cellGuide/cellGuide.test.ts
+++ b/frontend/tests/features/cellGuide/cellGuide.test.ts
 {
       CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_DEACTIVATE_MARKER_GENE_MODE
     );
-    /**
-     * Hover over the `neural cell` node, since it will still be in the tree
-     * view window when on a small viewport size.
-     * Otherwise, choosing a different node will risk it being hidden and
-     * not be hoverable
-     */
-    const node = page.getByTestId(
-      `${CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_RECT_OR_CIRCLE_PREFIX_ID}-CL:0002319__0-has-children-isTargetNode=false`
+    const neuralNodeId = `${CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_RECT_OR_CIRCLE_PREFIX_ID}-CL:0002319__0-has-children-isTargetNode=false`;
+
+    await isElementVisible(page, neuralNodeId);
+
+    const dagContent = page.getByTestId(
+      CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_CONTENT
+    );
+
+    await dagContent.hover();
+
+    await tryUntil(
+      async () => {
+        /**
+         * (thuang): Zoom out the tree view a little, in case the node is half hidden
+         * and not hoverable
+         */
+        await page.mouse.wheel(0, 1);
+        /**
+         * Hover over the `neural cell` node, since it will still be in the tree
+         * view window when on a small viewport size.
+         * Otherwise, choosing a different node will risk it being hidden and
+         * not be hoverable
+         */
+        await page.getByTestId(neuralNodeId).hover();
+        await isElementVisible(
+          page,
+          CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_TOOLTIP
+        );
+      },
+      { page }
     );
-    await node.hover();
-    await isElementVisible(page, CELL_GUIDE_CARD_ONTOLOGY_DAG_VIEW_TOOLTIP);
 
     // assert that the tooltip text contains the marker gene information
     const tooltipText = await page

From 25fb62dff732d52620eb33c93a4d2af53c2293c2 Mon Sep 17 00:00:00 2001
From: Nayib Gloria <55710092+nayib-jose-gloria@users.noreply.github.com>
Date: Wed, 3 Jan 2024 14:40:47 -0500
Subject: [PATCH 6/7] chore: remove schema 4 feature flag and temp reprocess citations batch job (#6369)

---
 .happy/terraform/modules/batch/main.tf        |  64 ----
 backend/common/feature_flag.py                |  10 +-
 .../api/v1/curation/collections/common.py     |   5 -
 backend/layers/processing/process_validate.py |   8 +-
 .../processing/reprocess_dataset_metadata.py  | 328 ------------------
 backend/portal/api/enrichment.py              |   9 +-
 backend/portal/api/portal_api.py              |   4 -
 scripts/smoke_tests/setup.py                  |  11 +-
 tests/functional/backend/common.py            |  13 +-
 .../unit/backend/common/test_feature_flag.py  |   8 +-
 .../backend/layers/business/test_business.py  |   1 -
 .../backend/layers/common/base_api_test.py    |   1 -
 tests/unit/backend/layers/common/base_test.py |   1 -
 tests/unit/processing/test_h5ad_data_file.py  |   6 -
 14 files changed, 20 insertions(+), 449 deletions(-)
 delete mode 100644 backend/layers/processing/reprocess_dataset_metadata.py

diff --git a/.happy/terraform/modules/batch/main.tf b/.happy/terraform/modules/batch/main.tf
index b904c50094631..c5fd5968df530 100644
--- a/.happy/terraform/modules/batch/main.tf
+++ b/.happy/terraform/modules/batch/main.tf
@@ -119,70 +119,6 @@ resource aws_batch_job_definition dataset_metadata_update {
 })
 }
 
-resource aws_batch_job_definition reprocess_dataset_metadata {
-  # this was used to reprocess dataset metadata in place when an error was found after publishing cellxgene schema 4.0
-  # TODO: can be removed after 4.0 migration is complete
-  type = "container"
-  name = "dp-${var.deployment_stage}-${var.custom_stack_name}-reprocess-dataset-metadata"
-  container_properties = jsonencode({
-    "command": ["python3", "-m", "backend.layers.processing.reprocess_dataset_metadata"],
-    "jobRoleArn": "${var.batch_role_arn}",
-    "image": "${var.image}",
-    "memory": var.batch_container_memory_limit,
-    "environment": [
-      {
-        "name": "ARTIFACT_BUCKET",
-        "value": "${var.artifact_bucket}"
-      },
-      {
-        "name": "CELLXGENE_BUCKET",
-        "value": "${var.cellxgene_bucket}"
-      },
-      {
-        "name": "DATASETS_BUCKET",
-        "value": "${var.datasets_bucket}"
-      },
-      {
-        "name": "DEPLOYMENT_STAGE",
-        "value": "${var.deployment_stage}"
-      },
-      {
-        "name": "AWS_DEFAULT_REGION",
-        "value": "${data.aws_region.current.name}"
-      },
-      {
-        "name": "REMOTE_DEV_PREFIX",
-        "value": "${var.remote_dev_prefix}"
-      }
-    ],
-    "vcpus": 8,
-    "linuxParameters": {
-      "maxSwap": 800000,
-      "swappiness": 60
-    },
-    "retryStrategy": {
-      "attempts": 3,
-      "evaluateOnExit": [
-        {
-          "action": "RETRY",
-          "onReason": "Task failed to start"
-        },
-        {
-          "action": "EXIT",
-          "onReason": "*"
-        }
-      ]
-    },
-    "logConfiguration": {
-      "logDriver": "awslogs",
-      "options": {
-        "awslogs-group": "${aws_cloudwatch_log_group.cloud_watch_logs_group.id}",
-        "awslogs-region": "${data.aws_region.current.name}"
-      }
-    }
-})
-}
-
 resource aws_cloudwatch_log_group cloud_watch_logs_group {
   retention_in_days = 365
   name = "/dp/${var.deployment_stage}/${var.custom_stack_name}/upload"
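
The next file in this commit, backend/common/feature_flag.py, drops the retired SCHEMA_4 flag. For readers skimming the hunk, here is a minimal, runnable sketch of the lookup pattern that file implements; the stub CorporaConfig below is a hypothetical stand-in for illustration, but the getattr/"true" comparison mirrors the diffed code:

```python
# Sketch of the feature-flag lookup shown in the next hunk. The stub config
# is an assumption for illustration; the real CorporaConfig loads values
# from deployment configuration.
class CorporaConfig:
    citation_update_feature_flag = "True"  # assumed configured value


class FeatureFlagService:
    @staticmethod
    def is_enabled(feature_flag: str) -> bool:
        # Flags that are unset resolve to "" and therefore read as disabled.
        flag_value = getattr(CorporaConfig(), feature_flag, "").lower()
        return flag_value == "true"


assert FeatureFlagService.is_enabled("citation_update_feature_flag")
assert not FeatureFlagService.is_enabled("schema_4_feature_flag")  # removed flag
```

Because unknown flags simply read as disabled, deleting a flag from configuration is safe even if a stale call site survives a refactor like this one.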
diff --git a/backend/common/feature_flag.py b/backend/common/feature_flag.py
index e51bd09a6ad5a..78498e528d007 100644
--- a/backend/common/feature_flag.py
+++ b/backend/common/feature_flag.py
@@ -11,26 +11,26 @@
 To use a feature flag:
 ```
-if FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4):
-    <feature flag code>
+if FeatureFlagService.is_enabled(FeatureFlagValues.<FLAG_NAME>):
+    <feature flag code>
 ```
 
 To mock a feature flag in a test:
 ```
 self.mock_config = CorporaConfig()
-self.mock_config.set(dict(schema_4_feature_flag="True"))
+self.mock_config.set(dict(<feature_flag_name>="True"))
 ```
 """
 
-FeatureFlag = Literal["schema_4_feature_flag", "citation_update_feature_flag"]
+FeatureFlag = Literal["citation_update_feature_flag"]
 
 
 class FeatureFlagValues:
-    SCHEMA_4 = "schema_4_feature_flag"
     CITATION_UPDATE = "citation_update_feature_flag"
 
 
 class FeatureFlagService:
+    @staticmethod
     def is_enabled(feature_flag: FeatureFlag) -> bool:
         flag_value = getattr(CorporaConfig(), feature_flag, "").lower()
         return flag_value == "true"
 
diff --git a/backend/curation/api/v1/curation/collections/common.py b/backend/curation/api/v1/curation/collections/common.py
index c7e65907ead1f..987dae3c4f84a 100644
--- a/backend/curation/api/v1/curation/collections/common.py
+++ b/backend/curation/api/v1/curation/collections/common.py
@@ -4,7 +4,6 @@
 from uuid import UUID
 
 from backend.common.corpora_config import CorporaConfig
-from backend.common.feature_flag import FeatureFlagService, FeatureFlagValues
 from backend.common.utils.http_exceptions import ForbiddenHTTPException, GoneHTTPException, NotFoundHTTPException
 from backend.layers.auth.user_info import UserInfo
 from backend.layers.business.business import BusinessLogic
@@ -228,10 +227,6 @@ def reshape_dataset_for_curation_api(
         if col is not None:
             ds[column] = col
 
-    if ds.get("tissue") is not None and not FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4):
-        for tissue in ds["tissue"]:
-            del tissue["tissue_type"]
-
     ds["dataset_id"] = dataset_version.dataset_id.id
     ds["dataset_version_id"] = dataset_version.version_id.id
     # Get none preview specific dataset fields
diff --git a/backend/layers/processing/process_validate.py b/backend/layers/processing/process_validate.py
index 5b80f6a9fa216..e1b9305d968aa 100644
--- a/backend/layers/processing/process_validate.py
+++ b/backend/layers/processing/process_validate.py
@@ -3,7 +3,6 @@
 import numpy
 import scanpy
 
-from backend.common.feature_flag import FeatureFlagService, FeatureFlagValues
 from backend.common.utils.corpora_constants import CorporaConstants
 from backend.layers.business.business_interface import BusinessLogicInterface
 from backend.layers.common.entities import (
@@ -82,8 +81,7 @@ def validate_h5ad_file_and_add_labels(
         if not is_valid:
             raise ValidationFailed(errors)
         else:
-            if FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4):
-                self.populate_dataset_citation(collection_version_id, dataset_version_id, output_filename)
+            self.populate_dataset_citation(collection_version_id, dataset_version_id, output_filename)
 
             # TODO: optionally, these could be batched into one
             self.update_processing_status(dataset_version_id, DatasetStatusKey.H5AD, DatasetConversionStatus.CONVERTED)
@@ -172,9 +170,7 @@ def _get_batch_condition() -> Optional[str]:
         return DatasetMetadata(
             name=adata.uns["title"],
             organism=_get_term_pairs("organism"),
-            tissue=_get_tissue_terms()
-            if FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4)
-            else _get_term_pairs("tissue"),
+            tissue=_get_tissue_terms(),
             assay=_get_term_pairs("assay"),
             disease=_get_term_pairs("disease"),
sex=_get_term_pairs("sex"), diff --git a/backend/layers/processing/reprocess_dataset_metadata.py b/backend/layers/processing/reprocess_dataset_metadata.py deleted file mode 100644 index a9b4a692b8af6..0000000000000 --- a/backend/layers/processing/reprocess_dataset_metadata.py +++ /dev/null @@ -1,328 +0,0 @@ -""" -Updates citations in-place across dataset artifacts for a Collection - -this was used to reprocess dataset metadata in place when an error was found after publishing cellxgene schema 4.0 -TODO: can be removed after 4.0 migration is complete -""" -import json -import logging -import os -from multiprocessing import Process - -import scanpy -import tiledb -from rpy2.robjects import ListVector, StrVector, r -from rpy2.robjects.packages import importr - -from backend.common.utils.corpora_constants import CorporaConstants -from backend.layers.business.business import BusinessLogic -from backend.layers.common.entities import ( - CollectionVersionId, - DatasetArtifactMetadataUpdate, - DatasetArtifactType, - DatasetConversionStatus, - DatasetProcessingStatus, - DatasetVersion, - DatasetVersionId, -) -from backend.layers.persistence.persistence import DatabaseProvider -from backend.layers.processing.exceptions import ProcessingFailed -from backend.layers.processing.h5ad_data_file import H5ADDataFile -from backend.layers.processing.logger import configure_logging -from backend.layers.processing.process_download import ProcessDownload -from backend.layers.thirdparty.s3_provider import S3Provider -from backend.layers.thirdparty.uri_provider import UriProvider - -base = importr("base") -seurat = importr("SeuratObject") - -configure_logging(level=logging.INFO) - - -class DatasetMetadataReprocessWorker(ProcessDownload): - def __init__(self, artifact_bucket: str, datasets_bucket: str) -> None: - # init each worker with business logic backed by non-shared DB connection - self.business_logic = BusinessLogic( - DatabaseProvider(), - None, - None, - None, - S3Provider(), - UriProvider(), - ) - super().__init__(self.business_logic, self.business_logic.uri_provider, self.business_logic.s3_provider) - self.artifact_bucket = artifact_bucket - self.datasets_bucket = datasets_bucket - - def create_updated_artifacts( - self, - file_name: str, - artifact_type: str, - key_prefix: str, - dataset_version_id: DatasetVersionId, - ): - try: - s3_uri = self.upload_artifact(file_name, key_prefix, self.artifact_bucket) - self.logger.info(f"Uploaded [{dataset_version_id}/{file_name}] to {s3_uri}") - - # TODO: include datasets_bucket uploads or not? 
- # key = ".".join((key_prefix, DatasetArtifactType.H5AD)) - # self.s3_provider.upload_file( - # file_name, self.datasets_bucket, key, extra_args={"ACL": "bucket-owner-full-control"} - # ) - # datasets_s3_uri = self.make_s3_uri(self.datasets_bucket, key_prefix, key) - # self.logger.info(f"Uploaded [{dataset_version_id}/{file_name}] to {datasets_s3_uri}") - except Exception: - self.logger.error(f"Uploading Artifact {artifact_type} from dataset {dataset_version_id} failed.") - raise ProcessingFailed from None - - def update_h5ad( - self, - h5ad_uri: str, - current_dataset_version: DatasetVersion, - new_key_prefix: str, - metadata_update: DatasetArtifactMetadataUpdate, - ): - h5ad_filename = self.download_from_source_uri( - source_uri=h5ad_uri, - local_path=CorporaConstants.LABELED_H5AD_ARTIFACT_FILENAME, - ) - - adata = scanpy.read_h5ad(h5ad_filename) - metadata = current_dataset_version.metadata - # maps artifact name for metadata field to DB field name, if different - for key, val in metadata_update.as_dict_without_none_values().items(): - adata.uns[key] = val - setattr(metadata, key, val) - - adata.write(h5ad_filename, compression="gzip") - current_dataset_version_id = current_dataset_version.version_id - self.business_logic.set_dataset_metadata(current_dataset_version_id, metadata) - self.create_updated_artifacts( - h5ad_filename, DatasetArtifactType.H5AD, new_key_prefix, current_dataset_version_id - ) - os.remove(h5ad_filename) - - def update_rds( - self, - rds_uri: str, - new_key_prefix: str, - current_dataset_version_id: DatasetVersionId, - metadata_update: DatasetArtifactMetadataUpdate, - ): - seurat_filename = self.download_from_source_uri( - source_uri=rds_uri, - local_path=CorporaConstants.LABELED_RDS_ARTIFACT_FILENAME, - ) - - rds_object = base.readRDS(seurat_filename) - - new_keys = [] - seurat_metadata = seurat.Misc(object=rds_object) - for key, val in metadata_update.as_dict_without_none_values().items(): - if seurat_metadata.rx2[key]: - val = val if isinstance(val, list) else [val] - seurat_metadata[seurat_metadata.names.index(key)] = StrVector(val) - else: - new_keys.append((key, val)) - - if new_keys: - new_key_vector = ListVector({k: v for k, v in new_keys}) - seurat_metadata += new_key_vector - r.assign("rds_object", rds_object) - r.assign("seurat_metadata", seurat_metadata) - r("rds_object@misc <- seurat_metadata") - rds_object = r["rds_object"] - - base.saveRDS(rds_object, file=seurat_filename) - - self.create_updated_artifacts( - seurat_filename, DatasetArtifactType.RDS, new_key_prefix, current_dataset_version_id - ) - os.remove(seurat_filename) - - def update_cxg( - self, - cxg_uri: str, - new_cxg_dir: str, - metadata_update: DatasetArtifactMetadataUpdate, - ): - self.s3_provider.upload_directory(cxg_uri, new_cxg_dir) - ctx = tiledb.Ctx(H5ADDataFile.tile_db_ctx_config) - array_name = f"{new_cxg_dir}/cxg_group_metadata" - with tiledb.open(array_name, mode="r", ctx=ctx) as metadata_array: - cxg_metadata_dict = json.loads(metadata_array.meta["corpora"]) - cxg_metadata_dict.update(metadata_update.as_dict_without_none_values()) - - with tiledb.open(array_name, mode="w", ctx=ctx) as metadata_array: - metadata_array.meta["corpora"] = json.dumps(cxg_metadata_dict) - - -class DatasetMetadataReprocess(ProcessDownload): - def __init__( - self, business_logic: BusinessLogic, artifact_bucket: str, cellxgene_bucket: str, datasets_bucket: str - ) -> None: - super().__init__(business_logic, business_logic.uri_provider, business_logic.s3_provider) - self.artifact_bucket = 
artifact_bucket - self.cellxgene_bucket = cellxgene_bucket - self.datasets_bucket = datasets_bucket - - @staticmethod - def update_h5ad( - artifact_bucket: str, - datasets_bucket: str, - h5ad_uri: str, - current_dataset_version: DatasetVersion, - new_key_prefix: str, - metadata_update: DatasetArtifactMetadataUpdate, - ): - DatasetMetadataReprocessWorker(artifact_bucket, datasets_bucket).update_h5ad( - h5ad_uri, - current_dataset_version, - new_key_prefix, - metadata_update, - ) - - @staticmethod - def update_rds( - artifact_bucket: str, - datasets_bucket: str, - rds_uri: str, - new_key_prefix: str, - current_dataset_version_id: DatasetVersionId, - metadata_update: DatasetArtifactMetadataUpdate, - ): - DatasetMetadataReprocessWorker(artifact_bucket, datasets_bucket).update_rds( - rds_uri, new_key_prefix, current_dataset_version_id, metadata_update - ) - - @staticmethod - def update_cxg( - artifact_bucket: str, - datasets_bucket: str, - cxg_uri: str, - new_cxg_dir: str, - metadata_update: DatasetArtifactMetadataUpdate, - ): - DatasetMetadataReprocessWorker(artifact_bucket, datasets_bucket).update_cxg( - cxg_uri, new_cxg_dir, metadata_update - ) - - def update_dataset_metadata( - self, - current_dataset_version_id: DatasetVersionId, - metadata_update: DatasetArtifactMetadataUpdate, - ): - current_dataset_version = self.business_logic.get_dataset_version(current_dataset_version_id) - if current_dataset_version.status.processing_status != DatasetProcessingStatus.SUCCESS: - self.logger.info( - f"Dataset {current_dataset_version_id} is not successfully processed. Skipping metadata update." - ) - return - - artifact_uris = {artifact.type: artifact.uri for artifact in current_dataset_version.artifacts} - - new_artifact_key_prefix = self.get_key_prefix(current_dataset_version_id.id) + "_updated" - - artifact_jobs = [] - - if DatasetArtifactType.H5AD in artifact_uris: - self.logger.info("Main: Starting thread for h5ad update") - h5ad_job = Process( - target=DatasetMetadataReprocess.update_h5ad, - args=( - self.artifact_bucket, - self.datasets_bucket, - artifact_uris[DatasetArtifactType.H5AD], - current_dataset_version, - new_artifact_key_prefix, - metadata_update, - ), - ) - artifact_jobs.append(h5ad_job) - h5ad_job.start() - else: - self.logger.error(f"Cannot find labeled H5AD artifact uri for {current_dataset_version_id}.") - raise ProcessingFailed from None - - if DatasetArtifactType.RDS in artifact_uris: - self.logger.info("Main: Starting thread for rds update") - # RDS-only, one-time - metadata_update.schema_reference = ( - "https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/4.0.0/schema.md" - ) - rds_job = Process( - target=DatasetMetadataReprocess.update_rds, - args=( - self.artifact_bucket, - self.datasets_bucket, - artifact_uris[DatasetArtifactType.RDS], - new_artifact_key_prefix, - current_dataset_version_id, - metadata_update, - ), - ) - artifact_jobs.append(rds_job) - rds_job.start() - elif current_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED: - pass - else: - self.logger.error( - f"Cannot find RDS artifact uri for {current_dataset_version_id}, and Conversion Status is not SKIPPED." 
- ) - raise ProcessingFailed from None - - if DatasetArtifactType.CXG in artifact_uris: - self.logger.info("Main: Starting thread for cxg update") - cxg_job = Process( - target=DatasetMetadataReprocess.update_cxg, - args=( - self.artifact_bucket, - self.datasets_bucket, - artifact_uris[DatasetArtifactType.CXG], - f"s3://{self.cellxgene_bucket}/{new_artifact_key_prefix}.cxg", - metadata_update, - ), - ) - artifact_jobs.append(cxg_job) - cxg_job.start() - else: - self.logger.error(f"Cannot find cxg artifact uri for {current_dataset_version_id}.") - raise ProcessingFailed from None - - # blocking call on async functions before checking for valid artifact statuses - [j.join() for j in artifact_jobs] - - def update_dataset_citation( - self, - collection_version_id: CollectionVersionId, - dataset_version_id: DatasetVersionId, - ): - - collection_version = self.business_logic.get_collection_version(collection_version_id) - doi = next((link.uri for link in collection_version.metadata.links if link.type == "DOI"), None) - new_citation = self.business_logic.generate_dataset_citation( - collection_version.collection_id, dataset_version_id, doi - ) - metadata_update = DatasetArtifactMetadataUpdate(citation=new_citation) - self.update_dataset_metadata(dataset_version_id, metadata_update) - - -if __name__ == "__main__": - business_logic = BusinessLogic( - DatabaseProvider(), - None, - None, - None, - S3Provider(), - UriProvider(), - ) - - artifact_bucket = os.environ.get("ARTIFACT_BUCKET", "test-bucket") - cellxgene_bucket = os.environ.get("CELLXGENE_BUCKET", "test-cellxgene-bucket") - datasets_bucket = os.environ.get("DATASETS_BUCKET", "test-datasets-bucket") - collection_version_id = CollectionVersionId(os.environ["COLLECTION_VERSION_ID"]) - dataset_version_id = DatasetVersionId(os.environ["DATASET_VERSION_ID"]) - DatasetMetadataReprocess( - business_logic, artifact_bucket, cellxgene_bucket, datasets_bucket - ).update_dataset_citation(collection_version_id, dataset_version_id) diff --git a/backend/portal/api/enrichment.py b/backend/portal/api/enrichment.py index 254ad59a249f4..b038dfb0e1b90 100644 --- a/backend/portal/api/enrichment.py +++ b/backend/portal/api/enrichment.py @@ -5,8 +5,6 @@ from collections import OrderedDict -from backend.common.feature_flag import FeatureFlagService, FeatureFlagValues - def enrich_dataset_with_ancestors(dataset, key, ontology_mapping): """ @@ -17,12 +15,9 @@ def enrich_dataset_with_ancestors(dataset, key, ontology_mapping): terms = [e["ontology_term_id"] for e in dataset[key]] - is_schema_4 = FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4) is_tissue = key == "tissue" - if is_tissue and is_schema_4: - # TODO remove is_schema_4 condition once Schema 4 is rolled out and - # feature flag is removed (#6266). "tissue" must include "tissue_type" - # when generating ancestors; "cell_type" and "development_stage" do not. + if is_tissue: + # "tissue" must include "tissue_type" when generating ancestors; "cell_type" and "development_stage" do not. 
terms = [generate_tagged_tissue_ontology_id(e) for e in dataset[key]] else: terms = [e["ontology_term_id"] for e in dataset[key]] diff --git a/backend/portal/api/portal_api.py b/backend/portal/api/portal_api.py index 8d0a49ed713c5..a75b299f77307 100644 --- a/backend/portal/api/portal_api.py +++ b/backend/portal/api/portal_api.py @@ -6,7 +6,6 @@ from flask import Response, jsonify, make_response -from backend.common.feature_flag import FeatureFlagService, FeatureFlagValues from backend.common.utils.http_exceptions import ( ConflictException, ForbiddenHTTPException, @@ -175,9 +174,6 @@ def _dataset_to_response( if is_in_revision and revision_created_at and dataset.created_at > revision_created_at: published = False tissue = None if dataset.metadata is None else _ontology_term_ids_to_response(dataset.metadata.tissue) - if tissue is not None and not FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4): - for t in tissue: - del t["tissue_type"] return remove_none( { "assay": None if dataset.metadata is None else _ontology_term_ids_to_response(dataset.metadata.assay), diff --git a/scripts/smoke_tests/setup.py b/scripts/smoke_tests/setup.py index 9af2544c42efc..310d2876bcd21 100644 --- a/scripts/smoke_tests/setup.py +++ b/scripts/smoke_tests/setup.py @@ -63,13 +63,10 @@ def publish_collection(self, collection_id): collection_count = smoke_test_init.get_collection_count() if collection_count >= NUM_TEST_COLLECTIONS: sys.exit(0) - if smoke_test_init.is_using_schema_4: - dataset_dropbox_url = ( - "https://www.dropbox.com/scl/fi/d99hpw3p2cxtmi7v4kyv5/" - "4_0_0_test_dataset.h5ad?rlkey=i5ownt8g1mropbu41r7fa0i06&dl=0" - ) - else: - dataset_dropbox_url = "https://www.dropbox.com/s/m1ur46nleit8l3w/3_0_0_valid.h5ad?dl=0" + dataset_dropbox_url = ( + "https://www.dropbox.com/scl/fi/d99hpw3p2cxtmi7v4kyv5/" + "4_0_0_test_dataset.h5ad?rlkey=i5ownt8g1mropbu41r7fa0i06&dl=0" + ) num_to_create = NUM_TEST_COLLECTIONS - collection_count threads = [] for _ in range(num_to_create): diff --git a/tests/functional/backend/common.py b/tests/functional/backend/common.py index 45e2cd7372388..d72fdd9017c34 100644 --- a/tests/functional/backend/common.py +++ b/tests/functional/backend/common.py @@ -10,7 +10,6 @@ from requests.packages.urllib3.util import Retry from backend.common.corpora_config import CorporaAuthConfig -from backend.common.feature_flag import FeatureFlagService, FeatureFlagValues API_URL = { "prod": "https://api.cellxgene.cziscience.com", @@ -39,17 +38,9 @@ def setUpClass(cls, smoke_tests: bool = False): super().setUpClass() cls.deployment_stage = os.environ["DEPLOYMENT_STAGE"] cls.config = CorporaAuthConfig() - cls.is_using_schema_4 = FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4) cls.test_dataset_uri = ( - ( - "https://www.dropbox.com/scl/fi/d99hpw3p2cxtmi7v4kyv5/" - "4_0_0_test_dataset.h5ad?rlkey=i5ownt8g1mropbu41r7fa0i06&dl=0" - ) - if cls.is_using_schema_4 - else ( - "https://www.dropbox.com/scl/fi/phrt3ru8ulep7ttnwttu2/" - "example_valid.h5ad?rlkey=mmcm2qd9xrnbqle3l3vyii0gx&dl=0" - ) + "https://www.dropbox.com/scl/fi/d99hpw3p2cxtmi7v4kyv5/" + "4_0_0_test_dataset.h5ad?rlkey=i5ownt8g1mropbu41r7fa0i06&dl=0" ) cls.session = requests.Session() # apply retry config to idempotent http methods we use + POST requests, which are currently all either diff --git a/tests/unit/backend/common/test_feature_flag.py b/tests/unit/backend/common/test_feature_flag.py index df3c9f9078516..842c8aa7cd2b7 100644 --- a/tests/unit/backend/common/test_feature_flag.py +++ 
b/tests/unit/backend/common/test_feature_flag.py @@ -1,17 +1,19 @@ import unittest +from unittest.mock import Mock, patch from backend.common.corpora_config import CorporaConfig -from backend.common.feature_flag import FeatureFlagService, FeatureFlagValues +from backend.common.feature_flag import FeatureFlagService class TestFeatureFlag(unittest.TestCase): def setUp(self): super().setUp() self.mock_config = CorporaConfig() - self.mock_config.set({"schema_4_feature_flag": "True"}) + self.mock_config.set({"test_feature_flag": "True"}) def tearDown(self): self.mock_config.reset() + @patch("backend.common.feature_flag.FeatureFlag", Mock()) def test_feature_flag(self): - self.assertTrue(FeatureFlagService.is_enabled(FeatureFlagValues.SCHEMA_4)) + self.assertTrue(FeatureFlagService.is_enabled("test_feature_flag")) diff --git a/tests/unit/backend/layers/business/test_business.py b/tests/unit/backend/layers/business/test_business.py index a6eeb6f59b581..70b642fe8c100 100644 --- a/tests/unit/backend/layers/business/test_business.py +++ b/tests/unit/backend/layers/business/test_business.py @@ -93,7 +93,6 @@ def setUp(self) -> None: self.mock_config.set( { "upload_max_file_size_gb": 30, - "schema_4_feature_flag": "True", "citation_update_feature_flag": "True", "dataset_assets_base_url": "https://dataset_assets_domain", "collections_base_url": "https://collections_domain", diff --git a/tests/unit/backend/layers/common/base_api_test.py b/tests/unit/backend/layers/common/base_api_test.py index 711dd6a13b278..1bdb1851e6ae6 100644 --- a/tests/unit/backend/layers/common/base_api_test.py +++ b/tests/unit/backend/layers/common/base_api_test.py @@ -78,7 +78,6 @@ def setUp(self): "upload_max_file_size_gb": 1, "dataset_assets_base_url": "http://domain", "citation_update_feature_flag": "True", - "schema_4_feature_flag": "True", } self.mock_auth_config = CorporaAuthConfig() diff --git a/tests/unit/backend/layers/common/base_test.py b/tests/unit/backend/layers/common/base_test.py index 959ef39b82b2f..30c41411525d5 100644 --- a/tests/unit/backend/layers/common/base_test.py +++ b/tests/unit/backend/layers/common/base_test.py @@ -84,7 +84,6 @@ def setUp(self): "collections_base_url": "https://domain", "dataset_assets_base_url": "http://domain", "citation_update_feature_flag": "True", - "schema_4_feature_flag": "True", } # Mock CorporaConfig # TODO: deduplicate with base_api diff --git a/tests/unit/processing/test_h5ad_data_file.py b/tests/unit/processing/test_h5ad_data_file.py index 74b387fcdd82c..92e0ea4befd7b 100644 --- a/tests/unit/processing/test_h5ad_data_file.py +++ b/tests/unit/processing/test_h5ad_data_file.py @@ -9,7 +9,6 @@ import tiledb from pandas import Categorical, DataFrame, Series -from backend.common.corpora_config import CorporaConfig from backend.common.utils.corpora_constants import CorporaConstants from backend.layers.processing.h5ad_data_file import H5ADDataFile from tests.unit.backend.fixtures.environment_setup import fixture_file_path @@ -22,9 +21,6 @@ def setUp(self): self.sample_output_directory = path.splitext(self.sample_h5ad_filename)[0] + ".cxg" - self.mock_config = CorporaConfig() - self.mock_config.set({"schema_4_feature_flag": "True"}) - def tearDown(self): if self.sample_h5ad_filename: remove(self.sample_h5ad_filename) @@ -32,8 +28,6 @@ def tearDown(self): if path.isdir(self.sample_output_directory): rmtree(self.sample_output_directory) - self.mock_config.reset() - def test__create_h5ad_data_file__non_h5ad_raises_exception(self): non_h5ad_filename = "my_fancy_dataset.csv" From 
0bdde21e8cb3cf54e9626a998df1631a217ae917 Mon Sep 17 00:00:00 2001 From: Daniel Hegeman Date: Wed, 3 Jan 2024 15:18:04 -0800 Subject: [PATCH 7/7] chore: reference correct CrossrefProvider code in unit tests (#6407) --- backend/common/providers/crossref_provider.py | 154 ------------------ .../layers/thirdparty/crossref_provider.py | 6 +- .../thirdparty/test_crossref_provider.py | 32 ++-- 3 files changed, 20 insertions(+), 172 deletions(-) delete mode 100644 backend/common/providers/crossref_provider.py diff --git a/backend/common/providers/crossref_provider.py b/backend/common/providers/crossref_provider.py deleted file mode 100644 index 8fdea7ebdbafa..0000000000000 --- a/backend/common/providers/crossref_provider.py +++ /dev/null @@ -1,154 +0,0 @@ -import html -import logging -from datetime import datetime -from urllib.parse import urlparse - -import requests - -from backend.common.corpora_config import CorporaConfig - - -class CrossrefException(Exception): - pass - - -class CrossrefFetchException(CrossrefException): - pass - - -class CrossrefDOINotFoundException(CrossrefException): - pass - - -class CrossrefParseException(CrossrefException): - pass - - -class CrossrefProvider: - """ - Provider class used to call Crossref and retrieve publisher metadata - """ - - def __init__(self) -> None: - self.base_crossref_uri = "https://api.crossref.org/works" - try: - self.crossref_api_key = CorporaConfig().crossref_api_key - except RuntimeError: - self.crossref_api_key = None - super().__init__() - - @staticmethod - def parse_date_parts(obj): - date_parts = obj["date-parts"][0] - year = date_parts[0] - month = date_parts[1] if len(date_parts) > 1 else 1 - day = date_parts[2] if len(date_parts) > 2 else 1 - return (year, month, day) - - def _fetch_crossref_payload(self, doi): - # Remove the https://doi.org part - parsed = urlparse(doi) - if parsed.scheme and parsed.netloc: - doi = parsed.path - - if self.crossref_api_key is None: - logging.info("No Crossref API key found, skipping metadata fetching.") - return None - - try: - res = requests.get( - f"{self.base_crossref_uri}/{doi}", - headers={"Crossref-Plus-API-Token": f"Bearer {self.crossref_api_key}"}, - ) - res.raise_for_status() - except requests.RequestException as e: - if e.response is not None and e.response.status_code == 404: - raise CrossrefDOINotFoundException from e - else: - raise CrossrefFetchException("Cannot fetch metadata from Crossref") from e - - return res - - def fetch_metadata(self, doi: str) -> dict: - """ - Fetches and extracts publisher metadata from Crossref for a specified DOI. - If the Crossref API URI isn't in the configuration, we will just return an empty object. - This is to avoid calling Crossref in non-production environments. 
- """ - - res = self._fetch_crossref_payload(doi) - if not res: - return - - try: - message = res.json()["message"] - - # Date - published_date = ( - message.get("published-print") or message.get("published") or message.get("published-online") - ) - - if published_date is None: - raise CrossrefParseException("Date node missing") - - published_year, published_month, published_day = self.parse_date_parts(published_date) - - dates = [] - for k, v in message.items(): - if isinstance(v, dict) and "date-parts" in v: - dt = v["date-parts"][0] - dates.append(f"{k}: {dt}") - - # Journal - try: - if "short-container-title" in message and message["short-container-title"]: - raw_journal = message["short-container-title"][0] - elif "container-title" in message and message["container-title"]: - raw_journal = message["container-title"][0] - elif "institution" in message: - raw_journal = message["institution"][0]["name"] - except Exception: - raise CrossrefParseException("Journal node missing") from None - - journal = html.unescape(raw_journal) - - # Authors - # Note: make sure that the order is preserved, as it is a relevant information - authors = message["author"] - parsed_authors = [] - for author in authors: - if "given" in author and "family" in author: - parsed_authors.append({"given": author["given"], "family": author["family"]}) - elif "family" in author: - # Assume family is consortium - parsed_authors.append({"name": author["family"]}) - elif "name" in author: - parsed_authors.append({"name": author["name"]}) - - # Preprint - is_preprint = message.get("subtype") == "preprint" - - return { - "authors": parsed_authors, - "published_year": published_year, - "published_month": published_month, - "published_day": published_day, - "published_at": datetime.timestamp(datetime(published_year, published_month, published_day)), - "journal": journal, - "is_preprint": is_preprint, - } - except Exception as e: - raise CrossrefParseException("Cannot parse metadata from Crossref") from e - - def fetch_preprint_published_doi(self, doi): - res = self._fetch_crossref_payload(doi) - message = res.json()["message"] - is_preprint = message.get("subtype") == "preprint" - - if is_preprint: - try: - published_doi = message["relation"]["is-preprint-of"] - if published_doi[0]["id-type"] == "doi": - return published_doi[0]["id"] - except Exception: - pass diff --git a/backend/layers/thirdparty/crossref_provider.py b/backend/layers/thirdparty/crossref_provider.py index 3f97ce70cbf1b..6401184a078c9 100644 --- a/backend/layers/thirdparty/crossref_provider.py +++ b/backend/layers/thirdparty/crossref_provider.py @@ -69,8 +69,8 @@ def _fetch_crossref_payload(self, doi): headers={"Crossref-Plus-API-Token": f"Bearer {self.crossref_api_key}"}, ) res.raise_for_status() - except Exception as e: - if res.status_code == 404: + except requests.RequestException as e: + if e.response is not None and e.response.status_code == 404: raise CrossrefDOINotFoundException from e else: raise CrossrefFetchException("Cannot fetch metadata from Crossref") from e @@ -85,6 +85,8 @@ def fetch_metadata(self, doi: str) -> dict: """ res = self._fetch_crossref_payload(doi) + if not res: + return try: message = res.json()["message"] diff --git a/tests/unit/backend/layers/thirdparty/test_crossref_provider.py b/tests/unit/backend/layers/thirdparty/test_crossref_provider.py index 4567001bac1b9..355beee4f27c7 100644 --- a/tests/unit/backend/layers/thirdparty/test_crossref_provider.py +++ b/tests/unit/backend/layers/thirdparty/test_crossref_provider.py @@ -5,7 
+5,7 @@ from requests import RequestException from requests.models import HTTPError, Response -from backend.common.providers.crossref_provider import ( +from backend.layers.thirdparty.crossref_provider import ( CrossrefDOINotFoundException, CrossrefException, CrossrefFetchException, @@ -15,15 +15,15 @@ class TestCrossrefProvider(unittest.TestCase): - @patch("backend.common.providers.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") def test__provider_does_not_call_crossref_in_test(self, mock_get): provider = CrossrefProvider() res = provider.fetch_metadata("test_doi") self.assertIsNone(res) mock_get.assert_not_called() - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_calls_crossref_if_api_key_defined(self, mock_config, mock_get): # Defining a mocked CorporaConfig will allow the provider to consider the `crossref_api_key` # not None, so it will go ahead and do the mocked call. @@ -71,8 +71,8 @@ def test__provider_calls_crossref_if_api_key_defined(self, mock_config, mock_get self.assertDictEqual(expected_response, res) - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_parses_authors_and_dates_correctly(self, mock_config, mock_get): response = Response() response.status_code = 200 @@ -136,8 +136,8 @@ def test__provider_parses_authors_and_dates_correctly(self, mock_config, mock_ge self.assertDictEqual(expected_response, res) - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_unescapes_journal_correctly(self, mock_config, mock_get): response = Response() response.status_code = 200 @@ -175,8 +175,8 @@ def test__provider_unescapes_journal_correctly(self, mock_config, mock_get): self.assertDictEqual(expected_response, res) - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_throws_exception_if_request_fails(self, mock_config, mock_get): """ Asserts a CrossrefFetchException if the GET request fails for any reason @@ -192,8 +192,8 @@ def test__provider_throws_exception_if_request_fails(self, mock_config, mock_get with self.assertRaises(CrossrefException): provider.fetch_metadata("test_doi") - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_throws_exception_if_request_fails_with_404(self, mock_config, mock_get): """ Asserts a CrossrefFetchException if the GET request fails for any reason @@ -207,8 +207,8 @@ def 
test__provider_throws_exception_if_request_fails_with_404(self, mock_config, with self.assertRaises(CrossrefDOINotFoundException): provider.fetch_metadata("test_doi") - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_throws_exception_if_request_fails_with_non_2xx_code(self, mock_config, mock_get): """ Asserts a CrossrefFetchException if the GET request return a 500 error (any non 2xx will work) @@ -223,8 +223,8 @@ def test__provider_throws_exception_if_request_fails_with_non_2xx_code(self, moc with self.assertRaises(CrossrefFetchException): provider.fetch_metadata("test_doi") - @patch("backend.common.providers.crossref_provider.requests.get") - @patch("backend.common.providers.crossref_provider.CorporaConfig") + @patch("backend.layers.thirdparty.crossref_provider.requests.get") + @patch("backend.layers.thirdparty.crossref_provider.CorporaConfig") def test__provider_throws_exception_if_request_cannot_be_parsed(self, mock_config, mock_get): """ Asserts an CrossrefParseException if the GET request succeeds but cannot be parsed
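
Taken together, commits 1 and 7 pin down two small CrossrefProvider behaviors that are easy to see in isolation: journal titles can arrive HTML-escaped, and a failed requests call may carry no response object at all. A condensed, self-contained sketch of both follows; the provider itself is abbreviated rather than reproduced, with exception classes named as in the diffs:

```python
# Condensed sketch of the two behaviors the CrossrefProvider commits exercise.
import html

import requests


class CrossrefDOINotFoundException(Exception):
    pass


class CrossrefFetchException(Exception):
    pass


def classify_fetch_error(error: requests.RequestException) -> Exception:
    # Commit 7's fix: error.response is None for connection-level failures,
    # so guard before reading status_code instead of touching the local `res`.
    if error.response is not None and error.response.status_code == 404:
        return CrossrefDOINotFoundException()
    return CrossrefFetchException("Cannot fetch metadata from Crossref")


# Commit 1's fix: Crossref can return HTML-escaped journal names.
assert html.unescape("Clinical &amp; Translational Med") == "Clinical & Translational Med"
```

Catching requests.RequestException rather than bare Exception also stops the handler from masking unrelated bugs, which is why the unit tests repatched in commit 7 now assert on the narrower exception path.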