from types import MappingProxyType
from typing import Any, Tuple, Union, Mapping, Callable, Optional, Sequence
from pathlib import Path
import scvelo as scv
from anndata import AnnData
import numpy as np
import pandas as pd
from sklearn.metrics import pairwise_distances
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, LinearSegmentedColormap
from matplotlib.collections import LineCollection
from cellrank import logging as logg
from cellrank.tl import Lineage
from cellrank.ul._docs import d
from cellrank.pl._utils import _held_karp
from cellrank.tl._utils import save_fig, _unique_order_preserving
from cellrank.ul._utils import _check_collection
from cellrank.tl._constants import ModeEnum, AbsProbKey
class LineageOrder(ModeEnum):
    """Accepted values of ``lineage_order`` in :func:`circular_projection`."""

    # Keep the lineages in the order in which they were selected.
    DEFAULT = "default"
    # Reorder the lineages by solving the Travelling salesman problem.
    OPTIMAL = "optimal"
class SpecialKey(ModeEnum):
    """Extra ``keys`` of :func:`circular_projection` that are computed from the lineage probabilities."""

    # Computed by `_priming_degree` and stored in `adata.obs`.
    DEGREE = "priming_degree"
    # Computed by `_priming_direction` and stored in `adata.obs` as a categorical.
    DIRECTION = "priming_direction"
class LabelRot(ModeEnum):
    """Accepted string values of ``labelrot`` in :func:`circular_projection`."""

    # Leave the rotation chosen by matplotlib untouched.
    DEFAULT = "default"
    # Adjust the rotation of each label after matplotlib has placed it.
    BEST = "best"
# Ways of specifying a metric: a name or a callable (both forwarded to
# `sklearn.metrics.pairwise_distances`), or an already computed pairwise distance matrix.
Metric_T = Union[str, Callable, np.ndarray, pd.DataFrame]
# Number of points sampled along each simplex edge when drawing its color gradient.
_N = 200
def _get_distances(data: Union[np.ndarray, Lineage], metric: Metric_T) -> np.ndarray:
    """
    Return the pairwise distance matrix between the columns of ``data``.

    Parameters
    ----------
    data
        2-dimensional data whose columns are compared with each other. If a :class:`Lineage`, its ``.X``
        is used and its ``.names`` are used to validate a :class:`pandas.DataFrame` ``metric``.
    metric
        Either a metric name or a callable, both forwarded to :func:`sklearn.metrics.pairwise_distances`,
        or a precomputed square distance matrix with one row/column per column of ``data``.

    Returns
    -------
    The distance matrix as a :class:`numpy.ndarray` of dtype :class:`numpy.float64`.

    Raises
    ------
    ValueError
        If a precomputed ``metric`` has a wrong shape or, when the names are known, wrong labels.
    TypeError
        If ``metric`` is of an unsupported type.
    """
    names = None
    if isinstance(data, Lineage):
        # Remember the names before unwrapping: the raw array does not carry them.
        names = list(data.names)
        data = data.X

    shape = (data.shape[1], data.shape[1])

    if isinstance(metric, str) or callable(metric):
        metric = pairwise_distances(data.T, metric=metric)
    elif isinstance(metric, np.ndarray):
        if metric.shape != shape:
            raise ValueError(
                f"Expected a `numpy.ndarray` of shape `{shape}`, found `{metric.shape}`."
            )
    elif isinstance(metric, pd.DataFrame):
        # Check the shape first so that differently sized labels never reach the comparison below.
        if metric.shape != shape:
            raise ValueError(
                f"Expected a `pandas.DataFrame` of shape `{shape}`, found `{metric.shape}`."
            )
        # The labels can only be validated when the names are known, i.e. when a `Lineage` was passed.
        if names is not None:
            if list(metric.index) != names:
                raise ValueError(
                    f"Expected `pandas.DataFrame` to have the following index `{names}`, "
                    f"found `{list(metric.index)}`."
                )
            if list(metric.columns) != names:
                raise ValueError(
                    f"Expected `pandas.DataFrame` to have the following columns `{names}`, "
                    f"found `{list(metric.columns)}`."
                )
    else:
        raise TypeError(
            f"Expected either metric defined by `str`, `callable` or a pairwise distance matrix of type"
            f" `numpy.ndarray` or `pandas.DataFrame`, found `{type(metric).__name__}`."
        )

    return np.asarray(metric, dtype=np.float64)
