Source code for cellrank.kernels._velocity_kernel

import logging
import time as _time
import warnings
from collections.abc import Callable, Sequence
from typing import Any, Literal

import numpy as np
import scipy.sparse as sp
from anndata import AnnData

from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import DEFAULT_BACKEND, Backend_t
from cellrank.kernels._base_kernel import BidirectionalKernel
from cellrank.kernels.mixins import ConnectivityMixin
from cellrank.kernels.utils import Deterministic, MonteCarlo, Stochastic
from cellrank.kernels.utils._moments import _knn_moments, _row_normalize_connectivities
from cellrank.kernels.utils._similarity import Similarity, SimilarityABC
from cellrank.kernels.utils._velocity_model import BackwardMode, VelocityModel

logger = logging.getLogger(__name__)
__all__ = ["VelocityKernel"]


@d.dedent
class VelocityKernel(ConnectivityMixin, BidirectionalKernel):
    """Kernel which computes a transition matrix based on RNA velocity.

    .. seealso::
        - See :doc:`../../../notebooks/tutorials/kernels/200_rna_velocity` on how to
          compute the :attr:`~cellrank.kernels.VelocityKernel.transition_matrix` based on RNA velocity.

    This borrows ideas from both :cite:`manno:18` and :cite:`bergen:20`. In short, for each cell :math:`i`,
    we compute transition probabilities :math:`T_{i, j}` to each cell :math:`j` in the neighborhood of :math:`i`.
    We quantify how much the velocity vector :math:`v_i` of cell :math:`i` points towards each of its nearest
    neighbors. For this comparison, we support various schemes including cosine similarity and pearson correlation.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    attr
        Attribute of :class:`~anndata.AnnData` to read from.
    xkey
        Key in :attr:`~anndata.AnnData.layers` or :attr:`~anndata.AnnData.obsm` where expected gene expression counts
        are stored.
    vkey
        Key in :attr:`~anndata.AnnData.layers` or :attr:`~anndata.AnnData.obsm` where velocities are stored.
    gene_subset
        List of genes to be used to compute transition probabilities.
        If not specified, genes from :attr:`adata.var['{vkey}_genes'] <anndata.AnnData.var>` are used.
        This feature is only available when reading from :attr:`anndata.AnnData.layers`
        and will be ignored otherwise.
    kwargs
        Keyword arguments for the :class:`~cellrank.kernels.Kernel`.
    """

    def __init__(
        self,
        adata: AnnData,
        backward: bool = False,
        attr: Literal["layers", "obsm"] | None = "layers",
        xkey: str | None = "Ms",
        vkey: str | None = "velocity",
        **kwargs: Any,
    ):
        super().__init__(
            adata,
            backward=backward,
            xkey=xkey,
            vkey=vkey,
            attr=attr,
            **kwargs,
        )
        # Populated by :meth:`compute_transition_matrix`; exposed through :attr:`logits`.
        self._logits: np.ndarray | None = None

    def _read_from_adata(
        self,
        xkey: str | None = "Ms",
        vkey: str | None = "velocity",
        attr: Literal["layers", "obsm"] | None = "layers",
        gene_subset: str | Sequence[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Load expression and velocity arrays and pre-compute the neighborhood velocity statistics.

        Sets ``_xdata`` and ``_vdata`` (columns whose velocity sum is NaN are removed from both)
        as well as ``_vexp`` and ``_vvar``, the two arrays returned by ``_knn_moments``.

        Parameters
        ----------
        xkey
            Key under which the expression data is stored.
        vkey
            Key under which the velocities are stored.
        attr
            Attribute of the annotated data object to read both keys from.
        gene_subset
            Genes to restrict the data to. Only used when ``attr = 'layers'``.
        kwargs
            Forwarded to the parent implementation.
        """
        super()._read_from_adata(**kwargs)

        velocity_genes_key = f"{vkey}_genes"
        if attr == "layers" and gene_subset is None and velocity_genes_key in self.adata.var:
            # Fall back to the velocity-genes column stored next to the velocities, if present.
            gene_subset = self.adata.var[velocity_genes_key]
        elif attr == "obsm" and gene_subset is not None:
            logger.warning(
                "Found `gene_subset != None`, but it is not supported for `adata.%s`. Using `gene_subset = None`.",
                attr,
            )
            gene_subset = None

        self._xdata = self._extract_data(key=xkey, attr=attr, subset=gene_subset)
        self._vdata = self._extract_data(key=vkey, attr=attr, subset=gene_subset)
        np.testing.assert_array_equal(
            self._xdata.shape,
            self._vdata.shape,
            err_msg=f"Shape mismatch: {self._xdata.shape} vs {self._vdata.shape}",
        )

        # A NaN anywhere in a column makes its sum NaN; such columns are dropped from both arrays
        # so that they stay aligned.
        nans = np.isnan(np.sum(self._vdata, axis=0))
        if np.any(nans):
            self._xdata = self._xdata[:, ~nans]
            self._vdata = self._vdata[:, ~nans]

        # NOTE(review): presumably the first and second velocity moments over the k-NN graph -
        # confirm against `_knn_moments`.
        W = _row_normalize_connectivities(self.connectivities)
        self._vexp, self._vvar = _knn_moments(W, self._vdata)
        self._vexp = self._vexp.astype(np.float64, copy=False)
        self._vvar = self._vvar.astype(np.float64, copy=False)

    # TODO(michalk8): remove the docrep in 2.0
    @inject_docs(m=VelocityModel, b=BackwardMode, s=Similarity)  # don't swap the order
    @d.dedent
    def compute_transition_matrix(
        self,
        model: Literal["deterministic", "stochastic", "monte_carlo"] = VelocityModel.DETERMINISTIC,
        backward_mode: Literal["transpose", "negate"] = BackwardMode.TRANSPOSE,
        similarity: Literal["correlation", "cosine", "dot_product"]
        | Callable[[np.ndarray, np.ndarray, float], tuple[np.ndarray, np.ndarray]] = Similarity.CORRELATION,
        softmax_scale: float | None = None,
        n_samples: int = 1000,
        seed: int | None = None,
        **kwargs: Any,
    ) -> "VelocityKernel":
        """Compute transition matrix based on velocity directions on the local manifold.

        For each cell, infer transition probabilities based on the cell's velocity-extrapolated cell state and the
        cell states of its *K* nearest neighbors.

        Parameters
        ----------
        %(velocity_mode)s
        %(velocity_backward_mode)s
        %(softmax_scale)s
        %(velocity_scheme)s
        n_samples
            Number of samples when ``model = {m.MONTE_CARLO!r}``.
        seed
            Random seed when ``model = {m.MONTE_CARLO!r}``.
        %(parallel)s
        kwargs
            Keyword arguments for the underlying ``model``.

        Returns
        -------
        Returns self and updates :attr:`transition_matrix`, :attr:`logits` and :attr:`params`.
        """
        _start = _time.perf_counter()
        logger.info("Computing transition matrix using %r model", model)

        # Normalize once; an unknown model name fails here, after the log line above.
        model_kind = VelocityModel(model)
        if model_kind != VelocityModel.DETERMINISTIC:
            warnings.warn(
                f"model={model!r} is deprecated and will be removed in CellRank 3.0. "
                "Use model='deterministic' instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        # fmt: off
        params = {"model": str(model), "similarity": str(similarity), "softmax_scale": softmax_scale}
        if self.backward:
            params["bwd_mode"] = str(backward_mode)
        if model_kind == VelocityModel.MONTE_CARLO:
            params["n_samples"] = n_samples
            params["seed"] = seed
        if self._reuse_cache(params):
            return self

        if softmax_scale is None:
            softmax_scale = self._estimate_softmax_scale(backward_mode=backward_mode, similarity=similarity)
            logger.info("Using `softmax_scale=%.4f`", softmax_scale)
            # Record the estimated value in the same dictionary that was handed to `_reuse_cache`.
            params["softmax_scale"] = softmax_scale
        # fmt: on

        velocity_model = self._create_model(
            model,
            backward_mode=backward_mode,
            similarity=similarity,
            softmax_scale=softmax_scale,
            n_samples=n_samples,
            seed=seed,
        )
        if isinstance(velocity_model, Stochastic):
            kwargs["backend"] = DEFAULT_BACKEND
        self.transition_matrix, self._logits = velocity_model(**kwargs)

        logger.info("    Finish (%.2fs)", _time.perf_counter() - _start)
        return self

    def _create_model(
        self,
        model: str | VelocityModel,
        backward_mode: Literal["negate", "transpose"],
        similarity: str | Similarity | Callable,
        n_samples: int = 1000,
        seed: int | None = None,
        **kwargs: Any,
    ) -> Deterministic | Stochastic | MonteCarlo:
        """Instantiate the velocity model that computes the transition matrix.

        Parameters
        ----------
        model
            Requested velocity model. It is replaced by the deterministic model for a backward kernel,
            and a stochastic model falls back to Monte Carlo when ``similarity`` has no ``hessian`` attribute.
        backward_mode
            Backward mode; only used when the kernel is backward, otherwise :obj:`None` is passed on.
        similarity
            Name of a predefined similarity or a callable.
        n_samples
            Number of samples, only passed to the Monte Carlo model.
        seed
            Random seed, only passed to the Monte Carlo model.
        kwargs
            Additional keyword arguments for the model constructor, e.g., ``softmax_scale``.

        Returns
        -------
        The model instance.

        Raises
        ------
        TypeError
            If ``similarity`` is neither a string nor a callable.
        NotImplementedError
            If no implementation exists for ``model``.
        """
        model = VelocityModel(model)
        backward_mode = BackwardMode(backward_mode) if self.backward else None

        if isinstance(similarity, str):
            similarity = SimilarityABC.create(Similarity(similarity))
        elif not callable(similarity):
            # The argument is called `similarity`; the message used to name a non-existent `scheme`.
            raise TypeError(f"Expected `similarity` to be a function, found `{type(similarity).__name__}`.")

        if self.backward and model != VelocityModel.DETERMINISTIC:
            logger.warning(
                "Model `%r` is currently not supported for backward process. Using `model=%r`",
                model,
                VelocityModel.DETERMINISTIC,
            )
            model = VelocityModel.DETERMINISTIC

        if model == VelocityModel.STOCHASTIC and not hasattr(similarity, "hessian"):
            model = VelocityModel.MONTE_CARLO
            logger.warning(
                "Unable to detect a method for Hessian computation. If using one of the "
                "predefined similarity functions, consider installing `jax` as "
                "`pip install jax`. Using `model=%r` and `n_samples=%s`",
                model,
                n_samples,
            )

        if model == VelocityModel.DETERMINISTIC:
            return Deterministic(
                self.connectivities,
                self._xdata,
                self._vdata,
                similarity=similarity,
                backward_mode=backward_mode,
                **kwargs,
            )
        if model == VelocityModel.STOCHASTIC:
            return Stochastic(
                self.connectivities,
                self._xdata,
                self._vexp,
                self._vvar,
                similarity=similarity,
                backward_mode=backward_mode,
                **kwargs,
            )
        if model == VelocityModel.MONTE_CARLO:
            return MonteCarlo(
                self.connectivities,
                self._xdata,
                self._vexp,
                self._vvar,
                similarity=similarity,
                backward_mode=backward_mode,
                n_samples=n_samples,
                seed=seed,
                **kwargs,
            )

        raise NotImplementedError(f"Model `{model}` is not yet implemented.")

    def _estimate_softmax_scale(
        self,
        n_jobs: int | None = None,
        backend: Backend_t = DEFAULT_BACKEND,
        **kwargs,
    ) -> float:
        """Estimate the softmax scale as the inverse median absolute logit of a deterministic model.

        Parameters
        ----------
        n_jobs
            First positional argument of the model call.
        backend
            Second positional argument of the model call.
        kwargs
            Keyword arguments for :meth:`_create_model`, e.g., ``backward_mode`` and ``similarity``.

        Returns
        -------
        The estimated scale.
        """
        model = self._create_model(VelocityModel.DETERMINISTIC, softmax_scale=1.0, **kwargs)
        _, logits = model(n_jobs, backend)
        return 1.0 / np.median(np.abs(logits.data))

    def _extract_data(
        self,
        key: str | None = None,
        attr: Literal["layers", "obsm"] | None = None,
        subset: str | Sequence[str] | None = None,
        dtype: np.dtype = np.float64,
    ) -> np.ndarray:
        """Fetch an array from the annotated data object, optionally restricted to some columns.

        Parameters
        ----------
        key
            Key to read. :obj:`None` and ``'X'`` both select ``adata.X``, regardless of ``attr``.
        attr
            Where to look up ``key``.
        subset
            Column selection. A string is interpreted as a column of ``adata.var``. A boolean mask with one
            entry per column is applied directly, anything else is matched against ``adata.var_names``.
        dtype
            Data type of the result.

        Returns
        -------
        The data as a dense array.

        Raises
        ------
        KeyError
            If ``key`` is not found in the requested attribute.
        """
        if key in (None, "X"):
            data = self.adata.X
        elif attr == "layers" and key in self.adata.layers:
            data = self.adata.layers[key]
        elif attr == "obsm" and key in self.adata.obsm:
            data = np.asarray(self.adata.obsm[key])
        else:
            raise KeyError(f"Unable to find data in `adata.{attr}[{key!r}]`.")

        if subset is not None:
            if isinstance(subset, str):
                subset = self.adata.var[subset]
            subset = np.asarray(subset)
            if np.issubdtype(subset.dtype, bool) and subset.shape == (data.shape[1],):
                data = data[:, subset]
            else:
                data = data[:, np.isin(self.adata.var_names, subset)]

        data = data.astype(dtype, copy=False)
        return data.toarray() if sp.issparse(data) else data

    @property
    def logits(self) -> np.ndarray | None:
        """Array of shape ``(n_cells, n_cells)`` containing the unnormalized transition matrix."""
        return self._logits

    def __invert__(self) -> "VelocityKernel":
        """Return a copy with the direction flipped, without the transition matrix, logits or parameters."""
        dk = self._copy_ignore("_transition_matrix", "_logits")
        dk._backward = not self.backward
        dk._params = {}
        return dk