# Source code for cellrank.tl.kernels._transport_map_kernel

from typing import Any, Dict, Tuple, Union, Mapping, Iterable, Optional, Sequence
from typing_extensions import Literal

from abc import ABC, abstractmethod
from types import MappingProxyType
from contextlib import contextmanager

import scanpy as sc
from anndata import AnnData
from cellrank import logging as logg
from cellrank.settings import settings
from cellrank.tl._enum import ModeEnum
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import _normalize
from cellrank.tl.kernels._experimental_time_kernel import ExperimentalTimeKernel

import numpy as np
import pandas as pd
from scipy.sparse import bmat, spdiags, issparse, spmatrix, csr_matrix

# Only the kernel base class is exported from this module.
__all__ = ("TransportMapKernel",)


class SelfTransitions(ModeEnum):
    """How to populate the within-time-point (diagonal) blocks when transport maps are stitched together."""

    UNIFORM = "uniform"  # row-normalized block of 1s for the last time point
    DIAGONAL = "diagonal"  # identity block for the last time point
    CONNECTIVITIES = "connectivities"  # connectivity-based transitions for the last time point
    ALL = "all"  # connectivities for every source time point, weighted by `conn_weight`


# Scalar type of experimental time points (and of numeric thresholds).
Numeric_t = Union[int, float]
# A pair of consecutive time points, ``(earlier, later)``; used as transport map keys.
Pair_t = Tuple[Numeric_t, Numeric_t]
# A percentile of non-zeros to remove, or one of the automatic thresholding modes.
Threshold_t = Union[Numeric_t, Literal["auto", "auto_local"]]


