Skip to content

Commit

Permalink
BUG division by zero when num_samples == num_vars (#175)
Browse files Browse the repository at this point in the history
* fix: check that the design matrix has strictly more samples than variables, else throw a ValueError

* test: check that a ValueError is thrown when the design matrix has num_samples==num_vars
  • Loading branch information
BorisMuzellec authored Oct 3, 2023
1 parent 07aa6b0 commit 8fbcc2f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
9 changes: 9 additions & 0 deletions pydeseq2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,15 @@ def fit_rough_dispersions(
"""

num_samples, num_vars = design_matrix.shape
# This method is only possible when num_samples > num_vars.
# If this is not the case, throw an error.
if num_samples == num_vars:
raise ValueError(
"The number of samples and the number of design variables are "
"equal, i.e., there are no replicates to estimate the "
"dispersion. Please use a design with fewer variables."
)

# Exclude genes with all zeroes
normed_counts = normed_counts[:, ~(normed_counts == 0).all(axis=0)]

Expand Down
26 changes: 24 additions & 2 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def test_one_factor():


def test_rank_deficient_design():
"""Test that a ValueError is thrown when the design matrix does not have full column
rank."""
"""Test that a UserWarning is thrown when the design matrix does not have full
column rank."""
counts_df = pd.DataFrame(
{"gene1": [0, 1], "gene2": [4, 12]}, index=["sample1", "sample2"]
)
Expand All @@ -178,6 +178,28 @@ def test_rank_deficient_design():
)


def test_equal_num_vars_num_samples_design():
    """Check that dispersion fitting rejects a design with no spare samples.

    Builds a three-sample dataset with two design factors and asserts that
    fitting genewise dispersions raises a ValueError.
    """
    sample_ids = ["sample1", "sample2", "sample3"]
    gene_counts = {"gene1": [0, 1, 55], "gene2": [4, 12, 60]}
    sample_info = {"condition": [0, 1, 0], "batch": ["A", "B", "B"]}

    counts_df = pd.DataFrame(gene_counts, index=sample_ids)
    metadata = pd.DataFrame(sample_info, index=sample_ids)

    # NOTE(review): presumably the two factors plus an intercept expand to
    # three design columns, matching the three samples -- confirm against
    # the design-matrix builder.
    dds = DeseqDataSet(
        counts=counts_df,
        metadata=metadata,
        design_factors=["condition", "batch"],
    )

    # Size factors have to be available before dispersions can be fitted.
    dds.fit_size_factors()

    with pytest.raises(ValueError):
        dds.fit_genewise_dispersions()


def test_reference_level():
"""Test that a ValueError is thrown when the reference level is not one of the
design factor values."""
Expand Down

0 comments on commit 8fbcc2f

Please sign in to comment.