from typing import Any, Tuple, Union, Iterable, Optional, Sequence
from copy import copy
from pathlib import Path
from anndata import AnnData
from cellrank import logging as logg
from cellrank._key import Key
from cellrank.ul._docs import d
from cellrank.pl._utils import _position_legend, _get_categorical_colors
from cellrank.tl._utils import save_fig, _unique_order_preserving
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from seaborn import stripplot
from matplotlib.cm import ScalarMappable
from matplotlib.axes import Axes
from matplotlib.colors import Normalize, to_hex
[docs]@d.dedent
def log_odds(
adata: AnnData,
lineage_1: str,
lineage_2: Optional[str] = None,
time_key: str = "exp_time",
backward: bool = False,
keys: Optional[Union[str, Sequence[str]]] = None,
threshold: Optional[Union[float, Sequence]] = None,
threshold_color: str = "red",
layer: Optional[str] = None,
use_raw: bool = False,
size: float = 2.0,
cmap: str = "viridis",
alpha: Optional[float] = 0.8,
ncols: Optional[int] = None,
fontsize: Optional[Union[float, str]] = None,
xticks_step_size: Optional[int] = 1,
legend_loc: Optional[str] = "best",
jitter: Union[bool, float] = True,
seed: Optional[int] = None,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[Union[str, Path]] = None,
show: bool = True,
**kwargs: Any,
) -> Optional[Union[Axes, Sequence[Axes]]]:
"""
Plot log-odds ratio between lineages.
Log-odds are plotted as a function of the experimental time.
Parameters
----------
%(adata)s
lineage_1
The first lineage for which to compute the log-odds.
lineage_2
The second lineage for which to compute the log-odds. If `None`, use the rest of the lineages.
time_key
Key in :attr:`anndata.AnnData.obs` containing the experimental time.
%(backward)s
keys
Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names`.
threshold
Visualize whether total expression per cell is greater than ``threshold``.
If a :class:`typing.Sequence`, it should be the same length as ``keys``.
threshold_color
Color to use when plotting thresholded expression values.
layer
Which layer to use to get expression values. If `None` or `'X'`, use :attr:`anndata.AnnData.X`.
use_raw
Whether to access :attr:`anndata.AnnData.raw`. If `True`, ``layer`` is ignored.
size
Size of the dots.
cmap
Colormap to use for continuous variables in ``keys``.
alpha
Alpha values for the dots.
ncols
Number of columns.
fontsize
Size of the font for the title, x- and y-label.
xticks_step_size
Show only every n-th ticks on x-axis. If `None`, don't show any ticks.
legend_loc
Position of the legend. If `None`, do not show the legend.
jitter
Amount of jitter to apply along x-axis.
seed
Seed for ``jitter`` to ensure reproducibility.
%(plotting)s
show
If `False`, return :class:`matplotlib.pyplot.Axes` or a sequence of them.
kwargs
Keyword arguments for :func:`seaborn.stripplot`.
Returns
-------
The axes object(s), if ``show = False``.
%(just_plots)s
"""
from cellrank.tl.kernels._utils import _ensure_numeric_ordered
def decorate(
ax: Axes, *, title: Optional[str] = None, show_ylabel: bool = True
) -> None:
ax.set_xlabel(time_key, fontsize=fontsize)
ax.set_title(title, fontdict={"fontsize": fontsize})
ax.set_ylabel(ylabel if show_ylabel else "", fontsize=fontsize)
if xticks_step_size is None:
ax.set_xticks([])
else:
step = max(1, xticks_step_size)
ax.set_xticks(np.arange(0, n_cats, step))
ax.set_xticklabels(df[time_key].cat.categories[::step])
def cont_palette(values: np.ndarray) -> Tuple[np.ndarray, ScalarMappable]:
cm = copy(plt.get_cmap(cmap))
cm.set_bad("grey")
sm = ScalarMappable(
cmap=cm, norm=Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
)
return np.array([to_hex(v) for v in (sm.to_rgba(values))]), sm
def get_data(
key: str,
thresh: Optional[float] = None,
) -> Tuple[
Optional[str], Optional[np.ndarray], Optional[np.ndarray], ScalarMappable
]:
try:
_, palette = _get_categorical_colors(adata, key)
df[key] = adata.obs[key].values[mask]
df[key] = df[key].cat.remove_unused_categories()
try:
# seaborn doesn't like numeric categories
df[key] = df[key].astype(float)
palette = {float(k): v for k, v in palette.items()}
except ValueError:
pass
# otherwise seaborn plots all
palette = {k: palette[k] for k in df[key].unique()}
hue, thresh_mask, sm = key, None, None
except TypeError:
palette, hue, thresh_mask, sm = (
cont_palette(adata.obs[key].values[mask])[0],
None,
None,
None,
)
except KeyError:
try:
# fmt: off
if thresh is None:
values = adata.raw.obs_vector(key) if use_raw else adata.obs_vector(key, layer=layer)
palette, sm = cont_palette(values)
hue, thresh_mask = None, None
else:
if use_raw:
values = np.asarray(adata.raw[:, key].X[mask].sum(1)).squeeze()
elif layer not in (None, "X"):
values = np.asarray(adata[:, key].layers[layer][mask].sum(1)).squeeze()
else:
values = np.asarray(adata[:, key].X[mask].sum(1)).squeeze()
thresh_mask = values > thresh
hue, palette, sm = None, None, None
# fmt: on
except KeyError as e:
raise e from None
return hue, palette, thresh_mask, sm
np.random.seed(seed)
_ = kwargs.pop("orient", None)
if use_raw and adata.raw is None:
logg.warning("No raw attribute set. Setting `use_raw=False`")
use_raw = False
# define log-odds
lineage_key = Key.obsm.abs_probs(backward)
if lineage_key not in adata.obsm:
raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")
time = _ensure_numeric_ordered(adata, time_key)
order = time.cat.categories[:: -1 if backward else 1]
fate1 = adata.obsm[lineage_key][lineage_1].X.squeeze(-1)
if lineage_2 is None:
fate2 = 1 - fate1
ylabel = rf"$\log{{\frac{{{lineage_1}}}{{rest}}}}$"
else:
fate2 = adata.obsm[lineage_key][lineage_2].X.squeeze(-1)
ylabel = rf"$\log{{\frac{{{lineage_1}}}{{{lineage_2}}}}}$"
# fmt: off
df = pd.DataFrame(
{
"log_odds": np.log(np.divide(fate1, fate2, where=fate2 != 0, out=np.zeros_like(fate1)) + 1e-12),
time_key: time,
}
)
mask = (fate1 != 0) & (fate2 != 0)
df = df[mask]
n_cats = len(df[time_key].cat.categories)
# fmt: on
if keys is None:
if figsize is None:
figsize = np.array([n_cats, n_cats * 4 / 6]) / 2
fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
ax = stripplot(
x=time_key,
y="log_odds",
data=df,
order=order,
jitter=jitter,
color="k",
size=size,
ax=ax,
**kwargs,
)
decorate(ax)
if save is not None:
save_fig(fig, save)
return None if show else ax
if isinstance(keys, str):
keys = (keys,)
if not len(keys):
raise ValueError("No keys have been selected.")
keys = _unique_order_preserving(keys)
if not isinstance(threshold, Iterable):
threshold = (threshold,) * len(keys)
if len(threshold) != len(keys):
raise ValueError(
f"Expected `threshold` to be of length `{len(keys)}`, found `{len(threshold)}`."
)
ncols = max(len(keys) if ncols is None else ncols, 1)
nrows = int(np.ceil(len(keys) / ncols))
if figsize is None:
figsize = np.array([n_cats * ncols, n_cats * nrows * 4 / 6]) / 2
fig, axes = plt.subplots(
nrows=nrows,
ncols=ncols,
figsize=figsize,
dpi=dpi,
constrained_layout=True,
sharey="all",
)
axes = np.ravel([axes])
i = 0
for i, (key, ax, thresh) in enumerate(zip(keys, axes, threshold)):
hue, palette, thresh_mask, sm = get_data(key, thresh)
show_ylabel = i % ncols == 0
ax = stripplot(
x=time_key,
y="log_odds",
data=df if thresh_mask is None else df[~thresh_mask],
hue=hue,
order=order,
jitter=jitter,
color="black",
palette=palette,
size=size,
alpha=alpha if alpha is not None else None if thresh_mask is None else 0.8,
ax=ax,
**kwargs,
)
if thresh_mask is not None:
stripplot(
x=time_key,
y="log_odds",
data=df if thresh_mask is None else df[thresh_mask],
hue=hue,
order=order,
jitter=jitter,
color=threshold_color,
palette=palette,
size=size * 2,
alpha=0.9,
ax=ax,
**kwargs,
)
key = rf"${key} > {thresh}$"
if sm is not None:
cax = ax.inset_axes([1.02, 0, 0.025, 1], transform=ax.transAxes)
fig.colorbar(sm, ax=ax, cax=cax)
else:
if legend_loc in (None, "none"):
legend = ax.get_legend()
if legend is not None:
legend.remove()
else:
handles, labels = ax.get_legend_handles_labels()
if len(handles):
_position_legend(
ax, legend_loc=legend_loc, handles=handles, labels=labels
)
decorate(ax, title=key, show_ylabel=show_ylabel)
for ax in axes[i + 1 :]:
ax.remove()
axes = axes[: i + 1]
if save is not None:
save_fig(fig, save)
return None if show else axes[0] if len(axes) == 1 else axes