# Source code for aggregate_absorption_probabilities (plot absorption probabilities aggregated per cluster).

from typing import Any, Tuple, Union, Literal, Mapping, Optional, Sequence

import math
from enum import auto
from types import MappingProxyType
from pathlib import Path
from collections import OrderedDict as odict

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib import cm
from seaborn import heatmap, clustermap

from anndata import AnnData
from scanpy.plotting import violin
from scvelo.plotting import paga

from cellrank import logging as logg
from cellrank._utils import Lineage
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import ModeEnum
from cellrank._utils._key import Key
from cellrank._utils._utils import (
    RandomKeys,
    save_fig,
    valuedispatch,
    _unique_order_preserving,
)
from cellrank.pl._utils import _position_legend

# Public API of this module: only the plotting entry point is exported.
__all__ = ["aggregate_absorption_probabilities"]

class AggregationMode(ModeEnum):
    """Plotting modes accepted by :func:`aggregate_absorption_probabilities`.

    Member values are generated by :func:`enum.auto` under :class:`ModeEnum`.
    They are presumably the lowercase member names (e.g. ``"paga_pie"``), because the
    plotting function converts its string ``mode`` argument via ``AggregationMode(mode)``
    -- TODO confirm against ``ModeEnum``.
    """

    BAR = auto()  # one bar-plot panel per cluster (mean +/- standard error)
    PAGA = auto()  # one PAGA graph per lineage, nodes colored by mean probability
    PAGA_PIE = auto()  # a single PAGA graph whose node colors are per-cluster fate mixtures
    VIOLIN = auto()  # violin plots of the per-cell probabilities
    HEATMAP = auto()  # lineage x cluster heatmap of mean probabilities
    CLUSTERMAP = auto()  # same as HEATMAP, but drawn with seaborn's clustermap (dendrograms)

