Source code for cellrank.pl._circular_projection

import enum
import pathlib
import types
from typing import Any, Callable, Literal, Mapping, Optional, Sequence, Tuple, Union

import scvelo as scv

import numpy as np
import pandas as pd
from sklearn.metrics import pairwise_distances

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap, LogNorm

from anndata import AnnData

from cellrank import logging as logg
from cellrank._utils import Lineage
from cellrank._utils._docs import d
from cellrank._utils._enum import ModeEnum
from cellrank._utils._key import Key
from cellrank._utils._lineage import PrimingDegree
from cellrank._utils._utils import _check_collection, _unique_order_preserving, save_fig
from cellrank.pl._utils import _held_karp

__all__ = ["circular_projection"]


class LineageOrder(ModeEnum):
    DEFAULT = enum.auto()
    OPTIMAL = enum.auto()


class LabelRot(ModeEnum):
    DEFAULT = enum.auto()
    BEST = enum.auto()


Metric_T = Union[str, Callable, np.ndarray, pd.DataFrame]
_N = 200


def _get_distances(data: Union[np.ndarray, Lineage], metric: Metric_T) -> np.ndarray:
    if isinstance(data, Lineage):
        data = data.X

    if isinstance(metric, str) or callable(metric):
        metric = pairwise_distances(data.T, metric=metric)
    elif isinstance(metric, (pd.DataFrame, np.ndarray)):
        shape = (data.shape[1], data.shape[1])
        if metric.shape != shape:
            raise ValueError(
                f"Expected an `numpy.array` or `pandas.DataFrame` of shape `{shape}`, found `{metric.shape}`."
            )
    else:
        raise TypeError(
            f"Expected either metric defined by `str`, `callable` or a pairwise distance matrix of type"
            f" `numpy.ndarray` or `pandas.DataFrame`, found `{type(metric).__name__}`."
        )

    return np.asarray(metric, dtype=np.float64)


def _get_optimal_order(data: Lineage, metric: Metric_T) -> Tuple[float, np.ndarray]:
    """Solve the TSP using dynamic programming."""
    return _held_karp(_get_distances(data, metric))


