Skip to content

Commit

Permalink
Update embedding_scatterplot fn to support different dim. red. meth…
Browse files Browse the repository at this point in the history
…ods, w/ default being PaCMAP
  • Loading branch information
nathanpainchaud committed Oct 17, 2023
1 parent 2553430 commit 0ff576a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ scipy = "*"
seaborn = "*"
matplotlib = "*"
umap-learn = { version = "*", extras = ["plot"] }
pacmap = "*"
pandas = "*"
h5py = "*"
PyYAML = "*"
Expand Down
46 changes: 33 additions & 13 deletions vital/utils/plot.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
import logging
from typing import Any, Dict, Iterable, Iterator, Union
from typing import Any, Dict, Iterable, Iterator, Literal, Union

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
import umap
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

logger = logging.getLogger(__name__)


def embedding_scatterplot(
data: pd.DataFrame, plots_kwargs: Iterable[Dict[str, Any]], umap_kwargs: Dict[str, Any] = None, data_tag: str = None
data: pd.DataFrame,
plots_kwargs: Iterable[Dict[str, Any]],
data_tag: str = None,
method: Literal["tsne", "umap", "pacmap"] = "pacmap",
**embedding_kwargs,
) -> Iterator[Axes]:
"""Generates 2D scatter plots of some data, reducing its dimensionality to 2 using UMAP if it's not already 2D.
"""Generates 2D scatter plots of some data, reducing its dimensionality to 2 if it's not already 2D.
Args:
data: Dataframe with each column representing a dimension of the data, and relevant metadata being stored in a
multiindex.
plots_kwargs: Sets of kwargs to use to generate different versions of the scatter plot, e.g. modifying the
variables used for hue and/or style.
umap_kwargs: If the data has more than 2 dimensions, UMAP is used to reduce the dimensionality of the data for
plotting purposes. This parameter is passed along to the UMAP estimator's `init`.
data_tag: String describing the data used in the titles/logs, etc. If not specified, it defaults to 'data'.
method: If the data has more than 2 dimensions, this parameter specifies the method to use to reduce the
dimensionality of the data for plotting purposes.
**embedding_kwargs: Parameters passed along to the embedding's constructor.
Returns:
An iterator over the generated scatter plots.
Expand All @@ -40,15 +44,31 @@ def embedding_scatterplot(
elif len(data.columns) == 2:
plot_title = f"2D {data_tag}"
else: # len(encoding_dims) > 2
if umap_kwargs is None:
umap_kwargs = {}
plot_title = f"2D UMAP embedding of the {len(data.columns)}D {data_tag}"
logger.info(f"Generating 2D UMAP embedding of {len(data.columns)}D {data_tag}...")
umap_embedding = umap.UMAP(**umap_kwargs).fit_transform(data)
if embedding_kwargs is None:
embedding_kwargs = {}
match method:
case "tsne":
from sklearn.manifold import TSNE

# Update the encodings dataframe with the new UMAP embedding
embedding_cls = TSNE
case "umap":
import umap

embedding_cls = umap.UMAP
case "pacmap":
from pacmap import PaCMAP

embedding_cls = PaCMAP
case _:
raise ValueError(f"Unknown embedding method '{method}'. Must be one of: ['tsne', 'umap', 'pacmap'].")

plot_title = f"2D {embedding_cls.__name__} embedding of the {len(data.columns)}D {data_tag}"
logger.info(f"Generating 2D {method} embedding of {len(data.columns)}D {data_tag}...")
data_2d = embedding_cls(**embedding_kwargs).fit_transform(data.to_numpy())

# Update the encodings dataframe with the new 2D embedding
data = data.drop(labels=data.columns, axis="columns")
data[[0, 1]] = umap_embedding
data[[0, 1]] = data_2d

# Generate a plot of the embedding for each set of plot kwargs provided
for plot_kwargs in plots_kwargs:
Expand Down

0 comments on commit 0ff576a

Please sign in to comment.