@d.dedent
@inject_docs(m=AggregationMode)
def aggregate_absorption_probabilities(
    adata: AnnData,
    mode: Literal[
        "bar", "paga", "paga_pie", "violin", "heatmap", "clustermap"
    ] = AggregationMode.PAGA_PIE,
    backward: bool = False,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    cluster_key: Optional[str] = "clusters",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    basis: Optional[str] = None,
    cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    fmt: str = "0.2f",
    xrot: float = 90,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot aggregate lineage probabilities at a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the %(terminal)s states, or in turn to
    have descended from the %(initial)s states. For mode `{m.PAGA!r}` and `{m.PAGA_PIE!r}`, we use *PAGA*
    :cite:`wolf:19`.

    Parameters
    ----------
    %(adata)s
    mode
        Type of plot to show. Valid options are:

            - `{m.BAR!r}` - barplot, one panel per cluster. The whiskers correspond to the standard error of the mean.
            - `{m.PAGA!r}` - scanpy's PAGA, one per %(initial_or_terminal)s state, colored in by fate.
            - `{m.PAGA_PIE!r}` - scanpy's PAGA with pie charts indicating aggregated fates.
            - `{m.VIOLIN!r}` - violin plots, one per %(initial_or_terminal)s state.
            - `{m.HEATMAP!r}` - a heatmap, showing average fates per cluster.
            - `{m.CLUSTERMAP!r}` - same as a heatmap, but with a dendrogram.
    %(backward)s
    lineages
        Lineages for which to visualize absorption probabilities. If `None`, use all lineages.
    cluster_key
        Key in :attr:`anndata.AnnData.obs` containing the clusters. May only be `None` when
        ``mode = {m.BAR!r}`` or ``mode = {m.VIOLIN!r}``.
    clusters
        Clusters to visualize. If `None`, all clusters will be plotted. Always set to all clusters when
        ``mode = {m.PAGA!r}`` or ``mode = {m.PAGA_PIE!r}``.
    basis
        Basis for scatterplot to use when ``mode = {m.PAGA!r}`` or ``mode = {m.PAGA_PIE!r}``.
        If `None`, don't show the scatterplot.
    cbar
        Whether to show colorbar when ``mode = {m.PAGA!r}``, ``mode = {m.HEATMAP!r}``
        or ``mode = {m.CLUSTERMAP!r}``.
    ncols
        Number of columns when ``mode = {m.BAR!r}``, ``mode = {m.PAGA!r}`` or ``mode = {m.VIOLIN!r}``.
    sharey
        Whether to share y-axis when ``mode = {m.BAR!r}`` or ``mode = {m.VIOLIN!r}``.
    fmt
        Format when using ``mode = {m.HEATMAP!r}`` or ``mode = {m.CLUSTERMAP!r}``.
    xrot
        Rotation of the labels on the x-axis.
    figsize
        Size of the figure.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        For ``mode = {m.PAGA_PIE!r}`` and ``basis = '...'``, this controls the placement of the
        absorption probabilities legend.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin`,
        :meth:`matplotlib.axes.Axes.bar`, :func:`seaborn.heatmap` or :func:`seaborn.clustermap`,
        depending on the ``mode``.

    Returns
    -------
    %(just_plots)s
    """

    # The mode-specific plotters below are closures: they read `probs`, `stats`, `clusters`,
    # `is_all`, `term_states`, `direction` and `use_clustermap`, which are all assigned further
    # down, before any plotter is called.

    def _plot_bar():
        """One bar-plot panel per cluster, with the standard error of the mean as whiskers."""
        cols = 4 if ncols is None else ncols
        n_rows = math.ceil(len(clusters) / cols)
        fig = plt.figure(
            None, (3.5 * cols, 5 * n_rows) if figsize is None else figsize, dpi=dpi
        )
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.5, hspace=0.5)

        share_ax = None
        colors = list(probs.colors)
        positions = np.arange(probs.nlin)

        for spec, cluster_name in zip(gs, stats.keys()):
            current_ax = fig.add_subplot(spec, sharey=share_ax)
            mean, sem = stats[cluster_name]
            current_ax.bar(
                x=positions,
                height=mean,
                color=colors,
                yerr=sem,
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            if sharey:
                share_ax = current_ax

            current_ax.set_xticks(positions)
            current_ax.set_xticklabels(probs.names, rotation=xrot)
            if not is_all:
                current_ax.set_xlabel(term_states)
            current_ax.set_ylabel("absorption probability")
            current_ax.set_title(cluster_name)

        return fig

    def _plot_paga():
        """One PAGA graph per lineage, nodes colored by the cluster's mean probability."""
        kwargs["save"] = None
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = probs.nlin if ncols is None else ncols
        nrows = math.ceil(probs.nlin / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(7 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )
        # `fig.tight_layout()` can't be used here because `colorbar.make_axes` fails with it

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        for i, (ax, lineage_name) in enumerate(zip(axes, probs.names)):
            # mean probability of lineage `i` in every cluster
            colors = [mean[i] for mean, _ in stats.values()]
            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = f"{direction} {lineage_name}"

            # track the global range so that a single colorbar can cover all panels
            vmin = np.min(colors + [vmin])
            vmax = np.max(colors + [vmax])

            paga(adata, **kwargs)

        if cbar:
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(ax, aspect=60)
            mpl.colorbar.ColorbarBase(
                cax,
                ticks=np.linspace(norm.vmin, norm.vmax, 5),
                norm=norm,
                cmap=kwargs["cmap"],
                label="average absorption probability",
            )

        # drop the unused grid cells
        for ax in axes[i + 1 :]:
            ax.remove()

        return fig

    def _plot_paga_pie():
        """A single PAGA graph whose nodes show each cluster's aggregated fates."""
        # node index -> {lineage color: mean probability}
        colors = {
            node: odict(zip(probs.colors, mean))
            for node, (mean, _) in enumerate(stats.values())
        }

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        fig.tight_layout()

        kwargs["ax"] = ax
        kwargs["show"] = False
        kwargs["colorbar"] = False  # has to be disabled
        kwargs["node_colors"] = colors
        kwargs.pop("save", None)  # we will handle saving

        kwargs["transitions"] = kwargs.get("transitions", "transitions_confidence")
        if "legend_loc" in kwargs:
            orig_ll = kwargs["legend_loc"]
            if orig_ll != "on data":
                kwargs["legend_loc"] = "none"  # we will handle legend
        else:
            orig_ll = None
            kwargs["legend_loc"] = "on data"

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        ax = paga(adata, **kwargs)
        ax.set_title(kwargs.get("title", cluster_key))

        if basis is not None and orig_ll not in ("none", "on data", None):
            handles = [
                ax.scatter([], [], label=cluster_name, c=color)
                for cluster_name, color in zip(
                    adata.obs[f"{cluster_key}"].cat.categories,
                    adata.uns[f"{cluster_key}_colors"],
                )
            ]
            first_legend = _position_legend(
                ax,
                legend_loc=orig_ll,
                handles=handles,
                **{k: v for k, v in legend_kwargs.items() if k != "loc"},
                title=cluster_key,
            )
            fig.add_artist(first_legend)

        if legend_kwargs.get("loc", None) not in ("none", "on data", None):
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            handles = [
                ax.scatter([], [], label=lineage_name, c=color)
                for lineage_name, color in zip(probs.names, colors[0].keys())
            ]
            if (
                len(colors[0].keys())
                != Lineage.from_adata(adata, backward=backward).nlin
            ):
                # only a subset of lineages is shown; the remainder is grouped as "Rest"
                handles.append(ax.scatter([], [], label="Rest", c="grey"))

            second_legend = _position_legend(
                ax,
                legend_loc=legend_kwargs["loc"],
                handles=handles,
                **{k: v for k, v in legend_kwargs.items() if k != "loc"},
                title=term_states,
            )
            fig.add_artist(second_legend)

        return fig

    def _plot_violin():
        """One violin plot per lineage, grouped by `cluster_key`."""
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)
        kwargs.pop("save", None)  # we will handle saving

        kwargs["show"] = False
        kwargs["groupby"] = cluster_key
        kwargs["rotation"] = xrot

        cols = probs.nlin if ncols is None else ncols
        nrows = math.ceil(probs.nlin / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            sharey=sharey,
            dpi=dpi,
        )
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.2, hspace=0.3)

        axes = np.ravel(axes if isinstance(axes, np.ndarray) else [axes])

        # temporarily store each lineage's probabilities in `adata.obs` so scanpy can plot them
        with RandomKeys(adata, probs.nlin, where="obs") as keys:
            _i = 0
            for _i, (name, key, ax) in enumerate(zip(probs.names, keys, axes)):
                adata.obs[key] = probs[name].X[:, 0]
                ax.set_title(f"{direction} {name}")
                violin(
                    adata, ylabel="absorption probability", keys=key, ax=ax, **kwargs
                )

        # drop the unused grid cells
        for ax in axes[_i + 1 :]:
            ax.remove()

        return fig

    def _plot_violin_no_cluster_key():
        """A single violin plot with one violin per lineage, over all cells."""
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)

        kwargs["show"] = False
        kwargs["groupby"] = term_states
        kwargs["xlabel"] = None
        kwargs["rotation"] = xrot

        # stack the lineages: all cells of lineage 0, then all cells of lineage 1, ...
        data = np.ravel(probs.X.T)[..., None]
        # `X` is only an all-zero placeholder (the values live in `.obs`), so no `dtype`
        # argument is needed; that argument is deprecated in recent `anndata` versions
        tmp = AnnData(csr_matrix(data.shape, dtype=data.dtype))
        tmp.obs["absorption probability"] = data
        labels = [f"{direction.lower()} {n}" for n in probs.names]
        tmp.obs[term_states] = (
            pd.Series(np.concatenate([[label] * adata.n_obs for label in labels]))
            .astype("category")
            .values
        )
        # keep the lineage order (and hence the colors below) aligned with `probs`
        tmp.obs[term_states] = tmp.obs[term_states].cat.reorder_categories(labels)
        tmp.uns[f"{term_states}_colors"] = probs.colors

        fig, ax = plt.subplots(
            figsize=figsize if figsize is not None else (8, 6), dpi=dpi
        )
        ax.set_title(term_states.capitalize())

        violin(tmp, keys=["absorption probability"], ax=ax, **kwargs)

        return fig

    def _plot_heatmap():
        """Heatmap (or clustermap) of mean probabilities, lineages x clusters."""
        data = pd.DataFrame(
            [mean for mean, _ in stats.values()], columns=probs.names, index=clusters
        ).T

        title = kwargs.pop("title", "average fate per cluster")
        vmin, vmax = data.values.min(), data.values.max()
        cbar_kws = {
            "label": "probability",
            "ticks": np.linspace(vmin, vmax, 5),
            "format": "%.3f",
        }
        kwargs.setdefault("cmap", "viridis")

        if use_clustermap:
            kwargs["cbar_pos"] = (0, 0.9, 0.025, 0.15) if cbar else None
            max_size = float(max(data.shape))

            g = clustermap(
                data=data,
                annot=True,
                vmin=vmin,
                vmax=vmax,
                fmt=fmt,
                row_colors=probs.colors,
                dendrogram_ratio=(
                    0.15 * data.shape[0] / max_size,
                    0.15 * data.shape[1] / max_size,
                ),
                cbar_kws=cbar_kws,
                figsize=figsize,
                **kwargs,
            )
            g.ax_heatmap.set_xlabel(cluster_key)
            g.ax_heatmap.set_ylabel("lineage")
            g.ax_col_dendrogram.set_title(title)

            fig = g.fig
            g = g.ax_heatmap
        else:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
            g = heatmap(
                data=data,
                vmin=vmin,
                vmax=vmax,
                annot=True,
                fmt=fmt,
                cbar=cbar,
                cbar_kws=cbar_kws,
                ax=ax,
                **kwargs,
            )
            ax.set_title(title)
            ax.set_xlabel(cluster_key)
            ax.set_ylabel("lineage")

        g.set_xticklabels(g.get_xticklabels(), rotation=xrot)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)

        return fig

    mode = AggregationMode(mode)

    if cluster_key is not None:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in (AggregationMode.BAR, AggregationMode.VIOLIN):
        raise ValueError(
            f"Not specifying cluster key is only available for modes "
            f"`{AggregationMode.BAR!r}` and `{AggregationMode.VIOLIN!r}`, found `mode={mode!r}`."
        )

    term_states = Key.obs.term_states(backward)
    direction = Key.where(backward)

    if cluster_key is not None:
        is_all = False
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _unique_order_preserving(clusters)
            if mode in (AggregationMode.PAGA, AggregationMode.PAGA_PIE):
                logg.debug(
                    f"Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                categories = adata.obs[cluster_key].cat.categories
                for cname in clusters:
                    if cname not in categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`."
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        # without clusters, all cells are aggregated into one pseudo-cluster
        is_all = True
        clusters = [term_states]

    probs = Lineage.from_adata(adata, backward=backward)
    if lineages is None:
        # must be a list for `scanpy.pl.violin`, otherwise the categories would be a string
        lineages = list(probs.names)
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)
    probs = probs[:, lineages]

    if mode == AggregationMode.VIOLIN and not is_all:
        # Subset the probabilities together with `adata`; otherwise the per-cluster masks
        # below and the `adata.obs` assignment in the violin plotter would have mismatched
        # lengths whenever `clusters` is a strict subset of the categories.
        keep = np.isin(adata.obs[cluster_key], clusters)
        adata = adata[keep].copy()
        probs = probs[keep, :]

    # cluster name -> [mean probability per lineage, standard error of the mean per lineage]
    stats = odict()
    for name in clusters:
        mask = (
            np.ones((adata.n_obs,), dtype=bool)
            if is_all
            else np.asarray(adata.obs[cluster_key] == name, dtype=bool)
        )
        cluster_probs = probs[mask, :].X
        mean = np.nanmean(cluster_probs, axis=0)
        sem = np.nanstd(cluster_probs, axis=0) / np.sqrt(cluster_probs.shape[0])
        stats[name] = [mean, sem]

    logg.debug(f"Plotting in mode `{mode!r}`")

    use_clustermap = False
    if mode == AggregationMode.CLUSTERMAP:
        use_clustermap = True
        mode = AggregationMode.HEATMAP
    elif (
        mode in (AggregationMode.PAGA, AggregationMode.PAGA_PIE)
        and "paga" not in adata.uns
    ):
        raise KeyError(
            "Compute PAGA first as `scvelo.tl.paga()` or `scanpy.tl.paga()`."
        )

    if mode == AggregationMode.VIOLIN and cluster_key is None:
        fig = _plot_violin_no_cluster_key()
    else:
        plotters = {
            AggregationMode.BAR: _plot_bar,
            AggregationMode.PAGA: _plot_paga,
            AggregationMode.PAGA_PIE: _plot_paga_pie,
            AggregationMode.VIOLIN: _plot_violin,
            AggregationMode.HEATMAP: _plot_heatmap,
        }
        plotter = plotters.get(mode)
        if plotter is None:
            raise NotImplementedError(mode.value)
        fig = plotter()

    if save is not None:
        save_fig(fig, save)