diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 4eb48b992..9c166432f 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -15,6 +15,8 @@ jobs:
   release:
     runs-on: ubuntu-latest
    concurrency: release
+    outputs:
+      released: ${{ steps.semrelease.outputs.released }}
     permissions:
       # NOTE: this enables trusted publishing.
       # See https://github.com/pypa/gh-action-pypi-publish/tree/release/v1#trusted-publishing
@@ -46,7 +48,8 @@ jobs:
       - name: Publish package to GitHub Release
         uses: python-semantic-release/upload-to-gh-release@main
-        if: ${{ steps.semrelease.outputs.released }} == 'true'
+        # NOTE: semrelease output is a string, so we need to compare it to a string
+        if: steps.semrelease.outputs.released == 'true'
         with:
           # NOTE: allows starting the workflow when a push on a tag gets executed
           # requires using GH_APP to authenticate, otherwise push authorised with
@@ -56,6 +59,7 @@ jobs:
           tag: ${{ steps.semrelease.outputs.tag }}

       - name: Store the distribution packages
+        if: steps.semrelease.outputs.released == 'true'
         uses: actions/upload-artifact@v4
         with:
           name: python-package-distributions
@@ -63,9 +67,8 @@ jobs:
   publish-to-pypi:
     needs: release
-    name: >-
-      Publish 📦 in PyPI
-    if: github.ref == 'refs/heads/main'
+    name: Publish 📦 in PyPI
+    if: github.ref == 'refs/heads/main' && needs.release.outputs.released == 'true'
     runs-on: ubuntu-latest
     environment:
       name: pypi
@@ -84,7 +87,7 @@ jobs:
   publish-to-testpypi:
     name: Publish 📦 in TestPyPI
     needs: release
-    if: github.ref != 'refs/heads/main'
+    if: github.ref == 'refs/heads/main' && needs.release.outputs.released == 'true'

     runs-on: ubuntu-latest
     environment:
@@ -108,7 +111,7 @@ jobs:
   documentation:
     needs: release
     runs-on: ubuntu-latest
-    if: github.ref == 'refs/heads/main'
+    if: github.ref == 'refs/heads/main' && needs.release.outputs.released == 'true'
     steps:
       - uses: actions/checkout@v4
         with:
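An aside on the `if:` fix above (editorial note, not part of the patch): GitHub Actions expands `${{ ... }}` before evaluating the condition, so the old form left behind a literal string such as `false == 'true'`, and any non-empty string is truthy, meaning the step always ran. A rough Python analogy of the two forms:

```python
# Rough analogy only; this is not GitHub's actual expression engine.
released = "false"  # step outputs are always strings, never booleans

# Buggy form: the expression is expanded first, leaving the plain string
# "false == 'true'" -- and a non-empty string is always truthy.
buggy = bool(f"{released} == 'true'")

# Fixed form: the comparison happens inside the expression itself.
fixed = released == "true"

assert buggy is True
assert fixed is False
```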
diff --git a/docs/python_api/common/version_engine.md b/docs/python_api/common/version_engine.md
deleted file mode 100644
index 28d9b4b2e..000000000
--- a/docs/python_api/common/version_engine.md
+++ /dev/null
@@ -1,12 +0,0 @@
----
-title: VersionEngine
----
-
-**VersionEngine**:
-
-The version engine allows registering a datasource-specific version seeker class to retrieve the datasource version used as input to gentropy steps. It is currently implemented only for the GnomAD datasource.
-
-This class can then be used to automate output directory versioning.
-
-:::gentropy.common.version_engine.VersionEngine
-:::gentropy.common.version_engine.GnomADVersionSeeker
diff --git a/poetry.lock b/poetry.lock
index d162df5b2..d5e3b3696 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2065,13 +2065,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio

 [[package]]
 name = "ipython"
-version = "8.29.0"
+version = "8.30.0"
 description = "IPython: Productive Interactive Computing"
 optional = false
 python-versions = ">=3.10"
 files = [
-    {file = "ipython-8.29.0-py3-none-any.whl", hash = "sha256:0188a1bd83267192123ccea7f4a8ed0a78910535dbaa3f37671dca76ebd429c8"},
-    {file = "ipython-8.29.0.tar.gz", hash = "sha256:40b60e15b22591450eef73e40a027cf77bd652e757523eebc5bd7c7c498290eb"},
+    {file = "ipython-8.30.0-py3-none-any.whl", hash = "sha256:85ec56a7e20f6c38fce7727dcca699ae4ffc85985aa7b23635a8008f918ae321"},
+    {file = "ipython-8.30.0.tar.gz", hash = "sha256:cb0a405a306d2995a5cbb9901894d240784a9f341394c6ba3f4fe8c6eb89ff6e"},
 ]

 [package.dependencies]
@@ -2081,16 +2081,16 @@ exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
 jedi = ">=0.16"
 matplotlib-inline = "*"
 pexpect = {version = ">4.3", markers = "sys_platform != \"win32\" and sys_platform != \"emscripten\""}
-prompt-toolkit = ">=3.0.41,<3.1.0"
+prompt_toolkit = ">=3.0.41,<3.1.0"
 pygments = ">=2.4.0"
-stack-data = "*"
+stack_data = "*"
 traitlets = ">=5.13.0"
-typing-extensions = {version = ">=4.6", markers = "python_version < \"3.12\""}
+typing_extensions = {version = ">=4.6", markers = "python_version < \"3.12\""}

 [package.extras]
 all = ["ipython[black,doc,kernel,matplotlib,nbconvert,nbformat,notebook,parallel,qtconsole]", "ipython[test,test-extra]"]
 black = ["black"]
-doc = ["docrepr", "exceptiongroup", "intersphinx-registry", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "tomli", "typing-extensions"]
+doc = ["docrepr", "exceptiongroup", "intersphinx_registry", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "tomli", "typing_extensions"]
 kernel = ["ipykernel"]
 matplotlib = ["matplotlib"]
 nbconvert = ["nbconvert"]
diff --git a/src/gentropy/common/version_engine.py b/src/gentropy/common/version_engine.py
deleted file mode 100644
index d852d8f5d..000000000
--- a/src/gentropy/common/version_engine.py
+++ /dev/null
@@ -1,158 +0,0 @@
-"""Mechanism to seek version from specific datasource."""
-
-from __future__ import annotations
-
-import re
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Callable
-
-from gentropy.common.types import DataSourceType
-
-
-class VersionEngine:
-    """Seek version from the datasource."""
-
-    def __init__(self, datasource: DataSourceType) -> None:
-        """Initialize VersionEngine.
-
-        Args:
-            datasource (DataSourceType): datasource to seek the version from
-        """
-        self.datasource = datasource
-
-    @staticmethod
-    def version_seekers() -> dict[DataSourceType, DatasourceVersionSeeker]:
-        """List version seekers.
-
-        Returns:
-            dict[DataSourceType, DatasourceVersionSeeker]: list of available data sources.
-        """
-        return {
-            "gnomad": GnomADVersionSeeker(),
-        }
-
-    def seek(self, text: str | Path) -> str:
-        """Interface for inferring the version from text by using the registered data source version inference method.
-
-        Args:
-            text (str | Path): text to seek version from
-
-        Returns:
-            str: inferred version
-
-        Raises:
-            TypeError: if version can not be found in the text
-
-        Examples:
-            >>> VersionEngine("gnomad").seek("gs://gcp-public-data--gnomad/release/2.1.1/vcf/genomes/gnomad.genomes.r2.1.1.sites.vcf.bgz")
-            '2.1.1'
-        """
-        match text:
-            case Path() | str():
-                text = str(text)
-            case _:
-                msg = f"Can not find version in {text}"
-                raise TypeError(msg)
-        infer_method = self._get_version_seek_method()
-        return infer_method(text)
-
-    def _get_version_seek_method(self) -> Callable[[str], str]:
-        """Method that gets the version seeker for the datasource.
-
-        Returns:
-            Callable[[str], str]: Method to seek version based on the initialized datasource
-
-        Raises:
-            ValueError: if datasource is not registered in the list of version seekers
-        """
-        if self.datasource not in self.version_seekers():
-            raise ValueError(f"Invalid datasource {self.datasource}")
-        return self.version_seekers()[self.datasource].seek_version
-
-    def amend_version(
-        self, analysis_input_path: str | Path, analysis_output_path: str | Path
-    ) -> str:
-        """Amend version to the analysis output path if it is not already present.
-
-        The path can be a gs:// path or a Path object, absolute or relative.
-        The analysis_input_path has to contain the version number.
-        If the analysis_output_path already contains the version inferred from the input,
-        then it will not be appended.
-
-        Args:
-            analysis_input_path (str | Path): step input path
-            analysis_output_path (str | Path): step output path
-
-        Returns:
-            str: Path with the amended version, does not return a Path object!
-
-        Examples:
-            >>> VersionEngine("gnomad").amend_version("gs://gcp-public-data--gnomad/release/2.1.1/vcf/genomes/gnomad.genomes.r2.1.1.sites.vcf.bgz", "/some/path/without/version")
-            '/some/path/without/version/2.1.1'
-        """
-        version = self.seek(analysis_input_path)
-        output_path = str(analysis_output_path)
-        if version in output_path:
-            return output_path
-        if output_path.endswith("/"):
-            return f"{analysis_output_path}{version}"
-        return f"{analysis_output_path}/{version}"
-
-
-class DatasourceVersionSeeker(ABC):
-    """Interface for datasource version seeker.
-
-    Raises:
-        NotImplementedError: if method is not implemented in the subclass
-    """
-
-    @staticmethod
-    @abstractmethod
-    def seek_version(text: str) -> str:
-        """Seek version from text. Implement this method in the subclass.
-
-        Args:
-            text (str): text to seek version from
-
-        Returns:
-            str: sought version
-
-        Raises:
-            NotImplementedError: if method is not implemented in the subclass
-
-        """
-        raise NotImplementedError
-
-
-class GnomADVersionSeeker(DatasourceVersionSeeker):
-    """Seek version from GnomAD datasource."""
-
-    @staticmethod
-    def seek_version(text: str) -> str:
-        """Seek GnomAD version from provided text by using regex.
-
-        Up to 3 digits are allowed in the version number.
-        Historically gnomAD version numbers have been in the format
-        2.1.1, 3.1, etc. as of 2024-05.
-        GnomAD versions can be found by
-        running `"gs://gcp-public-data--gnomad/release/*/*/*"`
-
-        Args:
-            text (str): text to seek version from
-
-        Raises:
-            ValueError: if version can not be found
-
-        Returns:
-            str: sought version
-
-        Examples:
-            >>> GnomADVersionSeeker.seek_version("gs://gcp-public-data--gnomad/release/2.1.1/vcf/genomes/gnomad.genomes.r2.1.1.sites.vcf.bgz")
-            '2.1.1'
-        """
-        result = re.search(r"v?((\d+){1}\.(\d+){1}\.?(\d+)?)", text)
-        match result:
-            case None:
-                raise ValueError(f"No GnomAD version found in provided text: {text}")
-            case _:
-                return result.group(1)
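For readers skimming the removed module: a minimal, self-contained sketch of what `GnomADVersionSeeker.seek_version` did, using the same regex as the deleted code (the example paths come from the deleted doctests and tests):

```python
import re

# An optional "v" prefix is tolerated but not captured (an "r" prefix is
# simply not consumed); group(1) holds up to three dot-separated numbers.
GNOMAD_VERSION = re.compile(r"v?((\d+){1}\.(\d+){1}\.?(\d+)?)")

for path, expected in [
    ("gs://gcp-public-data--gnomad/release/2.1.1/vcf/genomes/gnomad.genomes.r2.1.1.sites.vcf.bgz", "2.1.1"),
    ("gs://gcp-public-data--gnomad/release/4.0/ht/genomes/gnomad.genomes.v4.0.sites.ht/", "4.0"),
    ("/some/path/to/the/version/r20.111.44", "20.111.44"),
]:
    match = GNOMAD_VERSION.search(path)  # leftmost match wins
    assert match is not None and match.group(1) == expected
```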
diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index 9c454d41b..0d9d6abf1 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -17,7 +17,8 @@ class SessionConfig:
     write_mode: str = "errorifexists"
     spark_uri: str = "local[*]"
     hail_home: str = os.path.dirname(hail_location)
-    extended_spark_conf: dict[str, str] | None = field(default_factory=dict[str, str])
+    extended_spark_conf: dict[str, str] | None = field(
+        default_factory=dict[str, str])
     output_partitions: int = 200
     _target_: str = "gentropy.common.session.Session"

@@ -39,7 +40,8 @@ class ColocalisationConfig(StepConfig):
     credible_set_path: str = MISSING
     coloc_path: str = MISSING
     colocalisation_method: str = MISSING
-    colocalisation_method_params: dict[str, Any] = field(default_factory=dict[str, Any])
+    colocalisation_method_params: dict[str, Any] = field(
+        default_factory=dict[str, Any])
     _target_: str = "gentropy.colocalisation.ColocalisationStep"

@@ -124,7 +126,8 @@ class EqtlCatalogueConfig(StepConfig):
     eqtl_catalogue_paths_imported: str = MISSING
     eqtl_catalogue_study_index_out: str = MISSING
     eqtl_catalogue_credible_sets_out: str = MISSING
-    mqtl_quantification_methods_blacklist: list[str] = field(default_factory=lambda: [])
+    mqtl_quantification_methods_blacklist: list[str] = field(
+        default_factory=lambda: [])
     eqtl_lead_pvalue_threshold: float = 1e-3
     _target_: str = "gentropy.eqtl_catalogue.EqtlCatalogueStep"

@@ -146,7 +149,8 @@ class FinngenStudiesConfig(StepConfig):
     )
     finngen_summary_stats_url_suffix: str = ".gz"
     efo_curation_mapping_url: str = "https://raw.githubusercontent.com/opentargets/curation/24.09.1/mappings/disease/manual_string.tsv"
-    sample_size: int = 453733  # https://www.finngen.fi/en/access_results#:~:text=Total%20sample%20size%3A%C2%A0453%2C733%C2%A0(254%2C618%C2%A0females%20and%C2%A0199%2C115%20males)
+    # https://www.finngen.fi/en/access_results#:~:text=Total%20sample%20size%3A%C2%A0453%2C733%C2%A0(254%2C618%C2%A0females%20and%C2%A0199%2C115%20males)
+    sample_size: int = 453733
     _target_: str = "gentropy.finngen_studies.FinnGenStudiesStep"

@@ -199,7 +203,6 @@ class LDIndexConfig(StepConfig):
             "nfe",  # Non-Finnish European
         ]
     )
-    use_version_from_input: bool = False
     _target_: str = "gentropy.gnomad_ingestion.LDIndexStep"

@@ -409,7 +412,6 @@ class GnomadVariantConfig(StepConfig):
             "remaining",  # Other
         ]
     )
-    use_version_from_input: bool = False
     _target_: str = "gentropy.gnomad_ingestion.GnomadVariantIndexStep"

@@ -432,7 +434,6 @@ class PanUKBBConfig(StepConfig):
             "EUR",  # European
         ]
     )
-    use_version_from_input: bool = False
     _target_: str = "gentropy.pan_ukb_ingestion.PanUKBBVariantIndexStep"

@@ -680,7 +681,8 @@ class Config:
     """Application configuration."""

     # this is unfortunately verbose due to @dataclass limitations
-    defaults: List[Any] = field(default_factory=lambda: ["_self_", {"step": MISSING}])
+    defaults: List[Any] = field(default_factory=lambda: [
+        "_self_", {"step": MISSING}])
     step: StepConfig = MISSING
     datasets: dict[str, str] = field(default_factory=dict)

@@ -714,7 +716,8 @@ def register_config() -> None:
         name="gwas_catalog_top_hit_ingestion",
         node=GWASCatalogTopHitIngestionConfig,
    )
-    cs.store(group="step", name="ld_based_clumping", node=LDBasedClumpingConfig)
+    cs.store(group="step", name="ld_based_clumping",
+             node=LDBasedClumpingConfig)
     cs.store(group="step", name="ld_index", node=LDIndexConfig)
     cs.store(group="step", name="locus_to_gene", node=LocusToGeneConfig)
     cs.store(
@@ -732,7 +735,8 @@ def register_config() -> None:
     cs.store(group="step", name="pics", node=PICSConfig)
     cs.store(group="step", name="gnomad_variants", node=GnomadVariantConfig)
-    cs.store(group="step", name="ukb_ppp_eur_sumstat_preprocess", node=UkbPppEurConfig)
+    cs.store(group="step", name="ukb_ppp_eur_sumstat_preprocess",
+             node=UkbPppEurConfig)
     cs.store(group="step", name="variant_index", node=VariantIndexConfig)
     cs.store(group="step", name="variant_to_vcf", node=ConvertToVcfStepConfig)
     cs.store(
@@ -765,5 +769,7 @@ def register_config() -> None:
         name="locus_to_gene_associations",
         node=LocusToGeneAssociationsStepConfig,
     )
-    cs.store(group="step", name="finngen_ukb_meta_ingestion", node=FinngenUkbMetaConfig)
-    cs.store(group="step", name="credible_set_qc", node=CredibleSetQCStepConfig)
+    cs.store(group="step", name="finngen_ukb_meta_ingestion",
+             node=FinngenUkbMetaConfig)
+    cs.store(group="step", name="credible_set_qc",
+             node=CredibleSetQCStepConfig)
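Context for the `# this is unfortunately verbose due to @dataclass limitations` comment above: `@dataclass` rejects mutable defaults outright, so every dict/list default has to go through `field(default_factory=...)`, which is what makes these config definitions wordy. A minimal sketch (the class name is hypothetical, not a gentropy class):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchConfig:  # hypothetical stand-in for a StepConfig subclass
    # params: dict[str, Any] = {}  # would raise ValueError: mutable default
    params: dict[str, Any] = field(default_factory=dict)


a, b = SketchConfig(), SketchConfig()
a.params["key"] = "value"
assert b.params == {}  # each instance gets its own dict, no shared state
```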
diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py
index 67fe05eaf..fa06faec2 100644
--- a/src/gentropy/dataset/dataset.py
+++ b/src/gentropy/dataset/dataset.py
@@ -8,8 +9,9 @@ from functools import reduce
 from typing import TYPE_CHECKING, Any

-import pyspark.sql.functions as f
-from pyspark.sql.types import DoubleType
+from pyspark.sql import DataFrame
+from pyspark.sql import functions as f
+from pyspark.sql import types as t
 from pyspark.sql.window import Window
 from typing_extensions import Self

@@ -18,7 +19,7 @@ if TYPE_CHECKING:
     from enum import Enum

-    from pyspark.sql import Column, DataFrame
+    from pyspark.sql import Column
     from pyspark.sql.types import StructType

     from gentropy.common.session import Session
@@ -26,17 +27,34 @@
 @dataclass
 class Dataset(ABC):
-    """Open Targets Gentropy Dataset.
+    """Open Targets Gentropy Dataset Interface.

-    `Dataset` is a wrapper around a Spark DataFrame with a predefined schema. Schemas for each child dataset are described in the `schemas` module.
+    The `Dataset` interface is a wrapper around a Spark DataFrame with a predefined schema.
+    The class allows overriding the schema with the `_schema` parameter.
+    If `_schema` is not provided, the schema is inferred from the dataset-specific
+    `get_schema` method, which must be implemented by the child classes.
     """

     _df: DataFrame
-    _schema: StructType
+    _schema: StructType | None = None

     def __post_init__(self: Dataset) -> None:
-        """Post init."""
-        self.validate_schema()
+        """Post init.
+
+        Raises:
+            TypeError: If the type of the _df or _schema is not valid
+        """
+        match self._df:
+            case DataFrame():
+                pass
+            case _:
+                raise TypeError(f"Invalid type for _df: {type(self._df)}")
+
+        match self._schema:
+            case None | t.StructType():
+                self.validate_schema()
+            case _:
+                raise TypeError(f"Invalid type for _schema: {type(self._schema)}")

     @property
     def df(self: Dataset) -> DataFrame:
@@ -64,7 +82,7 @@ def schema(self: Dataset) -> StructType:
         Returns:
             StructType: Dataframe expected schema
         """
-        return self._schema
+        return self._schema or self.get_schema()

     @classmethod
     def _process_class_params(
@@ -172,7 +190,7 @@ def validate_schema(self: Dataset) -> None:
         Raises:
             SchemaValidationError: If the DataFrame schema does not match the expected schema
         """
-        expected_schema = self._schema
+        expected_schema = self.schema
         observed_schema = self._df.schema

         # Unexpected fields in dataset
@@ -244,7 +262,7 @@ def drop_infinity_values(self: Self, *cols: str) -> Self:
         if len(cols) == 0:
             return self
         inf_strings = ("Inf", "+Inf", "-Inf", "Infinity", "+Infinity", "-Infinity")
-        inf_values = [f.lit(v).cast(DoubleType()) for v in inf_strings]
+        inf_values = [f.lit(v).cast(t.DoubleType()) for v in inf_strings]
         conditions = [f.col(c).isin(inf_values) for c in cols]
         # reduce individual filter expressions with or statement
         # to col("beta").isin([lit(Inf)]) | col("beta").isin([lit(Inf)])...
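The effect of the now-optional `_schema` is easiest to see in a toy subclass. A minimal sketch assuming only what the diff shows (`get_schema` as the fallback and the new `__post_init__` type checks); `ToyDataset` is hypothetical:

```python
from pyspark.sql import SparkSession
from pyspark.sql import types as t

from gentropy.dataset.dataset import Dataset


class ToyDataset(Dataset):  # hypothetical subclass for illustration
    @classmethod
    def get_schema(cls) -> t.StructType:
        return t.StructType([t.StructField("value", t.IntegerType(), False)])


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,)], schema=ToyDataset.get_schema())

ds = ToyDataset(_df=df)                      # _schema omitted on purpose
assert ds.schema == ToyDataset.get_schema()  # falls back to get_schema()

try:
    ToyDataset(_df="not a dataframe")        # type: ignore[arg-type]
    raise AssertionError("expected TypeError")
except TypeError:
    pass  # rejected by the new __post_init__ type check
```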
diff --git a/src/gentropy/dataset/pairwise_ld.py b/src/gentropy/dataset/pairwise_ld.py
index b64592094..ab68a74ab 100644
--- a/src/gentropy/dataset/pairwise_ld.py
+++ b/src/gentropy/dataset/pairwise_ld.py
@@ -38,7 +38,7 @@ def __post_init__(self: PairwiseLD) -> None:
         ), f"The number of rows in a pairwise LD table has to be square. Got: {row_count}"

         self.dimension = (int(sqrt(row_count)), int(sqrt(row_count)))
-        self.validate_schema()
+        super().__post_init__()

     @classmethod
     def get_schema(cls: type[PairwiseLD]) -> StructType:
diff --git a/src/gentropy/dataset/study_index.py b/src/gentropy/dataset/study_index.py
index da310f6f1..6bc9d9fd3 100644
--- a/src/gentropy/dataset/study_index.py
+++ b/src/gentropy/dataset/study_index.py
@@ -67,6 +67,18 @@ class StudyIndex(Dataset):
     A study index dataset captures all the metadata for all studies including GWAS and Molecular QTL.
     """

+    VALID_TYPES = [
+        "gwas",
+        "eqtl",
+        "pqtl",
+        "sqtl",
+        "tuqtl",
+        "sceqtl",
+        "scpqtl",
+        "scsqtl",
+        "sctuqtl",
+    ]
+
     @staticmethod
     def _aggregate_samples_by_ancestry(merged: Column, ancestry: Column) -> Column:
         """Aggregate sample counts by ancestry in a list of struct columns.
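The `PairwiseLD.__post_init__` change above swaps a direct `validate_schema()` call for `super().__post_init__()`, so the subclass's own checks compose with the new type validation in `Dataset`. The pattern in isolation (toy classes, not gentropy ones):

```python
from dataclasses import dataclass


@dataclass
class Base:  # stands in for Dataset
    def __post_init__(self) -> None:
        print("type + schema validation")


@dataclass
class Child(Base):  # stands in for PairwiseLD
    def __post_init__(self) -> None:
        print("square-matrix assertion")  # subclass-specific check first
        super().__post_init__()           # then delegate to the shared checks


Child()  # runs both checks in order
```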
diff --git a/src/gentropy/gnomad_ingestion.py b/src/gentropy/gnomad_ingestion.py
index d930b54c6..e0960c8f7 100644
--- a/src/gentropy/gnomad_ingestion.py
+++ b/src/gentropy/gnomad_ingestion.py
@@ -4,7 +4,6 @@

 from gentropy.common.session import Session
 from gentropy.common.types import LD_Population, VariantPopulation
-from gentropy.common.version_engine import VersionEngine
 from gentropy.config import GnomadVariantConfig, LDIndexConfig
 from gentropy.datasource.gnomad.ld import GnomADLDMatrix
 from gentropy.datasource.gnomad.variants import GnomADVariants
@@ -26,10 +25,10 @@ def __init__(
         min_r2: float = LDIndexConfig().min_r2,
         ld_matrix_template: str = LDIndexConfig().ld_matrix_template,
         ld_index_raw_template: str = LDIndexConfig().ld_index_raw_template,
-        ld_populations: list[LD_Population | str] = LDIndexConfig().ld_populations,
+        ld_populations: list[LD_Population |
+                             str] = LDIndexConfig().ld_populations,
         liftover_ht_path: str = LDIndexConfig().liftover_ht_path,
         grch37_to_grch38_chain_path: str = LDIndexConfig().grch37_to_grch38_chain_path,
-        use_version_from_input: bool = LDIndexConfig().use_version_from_input,
     ) -> None:
         """Run step.

@@ -42,17 +41,9 @@ def __init__(
             ld_populations (list[LD_Population | str]): Population names derived from the ld file paths
             liftover_ht_path (str): Path to the liftover ht file
             grch37_to_grch38_chain_path (str): Path to the chain file used to lift over the coordinates.
-            use_version_from_input (bool): Append version derived from input ld_matrix_template to the output ld_index_out. Defaults to False.
-                In case use_version_from_input is set to True,
-                data source version inferred from ld_matrix_template is appended as the last path segment to the output path.

         Default values are provided in LDIndexConfig.
         """
-        if use_version_from_input:
-            # amend data source version to output path
-            ld_index_out = VersionEngine("gnomad").amend_version(
-                ld_matrix_template, ld_index_out
-            )
         (
             GnomADLDMatrix(
                 ld_matrix_template=ld_matrix_template,
@@ -84,7 +75,6 @@ def __init__(
         gnomad_variant_populations: list[
             VariantPopulation | str
         ] = GnomadVariantConfig().gnomad_variant_populations,
-        use_version_from_input: bool = GnomadVariantConfig().use_version_from_input,
     ) -> None:
         """Run Variant Annotation step.

@@ -93,18 +83,10 @@ def __init__(
             variant_annotation_path (str): Path to resulting dataset.
             gnomad_genomes_path (str): Path to gnomAD genomes hail table, e.g. `gs://gcp-public-data--gnomad/release/4.0/ht/genomes/gnomad.genomes.v4.0.sites.ht/`.
             gnomad_variant_populations (list[VariantPopulation | str]): List of populations to include.
-            use_version_from_input (bool): Append version derived from input gnomad_genomes_path to the output variant_annotation_path. Defaults to False.
-                In case use_version_from_input is set to True,
-                data source version inferred from gnomad_genomes_path is appended as the last path segment to the output path.

         All defaults are stored in the GnomadVariantConfig.
         """
         # amend data source version to output path
-        if use_version_from_input:
-            variant_annotation_path = VersionEngine("gnomad").amend_version(
-                gnomad_genomes_path, variant_annotation_path
-            )
-
         session.logger.info("Gnomad variant annotation path:")
         session.logger.info(variant_annotation_path)
         # Parse variant info from source.
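With `use_version_from_input` gone, callers that relied on it now have to version the output path themselves. A hypothetical replacement mirroring the removed `amend_version` behaviour (the function name, simplified regex, and bucket path below are assumptions, not gentropy API):

```python
import re


def amend_version(input_path: str, output_path: str) -> str:
    """Append the version found in input_path to output_path, if missing."""
    match = re.search(r"v?(\d+\.\d+(?:\.\d+)?)", input_path)
    if match is None:
        raise ValueError(f"No version found in {input_path}")
    version = match.group(1)
    if version in output_path:
        return output_path  # already versioned, leave untouched
    return f"{output_path.rstrip('/')}/{version}"


out = amend_version(
    "gs://gcp-public-data--gnomad/release/4.0/ht/genomes/gnomad.genomes.v4.0.sites.ht/",
    "gs://my-bucket/variant_annotation",  # hypothetical output location
)
assert out == "gs://my-bucket/variant_annotation/4.0"
```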
diff --git a/tests/gentropy/common/test_version_engine.py b/tests/gentropy/common/test_version_engine.py
deleted file mode 100644
index 2ee2e12ce..000000000
--- a/tests/gentropy/common/test_version_engine.py
+++ /dev/null
@@ -1,115 +0,0 @@
-"""Tests version engine class."""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from gentropy.common.version_engine import GnomADVersionSeeker, VersionEngine
-
-
-@pytest.mark.parametrize(
-    ["text", "version"],
-    [
-        pytest.param(
-            "gcp-public-data--gnomad/release/2.1.1/vcf/genomes/gnomad.genomes.r2.1.1.sites.7.vcf",
-            "2.1.1",
-            id="GnomAD v2.1.1",
-        ),
-        pytest.param(
-            "/gcp-public-data--gnomad/release/3.0/vcf/genomes/gnomad.genomes.r3.0.sites.chr6.vcf",
-            "3.0",
-            id="GnomAD v3.0",
-        ),
-        pytest.param(
-            "gs://gcp-public-data--gnomad/release/3.1.1/vcf/genomes/gnomad.genomes.v3.1.1.sites.chr1.vcf",
-            "3.1.1",
-            id="GnomAD v3.1.1",
-        ),
-        pytest.param(
-            "gs://gcp-public-data--gnomad/release/3.1.2/vcf/genomes/gnomad.genomes.v3.1.2.sites.chrY.vcf",
-            "3.1.2",
-            id="GnomAD v3.1.2",
-        ),
-        pytest.param(
-            "gsa://gcp-public-data--gnomad/release/4.0/vcf/genomes/gnomad.genomes.v4.0.sites.chrY.vcf",
-            "4.0",
-            id="GnomAD v4.0",
-        ),
-        pytest.param(
-            "gs://gcp-public-data--gnomad/release/4.1/vcf/genomes/gnomad.genomes.v4.1.sites.chr18.vcf",
-            "4.1",
-            id="GnomAD v4.1",
-        ),
-        pytest.param(
-            "/some/path/to/the/version/r20.111.44",
-            "20.111.44",
-            id="Extreme version number",
-        ),
-    ],
-)
-def test_extracting_version_with_gnomad_seeker(text: str, version: str) -> None:
-    """Test gnomad version extraction with GnomADVersionSeeker."""
-    version_seeker = GnomADVersionSeeker().seek_version
-    assert version_seeker(text) == version
-
-
-def test_not_registered_datasource_raises_error() -> None:
-    """Test that unknown datasource raises error."""
-    with pytest.raises(ValueError) as e:
-        VersionEngine("ClinVar").seek("some/path/to/the/version/v20.111.44")  # type: ignore
-    assert e.value.args[0].startswith("Invalid datasource ClinVar")
-
-
-def test_extracting_version_when_no_version_is_found() -> None:
-    """Test that a path without a version raises an error."""
-    with pytest.raises(ValueError) as e:
-        VersionEngine("ClinVar").seek("some/path/without/version")  # type: ignore
-    assert e.value.args[0].startswith(
-        "Can not find version in some/path/without/version"
-    )
-
-
-def test_non_string_path_raises_error() -> None:
-    """Test that non-string path raises error."""
-    with pytest.raises(TypeError) as e:
-        VersionEngine("gnomad").seek(123)  # type: ignore
-    assert e.value.args[0].startswith("Can not infer version from 123")
-
-
-@pytest.mark.parametrize(
-    ["text", "version"],
-    [
-        pytest.param(Path("some/file/path/v3.1.1"), "3.1.1", id="Path object"),
-        pytest.param("s3://some/file/path/v3.1.1", "3.1.1", id="S3 protocol"),
-        pytest.param("gs://some/file/path/v3.1.1", "3.1.1", id="GS protocol"),
-    ],
-)
-def test_extracting_version_with_version_engine(text: str | Path, version: str) -> None:
-    """Check that concrete data types and file protocols do not return an error when passed to VersionEngine."""
-    assert VersionEngine("gnomad").seek(text) == version
-
-
-@pytest.mark.parametrize(
-    ["input_path", "output_path", "expected_output"],
-    [
-        pytest.param(
-            "input/v20.111.44", "output", "output/20.111.44", id="Append version"
-        ),
-        pytest.param(
-            "input/1.0.0",
-            "output/1.0.0",
-            "output/1.0.0",
-            id="Do not append version, already present",
-        ),
-        pytest.param(
-            Path("input/1.0.0"), Path("output/"), "output/1.0.0", id="Path objects"
-        ),
-    ],
-)
-def test_appending_version_to_path(
-    input_path: Path | str, output_path: Path | str, expected_output: str
-) -> None:
-    """Test that the version is amended at the end of the output path."""
-    VersionEngine("gnomad").amend_version(input_path, output_path) == expected_output
diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py
index f19c28623..4cff392d4 100644
--- a/tests/gentropy/conftest.py
+++ b/tests/gentropy/conftest.py
@@ -76,6 +76,14 @@ def mock_colocalisation(spark: SparkSession) -> Colocalisation:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(coloc_schema)
+        .withColumnSpec(
+            "leftStudyLocusId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "rightStudyLocusId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("h0", percentNulls=0.1)
         .withColumnSpec("h1", percentNulls=0.1)
         .withColumnSpec("h2", percentNulls=0.1)
@@ -103,6 +111,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(si_schema)
+        .withColumnSpec(
+            "studyId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec(
             "traitFromSourceMappedIds",
             expr="array(cast(rand() AS string))",
@@ -123,7 +135,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
             expr='array(named_struct("sampleSize", cast(rand() as string), "ancestry", cast(rand() as string)))',
             percentNulls=0.1,
         )
-        .withColumnSpec("geneId", percentNulls=0.1)
+        .withColumnSpec(
+            "geneId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("pubmedId", percentNulls=0.1)
         .withColumnSpec("publicationFirstAuthor", percentNulls=0.1)
         .withColumnSpec("publicationDate", percentNulls=0.1)
@@ -134,9 +149,7 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
         .withColumnSpec("nControls", percentNulls=0.1)
         .withColumnSpec("nSamples", percentNulls=0.1)
         .withColumnSpec("summarystatsLocation", percentNulls=0.1)
-        .withColumnSpec(
-            "studyType", percentNulls=0.0, values=["eqtl", "pqtl", "sqtl", "gwas"]
-        )
+        .withColumnSpec("studyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES)
     )

     return data_spec.build()
@@ -164,12 +177,30 @@ def mock_study_locus_overlap(spark: SparkSession) -> StudyLocusOverlap:
     """Mock StudyLocusOverlap dataset."""
     overlap_schema = StudyLocusOverlap.get_schema()

-    data_spec = dg.DataGenerator(
-        spark,
-        rows=400,
-        partitions=4,
-        randomSeedMethod="hash_fieldname",
-    ).withSchema(overlap_schema)
+    data_spec = (
+        dg.DataGenerator(
+            spark,
+            rows=400,
+            partitions=4,
+            randomSeedMethod="hash_fieldname",
+        )
+        .withSchema(overlap_schema)
+        .withColumnSpec(
+            "leftStudyLocusId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "rightStudyLocusId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "tagVariantId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "rightStudyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES
+        )
+    )
     return StudyLocusOverlap(_df=data_spec.build(), _schema=overlap_schema)
@@ -186,6 +217,10 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(sl_schema)
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("chromosome", percentNulls=0.1)
         .withColumnSpec("position", minValue=100, percentNulls=0.1)
         .withColumnSpec("beta", percentNulls=0.1)
@@ -202,7 +237,7 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
         .withColumnSpec("finemappingMethod", percentNulls=0.1)
         .withColumnSpec(
             "locus",
-            expr='array(named_struct("is95CredibleSet", cast(rand() > 0.5 as boolean), "is99CredibleSet", cast(rand() > 0.5 as boolean), "logBF", rand(), "posteriorProbability", rand(), "variantId", cast(rand() as string), "beta", rand(), "standardError", rand(), "r2Overall", rand(), "pValueMantissa", rand(), "pValueExponent", rand()))',
+            expr='array(named_struct("is95CredibleSet", cast(rand() > 0.5 as boolean), "is99CredibleSet", cast(rand() > 0.5 as boolean), "logBF", rand(), "posteriorProbability", rand(), "variantId", cast(floor(rand() * 400) + 1 as string), "beta", rand(), "standardError", rand(), "r2Overall", rand(), "pValueMantissa", rand(), "pValueExponent", rand()))',
             percentNulls=0.1,
         )
     )
@@ -262,6 +297,10 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(vi_schema)
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("mostSevereConsequenceId", percentNulls=0.1)
         # Nested column handling workaround
         # https://github.com/databrickslabs/dbldatagen/issues/135
@@ -275,7 +314,7 @@
                     "assessment", cast(rand() as string),
                     "score", rand(),
                     "assessmentFlag", cast(rand() as string),
-                    "targetId", cast(rand() as string),
+                    "targetId", cast(floor(rand() * 400) + 1 as string),
                     "normalizedScore", cast(rand() as float)
                 )
             )
@@ -298,11 +337,11 @@
                     "uniprotAccessions", array(cast(rand() as string)),
                     "isEnsemblCanonical", cast(rand() as boolean),
                     "codons", cast(rand() as string),
-                    "distanceFromTss", cast(rand() as long),
-                    "distanceFromFootprint", cast(rand() as long),
+                    "distanceFromTss", cast(floor(rand() * 500000) as long),
+                    "distanceFromFootprint", cast(floor(rand() * 500000) as long),
                     "appris", cast(rand() as string),
                     "maneSelect", cast(rand() as string),
-                    "targetId", cast(rand() as string),
+                    "targetId", cast(floor(rand() * 400) + 1 as string),
                     "impact", cast(rand() as string),
                     "lofteePrediction", cast(rand() as string),
                     "siftPrediction", rand(),
@@ -310,7 +349,7 @@
                     "consequenceScore", cast(rand() as float),
                     "transcriptIndex", cast(rand() as integer),
                     "transcriptId", cast(rand() as string),
-                    "biotype", cast(rand() as string),
+                    "biotype", 'protein_coding',
                     "approvedSymbol", cast(rand() as string)
                 )
             )
@@ -355,13 +394,20 @@ def mock_summary_statistics_data(spark: SparkSession) -> DataFrame:
             name="summaryStats",
         )
         .withSchema(ss_schema)
+        .withColumnSpec(
+            "studyId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         # Allowing missingness in effect allele frequency and enforce upper limit:
         .withColumnSpec(
             "effectAlleleFrequencyFromSource", percentNulls=0.1, maxValue=1.0
         )
         # Allowing missingness:
         .withColumnSpec("standardError", percentNulls=0.1)
-        # Making sure p-values are below 1:
     )

     return data_spec.build()
@@ -390,6 +436,10 @@ def mock_ld_index(spark: SparkSession) -> LDIndex:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(ld_schema)
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec(
             "ldSet",
             expr="array(named_struct('tagVariantId', cast(rand() as string), 'rValues', array(named_struct('population', cast(rand() as string), 'r', cast(rand() as double)))))",
@@ -526,18 +576,24 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
     data_spec = (
         dg.DataGenerator(
             spark,
-            rows=400,
+            rows=30,
             partitions=4,
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(gi_schema)
+        .withColumnSpec(
+            "geneId",
+            expr="cast(id as string)",
+        )
.withColumnSpec("approvedSymbol", percentNulls=0.1) - .withColumnSpec("biotype", percentNulls=0.1) + .withColumnSpec( + "biotype", percentNulls=0.1, values=["protein_coding", "lncRNA"] + ) .withColumnSpec("approvedName", percentNulls=0.1) .withColumnSpec("tss", percentNulls=0.1) .withColumnSpec("start", percentNulls=0.1) .withColumnSpec("end", percentNulls=0.1) - .withColumnSpec("strand", percentNulls=0.1) + .withColumnSpec("strand", percentNulls=0.1, values=[1, -1]) ) return GeneIndex(_df=data_spec.build(), _schema=gi_schema) @@ -559,6 +615,10 @@ def mock_biosample_index(spark: SparkSession) -> BiosampleIndex: randomSeedMethod="hash_fieldname", ) .withSchema(bi_schema) + .withColumnSpec( + "biosampleId", + expr="cast(id as string)", + ) .withColumnSpec("biosampleName", percentNulls=0.1) .withColumnSpec("description", percentNulls=0.1) .withColumnSpec("xrefs", expr=array_expression, percentNulls=0.1) @@ -613,9 +673,35 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix: def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard: """Mock l2g gold standard dataset.""" schema = L2GGoldStandard.get_schema() - data_spec = dg.DataGenerator( - spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname" - ).withSchema(schema) + data_spec = ( + dg.DataGenerator( + spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname" + ) + .withSchema(schema) + .withColumnSpec( + "studyLocusId", + expr="cast(id as string)", + ) + .withColumnSpec( + "variantId", + expr="cast(id as string)", + ) + .withColumnSpec( + "geneId", + expr="cast(id as string)", + ) + .withColumnSpec( + "traitFromSourceMappedId", + expr="cast(id as string)", + ) + .withColumnSpec( + "goldStandardSet", + values=[ + L2GGoldStandard.GS_NEGATIVE_LABEL, + L2GGoldStandard.GS_POSITIVE_LABEL, + ], + ) + ) return L2GGoldStandard(_df=data_spec.build(), _schema=schema) @@ -624,9 +710,20 @@ def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard: def mock_l2g_predictions(spark: SparkSession) -> L2GPrediction: """Mock l2g predictions dataset.""" schema = L2GPrediction.get_schema() - data_spec = dg.DataGenerator( - spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname" - ).withSchema(schema) + data_spec = ( + dg.DataGenerator( + spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname" + ) + .withSchema(schema) + .withColumnSpec( + "studyLocusId", + expr="cast(id as string)", + ) + .withColumnSpec( + "geneId", + expr="cast(id as string)", + ) + ) return L2GPrediction(_df=data_spec.build(), _schema=schema) diff --git a/tests/gentropy/dataset/test_dataset.py b/tests/gentropy/dataset/test_dataset.py index 7c61f3f52..96a96ec27 100644 --- a/tests/gentropy/dataset/test_dataset.py +++ b/tests/gentropy/dataset/test_dataset.py @@ -21,32 +21,44 @@ def get_schema(cls) -> StructType: return StructType([StructField("value", IntegerType(), False)]) -class TestCoalesceAndRepartition: +class TestDataset: """Test TestDataset.coalesce and TestDataset.repartition.""" - def test_repartition(self: TestCoalesceAndRepartition) -> None: + def test_repartition(self: TestDataset) -> None: """Test Dataset.repartition.""" initial_partitions = self.test_dataset._df.rdd.getNumPartitions() new_partitions = initial_partitions + 1 self.test_dataset.repartition(new_partitions) assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions - def test_coalesce(self: TestCoalesceAndRepartition) -> None: + def test_coalesce(self: TestDataset) -> None: """Test Dataset.coalesce.""" initial_partitions = 
diff --git a/tests/gentropy/dataset/test_dataset.py b/tests/gentropy/dataset/test_dataset.py
index 7c61f3f52..96a96ec27 100644
--- a/tests/gentropy/dataset/test_dataset.py
+++ b/tests/gentropy/dataset/test_dataset.py
@@ -21,32 +21,44 @@ def get_schema(cls) -> StructType:
         return StructType([StructField("value", IntegerType(), False)])


-class TestCoalesceAndRepartition:
+class TestDataset:
     """Test TestDataset.coalesce and TestDataset.repartition."""

-    def test_repartition(self: TestCoalesceAndRepartition) -> None:
+    def test_repartition(self: TestDataset) -> None:
         """Test Dataset.repartition."""
         initial_partitions = self.test_dataset._df.rdd.getNumPartitions()
         new_partitions = initial_partitions + 1
         self.test_dataset.repartition(new_partitions)
         assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions

-    def test_coalesce(self: TestCoalesceAndRepartition) -> None:
+    def test_coalesce(self: TestDataset) -> None:
         """Test Dataset.coalesce."""
         initial_partitions = self.test_dataset._df.rdd.getNumPartitions()
         new_partitions = initial_partitions - 1 if initial_partitions > 1 else 1
         self.test_dataset.coalesce(new_partitions)
         assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions

+    def test_initialize_without_schema(self: TestDataset, spark: SparkSession) -> None:
+        """Test if Dataset derived class collects the schema from assets if schema is not provided."""
+        df = spark.createDataFrame([(1,)], schema=MockDataset.get_schema())
+        ds = MockDataset(_df=df)
+        assert (
+            ds.schema == MockDataset.get_schema()
+        ), "Schema should be inferred from df"
+
+    def test_passing_incorrect_types(self: TestDataset, spark: SparkSession) -> None:
+        """Test if passing incorrect object types to Dataset raises an error."""
+        with pytest.raises(TypeError):
+            MockDataset(_df="not a dataframe")
+        with pytest.raises(TypeError):
+            MockDataset(_df=self.df, _schema="not a schema")
+
     @pytest.fixture(autouse=True)
-    def _setup(self: TestCoalesceAndRepartition, spark: SparkSession) -> None:
+    def _setup(self: TestDataset, spark: SparkSession) -> None:
         """Setup fixture."""
-        self.test_dataset = MockDataset(
-            _df=spark.createDataFrame(
-                [(1,), (2,), (3,)], schema=MockDataset.get_schema()
-            ),
-            _schema=MockDataset.get_schema(),
-        )
+        df = spark.createDataFrame([(1,), (2,), (3,)], schema=MockDataset.get_schema())
+        self.df = df
+        self.test_dataset = MockDataset(_df=df, _schema=MockDataset.get_schema())


 def test_dataset_filter(mock_study_index: StudyIndex) -> None:
@@ -68,6 +80,7 @@ def test_dataset_drop_infinity_values() -> None:
     rows = [(v,) for v in data]
     schema = StructType([StructField("field", DoubleType())])
     input_df = spark.createDataFrame(rows, schema=schema)
+    assert input_df.count() == 7

     # run without specifying *cols results in no filtering
     ds = MockDataset(_df=input_df, _schema=schema)
@@ -76,7 +89,7 @@ def test_dataset_drop_infinity_values() -> None:
     assert ds.drop_infinity_values("field").df.count() == 1


-def test__process_class_params(spark: SparkSession) -> None:
+def test_process_class_params(spark: SparkSession) -> None:
     """Test splitting of parameters between class and spark parameters."""
     params = {
         "_df": spark.createDataFrame([(1,)], schema=MockDataset.get_schema()),