[docs] @d.dedent def circular_projection( adata: AnnData, keys: Union[str, Sequence[str]], backward: bool = False, lineages: Optional[Union[str, Sequence[str]]] = None, early_cells: Optional[Union[Mapping[str, Sequence[str]], Sequence[str]]] = None, lineage_order: Optional[Literal["default", "optimal"]] = None, metric: Union[str, Callable, np.ndarray, pd.DataFrame] = "correlation", normalize_by_mean: bool = True, ncols: int = 4, space: float = 0.25, use_raw: bool = False, text_kwargs: Mapping[str, Any] = types.MappingProxyType({}), label_distance: float = 1.25, label_rot: Union[Literal["default", "best"], float] = "best", show_edges: bool = True, key_added: Optional[str] = None, figsize: Optional[Tuple[float, float]] = None, dpi: Optional[int] = None, save: Optional[Union[str, pathlib.Path]] = None, **kwargs: Any, ) -> None: r"""Visualize fate probabilities in a circular embedding :cite:`velten:17,jaitin:14`. We arrange all computed terminal states evenly spaced around the unit circle and place cells inside the unit circle in a way that reflects their fate probabilities. In other words, the more certain we are that a cell will transition towards a given terminal state, the closer we place it to that terminal state. Uncommitted cells thus reside in the middle of the circle. Parameters ---------- %(adata)s keys Keys in :attr:`~anndata.AnnData.obs` or :attr:`~anndata.AnnData.var_names`. Additional keys are: - ``'kl_divergence'`` - as in :cite:`velten:17`, computes KL-divergence between the fate probabilities of a cell and the average fate probabilities. See ``early_cells`` for more information. - ``'entropy'`` - as in :cite:`setty:19`, computes entropy over a cells fate probabilities. %(backward)s lineages Lineages to plot. If :obj:`None`, plot all lineages. early_cells Cell ids or a mask marking early cells used to define the average fate probabilities. If :obj:`None`, use all cells. Only used when ``'kl_divergence'`` is in the ``keys``. If a :class:`dict`, key specifies a cluster key in :attr:`~anndata.AnnData.obs` and the values specify cluster labels containing early cells. lineage_order Can be one of the following: - :obj:`None` - it will be determined automatically, based on the number of lineages. - ``'optimal'`` - order lineages optimally by solving the Traveling salesman problem (TSP). Recommended for <= :math:`20` lineages. - ``'default'`` - use the order as specified by ``lineages``. metric Metric to use when constructing pairwise distance matrix when ``lineage_order = 'optimal'``. For available options, see :func:`~sklearn.metrics.pairwise_distances`. normalize_by_mean If :obj:`True`, normalize each lineage by its mean probability, as done in :cite:`velten:17`. ncols Number of columns when plotting multiple ``keys``. space Horizontal and vertical space between for :func:`~matplotlib.pyplot.subplots_adjust`. use_raw Whether to access :attr:`~anndata.AnnData.raw` when there are ``keys`` in :attr:`~anndata.AnnData.var_names`. text_kwargs Keyword arguments for :meth:`~matplotlib.axes.Axes.text`. label_distance Distance at which the lineage labels will be drawn. label_rot How to rotate the labels. Valid options are: - ``'best'`` - rotate labels so that they are easily readable. - ``'default'`` - use :mod:`matplotlib`'s default. - :obj:`None` - same as ``'default'``. - :class:`float`, all labels will be rotated by this many degrees. show_edges Whether to show the edges surrounding the simplex. key_added Key in :attr:`~anndata.AnnData.obsm` where to add the circular embedding. If :obj:`None`, it will be set to ``'X_fate_simplex_{fwd,bwd}'``, based on the ``backward``. %(plotting)s kwargs Keyword arguments for :func:`~scvelo.pl.scatter`. Returns ------- %(just_plots)s Also updates ``adata`` with the following fields: - :attr:`adata.obsm['{key_added}'] <anndata.AnnData.obsm>` - the circular projection. - :attr:`adata.obs['to_{initial,terminal}_states_{method}'] <anndata.AnnData.obs>` - the priming degree, if a method is present in the ``keys``. """ if label_distance is not None and label_distance < 0: raise ValueError(f"Expected `label_distance` to be positive, found `{label_distance}`.") if label_rot is None: label_rot = LabelRot.DEFAULT label_rot = LabelRot(label_rot) suffix = "bwd" if backward else "fwd" if key_added is None: key_added = "X_fate_simplex_" + suffix if isinstance(keys, str): keys = (keys,) keys = _unique_order_preserving(keys) keys_ = _check_collection(adata, keys, "obs", key_name="Observation", raise_exc=False) + _check_collection( adata, keys, "var_names", key_name="Gene", raise_exc=False, use_raw=use_raw ) haystack = set(PrimingDegree) keys = keys_ + [k for k in keys if k in haystack] keys = _unique_order_preserving(keys) if not len(keys): raise ValueError("No valid keys have been selected.") probs = Lineage.from_adata(adata, backward=backward) if isinstance(lineages, str): lineages = [lineages] elif lineages is None: lineages = probs.names lineages = _unique_order_preserving(lineages) probs = probs[lineages] if probs.nlin < 3: raise ValueError(f"Expected at least `3` lineages, found `{probs.nlin}`.") X = probs.X.copy() if normalize_by_mean: X /= np.mean(X, axis=0)[None, :] X /= X.sum(1)[:, None] # this happens when cells for sel. lineages sum to 1 (or when the lineage average is 0, which is unlikely) X = np.nan_to_num(X, nan=1.0 / probs.nlin, copy=False) if lineage_order is None: lineage_order = LineageOrder.OPTIMAL if 3 < probs.nlin <= 20 else LineageOrder.DEFAULT logg.debug(f"Set ordering to `{lineage_order}`") lineage_order = LineageOrder(lineage_order) if lineage_order == LineageOrder.OPTIMAL: logg.info(f"Solving TSP for `{probs.nlin}` states") _, order = _get_optimal_order(X, metric=metric) else: order = np.arange(probs.nlin) probs = probs[:, order] X = X[:, order] angle_vec = np.linspace(0, 2 * np.pi, probs.nlin, endpoint=False) angle_vec_sin = np.cos(angle_vec) angle_vec_cos = np.sin(angle_vec) x = np.sum(X * angle_vec_sin, axis=1) y = np.sum(X * angle_vec_cos, axis=1) adata.obsm[key_added] = np.c_[x, y] nrows = int(np.ceil(len(keys) / ncols)) fig, ax = plt.subplots( nrows=nrows, ncols=ncols, figsize=(ncols * 5, nrows * 5) if figsize is None else figsize, dpi=dpi, ) fig.subplots_adjust(wspace=space, hspace=space) axes = np.ravel([ax]) text_kwargs = dict(text_kwargs) text_kwargs["ha"] = "center" text_kwargs["va"] = "center" _i, lineage_key = 0, Key.obsm.fate_probs(backward) for _i, (k, ax) in enumerate(zip(keys, axes)): set_lognorm, colorbar = False, kwargs.pop("colorbar", True) try: _ = PrimingDegree(k) logg.debug(f"Calculating priming degree using `method={k}`") val = probs.priming_degree(method=k, early_cells=early_cells) k = f"{lineage_key}_{k}" adata.obs[k] = val except ValueError: # TODO(michalk8): parse the exception pass scv.pl.scatter( adata, basis=key_added, color=k, show=False, ax=ax, use_raw=use_raw, norm=LogNorm() if set_lognorm else None, colorbar=colorbar, **kwargs, ) if colorbar and set_lognorm: cbar = ax.collections[0].colorbar cax = cbar.locator.axis ticks = cax.minor.locator.tick_values(cbar.vmin, cbar.vmax) ticks = [ticks[0], ticks[len(ticks) // 2 + 1], ticks[-1]] cbar.set_ticks(ticks) cbar.set_ticklabels([f"{t:.2f}" for t in ticks]) cbar.update_ticks() patches, texts = ax.pie( np.ones_like(angle_vec), labeldistance=label_distance, rotatelabels=True, labels=probs.names[::-1], startangle=-360 / len(angle_vec) / 2, counterclock=False, textprops=text_kwargs, ) for patch in patches: patch.set_visible(False) # clockwise for color, text in zip(probs.colors[::-1], texts): if isinstance(label_rot, (int, float, np.number)): text.set_rotation(label_rot) elif label_rot == LabelRot.BEST: rot = text.get_rotation() text.set_rotation(rot + 90 + (1 - rot // 180) * 180) elif label_rot != LabelRot.DEFAULT: raise NotImplementedError(f"Label rotation `{label_rot}` is not yet implemented.") text.set_color(color) if not show_edges: continue for i, color in enumerate(probs.colors): next = (i + 1) % probs.nlin x = 1.04 * np.linspace(angle_vec_sin[i], angle_vec_sin[next], _N) y = 1.04 * np.linspace(angle_vec_cos[i], angle_vec_cos[next], _N) points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) cmap = LinearSegmentedColormap.from_list("fate_prob_cmap", [color, probs.colors[next]], N=_N) lc = LineCollection(segments, cmap=cmap, zorder=-1) lc.set_array(np.linspace(0, 1, _N)) lc.set_linewidth(2) ax.add_collection(lc) for j in range(_i + 1, len(axes)): axes[j].remove() if save is not None: save_fig(fig, save)