def _get_optimal_order(data: Lineage, metric: Metric_T) -> Tuple[float, np.ndarray]:
    """
    Solve the TSP using dynamic programming.

    The pairwise distances between the columns of ``data`` are obtained from :func:`_get_distances`
    and handed to ``_held_karp``, whose result is returned unchanged.
    """
    distances = _get_distances(data, metric)
    return _held_karp(distances)
def _priming_degree(probs: Union[np.ndarray, Lineage]) -> np.ndarray:
    r"""
    Compute, for every row, the divergence from the column-wise mean in bits.

    The value for row :math:`i` is :math:`\sum_{j} p_{ij} \log_2 \frac{p_{ij}}{\overline{p_j}}`,
    where :math:`\overline{p_j}` is the mean of column :math:`j`.

    Parameters
    ----------
    probs
        2-dimensional array of probabilities. If a :class:`Lineage`, its ``.X`` is used.

    Returns
    -------
    1-dimensional array with one value per row of ``probs``.
    """
    if isinstance(probs, Lineage):
        probs = probs.X

    col_mean = np.mean(probs, axis=0)
    # `0 * log2(0)` evaluates to NaN; silence the warnings here and patch those terms below.
    with np.errstate(divide="ignore", invalid="ignore"):
        terms = probs * np.log2(probs / col_mean)
    # By convention a zero probability contributes nothing to the sum.
    terms = np.where(probs == 0, 0.0, terms)

    return np.sum(terms, axis=1)
def _priming_direction(probs: Union[np.ndarray, Lineage]) -> np.ndarray:
    """
    Return, for every row, the index of the column with the largest column-normalized value.

    Each column is divided by its sum before the row-wise maximum is located.
    """
    values = probs.X if isinstance(probs, Lineage) else probs
    column_totals = np.sum(values, axis=0)
    scaled = values / column_totals
    return np.argmax(scaled, axis=1)