class TransportMapKernel(ExperimentalTimeKernel, ABC):
    """Kernel base class which computes transition matrix based on transport maps for consecutive time point pairs."""

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # populated by `compute_transition_matrix`, keyed by consecutive time point pairs
        self._tmaps: Optional[Dict[Pair_t, AnnData]] = None

    @abstractmethod
    def _compute_tmap(
        self, t1: Numeric_t, t2: Numeric_t, **kwargs: Any
    ) -> Union[np.ndarray, spmatrix, AnnData]:
        """
        Compute transport matrix for a time point pair.

        Parameters
        ----------
        t1
            Earlier time point in :attr:`experimental_time`.
        t2
            Later time point in :attr:`experimental_time`.
        kwargs
            Additional keyword arguments.

        Returns
        -------
        Array of shape ``(n_cells_early, n_cells_late)``. If :class:`anndata.AnnData`,
        :attr:`anndata.AnnData.obs_names` and :attr:`anndata.AnnData.var_names` correspond to subsets
        of observations from :attr:`adata`.
        """

    @staticmethod
    def _min_row_max(tmat: Union[np.ndarray, spmatrix]) -> Numeric_t:
        """
        Return the smallest of the per-row maxima of ``tmat``.

        Implicit zeros of sparse matrices take part in the maximum, same as for dense arrays.
        Removing only the values strictly below this number keeps every row's maximum.
        """
        row_max = csr_matrix(tmat).max(axis=1)
        # depending on the SciPy version, the axis-wise maximum is sparse or dense
        row_max = row_max.toarray() if issparse(row_max) else np.asarray(row_max)
        return row_max.min()

    @d.get_sections(base="tmk_thresh", sections=["Parameters"])
    def _threshold_transport_maps(
        self,
        tmaps: Dict[Pair_t, AnnData],
        threshold: Threshold_t,
        copy: bool = False,
    ) -> Optional[Mapping[Pair_t, AnnData]]:
        """
        Remove small non-zero values from each transport map in ``tmaps``.

        Parameters
        ----------
        threshold
            How to remove small non-zero values from the transition matrix. Valid options are:

                - `'auto'` - find the maximum threshold value which will not remove every non-zero value from any row.
                - `'auto_local'` - same as above, but done for each transport separately.
                - :class:`float` - value in `[0, 100]` corresponding to a percentage of non-zeros to remove in each
                  transport map.

            Rows where all values are removed will have a uniform distribution over their previously non-zero
            entries and a warning will be issued.
        copy
            Whether to return a copy of thresholded ``tmaps`` or modify inplace.

        Returns
        -------
        If ``copy = True``, returns thresholded copies of the transport maps and leaves ``tmaps`` untouched.
        Otherwise, modifies ``tmaps`` inplace and returns `None`.
        """
        if threshold == "auto":
            thresh = min(self._min_row_max(adata.X) for adata in tmaps.values())
            logg.info(f"Using automatic `threshold={thresh}`")
        elif threshold == "auto_local":
            logg.info("Using automatic `threshold` for each time point pair separately")
        elif not (0 <= threshold <= 100):
            raise ValueError(
                f"Expected `threshold` to be in `[0, 100]`, found `{threshold}`."
            )

        # never write into the caller's mapping when a copy was requested
        result = {} if copy else tmaps
        for key, adata in tmaps.items():
            if copy:
                adata = adata.copy()
            # CSR first, so that `.data` only contains the stored (non-zero) entries, also for dense inputs
            tmat = csr_matrix(adata.X, dtype=adata.X.dtype)

            if threshold == "auto_local":
                thresh = self._min_row_max(tmat)
                logg.debug(f"Using `threshold={thresh}` at `{key}`")
            elif threshold != "auto":
                thresh = np.percentile(tmat.data, threshold)
                logg.debug(f"Using `threshold={thresh}` at `{key}`")

            tmat.data[tmat.data < thresh] = 0.0
            # `zero_rows` holds row *indices*; test its size, not its truthiness (index 0 is falsy)
            zero_rows = np.flatnonzero(np.asarray(tmat.sum(axis=1)).ravel() == 0)
            if zero_rows.size:
                logg.warning(
                    f"After thresholding, `{len(zero_rows)}` row(s) of transport at `{key}` are forced to be uniform"
                )
                for ix in zero_rows:
                    start, end = tmat.indptr[ix], tmat.indptr[ix + 1]
                    size = end - start
                    tmat.data[start:end] = np.ones(size, dtype=tmat.dtype) / size
            # only after the rows above were handled: their zeroed entries are still stored, marking the support
            tmat.eliminate_zeros()

            result[key] = AnnData(tmat, obs=adata.obs, var=adata.var, dtype=tmat.dtype)

        return result if copy else None

    @d.get_sections(base="tmk_tmat", sections=["Parameters"])
    @d.dedent
    @inject_docs(st=SelfTransitions)
    def compute_transition_matrix(
        self,
        threshold: Optional[Threshold_t] = "auto",
        self_transitions: Union[
            Literal["uniform", "diagonal", "connectivities", "all"],
            Sequence[Numeric_t],
        ] = SelfTransitions.CONNECTIVITIES,
        conn_weight: Optional[float] = None,
        conn_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs: Any,
    ) -> "TransportMapKernel":
        """
        Compute transition matrix using transport maps.

        Parameters
        ----------
        %(tmk_thresh.parameters)s
        self_transitions
            How to define transitions within the diagonal blocks that correspond to transitions within the same
            time point. Valid options are:

                - `{st.UNIFORM!r}` - row-normalized matrix of 1s for transitions. Only applied to the last time point.
                - `{st.DIAGONAL!r}` - diagonal matrix with 1s on the diagonal. Only applied to the last time point.
                - `{st.CONNECTIVITIES!r}` - use transition matrix from :class:`cellrank.tl.kernels.ConnectivityKernel`.
                  Only applied to the last time point.
                - :class:`typing.Sequence` - sequence of source time points defining which blocks should be weighted
                  by connectivities. Always applied to the last time point.
                - `{st.ALL!r}` - same as above, but for all time points.
        conn_weight
            Weight of connectivities self transitions. Only used when ``self_transitions = {st.ALL!r}`` or a sequence
            of source time points is passed.
        conn_kwargs
            Keyword arguments for :func:`scanpy.pp.neighbors` when ``self_transitions`` uses
            :class:`cellrank.tl.kernels.ConnectivityKernel`. Can contain `'density_normalize'` for
            :meth:`cellrank.tl.kernels.ConnectivityKernel.compute_transition_matrix`.
        kwargs
            Keyword arguments for :meth:`_compute_tmap`.

        Returns
        -------
        Self and updates the following attributes:

            - :attr:`transition_matrix` - transition matrix.
            - :attr:`transport_maps` - transport maps between consecutive time points.
        """
        cache_params = dict(kwargs)
        cache_params["threshold"] = threshold
        cache_params["self_transitions"] = self_transitions
        # `conn_weight` only matters for `'all'` or an explicit sequence of time points;
        # it belongs to the cache key, never to `conn_kwargs` (forwarded to `scanpy.pp.neighbors`)
        if isinstance(self_transitions, (str, SelfTransitions)):
            uses_conn_weight = SelfTransitions(self_transitions) == SelfTransitions.ALL
        else:
            uses_conn_weight = True
        if uses_conn_weight:
            cache_params["conn_weight"] = conn_weight
        if conn_kwargs:
            # connectivity parameters change the result as well
            cache_params["conn_kwargs"] = dict(conn_kwargs)

        if self._reuse_cache(cache_params):
            return self

        categories = self.experimental_time.cat.categories
        pairs = list(zip(categories[:-1], categories[1:]))
        self._tmaps = {
            (t1, t2): self._tmat_to_adata(t1, t2, self._compute_tmap(t1, t2, **kwargs))
            for t1, t2 in pairs
        }
        if threshold is not None:
            self._threshold_transport_maps(self._tmaps, threshold=threshold, copy=False)

        self.transition_matrix = self._restich_tmaps(
            self._tmaps,
            self_transitions=self_transitions,
            conn_weight=conn_weight,
            conn_kwargs=conn_kwargs,
        ).X
        return self

    @d.dedent
    def _restich_tmaps(
        self,
        tmaps: Mapping[Pair_t, AnnData],
        self_transitions: Union[
            str, SelfTransitions, Sequence[Numeric_t]
        ] = SelfTransitions.DIAGONAL,
        conn_weight: Optional[float] = None,
        conn_kwargs: Mapping[str, Any] = MappingProxyType({}),
    ) -> AnnData:
        """
        Group individual transport maps into 1 matrix aligned with :attr:`adata`.

        Parameters
        ----------
        tmaps
            Sorted transport maps as ``{(t1, t2): tmat_1, (t2, t3): tmat_2, ...}``.
        %(tmk_tmat.parameters)s

        Returns
        -------
        Merged transport maps into one :class:`anndata.AnnData` object.
        """
        if isinstance(self_transitions, (str, SelfTransitions)):
            self_transitions = SelfTransitions(self_transitions)
            if self_transitions == SelfTransitions.ALL:
                self_transitions = tuple(t1 for t1, _ in tmaps.keys())
        elif isinstance(self_transitions, Iterable):
            self_transitions = tuple(self_transitions)
        else:
            raise TypeError(
                f"Expected `self_transitions` to be a `str` or a `Sequence`, "
                f"found `{type(self_transitions).__name__}`."
            )

        tmaps = self._validate_tmaps(tmaps)
        conn_kwargs = dict(conn_kwargs)
        conn_kwargs["copy"] = False
        conn_kwargs.pop("key_added", None)

        # block (i, i + 1) holds the transport map from time point i to time point i + 1
        n_blocks = len(tmaps) + 1
        blocks = [[None] * n_blocks for _ in range(n_blocks)]
        nrows = 0
        obs_names, obs = [], []
        for i, tmap in enumerate(tmaps.values()):
            blocks[i][i + 1] = tmap.X
            nrows += tmap.n_obs
            obs_names.extend(tmap.obs_names)
            obs.append(tmap.obs)
        # cells of the last time point only appear as targets of the last transport map
        last_tmap = tmap
        obs_names.extend(last_tmap.var_names)

        n = self.adata.n_obs - nrows
        if self_transitions == SelfTransitions.DIAGONAL:
            blocks[-1][-1] = spdiags([1] * n, 0, n, n)
        elif self_transitions == SelfTransitions.UNIFORM:
            # keep every block sparse, as expected by `bmat`
            blocks[-1][-1] = csr_matrix(np.ones((n, n)) / float(n))
        elif self_transitions == SelfTransitions.CONNECTIVITIES:
            blocks[-1][-1] = self._compute_connectivity_tmat(
                self.adata[last_tmap.var_names], **conn_kwargs
            )
        elif isinstance(self_transitions, tuple):
            if conn_weight is None or not (0 < conn_weight < 1):
                raise ValueError("Please specify `conn_weight` in interval `(0, 1)`.")
            verbosity = settings.verbosity
            try:
                # ignore overly verbose logging
                settings.verbosity = 0
                for i, ((t1, _), tmap) in enumerate(tmaps.items()):
                    if t1 not in self_transitions:
                        continue
                    blocks[i][i] = conn_weight * self._compute_connectivity_tmat(
                        self.adata[tmap.obs_names], **conn_kwargs
                    )
                    blocks[i][i + 1] = (1 - conn_weight) * _normalize(blocks[i][i + 1])
                blocks[-1][-1] = self._compute_connectivity_tmat(
                    self.adata[last_tmap.var_names], **conn_kwargs
                )
            finally:
                settings.verbosity = verbosity
        else:
            raise NotImplementedError(
                f"Self transitions' mode `{self_transitions}` is not yet implemented."
            )

        if blocks[0][0] is None:
            # `bmat` infers each block column's width from its non-`None` blocks; the first column is
            # only populated when connectivities are used for the first time point
            n_first = blocks[0][1].shape[0]
            blocks[0][0] = csr_matrix((n_first, n_first))

        tmp = AnnData(bmat(blocks, format="csr"), dtype="float64")
        tmp.obs_names = obs_names
        tmp.var_names = obs_names
        tmp = tmp[self.adata.obs_names, :][:, self.adata.obs_names]
        tmp.obs = pd.merge(
            tmp.obs,
            pd.concat(obs),
            left_index=True,
            right_index=True,
            how="left",
        )

        return tmp

    def _validate_tmaps(
        self,
        tmaps: Dict[Pair_t, AnnData],
        allow_reorder: bool = True,
    ) -> Mapping[Pair_t, AnnData]:
        """
        Validate that transport maps conform to various invariants.

        Parameters
        ----------
        tmaps
            Transport maps where :attr:`anndata.AnnData.var_names` in the earlier time point must correspond
            to :attr:`anndata.AnnData.obs_names` in the later time point. Reordered maps are written back into it.
        allow_reorder
            Whether the target cells in the earlier transport map can be used to reorder the source cells
            of the later transport map.

        Returns
        -------
        Possibly reordered transport maps.

        Raises
        ------
        ValueError
            If ``tmaps`` is empty, or if ``allow_reorder = False`` and the target/source cell names
            on earlier/later transport map don't match.
        KeyError
            If ``allow_reorder = True`` and the target cells in earlier transport map are not found in the later one,
            or if the transport maps don't cover exactly the cells in :attr:`adata`.
        """
        if not len(tmaps):
            raise ValueError("No transport maps have been computed.")

        keys = list(tmaps.keys())
        for (t1, t2), (t3, t4) in zip(keys[:-1], keys[1:]):
            tmap1, tmap2 = tmaps[t1, t2], tmaps[t3, t4]
            if np.array_equal(tmap1.var_names, tmap2.obs_names):
                continue
            if not allow_reorder:
                raise ValueError(
                    f"Target cells of coupling with shape `{tmap1.shape}` at `{(t1, t2)}` does not "
                    f"match the source cells of coupling with shape `{tmap2.shape}` at `{(t3, t4)}`"
                )
            try:
                # if a subset happens, we still assume that the transport maps cover all cells from `adata`
                # i.e. the transport maps define a superset
                tmaps[t3, t4] = tmap2[tmap1.var_names]  # keep the view
            except KeyError as e:
                raise KeyError(
                    f"Unable to reorder transport map at `{(t3, t4)}` "
                    f"using transport map at `{(t1, t2)}`."
                ) from e

        # every time point contributes its cells once as sources, except the last one,
        # whose cells only appear as the targets of the last transport map
        seen_obs = []
        for tmap in tmaps.values():
            seen_obs.extend(tmap.obs_names)
        seen_obs.extend(tmap.var_names)

        if sorted(seen_obs) != sorted(self.adata.obs_names):
            raise KeyError(
                "Observations from transport maps don't match "
                "the observations from the underlying `AnnData` object."
            )

        return tmaps

    def _tmat_to_adata(
        self, t1: Numeric_t, t2: Numeric_t, tmat: Union[np.ndarray, spmatrix, AnnData]
    ) -> AnnData:
        """
        Convert transport map ``tmat`` to :class:`anndata.AnnData`.

        Rows are named by the cells at ``t1`` and columns by the cells at ``t2``.
        Do nothing if ``tmat`` is of the correct type.
        """
        if isinstance(tmat, AnnData):
            return tmat

        tmat = AnnData(X=tmat, dtype=tmat.dtype)
        tmat.obs_names = self.adata[self.adata.obs[self._time_key] == t1].obs_names
        tmat.var_names = self.adata[self.adata.obs[self._time_key] == t2].obs_names
        return tmat

    @contextmanager
    def _tmap_as_tmat(
        self, threshold: Optional[Threshold_t] = None, **kwargs: Any
    ) -> None:
        """
        Temporarily set :attr:`transport_maps` as :attr:`transition_matrix`.

        Parameters
        ----------
        threshold
            If not `None`, threshold a copy of :attr:`transport_maps` before stitching them,
            see :meth:`_threshold_transport_maps`. :attr:`transport_maps` themselves are not thresholded.
        kwargs
            Keyword arguments for :meth:`_restich_tmaps`.
        """
        if self.transport_maps is None:
            raise RuntimeError(
                "Compute transport maps first as `.compute_transition_matrix()`."
            )

        tmaps = self.transport_maps
        if threshold is not None:
            tmaps = self._threshold_transport_maps(tmaps, threshold=threshold, copy=True)

        tmat = self._transition_matrix
        try:
            self._transition_matrix = self._restich_tmaps(tmaps, **kwargs).X
            yield
        finally:
            self._transition_matrix = tmat

    @staticmethod
    def _compute_connectivity_tmat(
        adata: AnnData, **kwargs: Any
    ) -> Union[np.ndarray, spmatrix]:
        """Compute neighbors on ``adata`` and return the resulting connectivity kernel's transition matrix."""
        from cellrank.tl.kernels import ConnectivityKernel

        # same default as in `ConnectivityKernel`
        density_normalize = kwargs.pop("density_normalize", True)
        sc.pp.neighbors(adata, **kwargs)
        return (
            ConnectivityKernel(adata)
            .compute_transition_matrix(density_normalize=density_normalize)
            .transition_matrix
        )

    @d.dedent
    def plot_single_flow(
        self,
        cluster: str,
        cluster_key: str,
        time_key: Optional[str] = None,
        use_transport_maps: bool = False,
        threshold: Optional[Threshold_t] = None,
        **kwargs: Any,
    ) -> None:
        """
        %(plot_single_flow.full_desc)s

        Parameters
        ----------
        %(plot_single_flow.parameters)s

        Returns
        -------
        %(plot_single_flow.returns)s
        """  # noqa: D400
        if use_transport_maps:
            # temporarily swap in the (optionally thresholded) stitched transport maps
            with self._tmap_as_tmat(threshold=threshold):
                return super().plot_single_flow(
                    cluster, cluster_key, time_key=time_key, **kwargs
                )

        return super().plot_single_flow(
            cluster, cluster_key, time_key=time_key, **kwargs
        )

    @property
    def transport_maps(self) -> Optional[Dict[Pair_t, AnnData]]:
        """Transport maps for consecutive time pairs."""
        return self._tmaps

    def __invert__(self) -> "TransportMapKernel":
        """Return the inverted kernel; transport maps are not carried over to it."""
        # `"_tmaps"` is forwarded to the parent, presumably to skip copying it -- TODO confirm
        tk = super().__invert__("_tmaps")
        # don't copy transport maps
        tk._tmaps = None
        return tk