Source code for cellrank.estimators.terminal_states._term_states_estimator

import abc
import types
from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union

import scvelo as scv

import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype

from matplotlib.colors import to_hex

from anndata import AnnData

from cellrank import logging as logg
from cellrank._utils._colors import (
    _convert_to_hex_colors,
    _create_categorical_colors,
    _map_names_and_colors,
)
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._key import Key
from cellrank._utils._lineage import Lineage
from cellrank._utils._utils import (
    RandomKeys,
    _convert_to_categorical_series,
    _merge_categorical_series,
    _unique_order_preserving,
)
from cellrank.estimators._base_estimator import BaseEstimator
from cellrank.estimators.mixins._utils import (
    PlotMode,
    SafeGetter,
    StatesHolder,
    logger,
    shadow,
)
from cellrank.kernels._base_kernel import KernelExpression

__all__ = ["TermStatesEstimator"]


@d.dedent
class TermStatesEstimator(BaseEstimator, abc.ABC):
    """Base class for all estimators predicting the initial and terminal states.

    Parameters
    ----------
    %(base_estimator.parameters)s
    """

    def __init__(
        self,
        object: Union[AnnData, np.ndarray, sp.spmatrix, KernelExpression],
        **kwargs: Any,
    ):
        super().__init__(object=object, **kwargs)
        # containers for the initial and the terminal state annotations
        self._init_states = StatesHolder()
        self._term_states = StatesHolder()

    @property
    @d.get_summary(base="tse_term_states")
    def terminal_states(self) -> Optional[pd.Series]:
        """Categorical annotation of terminal states.

        By default, all transient cells will be labeled as `NaN`.
        """
        return self._term_states.assignment

    @property
    @d.get_summary(base="tse_term_states_probs")
    def terminal_states_probabilities(self) -> Optional[pd.Series]:
        """Probability to be a terminal state."""
        return self._term_states.probs

    @property
    @d.get_summary(base="tse_init_states")
    def initial_states(self) -> Optional[pd.Series]:
        """Categorical annotation of initial states.

        By default, all transient cells will be labeled as `NaN`.
        """
        return self._init_states.assignment

    @property
    @d.get_summary(base="tse_init_states_probs")
    def initial_states_probabilities(self) -> Optional[pd.Series]:
        """Probability to be an initial state."""
        return self._init_states.probs

    @d.dedent
    def set_terminal_states(
        self,
        states: Union[pd.Series, Dict[str, Sequence[Any]]],
        cluster_key: Optional[str] = None,
        allow_overlap: bool = False,
        **kwargs: Any,
    ) -> "TermStatesEstimator":
        """Set the :attr:`terminal_states`.

        Parameters
        ----------
        states
            States to select. Valid options are:

            - categorical :class:`~pandas.Series` where each category corresponds to an individual state.
              `NaN` entries denote cells that do not belong to any state, i.e., transient cells.
            - :class:`dict` where keys are states and values are lists of cell barcodes corresponding to
              annotations in :attr:`~anndata.AnnData.obs_names`.
              If only 1 key is provided, values should correspond to clusters if a categorical
              :class:`~pandas.Series` can be found in :attr:`~anndata.AnnData.obs`.
        cluster_key
            Key in :attr:`~anndata.AnnData.obs` to associate names and colors with :attr:`terminal_states`.
            Each state will be given the name and color corresponding to the cluster it mostly overlaps with.
        %(allow_overlap)s
        kwargs
            Additional keyword arguments.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`terminal_states` - %(tse_term_states.summary)s
        - :attr:`terminal_states_probabilities` - %(tse_term_states_probs.summary)s
        """
        states, colors = self._set_categorical_labels(
            categories=states,
            cluster_key=cluster_key,
            existing=None,
        )
        self._write_states(
            "terminal",
            states=states,
            colors=colors,
            allow_overlap=allow_overlap,
            **kwargs,
        )
        return self

    @d.dedent
    def set_initial_states(
        self,
        states: Union[pd.Series, Dict[str, Sequence[Any]]],
        cluster_key: Optional[str] = None,
        allow_overlap: bool = False,
        **kwargs: Any,
    ) -> "TermStatesEstimator":
        """Set the :attr:`initial_states`.

        Parameters
        ----------
        states
            Which states to select. Valid options are:

            - categorical :class:`~pandas.Series` where each category corresponds to an individual state.
              `NaN` entries denote cells that do not belong to any state, i.e., transient cells.
            - :class:`dict` where keys are states and values are lists of cell barcodes corresponding to
              annotations in :attr:`~anndata.AnnData.obs_names`.
              If only 1 key is provided, values should correspond to clusters if a categorical
              :class:`~pandas.Series` can be found in :attr:`~anndata.AnnData.obs`.
        cluster_key
            Key in :attr:`~anndata.AnnData.obs` to associate names and colors with :attr:`initial_states`.
            Each state will be given the name and color corresponding to the cluster it mostly overlaps with.
        %(allow_overlap)s
        kwargs
            Additional keyword arguments.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`initial_states` - %(tse_init_states.summary)s
        - :attr:`initial_states_probabilities` - %(tse_init_states_probs.summary)s
        """
        states, colors = self._set_categorical_labels(
            categories=states,
            cluster_key=cluster_key,
            existing=None,
        )
        self._write_states(
            "initial",
            states=states,
            colors=colors,
            allow_overlap=allow_overlap,
            **kwargs,
        )
        return self

    @d.get_sections(base="tse_rename_term_states", sections=["Parameters", "Returns"])
    @d.get_full_description(base="tse_rename_term_states")
    @d.dedent
    def rename_terminal_states(self, old_new: Dict[str, str]) -> "TermStatesEstimator":
        """Rename the :attr:`terminal_states`.

        Parameters
        ----------
        old_new
            Dictionary that maps old names to unique new names.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`terminal_states` - %(tse_term_states.summary)s
        """
        return self._rename_states("terminal", old_new)

    @d.dedent
    def rename_initial_states(self, old_new: Dict[str, str]) -> "TermStatesEstimator":
        """Rename the :attr:`initial_states`.

        Parameters
        ----------
        old_new
            Dictionary that maps old names to unique new names.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`initial_states` - %(tse_init_states.summary)s
        """
        return self._rename_states("initial", old_new)

    def _rename_states(
        self, which: Literal["initial", "terminal"], old_new: Dict[str, str]
    ) -> "TermStatesEstimator":
        """Shared implementation of :meth:`rename_initial_states` and :meth:`rename_terminal_states`.

        Parameters
        ----------
        which
            Whether to rename the ``'initial'`` or the ``'terminal'`` states.
        old_new
            Dictionary that maps old names to unique new names.

        Returns
        -------
        Returns self with renamed state categories (and membership names, if present).

        Raises
        ------
        RuntimeError
            If the states have not been computed or set yet.
        TypeError
            If ``old_new`` is not a :class:`dict`.
        ValueError
            If ``old_new`` refers to unknown states or the renaming would produce duplicate names.
        """
        backward = which == "initial"
        holder = self._init_states if backward else self._term_states
        states = holder.assignment
        if states is None:
            raise RuntimeError(
                f"Compute {which} states first as `.predict_{which}_states()` or "
                f"set them manually as `.set_{which}_states()`."
            )

        if not isinstance(old_new, dict):
            raise TypeError(f"Expected new names to be a `dict`, found `{type(old_new)}`.")
        if not old_new:
            return self

        old_names = states.cat.categories
        requested = list(old_new.keys())
        mask = np.isin(requested, old_names)
        if not np.all(mask):
            invalid = sorted(np.array(requested)[~mask])
            raise ValueError(
                f"Invalid {which} states names: `{invalid}`. Valid names are: `{sorted(old_names)}`."
            )
        names_after_renaming = [old_new.get(n, n) for n in old_names]
        if len(set(names_after_renaming)) != len(old_names):
            raise ValueError(
                f"After renaming, {which} states will no longer be unique: `{names_after_renaming}`."
            )

        assignment = states.cat.rename_categories(old_new)
        memberships = holder.memberships  # save before `_write_states` updates the holder
        key = Key.obs.term_states(self.backward, bwd=backward)
        self._write_states(
            which,
            states=assignment,
            colors=holder.colors,
            probs=holder.probs,
            # keep the parameters used to compute the states instead of resetting them
            params=self.params.get(key, {}),
            # renaming does not change which cells belong to a state, so the overlap
            # check was already done (or explicitly disabled) when the states were set
            allow_overlap=True,
            log=False,
        )

        if memberships is not None:
            memberships.names = [old_new.get(n, n) for n in memberships.names]
        if backward:
            self._init_states = self._init_states.set(assignment=assignment, memberships=memberships)
        else:
            self._term_states = self._term_states.set(assignment=assignment, memberships=memberships)

        return self

    @d.dedent
    @inject_docs(m=PlotMode)
    def plot_macrostates(
        self,
        which: Literal["all", "initial", "terminal"],
        states: Optional[Union[str, Sequence[str]]] = None,
        color: Optional[str] = None,
        discrete: bool = True,
        mode: Literal["embedding", "time"] = PlotMode.EMBEDDING,
        time_key: str = "latent_time",
        same_plot: bool = True,
        title: Optional[Union[str, Sequence[str]]] = None,
        cmap: str = "viridis",
        **kwargs: Any,
    ) -> None:
        """Plot macrostates on an embedding or along pseudotime.

        Parameters
        ----------
        which
            Which macrostates to plot. Valid options are:

            - ``'all'`` - plot all macrostates.
            - ``'initial'`` - plot macrostates marked as :attr:`initial_states`.
            - ``'terminal'`` - plot macrostates marked as :attr:`terminal_states`.
        states
            Subset of the macrostates to show. If :obj:`None`, plot all macrostates.
        color
            Key in :attr:`~anndata.AnnData.obs` or :attr:`~anndata.AnnData.var` used to color the observations.
        discrete
            Whether to plot the data as continuous or discrete observations.
            If the data cannot be plotted as continuous observations, it will be plotted as discrete.
        mode
            Whether to plot the probabilities in an embedding or along the pseudotime.
        time_key
            Key in :attr:`~anndata.AnnData.obs` where pseudotime is stored. Only used when ``mode = {m.TIME!r}``.
        same_plot
            Whether to plot the data on the same plot or not. Only used when ``mode = {m.EMBEDDING!r}``.
            If `True` and ``discrete = False``, ``color`` is ignored.
        title
            Title of the plot.
        cmap
            Colormap for continuous annotations.
        kwargs
            Keyword arguments for :func:`~scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s
        """
        if which == "all":
            obj: Optional[StatesHolder] = getattr(self, "_macrostates", None)
            if obj is None:
                raise RuntimeError(f"`{type(self).__name__}` cannot plot macrostates.")
        elif which == "initial":
            obj = self._init_states
        elif which == "terminal":
            obj = self._term_states
        else:
            raise ValueError(
                f"Unable to plot `{which!r}` states. "
                f"Valid options are: `{['all', 'initial', 'terminal']}`."
            )

        name = "macrostates" if which == "all" else f"{which} states"
        if obj.assignment is None and obj.memberships is None:
            raise RuntimeError(f"Compute {name} first.")

        # fall back to whichever representation is actually available
        if not discrete and obj.memberships is None:
            logg.warning(f"Unable to plot {name} in continuous mode, using discrete")
            discrete = True
        elif discrete and obj.assignment is None:
            logg.warning(f"Unable to plot {name} in discrete mode, using continuous")
            discrete = False

        data = obj.assignment if discrete else obj.memberships
        colors = obj.colors
        if discrete:
            return self._plot_discrete(
                _data=data,
                _colors=colors,
                _title=name,
                states=states,
                color=color,
                same_plot=same_plot,
                title=title,
                cmap=cmap,
                **kwargs,
            )

        return self._plot_continuous(
            _data=data,
            _colors=colors,
            _title=name,
            states=states,
            color=color,
            mode=mode,
            time_key=time_key,
            same_plot=same_plot,
            title=title,
            cmap=cmap,
            **kwargs,
        )

    def _plot_discrete(
        self,
        _data: pd.Series,
        _colors: Optional[np.ndarray] = None,
        _title: Optional[str] = None,
        states: Optional[Union[str, Sequence[str]]] = None,
        color: Optional[str] = None,
        title: Optional[Union[str, Sequence[str]]] = None,
        same_plot: bool = True,
        cmap: str = "viridis",
        **kwargs: Any,
    ) -> None:
        """Plot a categorical state annotation with :func:`scvelo.pl.scatter`.

        Parameters
        ----------
        _data
            Categorical series with one category per state; `NaN` marks cells not assigned to any state.
        _colors
            Colors for the categories of ``_data``. If :obj:`None`, new categorical colors are created.
        _title
            Base title, used when ``title`` is :obj:`None`.
        states
            Subset of states to plot. A single string is treated as one state name.
            If :obj:`None` or empty, all states are plotted.
        color
            Additional key(s) passed to :func:`scvelo.pl.scatter` as extra panels.
        title
            Title(s) of the state panel(s).
        same_plot
            Whether to plot all selected states in one panel.
        cmap
            Colormap forwarded as ``color_map``.
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Raises
        ------
        TypeError
            If ``_data`` is not a categorical :class:`~pandas.Series`.
        ValueError
            If the number of colors does not match the number of categories or an unknown state is requested.
        """
        if not isinstance(_data, pd.Series):
            raise TypeError(f"Expected `data` to be of type `pandas.Series`, found `{type(_data)}`.")
        if not isinstance(_data.dtype, pd.CategoricalDtype):
            raise TypeError(f"Expected `data` to be `categorical`, found `{infer_dtype(_data)}`.")

        names = list(_data.cat.categories)
        if _colors is None:
            _colors = _create_categorical_colors(len(names))
        if len(_colors) != len(names):
            raise ValueError(f"Expected `colors` to be of length `{len(names)}`, found `{len(_colors)}`.")
        color_mapper = dict(zip(names, _colors))

        # a single state name must not be iterated character by character
        if isinstance(states, str):
            states = [states]
        states = _unique_order_preserving(states or names)
        if not len(states):
            raise ValueError("No states have been selected.")
        for name in states:
            if name not in names:
                raise ValueError(f"Invalid name `{name!r}`. Valid options are: `{sorted(names)}`.")
        _data = _data.cat.set_categories(states)

        color = [] if color is None else (color,) if isinstance(color, str) else color
        color = _unique_order_preserving(color)

        same_plot = same_plot or len(names) == 1
        kwargs.setdefault("legend_loc", "on data")
        kwargs["color_map"] = cmap

        # temporary keys in `adata.obs` holding the data to plot; removed when the context exits
        with RandomKeys(self.adata, n=1 if same_plot else len(states), where="obs") as keys:
            if same_plot:
                outline = _data.cat.categories.to_list()
                # cells outside of the selected states are drawn in light gray
                _data = _data.cat.add_categories(["nan"]).fillna("nan")
                states.append("nan")
                color_mapper["nan"] = "#dedede"
                self.adata.obs[keys[0]] = _data
                self.adata.uns[f"{keys[0]}_colors"] = [color_mapper[name] for name in states]
                title = _title if title is None else title
            else:
                outline = None
                for key, cat in zip(keys, states):
                    self.adata.obs[key] = _data.cat.set_categories([cat])
                    self.adata.uns[f"{key}_colors"] = [color_mapper[cat]]
                title = [f"{_title} {name}" for name in states] if title is None else title

            title = [title] if isinstance(title, str) else list(title)

            scv.pl.scatter(
                self.adata,
                color=color + keys,
                title=color + title,
                add_outline=outline,
                **kwargs,
            )

    def _plot_continuous(
        self,
        _data: Lineage,
        _colors: Optional[np.ndarray] = None,
        _title: Optional[str] = None,
        states: Optional[Union[str, Sequence[str]]] = None,
        color: Optional[str] = None,
        mode: Literal["embedding", "time"] = PlotMode.EMBEDDING,
        time_key: str = "latent_time",
        title: Optional[Union[str, Sequence[str]]] = None,
        same_plot: bool = True,
        cmap: str = "viridis",
        **kwargs: Any,
    ) -> None:
        """Plot state memberships (a :class:`~cellrank.Lineage`) on an embedding or along pseudotime.

        Parameters
        ----------
        _data
            Memberships with one column per state.
        _colors
            Unused here; accepted for symmetry with :meth:`_plot_discrete`.
        _title
            Base title, used when ``title`` is :obj:`None`.
        states
            Subset of states to plot. A single string is treated as one state name.
            If :obj:`None`, all states are plotted.
        color
            Additional key(s) passed to :func:`scvelo.pl.scatter`.
        mode
            Whether to plot in an embedding or along the pseudotime.
        time_key
            Key in :attr:`~anndata.AnnData.obs` with the pseudotime, used in time mode.
        title
            Title(s) of the plot(s).
        same_plot
            Whether to plot all states in one panel; forced to `False` in time mode or for a single state.
        cmap
            Colormap forwarded as ``color_map``.
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Raises
        ------
        TypeError
            If ``_data`` is not a :class:`~cellrank.Lineage`.
        ValueError
            If no states are selected or ``color`` has an invalid length in time mode.
        KeyError
            If the pseudotime cannot be found in time mode.
        """
        mode = PlotMode(mode)
        if not isinstance(_data, Lineage):
            raise TypeError(f"Expected data to be of type `Lineage`, found `{type(_data)}`.")

        if states is None:
            states = _data.names
        elif isinstance(states, str):
            # a single state name must not be iterated character by character
            states = [states]
        if not len(states):
            raise ValueError("No lineages have been selected.")

        is_singleton = _data.shape[1] == 1
        _data = _data[states].copy()

        if mode == PlotMode.TIME and same_plot:
            logg.warning("Invalid combination `mode='time'` and `same_plot=True`. Using `same_plot=False`")
            same_plot = False

        _data_X = _data.X  # list(_data.T) behaves differently than a numpy.array
        if _data_X.shape[1] == 1:
            same_plot = False
            if np.allclose(_data_X, 1.0):
                # matplotlib shows even tiny perturbations in the colormap
                _data_X = np.ones_like(_data_X)

        for col in _data_X.T:
            mask = ~np.isclose(col, 1.0)
            # change the maximum value - the 1 is artificial and obscures the color scaling
            if np.any(mask):
                col[~mask] = np.nanmax(col[mask])

        color = [] if color is None else (color,) if isinstance(color, str) else color
        color = _unique_order_preserving(color)

        if mode == PlotMode.TIME:
            kwargs.setdefault("legend_loc", "best")
            if time_key is None:
                raise KeyError(
                    "The name of the column in `adata.obs` defining the pseudotime needs to be defined via the "
                    "`time_key` argument."
                )
            if title is None:
                title = [f"{_title} {state}" for state in states]
            if time_key not in self.adata.obs:
                raise KeyError(f"Unable to find pseudotime in `adata.obs[{time_key!r}]`.")
            if len(color) and len(color) not in (1, _data_X.shape[1]):
                raise ValueError(
                    f"Expected `color` to be of length `1` or `{_data_X.shape[1]}`, "
                    f"found `{len(color)}`."
                )
            kwargs["x"] = self.adata.obs[time_key]
            kwargs["y"] = list(_data_X.T)
            kwargs["color"] = color if len(color) else None
            kwargs["xlabel"] = [time_key] * len(states)
            kwargs["ylabel"] = ["probability"] * len(states)
        elif mode == PlotMode.EMBEDDING:
            kwargs.setdefault("legend_loc", "on data")
            if same_plot:
                if color:
                    # https://github.com/theislab/scvelo/issues/673
                    logg.warning("Ignoring `color` when `mode='embedding'` and `same_plot=True`")
                title = [_title] if title is None else title
                kwargs["color_gradients"] = _data
            else:
                title = [f"{_title} {state}" for state in states] if title is None else title
                title = [title] if isinstance(title, str) else list(title)
                title = color + title
                kwargs["color"] = color + list(_data_X.T)
        else:
            raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")

        # e.g. a stationary distribution
        if is_singleton and not np.allclose(_data_X, 1.0):
            kwargs.setdefault("perc", [0, 95])
            _ = kwargs.pop("color_gradients", None)

        scv.pl.scatter(
            self.adata,
            title=title,
            color_map=cmap,
            **kwargs,
        )

    def _set_categorical_labels(
        self,
        categories: Union[pd.Series, Dict[str, Any]],
        cluster_key: Optional[str] = None,
        existing: Optional[pd.Series] = None,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Convert user-provided states to a categorical series and assign names and colors.

        Parameters
        ----------
        categories
            Categorical series, or a :class:`dict` mapping state names to cell barcodes. If the dict has
            exactly one key which is a categorical column in :attr:`~anndata.AnnData.obs`, its values are
            interpreted as cluster names of that column.
        cluster_key
            Key in :attr:`~anndata.AnnData.obs` used to derive the state names and colors.
        existing
            Already existing states to merge with ``categories``.

        Returns
        -------
        The categorical states and their colors.

        Raises
        ------
        TypeError
            If the states are not categorical or ``adata.obs[cluster_key]`` is not categorical.
        KeyError
            If ``cluster_key`` is not found in :attr:`~anndata.AnnData.obs`.
        """
        if isinstance(categories, dict):
            if len(categories) == 1:
                key = next(iter(categories))
                data = self.adata.obs.get(key, None)
                if data is not None and isinstance(data.dtype, pd.CategoricalDtype):
                    # interpret the values as cluster names of `adata.obs[key]`
                    vals = categories[key]
                    if isinstance(vals, str) or not isinstance(vals, (Sequence, np.ndarray, pd.Index, pd.Series)):
                        vals = (vals,)
                    clusters = data.cat.rename_categories({c: str(c) for c in data.cat.categories})
                    vals = tuple(str(v) for v in vals)
                    categories = {cat: self.adata.obs_names[(clusters == cat).to_numpy()] for cat in vals}
            categories = _convert_to_categorical_series(categories, list(self.adata.obs_names))

        if not isinstance(categories.dtype, pd.CategoricalDtype):
            raise TypeError(f"Expected object to be `categorical`, found `{infer_dtype(categories)}`.")

        if existing is not None:
            categories = _merge_categorical_series(old=existing, new=categories)

        if cluster_key is None:
            colors = _create_categorical_colors(len(categories.cat.categories))
            return categories, colors

        # check that we can load the reference series from adata
        if cluster_key not in self.adata.obs:
            raise KeyError(f"Unable to find clusters in `adata.obs[{cluster_key!r}]`.")
        series_query, series_reference = categories, self.adata.obs[cluster_key]
        if not isinstance(series_reference.dtype, pd.CategoricalDtype):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be `categorical`, "
                f"found `{infer_dtype(series_reference)}`."
            )
        series_reference = series_reference.cat.rename_categories(
            {c: str(c) for c in series_reference.cat.categories}
        )

        # load the reference colors if they exist
        if Key.uns.colors(cluster_key) in self.adata.uns:
            colors_reference = _convert_to_hex_colors(self.adata.uns[Key.uns.colors(cluster_key)])
        else:
            colors_reference = _create_categorical_colors(len(series_reference.cat.categories))

        names, colors = _map_names_and_colors(
            series_reference=series_reference,
            series_query=series_query,
            colors_reference=colors_reference,
        )
        cats = categories.cat.categories
        categories = categories.cat.rename_categories(dict(zip(cats, names)))

        return categories, colors

    @logger
    @shadow
    def _write_states(
        self,
        which: Literal["initial", "terminal"],
        states: Optional[pd.Series],
        colors: Optional[np.ndarray],
        probs: Optional[pd.Series] = None,
        params: Dict[str, Any] = types.MappingProxyType({}),
        allow_overlap: bool = False,
    ) -> str:
        """Store initial or terminal states in :attr:`adata` and on the estimator.

        Parameters
        ----------
        which
            Whether to write the ``'initial'`` or the ``'terminal'`` states.
        states
            Categorical state assignment.
        colors
            Colors of the states.
        probs
            Probabilities of the states.
        params
            Parameters stored in :attr:`params` under the states' key.
        allow_overlap
            Whether initial and terminal states may share cells.

        Returns
        -------
        Message listing the written fields.

        Raises
        ------
        ValueError
            If initial and terminal states overlap and ``allow_overlap = False``.
        """
        backward = which == "initial"
        if not allow_overlap:
            fwd_states, bwd_states = (self.terminal_states, states) if backward else (states, self.initial_states)
            if fwd_states is not None and bwd_states is not None:
                overlap = np.sum(~pd.isnull(fwd_states) & ~pd.isnull(bwd_states))
                if overlap:
                    raise ValueError(
                        f"Found `{overlap}` overlapping cells between initial and terminal states. "
                        f"If this is intended, please use `allow_overlap=True`."
                    )

        key = Key.obs.term_states(self.backward, bwd=backward)
        self._set(obj=self.adata.obs, key=key, value=states)
        self._set(obj=self.adata.obs, key=Key.obs.probs(key), value=probs)
        self._set(obj=self.adata.uns, key=Key.uns.colors(key), value=colors)
        if backward:
            self._init_states = self._init_states.set(assignment=states, probs=probs, colors=colors)
        else:
            self._term_states = self._term_states.set(assignment=states, probs=probs, colors=colors)
        self.params[key] = dict(params)

        return (
            f"Adding `adata.obs[{key!r}]`\n"
            f"       `adata.obs[{Key.obs.probs(key)!r}]`\n"
            f"       `.{which}_states`\n"
            f"       `.{which}_states_probabilities`\n"
            "    Finish"
        )

    def _read_from_adata(self, adata: AnnData, **kwargs: Any) -> bool:
        """Restore initial and terminal states from ``adata``.

        Returns
        -------
        Whether reading succeeded; by design, based on the last iteration (the terminal states).
        """
        ok = super()._read_from_adata(adata, **kwargs)
        if not ok:
            return False

        # fmt: off
        for backward in [True, False]:
            key = Key.obs.term_states(self.backward, bwd=backward)
            with SafeGetter(self, allowed=KeyError) as sg:
                assignment = self._get(obj=adata.obs, key=key, shadow_attr="obs", dtype=pd.Series)
                probs = self._get(obj=adata.obs, key=Key.obs.probs(key), shadow_attr="obs", dtype=pd.Series)
                colors = self._get(obj=adata.uns, key=Key.uns.colors(key), shadow_attr="uns", dtype=(list, tuple, np.ndarray))
                colors = np.asarray([to_hex(c) for c in colors])
                if backward:
                    self._init_states = StatesHolder(assignment=assignment, probs=probs, colors=colors)
                else:
                    self._term_states = StatesHolder(assignment=assignment, probs=probs, colors=colors)
                self.params[key] = self._read_params(key)
        # fmt: on

        # status is based on `backward=False` by design
        return sg.ok

    def _format_params(self) -> str:
        """Extend the base parameter summary with the names of the initial and terminal states."""
        fmt = super()._format_params()
        # fmt: off
        init_states = None if self.initial_states is None else sorted(self.initial_states.cat.categories)
        term_states = None if self.terminal_states is None else sorted(self.terminal_states.cat.categories)
        return fmt + f", initial_states={init_states}, terminal_states={term_states}"
        # fmt: on