Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(isProteinCoding): get all overlapping genes from the variant index #948

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
56 changes: 35 additions & 21 deletions src/gentropy/dataset/l2g_features/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,42 +85,56 @@ def common_genecount_feature_logic(
def is_protein_coding_feature_logic(
study_loci_to_annotate: StudyLocus | L2GGoldStandard,
*,
gene_index: GeneIndex,
variant_index: VariantIndex,
feature_name: str,
genomic_window: int,
genomic_window: int = 500_000,
) -> DataFrame:
"""Computes the feature to indicate if a gene is protein-coding or not.

Args:
study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci
that will be used for annotation
gene_index (GeneIndex): Dataset containing information related to all genes in release.
variant_index (VariantIndex): Dataset containing information related to all overlapping genes within a genomic window.
feature_name (str): The name of the feature
genomic_window (int): The maximum window size to consider
genomic_window (int): The window size around the locus to consider. Defaults to its maximum value: 500kb upstream and downstream of the locus

Returns:
DataFrame: Feature dataset, with 1 if the gene is protein-coding, 0 if not.
"""
study_loci_window = (
study_loci_to_annotate.df.withColumn(
"window_start", f.col("position") - (genomic_window / 2)
assert genomic_window <= 500_000, "Genomic window must be less than 500kb."
genes_in_window = (
variant_index.df.withColumn(
"transcriptConsequence", f.explode("transcriptConsequences")
)
.withColumn("window_end", f.col("position") + (genomic_window / 2))
.withColumnRenamed("chromosome", "SL_chromosome")
.select(
"variantId",
f.col("transcriptConsequence.targetId").alias("geneId"),
f.col("transcriptConsequence.biotype").alias("biotype"),
f.col("transcriptConsequence.distanceFromFootprint").alias(
"distanceFromFootprint"
),
)
.filter(f.col("distanceFromFootprint") <= genomic_window)
)
if isinstance(study_loci_to_annotate, StudyLocus):
variants_df = study_loci_to_annotate.df.select(
f.explode_outer("locus.variantId").alias("variantId"),
"studyLocusId",
).filter(f.col("variantId").isNotNull())
elif isinstance(study_loci_to_annotate, L2GGoldStandard):
variants_df = study_loci_to_annotate.df.select("studyLocusId", "variantId")
return (
study_loci_window.join(
gene_index.df.alias("genes"),
on=(
(f.col("SL_chromosome") == f.col("genes.chromosome"))
& (f.col("genes.tss") >= f.col("window_start"))
& (f.col("genes.tss") <= f.col("window_end"))
),
how="inner",
# Annotate all genes in the window of a locus
variants_df.join(
genes_in_window,
on="variantId",
)
# Apply flag across all variants in the locus
.withColumn(
feature_name,
f.when(f.col("biotype") == "protein_coding", f.lit(1)).otherwise(f.lit(0)),
f.when(f.col("biotype") == "protein_coding", f.lit(1.0)).otherwise(
f.lit(0.0)
),
)
.select("studyLocusId", "geneId", feature_name)
.distinct()
Expand Down Expand Up @@ -211,7 +225,7 @@ def compute(
class ProteinCodingFeature(L2GFeature):
"""Indicates whether a gene within a specified window of the study locus is protein-coding."""

feature_dependency_type = GeneIndex
feature_dependency_type = VariantIndex
feature_name = "isProteinCoding"

@classmethod
Expand All @@ -224,12 +238,12 @@ def compute(

Args:
study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
feature_dependency (dict[str, Any]): Dictionary containing dependencies, including gene index
feature_dependency (dict[str, Any]): Dictionary containing dependencies, including variant index

Returns:
ProteinCodingFeature: Feature dataset with 1 if the gene is protein-coding, 0 otherwise
"""
genomic_window = 1000000
genomic_window = 500_000
protein_coding_df = is_protein_coding_feature_logic(
study_loci_to_annotate=study_loci_to_annotate,
feature_name=cls.feature_name,
Expand Down
84 changes: 52 additions & 32 deletions tests/gentropy/dataset/test_l2g_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,11 @@ def sample_variant_index_schema() -> StructType:
ArrayType(
StructType(
[
StructField("distanceFromFootprint", LongType(), True),
StructField("distanceFromTss", LongType(), True),
StructField("targetId", StringType(), True),
StructField("isEnsemblCanonical", BooleanType(), True),
StructField("biotype", StringType(), True),
]
)
),
Expand Down Expand Up @@ -624,13 +626,17 @@ def _setup(
[
{
"distanceFromTss": 10,
"distanceFromFootprint": 0,
"targetId": "gene1",
"isEnsemblCanonical": True,
"biotype": "protein_coding",
},
{
"distanceFromTss": 2,
"distanceFromFootprint": 0,
"targetId": "gene2",
"isEnsemblCanonical": True,
"biotype": "protein_coding",
},
],
),
Expand All @@ -643,8 +649,10 @@ def _setup(
[
{
"distanceFromTss": 5,
"distanceFromFootprint": 0,
"targetId": "gene1",
"isEnsemblCanonical": True,
"biotype": "protein_coding",
},
],
),
Expand Down Expand Up @@ -928,9 +936,8 @@ class TestCommonProteinCodingFeatureLogic:
[
(
[
{"studyLocusId": "1", "geneId": "gene1", "isProteinCoding500kb": 1},
{"studyLocusId": "1", "geneId": "gene2", "isProteinCoding500kb": 1},
{"studyLocusId": "1", "geneId": "gene3", "isProteinCoding500kb": 0},
{"studyLocusId": "1", "geneId": "gene1", "isProteinCoding": 1.0},
{"studyLocusId": "1", "geneId": "gene2", "isProteinCoding": 0.0},
]
),
],
Expand All @@ -944,25 +951,28 @@ def test_is_protein_coding_feature_logic(
observed_df = (
is_protein_coding_feature_logic(
study_loci_to_annotate=self.sample_study_locus,
gene_index=self.sample_gene_index,
feature_name="isProteinCoding500kb",
genomic_window=500000,
variant_index=self.sample_variant_index,
feature_name="isProteinCoding",
)
.select("studyLocusId", "geneId", "isProteinCoding500kb")
.select("studyLocusId", "geneId", "isProteinCoding")
.orderBy("studyLocusId", "geneId")
)

expected_df = (
spark.createDataFrame(expected_data)
.select("studyLocusId", "geneId", "isProteinCoding500kb")
.select("studyLocusId", "geneId", "isProteinCoding")
.orderBy("studyLocusId", "geneId")
)
assert (
observed_df.collect() == expected_df.collect()
), "Expected and observed DataFrames do not match."

@pytest.fixture(autouse=True)
def _setup(self: TestCommonProteinCodingFeatureLogic, spark: SparkSession) -> None:
def _setup(
self: TestCommonProteinCodingFeatureLogic,
spark: SparkSession,
sample_variant_index_schema: StructType,
) -> None:
"""Set up sample data for the test."""
# Sample study locus data
self.sample_study_locus = StudyLocus(
Expand All @@ -974,39 +984,47 @@ def _setup(self: TestCommonProteinCodingFeatureLogic, spark: SparkSession) -> No
"studyId": "study1",
"chromosome": "1",
"position": 1000000,
"locus": [
{
"variantId": "var1",
},
],
},
],
StudyLocus.get_schema(),
),
_schema=StudyLocus.get_schema(),
)

# Sample gene index data with biotype
self.sample_gene_index = GeneIndex(
self.sample_variant_index = VariantIndex(
_df=spark.createDataFrame(
[
{
"geneId": "gene1",
"chromosome": "1",
"tss": 950000,
"biotype": "protein_coding",
},
{
"geneId": "gene2",
"chromosome": "1",
"tss": 1050000,
"biotype": "protein_coding",
},
{
"geneId": "gene3",
"chromosome": "1",
"tss": 1010000,
"biotype": "non_coding",
},
(
"var1",
"chrom",
1,
"A",
"T",
[
{
"distanceFromFootprint": 0,
"distanceFromTss": 10,
"targetId": "gene1",
"biotype": "protein_coding",
"isEnsemblCanonical": True,
},
{
"distanceFromFootprint": 0,
"distanceFromTss": 20,
"targetId": "gene2",
"biotype": "non_coding",
"isEnsemblCanonical": True,
},
],
),
],
GeneIndex.get_schema(),
sample_variant_index_schema,
),
_schema=GeneIndex.get_schema(),
_schema=VariantIndex.get_schema(),
)


Expand Down Expand Up @@ -1067,8 +1085,10 @@ def _setup(
[
{
"distanceFromTss": 10,
"distanceFromFootprint": 0,
"targetId": "gene1",
"isEnsemblCanonical": True,
"biotype": "protein_coding",
},
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pyspark.sql.session import SparkSession

from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.study_locus import StudyLocus


Expand Down Expand Up @@ -162,15 +161,15 @@ def test_build_feature_matrix(
mock_study_locus: StudyLocus,
mock_colocalisation: Colocalisation,
mock_study_index: StudyIndex,
mock_gene_index: GeneIndex,
mock_variant_index: VariantIndex,
) -> None:
"""Test building feature matrix with the eQtlColocH4Maximum feature."""
"""Test building feature matrix with the eQtlColocH4Maximum and isProteinCoding features."""
features_list = ["eQtlColocH4Maximum", "isProteinCoding"]
loader = L2GFeatureInputLoader(
colocalisation=mock_colocalisation,
study_index=mock_study_index,
study_locus=mock_study_locus,
gene_index=mock_gene_index,
variant_index=mock_variant_index,
)
fm = mock_study_locus.build_feature_matrix(features_list, loader)
assert isinstance(
Expand Down
Loading