import collections
import enum
import math
import pathlib
import types
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional, Union
from scvelo.plotting import paga
import numpy as np
import pandas as pd
import scipy.sparse as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm, colors
from anndata import AnnData
from scanpy.plotting import violin
from cellrank import logging as logg
from cellrank._utils import Lineage
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import ModeEnum
from cellrank._utils._key import Key
from cellrank._utils._utils import (
RandomKeys,
_unique_order_preserving,
save_fig,
valuedispatch,
)
from cellrank.pl._utils import _position_legend
__all__ = ["aggregate_fate_probabilities"]
class AggregationMode(ModeEnum):
BAR = enum.auto()
PAGA = enum.auto()
PAGA_PIE = enum.auto()
VIOLIN = enum.auto()
HEATMAP = enum.auto()
CLUSTERMAP = enum.auto()
[docs]
@d.dedent
@inject_docs(m=AggregationMode)
def aggregate_fate_probabilities(
adata: AnnData,
mode: Literal["bar", "paga", "paga_pie", "violin", "heatmap", "clustermap"] = AggregationMode.PAGA_PIE,
backward: bool = False,
lineages: Optional[Union[str, Sequence[str]]] = None,
cluster_key: Optional[str] = "clusters",
clusters: Optional[Union[str, Sequence[str]]] = None,
basis: Optional[str] = None,
cbar: bool = True,
ncols: Optional[int] = None,
sharey: bool = False,
fmt: str = "0.2f",
xrot: float = 90,
legend_kwargs: Mapping[str, Any] = types.MappingProxyType({"loc": "best"}),
figsize: Optional[tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[Union[str, pathlib.Path]] = None,
**kwargs: Any,
) -> None:
"""Plot aggregate lineage probabilities at a cluster level.
This can be used to investigate how likely a certain cluster is to go to the %(terminal)s states,
or in turn to have descended from the %(initial)s states.
For mode `{m.PAGA!r}` and `{m.PAGA_PIE!r}`, we use *PAGA* :cite:`wolf:19`.
Parameters
----------
%(adata)s
mode
Type of plot to show. Valid options are:
- ``{m.BAR!r}`` - barplot, one panel per cluster. The whiskers correspond to the standard error of the mean.
- ``{m.PAGA!r}`` - :func:`~scvelo.pl.paga`, one per %(initial_or_terminal)s state, colored in by fate.
- ``{m.PAGA_PIE!r}`` - :func:`~scvelo.pl.paga` with pie charts indicating aggregated fates.
- ``{m.VIOLIN!r}`` - violin plots, one per %(initial_or_terminal)s state.
- ``{m.HEATMAP!r}`` - a heatmap, showing average fates per cluster.
- ``{m.CLUSTERMAP!r}`` - same as a heatmap, but with a dendrogram.
%(backward)s
lineages
Lineages for which to visualize the fate probabilities. If :obj:`None`, use all lineages.
cluster_key
Key in :attr:`~anndata.AnnData.obs` containing the clusters.
clusters
Clusters to visualize. If :obj:`None`, all clusters will be plotted.
basis
Basis for scatterplot to use when ``mode = {m.PAGA_PIE!r}``. If :obj:`None`, don't show the scatterplot.
cbar
Whether to show colorbar when ``mode = {m.PAGA_PIE!r}``.
ncols
Number of columns when ``mode = {m.BAR!r}`` or ``mode = {m.PAGA!r}``.
sharey
Whether to share y-axis when ``mode = {m.BAR!r}``.
fmt
Format when using ``mode = {m.HEATMAP!r}`` or ``mode = {m.CLUSTERMAP!r}``.
xrot
Rotation of the labels on the x-axis.
figsize
Size of the figure.
legend_kwargs
Keyword arguments for :meth:`~matplotlib.axes.Axes.legend`.
For ``mode = {m.PAGA_PIE!r}`` and ``basis = '...'``, this controls the placement of the
fate probabilities legend.
%(plotting)s
kwargs
Keyword arguments for :func:`~scvelo.pl.paga`, :func:`~scanpy.pl.violin` or
:func:`~matplotlib.pyplot.bar`, depending on the ``mode``.
Returns
-------
%(just_plots)s
"""
@valuedispatch
def plot(mode: AggregationMode, *_args, **_kwargs):
raise NotImplementedError(mode.value)
@plot.register(AggregationMode.BAR)
def _():
cols = 4 if ncols is None else ncols
n_rows = math.ceil(len(clusters) / cols)
fig = plt.figure(None, (3.5 * cols, 5 * n_rows) if figsize is None else figsize, dpi=dpi)
fig.tight_layout()
gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.5, hspace=0.5)
ax = None
colors = list(probs.colors)
for g, k in zip(gs, d.keys()):
current_ax = fig.add_subplot(g, sharey=ax)
current_ax.bar(
x=np.arange(probs.nlin),
height=d[k][0],
color=colors,
yerr=d[k][1],
ecolor="black",
capsize=10,
**kwargs,
)
if sharey:
ax = current_ax
current_ax.set_xticks(np.arange(probs.nlin))
current_ax.set_xticklabels(probs.names, rotation=xrot)
if not is_all:
current_ax.set_xlabel(term_states)
current_ax.set_ylabel("fate probability")
current_ax.set_title(k)
return fig
@plot.register(AggregationMode.PAGA)
def _():
kwargs["save"] = None
kwargs["show"] = False
if "cmap" not in kwargs:
kwargs["cmap"] = cm.viridis
cols = probs.nlin if ncols is None else ncols
nrows = math.ceil(probs.nlin / cols)
fig, axes = plt.subplots(
nrows,
cols,
figsize=(7 * cols, 4 * nrows) if figsize is None else figsize,
constrained_layout=True,
dpi=dpi,
)
# fig.tight_layout() can't use this because colorbar.make_axes fails
i = 0
axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
vmin, vmax = np.inf, -np.inf
if basis is not None:
kwargs["basis"] = basis
kwargs["scatter_flag"] = True
kwargs["color"] = cluster_key
for i, (ax, lineage_name) in enumerate(zip(axes, probs.names)):
cols = [v[0][i] for v in d.values()]
kwargs["ax"] = ax
kwargs["colors"] = tuple(cols)
kwargs["title"] = f"{direction} {lineage_name}"
vmin = np.min(cols + [vmin])
vmax = np.max(cols + [vmax])
paga(adata, **kwargs)
if cbar:
norm = colors.Normalize(vmin=vmin, vmax=vmax)
cax, _ = mpl.colorbar.make_axes(ax, aspect=60)
_ = mpl.colorbar.ColorbarBase(
cax,
ticks=np.linspace(norm.vmin, norm.vmax, 5),
norm=norm,
cmap=kwargs["cmap"],
label="average fate probability",
)
for ax in axes[i + 1 :]:
ax.remove()
return fig
@plot.register(AggregationMode.PAGA_PIE)
def _():
colors = {i: collections.OrderedDict(zip(probs.colors, mean)) for i, (mean, _) in enumerate(d.values())}
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
fig.tight_layout()
kwargs["ax"] = ax
kwargs["show"] = False
kwargs["colorbar"] = False # has to be disabled
kwargs["show"] = False
kwargs["node_colors"] = colors
kwargs.pop("save", None) # we will handle saving
kwargs["transitions"] = kwargs.get("transitions", "transitions_confidence")
if "legend_loc" in kwargs:
orig_ll = kwargs["legend_loc"]
if orig_ll != "on data":
kwargs["legend_loc"] = "none" # we will handle legend
else:
orig_ll = None
kwargs["legend_loc"] = "on data"
if basis is not None:
kwargs["basis"] = basis
kwargs["scatter_flag"] = True
kwargs["color"] = cluster_key
ax = paga(adata, **kwargs)
ax.set_title(kwargs.get("title", cluster_key))
if basis is not None and orig_ll not in ("none", "on data", None):
handles = []
for cluster_name, color in zip(
adata.obs[f"{cluster_key}"].cat.categories,
adata.uns[f"{cluster_key}_colors"],
):
handles += [ax.scatter([], [], label=cluster_name, c=color)]
first_legend = _position_legend(
ax,
legend_loc=orig_ll,
handles=handles,
**{k: v for k, v in legend_kwargs.items() if k != "loc"},
title=cluster_key,
)
fig.add_artist(first_legend)
if legend_kwargs.get("loc", None) not in ("none", "on data", None):
# we need to use these, because scvelo can have its own handles and
# they would be plotted here
handles = []
for lineage_name, color in zip(probs.names, colors[0].keys()):
handles += [ax.scatter([], [], label=lineage_name, c=color)]
if len(colors[0].keys()) != Lineage.from_adata(adata, backward=backward).nlin:
handles += [ax.scatter([], [], label="Rest", c="grey")]
second_legend = _position_legend(
ax,
legend_loc=legend_kwargs["loc"],
handles=handles,
**{k: v for k, v in legend_kwargs.items() if k != "loc"},
title=term_states,
)
fig.add_artist(second_legend)
return fig
@plot.register(AggregationMode.VIOLIN)
def _():
kwargs.pop("ax", None)
kwargs.pop("keys", None)
kwargs.pop("save", None) # we will handle saving
kwargs["show"] = False
kwargs["groupby"] = cluster_key
kwargs["rotation"] = xrot
cols = probs.nlin if ncols is None else ncols
nrows = math.ceil(probs.nlin / cols)
fig, axes = plt.subplots(
nrows,
cols,
figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
sharey=sharey,
dpi=dpi,
)
fig.tight_layout()
fig.subplots_adjust(wspace=0.2, hspace=0.3)
if not isinstance(axes, np.ndarray):
axes = [axes]
axes = np.ravel(axes)
with RandomKeys(adata, probs.nlin, where="obs") as keys:
_i = 0
for _i, (name, key, ax) in enumerate(zip(probs.names, keys, axes)):
adata.obs[key] = probs[name].X[:, 0]
ax.set_title(f"{direction} {name}")
violin(adata, ylabel="fate probability", keys=key, ax=ax, **kwargs)
for ax in axes[_i + 1 :]:
ax.remove()
return fig
def plot_violin_no_cluster_key():
kwargs.pop("ax", None)
kwargs.pop("keys", None) # don't care
kwargs.pop("save", None)
kwargs["show"] = False
kwargs["groupby"] = term_states
kwargs["xlabel"] = None
kwargs["rotation"] = xrot
data = np.ravel(probs.X.T)[..., None]
tmp = AnnData(sp.csr_matrix(data.shape, dtype=data.dtype))
tmp.obs["fate probability"] = data
tmp.obs[term_states] = (
pd.Series(np.concatenate([[f"{direction.lower()} {n}"] * adata.n_obs for n in probs.names]))
.astype("category")
.values
)
tmp.obs[term_states] = tmp.obs[term_states].cat.reorder_categories(
[f"{direction.lower()} {n}" for n in probs.names]
)
tmp.uns[f"{term_states}_colors"] = probs.colors
fig, ax = plt.subplots(figsize=figsize if figsize is not None else (8, 6), dpi=dpi)
ax.set_title(term_states.capitalize())
violin(tmp, keys=["fate probability"], ax=ax, **kwargs)
return fig
@plot.register(AggregationMode.HEATMAP)
def _():
data = pd.DataFrame([mean for mean, _ in d.values()], columns=probs.names, index=clusters).T
title = kwargs.pop("title", "average fate per cluster")
vmin, vmax = data.values.min(), data.values.max()
cbar_kws = {
"label": "probability",
"ticks": np.linspace(vmin, vmax, 5),
"format": "%.3f",
}
kwargs.setdefault("cmap", "viridis")
if use_clustermap:
kwargs["cbar_pos"] = (0, 0.9, 0.025, 0.15) if cbar else None
max_size = float(max(data.shape))
g = sns.clustermap(
data=data,
annot=True,
vmin=vmin,
vmax=vmax,
fmt=fmt,
row_colors=probs.colors,
dendrogram_ratio=(
0.15 * data.shape[0] / max_size,
0.15 * data.shape[1] / max_size,
),
cbar_kws=cbar_kws,
figsize=figsize,
**kwargs,
)
g.ax_heatmap.set_xlabel(cluster_key)
g.ax_heatmap.set_ylabel("lineage")
g.ax_col_dendrogram.set_title(title)
fig = g.fig
g = g.ax_heatmap
else:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
g = sns.heatmap(
data=data,
vmin=vmin,
vmax=vmax,
annot=True,
fmt=fmt,
cbar=cbar,
cbar_kws=cbar_kws,
ax=ax,
**kwargs,
)
ax.set_title(title)
ax.set_xlabel(cluster_key)
ax.set_ylabel("lineage")
g.set_xticklabels(g.get_xticklabels(), rotation=xrot)
g.set_yticklabels(g.get_yticklabels(), rotation=0)
return fig
mode = AggregationMode(mode)
if cluster_key is not None:
if cluster_key not in adata.obs:
raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
elif mode not in (mode.BAR, mode.VIOLIN):
raise ValueError(
f"Not specifying cluster key is only available for modes "
f"`{AggregationMode.BAR!r}` and `{AggregationMode.VIOLIN!r}`, found `mode={mode!r}`."
)
term_states = Key.obs.term_states(backward)
direction = Key.where(backward)
if cluster_key is not None:
is_all = False
if clusters is not None:
if isinstance(clusters, str):
clusters = [clusters]
clusters = _unique_order_preserving(clusters)
if mode in (mode.PAGA, mode.PAGA_PIE):
logg.debug(f"Setting `clusters` to all available ones because of `mode={mode!r}`")
clusters = list(adata.obs[cluster_key].cat.categories)
else:
for cname in clusters:
if cname not in adata.obs[cluster_key].cat.categories:
raise KeyError(f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`.")
else:
clusters = list(adata.obs[cluster_key].cat.categories)
else:
is_all = True
clusters = [term_states]
probs = Lineage.from_adata(adata, backward=backward)
if lineages is None:
# must be list for `sc.pl.violin`, else cats str
lineages = list(probs.names)
elif isinstance(lineages, str):
lineages = [lineages]
lineages = _unique_order_preserving(lineages)
probs = probs[:, lineages]
if mode == mode.VIOLIN and not is_all:
mask = np.isin(adata.obs[cluster_key], clusters)
adata = adata[mask].copy()
probs = probs[mask]
d = collections.OrderedDict()
for name in clusters:
data = probs.X if is_all else probs.X[(adata.obs[cluster_key] == name).to_numpy()]
mean = np.nanmean(data, axis=0)
std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
d[name] = [mean, std]
logg.debug(f"Plotting in mode `{mode!r}`")
use_clustermap = False
if mode == mode.CLUSTERMAP:
use_clustermap = True
mode = mode.HEATMAP
elif mode in (AggregationMode.PAGA, AggregationMode.PAGA_PIE) and "paga" not in adata.uns:
raise KeyError("Compute PAGA first as `scvelo.tl.paga()` or `scanpy.tl.paga()`.")
fig = plot_violin_no_cluster_key() if mode == AggregationMode.VIOLIN and cluster_key is None else plot(mode)
if save is not None:
save_fig(fig, save)