import copy
import pathlib
from collections.abc import Iterable, Sequence
from typing import Any, Optional, Union
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, to_hex
from anndata import AnnData
from cellrank import logging as logg
from cellrank._utils import Lineage
from cellrank._utils._docs import d
from cellrank._utils._utils import _unique_order_preserving, save_fig
from cellrank.pl._utils import _get_categorical_colors, _position_legend
# Public API of this module: only the plotting entry point is exported.
__all__ = ["log_odds"]
[docs]
@d.dedent
def log_odds(
    adata: AnnData,
    lineage_1: str,
    lineage_2: Optional[str] = None,
    time_key: str = "exp_time",
    backward: bool = False,
    keys: Optional[Union[str, Sequence[str]]] = None,
    threshold: Optional[Union[float, Sequence]] = None,
    threshold_color: str = "red",
    layer: Optional[str] = None,
    use_raw: bool = False,
    size: float = 2.0,
    cmap: str = "viridis",
    alpha: Optional[float] = 0.8,
    ncols: Optional[int] = None,
    fontsize: Optional[Union[float, str]] = None,
    xticks_step_size: Optional[int] = 1,
    legend_loc: Optional[str] = "best",
    jitter: Union[bool, float] = True,
    seed: Optional[int] = None,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, pathlib.Path]] = None,
    show: bool = True,
    **kwargs: Any,
) -> Optional[Union[Axes, Sequence[Axes]]]:
    """Plot log-odds ratio between trajectories.

    This plotting function is geared towards time-series datasets that have been analyzed
    using the :class:`~cellrank.kernels.RealTimeKernel`. It visualizes log-odds ratios
    between different trajectories per cell, with the option to color in certain
    molecular features, like genes. This can be useful to detect and visualize fate-decisive genes.

    Parameters
    ----------
    %(adata)s
    lineage_1
        The first lineage for which to compute the log-odds.
    lineage_2
        The second lineage for which to compute the log-odds. If :obj:`None`, use the rest of the lineages.
    time_key
        Key in :attr:`~anndata.AnnData.obs` containing the experimental time.
    %(backward)s
    keys
        Key in :attr:`~anndata.AnnData.obs` or :attr:`~anndata.AnnData.var_names`.
    threshold
        Visualize whether total expression per cell is greater than ``threshold``.
        If a :class:`~typing.Sequence`, it should be the same length as ``keys``.
    threshold_color
        Color to use when plotting thresholded expression values.
    layer
        Which layer to use to get expression values. If :obj:`None` or ``'X'``, use :attr:`~anndata.AnnData.X`.
    use_raw
        Whether to access :attr:`~anndata.AnnData.raw`. If :obj:`True`, ``layer`` is ignored.
    size
        Size of the dots.
    cmap
        Colormap to use for continuous variables in ``keys``.
    alpha
        Alpha values for the dots.
    ncols
        Number of columns.
    fontsize
        Size of the font for the title, x- and y-label.
    xticks_step_size
        Show only every *n-th* tick on the x-axis. If :obj:`None`, don't show any ticks.
    legend_loc
        Position of the legend. If :obj:`None`, do not show the legend.
    jitter
        Amount of jitter to apply along x-axis.
    seed
        Seed for ``jitter`` to ensure reproducibility.
    %(plotting)s
    show
        If `False`, return :class:`~matplotlib.axes.Axes` or a sequence of them.
    kwargs
        Keyword arguments for :func:`~seaborn.stripplot`.

    Returns
    -------
    %(just_plots)s
    If ``show = False``, returns the axes object.

    Raises
    ------
    ValueError
        If ``keys`` is empty or if ``threshold`` is a sequence whose length differs from ``keys``.
    KeyError
        If a key cannot be resolved by the expression lookup either.
    """
    from cellrank.kernels._utils import _ensure_numeric_ordered

    def decorate(ax: Axes, *, title: Optional[str] = None, show_ylabel: bool = True) -> None:
        """Set the axis labels, title and x-ticks of a single panel."""
        ax.set_xlabel(time_key, fontsize=fontsize)
        ax.set_title(title, fontdict={"fontsize": fontsize})
        ax.set_ylabel(ylabel if show_ylabel else "", fontsize=fontsize)

        if xticks_step_size is None:
            ax.set_xticks([])
        else:
            step = max(1, xticks_step_size)
            ax.set_xticks(np.arange(0, n_cats, step))
            # The points are drawn using `order` (reversed when `backward=True`), so the labels
            # must follow the same order; otherwise they are mirrored in the backward case.
            ax.set_xticklabels(order[::step])

    def cont_palette(values: np.ndarray) -> tuple[list[str], ScalarMappable]:
        """Map continuous ``values`` to hex colors using ``cmap``; NaNs are shown in grey."""
        # copy so that `set_bad` does not modify the globally registered colormap
        cm = copy.copy(plt.get_cmap(cmap))
        cm.set_bad("grey")
        sm = ScalarMappable(cmap=cm, norm=Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values)))
        return [to_hex(v) for v in (sm.to_rgba(values))], sm

    def get_data(
        key: str,
        thresh: Optional[float] = None,
    ) -> tuple[Optional[str], Optional[np.ndarray], Optional[np.ndarray], ScalarMappable]:
        """Resolve ``key`` into ``(hue, palette, thresh_mask, sm)`` used to color one panel.

        The key is first tried as a categorical column in ``adata.obs`` (a new column is added to ``df``),
        then as a non-categorical column in ``adata.obs`` and lastly as an expression lookup,
        optionally binarized using ``thresh``.
        """
        try:
            _, palette = _get_categorical_colors(adata, key)
            df[key] = adata.obs[key].values[mask]
            df[key] = df[key].cat.remove_unused_categories()
            try:
                # seaborn doesn't like numeric categories
                df[key] = df[key].astype(float)
                palette = {float(k): v for k, v in palette.items()}
            except ValueError:
                # categories are not numeric, keep them as they are
                pass
            # otherwise seaborn plots all
            palette = {k: palette[k] for k in df[key].unique()}
            hue, thresh_mask, sm = key, None, None
        except TypeError:
            # not a categorical
            palette, hue, thresh_mask, sm = (
                cont_palette(adata.obs[key].values[mask])[0],
                None,
                None,
                None,
            )
        except KeyError:
            # unable to find data in `adata.obs`, fall back to the expression lookup
            try:
                # fmt: off
                if thresh is None:
                    # NOTE(review): unlike the other branches, `values` is not subset by `mask` here,
                    # so the palette has one entry per cell in `adata` rather than per row of `df`
                    # - confirm this is intended.
                    values = adata.raw.obs_vector(key) if use_raw else adata.obs_vector(key, layer=layer)
                    palette, sm = cont_palette(values)
                    hue, thresh_mask = None, None
                else:
                    if use_raw:
                        values = np.asarray(adata.raw[:, key].X[mask].sum(1)).squeeze()
                    elif layer not in (None, "X"):
                        values = np.asarray(adata[:, key].layers[layer][mask].sum(1)).squeeze()
                    else:
                        values = np.asarray(adata[:, key].X[mask].sum(1)).squeeze()
                    thresh_mask = values > thresh
                    hue, palette, sm = None, None, None
                # fmt: on
            except KeyError as e:
                # drop the chained `adata.obs` lookup failure from the traceback
                raise e from None

        return hue, palette, thresh_mask, sm

    # seaborn draws the jitter from the global NumPy random state
    np.random.seed(seed)  # noqa: NPY002
    # the orientation is fixed by the x/y assignment below
    _ = kwargs.pop("orient", None)

    if use_raw and adata.raw is None:
        logg.warning("No raw attribute set. Setting `use_raw=False`")
        use_raw = False

    probs = Lineage.from_adata(adata, backward=backward)
    time = _ensure_numeric_ordered(adata, time_key)
    order = time.cat.categories[:: -1 if backward else 1]

    fate1 = probs[lineage_1].X.squeeze(-1)
    if lineage_2 is None:
        fate2 = 1 - fate1
        ylabel = rf"$\log{{\frac{{{lineage_1}}}{{rest}}}}$"
    else:
        fate2 = probs[lineage_2].X.squeeze(-1)
        ylabel = rf"$\log{{\frac{{{lineage_1}}}{{{lineage_2}}}}}$"

    # fmt: off
    df = pd.DataFrame(
        {
            "log_odds": np.log(np.divide(fate1, fate2, where=fate2 != 0, out=np.zeros_like(fate1)) + 1e-12),
            time_key: time,
        }
    )
    # cells with a zero probability for either fate have no finite log-odds and are not plotted
    mask = (fate1 != 0) & (fate2 != 0)
    # explicit copy: `get_data` adds columns to `df`, which would otherwise trigger
    # pandas' `SettingWithCopyWarning` on the filtered frame
    df = df[mask].copy()
    n_cats = len(df[time_key].cat.categories)
    # fmt: on

    if keys is None:
        # single panel, no coloring
        if figsize is None:
            figsize = np.array([n_cats, n_cats * 4 / 6]) / 2

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
        ax = sns.stripplot(
            x=time_key,
            y="log_odds",
            data=df,
            order=order,
            jitter=jitter,
            color="k",
            size=size,
            ax=ax,
            **kwargs,
        )
        decorate(ax)

        if save is not None:
            save_fig(fig, save)
        return None if show else ax

    if isinstance(keys, str):
        keys = (keys,)
    if not len(keys):
        raise ValueError("No keys have been selected.")
    keys = _unique_order_preserving(keys)

    # broadcast a scalar (or `None`) threshold to all keys
    if not isinstance(threshold, Iterable):
        threshold = (threshold,) * len(keys)
    if len(threshold) != len(keys):
        raise ValueError(f"Expected `threshold` to be of length `{len(keys)}`, found `{len(threshold)}`.")

    ncols = max(len(keys) if ncols is None else ncols, 1)
    nrows = int(np.ceil(len(keys) / ncols))
    if figsize is None:
        figsize = np.array([n_cats * ncols, n_cats * nrows * 4 / 6]) / 2

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        dpi=dpi,
        constrained_layout=True,
        sharey="all",
    )
    # `plt.subplots` returns a bare `Axes` for a 1x1 grid, normalize to a flat array
    axes = np.ravel([axes])

    i = 0
    for i, (key, ax, thresh) in enumerate(zip(keys, axes, threshold)):
        hue, palette, thresh_mask, sm = get_data(key, thresh)
        show_ylabel = i % ncols == 0

        ax = sns.stripplot(
            x=time_key,
            y="log_odds",
            data=df if thresh_mask is None else df[~thresh_mask],
            hue=hue,
            order=order,
            jitter=jitter,
            color="black",
            palette=palette,
            size=size,
            alpha=alpha if alpha is not None else (None if thresh_mask is None else 0.8),
            ax=ax,
            **kwargs,
        )
        if thresh_mask is not None:
            # overlay the cells above the threshold as larger, highlighted dots
            sns.stripplot(
                x=time_key,
                y="log_odds",
                data=df[thresh_mask],
                hue=hue,
                order=order,
                jitter=jitter,
                color=threshold_color,
                palette=palette,
                size=size * 2,
                alpha=0.9,
                ax=ax,
                **kwargs,
            )
            key = rf"${key} > {thresh}$"

        if sm is not None:
            # continuous coloring: thin colorbar attached to the right of the panel
            cax = ax.inset_axes([1.02, 0, 0.025, 1], transform=ax.transAxes)
            fig.colorbar(sm, ax=ax, cax=cax)
        else:
            if legend_loc in (None, "none"):
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
            else:
                handles, labels = ax.get_legend_handles_labels()
                if len(handles):
                    _position_legend(ax, legend_loc=legend_loc, handles=handles, labels=labels)

        decorate(ax, title=key, show_ylabel=show_ylabel)

    # remove the grid cells that were not used by any key
    for ax in axes[i + 1 :]:
        ax.remove()
    axes = axes[: i + 1]

    if save is not None:
        save_fig(fig, save)

    return None if show else axes[0] if len(axes) == 1 else axes