@d.dedent
def circular_projection(
    adata: AnnData,
    keys: Union[str, Sequence[str]],
    backward: bool = False,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    lineage_order: Optional[str] = None,
    metric: Union[str, Callable, np.ndarray, pd.DataFrame] = "correlation",
    normalize_by_mean: bool = True,
    ncols: int = 4,
    space: float = 0.25,
    use_raw: bool = False,
    text_kwargs: Mapping[str, Any] = MappingProxyType({}),
    labeldistance: float = 1.25,
    labelrot: Union[str, float] = "best",
    show_edges: bool = True,
    key_added: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
):
    r"""
    Plot absorption probabilities on a circular embedding as done in [Velten17]_.

    Parameters
    ----------
    %(adata)s
    keys
        Keys in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names`. Extra available keys:

            - `'priming_degree'`: the entropy relative to the "average" cell, as described in [Velten17]_.
              It follows the formula :math:`\sum_{j} p_{i, j} \log \frac{p_{ij}}{\overline{p_j}}` where :math:`i`
              corresponds to a cell, :math:`j` corresponds to a lineage and :math:`\overline{p_j}` is the average
              probability for lineage :math:`j`.
            - `'priming_direction'`: defined as :math:`d_i = argmax_j \frac{p_{ij}}{\overline{p_j}}` in [Velten17]_
              where :math:`i`, :math:`j` and :math:`\overline{p_j}` are defined as above.
    %(backward)s
    lineages
        Lineages to plot. If `None`, plot all lineages.
    lineage_order
        Can be one of the following:

            - `None`: it will be determined automatically, based on the number of lineages.
            - `'optimal'`: order the lineages optimally by solving the Travelling salesman problem (TSP).
              Recommended for <= `20` lineages.
            - `'default'`: use the order as specified in ``lineages``.
    metric
        Metric to use when constructing pairwise distance matrix when ``lineage_order = 'optimal'``. For available
        options, see :func:`sklearn.metrics.pairwise_distances`.
    normalize_by_mean
        If `True`, normalize each lineage by its mean probability, as done in [Velten17]_.
    ncols
        Number of columns when plotting multiple ``keys``.
    space
        Horizontal and vertical space between the subplots, passed to :func:`matplotlib.pyplot.subplots_adjust`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw` when there are ``keys`` in :attr:`anndata.AnnData.var_names`.
    text_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.
    labeldistance
        Distance at which the lineage labels will be drawn.
    labelrot
        How to rotate the labels. Valid options are:

            - `'best'`: rotate labels so that they are easily readable.
            - `'default'`: use :mod:`matplotlib`'s default.
            - `None`: same as `'default'`.

        If a :class:`float`, all labels will be rotated by this many degrees.
    show_edges
        Whether to show the edges surrounding the simplex.
    key_added
        Key in :attr:`anndata.AnnData.obsm` where to add the circular embedding. If `None`, it will be set to
        `'X_fate_simplex_{fwd,bwd}'`, based on ``backward``.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`scvelo.pl.scatter`.

    Returns
    -------
    %(just_plots)s
        Also updates ``adata`` with the following fields:

            - :attr:`anndata.AnnData.obsm` ``[{key_added}]``: the circular projection.
            - :attr:`anndata.AnnData.obs` ``['priming_degree_{fwd,bwd}']``: the priming degree, if in ``keys``.
            - :attr:`anndata.AnnData.obs` ``['priming_direction_{fwd,bwd}']``: the priming direction, if in ``keys``.

    Raises
    ------
    ValueError
        If ``labeldistance`` is negative, if no valid key has been selected or if fewer than `3` lineages
        have been selected.
    KeyError
        If the lineages are not present in :attr:`anndata.AnnData.obsm`.
    """
    if labeldistance is not None and labeldistance < 0:
        raise ValueError(
            f"Expected `labeldistance` to be non-negative, found `{labeldistance}`."
        )

    if labelrot is None:
        labelrot = LabelRot.DEFAULT
    if isinstance(labelrot, str):
        labelrot = LabelRot(labelrot)

    suffix = "bwd" if backward else "fwd"
    if key_added is None:
        key_added = "X_fate_simplex_" + suffix

    if isinstance(keys, str):
        keys = (keys,)

    # Keep only the keys found in `adata.obs`/`adata.var_names` plus the special keys computed below.
    keys = _unique_order_preserving(keys)
    keys_ = _check_collection(
        adata, keys, "obs", key_name="Observation", raise_exc=False
    ) + _check_collection(
        adata, keys, "var_names", key_name="Gene", raise_exc=False, use_raw=use_raw
    )
    haystack = {s.s for s in SpecialKey}
    keys = keys_ + [k for k in keys if k in haystack]
    keys = _unique_order_preserving(keys)

    if not len(keys):
        raise ValueError("No valid keys have been selected.")

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    probs = adata.obsm[lineage_key]

    if isinstance(lineages, str):
        lineages = (lineages,)
    elif lineages is None:
        lineages = probs.names

    probs = probs[lineages]
    n_lin = probs.shape[1]
    if n_lin <= 2:
        raise ValueError(f"Expected at least `3` lineages, found `{n_lin}`")

    X = probs.X.copy()
    if normalize_by_mean:
        X /= np.mean(X, axis=0)[None, :]
        X /= X.sum(1)[:, None]
        # NaNs appear when a cell's values for the selected lineages sum to 0
        # (or when the lineage average is 0, which is unlikely)
        X = np.nan_to_num(X, nan=1.0 / n_lin, copy=False)

    if lineage_order is None:
        lineage_order = LineageOrder.OPTIMAL if n_lin <= 15 else LineageOrder.DEFAULT
        logg.debug(f"Set ordering to `{lineage_order}`")
    lineage_order = LineageOrder(lineage_order)

    if lineage_order == LineageOrder.OPTIMAL:
        logg.info(f"Solving TSP for `{n_lin}` states")
        _, order = _get_optimal_order(X, metric=metric)
    else:
        order = np.arange(n_lin)

    probs = probs[:, order]
    X = X[:, order]

    # Lineages are placed at equally spaced angles on the unit circle;
    # each cell is the combination of these vertices weighted by its row of `X`.
    angle_vec = np.linspace(0, 2 * np.pi, n_lin, endpoint=False)
    vertex_x = np.cos(angle_vec)
    vertex_y = np.sin(angle_vec)

    emb_x = np.sum(X * vertex_x, axis=1)
    emb_y = np.sum(X * vertex_y, axis=1)
    adata.obsm[key_added] = np.c_[emb_x, emb_y]

    nrows = int(np.ceil(len(keys) / ncols))
    fig, ax = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * 5, nrows * 5) if figsize is None else figsize,
        dpi=dpi,
    )
    fig.subplots_adjust(wspace=space, hspace=space)
    axes = np.ravel([ax])

    text_kwargs = dict(text_kwargs)
    text_kwargs["ha"] = "center"
    text_kwargs["va"] = "center"

    # Extract the option once so that it applies to every panel, not only to the first one.
    colorbar = kwargs.pop("colorbar", True)

    _i = 0
    for _i, (k, ax) in enumerate(zip(keys, axes)):
        set_lognorm = False

        # Only the enum lookup may fail here; a failure means `k` is a regular obs/var key.
        try:
            special = SpecialKey(k)
        except ValueError:
            special = None

        if special is not None:
            logg.debug(f"Calculating `{special}`")
            if special == SpecialKey.DEGREE:
                val = _priming_degree(probs)
                set_lognorm = True
            elif special == SpecialKey.DIRECTION:
                val = pd.Series(
                    probs.names[_priming_direction(probs)], dtype="category"
                )
                # Store the colors in the order of the categories.
                color_mapper = dict(zip(probs.names, probs.colors))
                adata.uns[f"{special.s}_{suffix}_colors"] = [
                    color_mapper[c] for c in val.cat.categories
                ]
                val = val.values
            else:
                raise NotImplementedError(
                    f"Special key `{special}` is not implemented."
                )

            k = f"{special}_{suffix}"
            if k in adata.obs:
                logg.debug(f"Overwriting `adata.obs[{k!r}]`")
            adata.obs[k] = val

        scv.pl.scatter(
            adata,
            basis=key_added,
            color=k,
            show=False,
            ax=ax,
            use_raw=use_raw,
            norm=LogNorm() if set_lognorm else None,
            colorbar=colorbar,
            **kwargs,
        )

        if colorbar and set_lognorm:
            # NOTE(review): relies on the colorbar attached to the first collection and on
            # matplotlib's colorbar locator attributes - confirm against the supported versions.
            cbar = ax.collections[0].colorbar
            cax = cbar.locator.axis
            ticks = cax.minor.locator.tick_values(cbar.vmin, cbar.vmax)
            if len(ticks):
                # Show only the first, a middle and the last tick; clamp the middle index
                # so that very short tick lists do not raise an `IndexError`.
                mid = min(len(ticks) // 2 + 1, len(ticks) - 1)
                ticks = [ticks[0], ticks[mid], ticks[-1]]
                cbar.set_ticks(ticks)
                cbar.set_ticklabels([f"{t:.2f}" for t in ticks])
                cbar.update_ticks()

        # An invisible pie chart is used solely to position the lineage labels around the circle.
        patches, texts = ax.pie(
            np.ones_like(angle_vec),
            labeldistance=labeldistance,
            rotatelabels=True,
            labels=probs.names[::-1],
            startangle=-360 / len(angle_vec) / 2,
            counterclock=False,
            textprops=text_kwargs,
        )

        for patch in patches:
            patch.set_visible(False)

        # clockwise
        for color, text in zip(probs.colors[::-1], texts):
            if isinstance(labelrot, (int, float)):
                text.set_rotation(labelrot)
            elif labelrot == LabelRot.BEST:
                rot = text.get_rotation()
                text.set_rotation(rot + 90 + (1 - rot // 180) * 180)
            elif labelrot != LabelRot.DEFAULT:
                raise NotImplementedError(
                    f"Label rotation `{labelrot}` is not yet implemented."
                )
            text.set_color(color)

        if not show_edges:
            continue

        # Connect neighboring vertices by a line whose color fades from one lineage to the next.
        for i, color in enumerate(probs.colors):
            nxt = (i + 1) % n_lin
            edge_x = 1.04 * np.linspace(vertex_x[i], vertex_x[nxt], _N)
            edge_y = 1.04 * np.linspace(vertex_y[i], vertex_y[nxt], _N)
            points = np.array([edge_x, edge_y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            cmap = LinearSegmentedColormap.from_list(
                "abs_prob_cmap", [color, probs.colors[nxt]], N=_N
            )
            lc = LineCollection(segments, cmap=cmap, zorder=-1)
            lc.set_array(np.linspace(0, 1, _N))
            lc.set_linewidth(2)
            ax.add_collection(lc)

    # Remove the axes of the grid that were not needed.
    for unused_ax in axes[_i + 1 :]:
        unused_ax.remove()

    if save is not None:
        save_fig(fig, save)