Skip to content

Commit

Permalink
Add residual plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Oct 19, 2024
1 parent 256e492 commit 21306c6
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
10 changes: 10 additions & 0 deletions src/jnotype/checks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
calculate_mcc,
calculate_mutation_frequencies,
calculate_number_of_mutations_histogram,
convert_genotypes_to_integers,
convert_integers_to_genotypes,
calculate_atoms_occurrence,
subsample_pytree,
simulate_summary_statistic,
)
from jnotype.checks._plots import rc_context, rcParams, plot_summary_statistic

Expand All @@ -17,4 +22,9 @@
"rc_context",
"rcParams",
"plot_summary_statistic",
"convert_genotypes_to_integers",
"convert_integers_to_genotypes",
"calculate_atoms_occurrence",
"subsample_pytree",
"simulate_summary_statistic",
]
30 changes: 27 additions & 3 deletions src/jnotype/checks/_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _wrap_array(y):
Optional[ArrayLike] -> Optional[NumPy Array]
"""
if y is not None:
return np.asarray(y)
return np.array(y)
else:
return None

Expand Down Expand Up @@ -75,7 +75,7 @@ def _plot_quantiles(

# If alpha is not set, calculate a reasonable value
if alpha is None:
alpha = min(0.1, 1 / (1 + len(quantiles)))
alpha = min(0.2, 1 / (1 + len(quantiles)))

if color is None:
color = rcParams["color_simulations"]
Expand Down Expand Up @@ -121,7 +121,7 @@ def _plot_trajectories(
if color is None:
color = rcParams["color_simulations"]
if alpha is None:
alpha = min(0.02, 1 / (1 + num_trajectories))
alpha = min(0.1, 1 / (1 + num_trajectories))

num_simulations = y_simulated.shape[0]

Expand Down Expand Up @@ -288,6 +288,8 @@ def plot_summary_statistic(
data_linewidth: Optional[float] = None,
data_markersize: Optional[float] = None,
data_marker: str = "default",
residuals: bool = False,
residuals_type: Literal[None, "mean", "median"] = None,
) -> None:
"""Plots a summary statistic together with uncertainty.
Expand Down Expand Up @@ -325,6 +327,28 @@ def plot_summary_statistic(
if len(y_simulated.shape) != 2 or y_simulated.shape[-1] != n_points:
raise ValueError("Simulated data has wrong shape.")

# Transform data: when residuals are requested, re-express both simulated
# and observed values relative to a per-point reference curve computed
# from the simulations (their mean or median along the leading axis).
if residuals:
    if y_simulated is None:
        raise ValueError("For residual plot one has to provide simulated data.")
    # Try to infer residuals_type from summary_type, if not provided
    if residuals_type is None and summary_type in ["median", "mean"]:
        residuals_type = summary_type  # type: ignore

    if residuals_type is None:
        raise ValueError("Residuals type could not be automatically inferred.")
    elif residuals_type == "mean":
        y_perfect = np.mean(y_simulated, axis=0)
    elif residuals_type == "median":
        # Use the median here; previously this branch also used the mean,
        # making "median" residuals identical to "mean" residuals.
        y_perfect = np.median(y_simulated, axis=0)
    else:
        raise ValueError(f"Residuals type {residuals_type} not known.")

    # Calculate the residuals: subtract the reference curve from every
    # simulated row and from the observed data (if any).
    y_simulated = y_simulated - y_perfect[None, :]
    if y_data is not None:
        y_data = y_data - y_perfect

# Plot simulated data
if y_simulated is not None:
# Start by plotting uncertainty
Expand Down
13 changes: 10 additions & 3 deletions src/jnotype/checks/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,17 @@ def calculate_atoms_occurrence(X: _DataSet) -> Int[Array, " 2**n_genes"]:
return jnp.bincount(indices, length=length) # type: ignore


def _get_leading_axis_size(pytree):
def get_leading_axis_size(pytree) -> int:
"""Infers the number of samples in a PyTree."""
# Extract all leaf nodes from the PyTree
leaves = jax.tree_util.tree_leaves(pytree)

if not leaves:
raise ValueError("The PyTree has no leaves.")

# TODO(Pawel): Go through all the leaves and check
# if shapes agree

# Assume the first leaf contains the leading axis
first_leaf = leaves[0]

Expand All @@ -125,7 +129,10 @@ def subsample_pytree(
n_samples: Optional[int] = None,
):
"""Subsamples a PyTree along the leading axis."""
leading_size = _get_leading_axis_size(samples)
leading_size = get_leading_axis_size(samples)

if n_samples is None:
n_samples = leading_size

if n_samples > leading_size:
raise ValueError("n_samples cannot be larger than the leading axis size.")
Expand Down Expand Up @@ -162,7 +169,7 @@ def simulate_summary_statistic(
which has a leading (0th) axis in each leaf
corresponding to the samples from the distribution
"""
n_samples = _get_leading_axis_size(samples)
n_samples = get_leading_axis_size(samples)
keys = jax.random.split(key, n_samples)

def f(subkey, sample):
Expand Down

0 comments on commit 21306c6

Please sign in to comment.