Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calculate variance explained #88

Merged
merged 9 commits into from
Aug 8, 2024
138 changes: 84 additions & 54 deletions docs/tutorials/generate_in_silico_data.ipynb

Large diffs are not rendered by default.

538 changes: 528 additions & 10 deletions docs/tutorials/hyperparameter_sweep.ipynb

Large diffs are not rendered by default.

333 changes: 320 additions & 13 deletions docs/tutorials/lightning_crash_course.ipynb

Large diffs are not rendered by default.

1,228 changes: 1,192 additions & 36 deletions docs/tutorials/testing_model_metrics.ipynb

Large diffs are not rendered by default.

4,774 changes: 4,567 additions & 207 deletions docs/tutorials/visualizing_and_testing_data_generation_methods.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion experiments/simple_model_synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def simple_model_synthetic_data_experiment(
data_module = SyntheticDataLoader(
batch_size=batch_size,
num_genes=1000,
signal=[0.1, 0.15, 0.2, 0.25, 0.3],
bound=[0.1, 0.15, 0.2, 0.25, 0.3],
n_sample=[1, 1, 2, 2, 4],
val_size=0.1,
test_size=0.1,
Expand Down
2 changes: 1 addition & 1 deletion yeastdnnexplorer/data_loaders/real_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def prepare_data(self) -> None:
perturbation_pvalues.values, dtype=torch.float64
)

# note that we no longer have a signal / noise tensor
# note that we no longer have a bound / unbound tensor
# (like for the synthetic data)
self.final_data_tensor = torch.stack(
[
Expand Down
48 changes: 24 additions & 24 deletions yeastdnnexplorer/data_loaders/synthetic_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(
self,
batch_size: int = 32,
num_genes: int = 1000,
signal: list[float] = [0.1, 0.2, 0.2, 0.4, 0.5],
signal_mean: float = 3.0,
bound: list[float] = [0.1, 0.2, 0.2, 0.4, 0.5],
bound_mean: float = 3.0,
n_sample: list[int] = [1, 2, 2, 4, 4],
val_size: float = 0.1,
test_size: float = 0.1,
Expand All @@ -47,10 +47,10 @@ def __init__(
:param num_genes: The number of genes in the synthetic data (this is the number
of datapoints in our dataset)
:type num_genes: int
:param signal: The proportion of genes in each sample group that are put in the
signal grop (i.e. have a non-zero binding effect and expression response)
:type signal: List[int]
:param n_sample: The number of samples to draw from each signal group
:param bound: The proportion of genes in each sample group that are put in the
bound group (i.e. have a non-zero binding effect and expression response)
:type bound: List[int]
:param n_sample: The number of samples to draw from each bound group
:type n_sample: List[int]
:param val_size: The proportion of the dataset to include in the validation
split
Expand All @@ -60,23 +60,23 @@ def __init__(
:param random_state: The random seed to use for splitting the data (keep this
consistent to ensure reproducibility)
:type random_state: int
:param signal_mean: The mean of the signal distribution
:type signal_mean: float
:param bound_mean: The mean of the bound distribution
:type bound_mean: float
:param max_mean_adjustment: The maximum mean adjustment to apply to the mean
of the signal (bound) perturbation effects
of the bound perturbation effects
:type max_mean_adjustment: float
:param adjustment_function: A function that adjusts the mean of the signal
:param adjustment_function: A function that adjusts the mean of the bound
perturbation effects
:type adjustment_function: Callable[[torch.Tensor, float, float,
float, dict[int, list[int]]], torch.Tensor]
:raises TypeError: If batch_size is not a positive integer
:raises TypeError: If num_genes is not a positive integer
:raises TypeError: If signal is not a list of integers or floats
:raises TypeError: If bound is not a list of integers or floats
:raises TypeError: If n_sample is not a list of integers
:raises TypeError: If val_size is not a float between 0 and 1 (inclusive)
:raises TypeError: If test_size is not a float between 0 and 1 (inclusive)
:raises TypeError: If random_state is not an integer
:raises TypeError: If signal_mean is not a float
:raises TypeError: If bound_mean is not a float
:raises ValueError: If val_size + test_size is greater than 1 (i.e. the splits
are too large)

Expand All @@ -85,10 +85,10 @@ def __init__(
raise TypeError("batch_size must be a positive integer")
if not isinstance(num_genes, int) or num_genes < 1:
raise TypeError("num_genes must be a positive integer")
if not isinstance(signal, list) or not all(
isinstance(x, (int, float)) for x in signal
if not isinstance(bound, list) or not all(
isinstance(x, (int, float)) for x in bound
):
raise TypeError("signal must be a list of integers or floats")
raise TypeError("bound must be a list of integers or floats")
if not isinstance(n_sample, list) or not all(
isinstance(x, int) for x in n_sample
):
Expand All @@ -99,17 +99,17 @@ def __init__(
raise TypeError("test_size must be a float between 0 and 1 (inclusive)")
if not isinstance(random_state, int):
raise TypeError("random_state must be an integer")
if not isinstance(signal_mean, float):
raise TypeError("signal_mean must be a float")
if not isinstance(bound_mean, float):
raise TypeError("bound_mean must be a float")
if test_size + val_size > 1:
raise ValueError("val_size + test_size must be less than or equal to 1")

super().__init__()
self.batch_size = batch_size
self.num_genes = num_genes
self.signal_mean = signal_mean
self.signal = signal or [0.1, 0.15, 0.2, 0.25, 0.3]
self.n_sample = n_sample or [1 for _ in range(len(self.signal))]
self.bound_mean = bound_mean
self.bound = bound or [0.1, 0.15, 0.2, 0.25, 0.3]
self.n_sample = n_sample or [1 for _ in range(len(self.bound))]
self.num_tfs = sum(self.n_sample) # sum of all n_sample is the number of TFs
self.val_size = val_size
self.test_size = test_size
Expand All @@ -132,10 +132,10 @@ def prepare_data(self) -> None:
performed as that is handled in the functions in generate_data.py."""
# this will be a list of length 10 with a GenePopulation object in each element
gene_populations_list = []
for signal_proportion, n_draws in zip(self.signal, self.n_sample):
for bound_proportion, n_draws in zip(self.bound, self.n_sample):
for _ in range(n_draws):
gene_populations_list.append(
generate_gene_population(self.num_genes, signal_proportion)
generate_gene_population(self.num_genes, bound_proportion)
)

# Generate binding data for each gene population
Expand Down Expand Up @@ -166,7 +166,7 @@ def prepare_data(self) -> None:
if self.max_mean_adjustment > 0:
perturbation_effects_list = generate_perturbation_effects(
binding_data_tensor,
signal_mean=self.signal_mean,
bound_mean=self.bound_mean,
tf_index=0, # unused
max_mean_adjustment=self.max_mean_adjustment,
adjustment_function=self.adjustment_function,
Expand All @@ -188,7 +188,7 @@ def prepare_data(self) -> None:
perturbation_effects_list = [
generate_perturbation_effects(
binding_data_tensor[:, tf_index, :].unsqueeze(1),
signal_mean=self.signal_mean,
bound_mean=self.bound_mean,
tf_index=0, # unused
)
for tf_index in range(sum(self.n_sample))
Expand Down
Loading
Loading