Source code for cellrank.tl.kernels._precomputed_kernel

from typing import Any, Dict, Union, Optional

from copy import copy

from anndata import AnnData
from cellrank import logging as logg
from cellrank._key import Key
from cellrank.ul._docs import d
from cellrank.ul._utils import _read_graph_data
from cellrank.tl.kernels._base_kernel import _RTOL, Kernel, KernelExpression

import numpy as np
from scipy.sparse import spmatrix, csr_matrix


@d.dedent
class PrecomputedKernel(Kernel):
    """
    Kernel which contains a precomputed transition matrix.

    Parameters
    ----------
    transition_matrix
        Row-normalized transition matrix, a key in :attr:`anndata.AnnData.obsp`,
        or a :class:`cellrank.tl.kernels.KernelExpression` with a precomputed transition matrix.
        If `None`, try to determine the key based on ``backward``.
    %(adata)s
        If `None`, a temporary placeholder object is created.
    %(backward)s
    %(cond_num)s
    kwargs
        Keyword arguments for :class:`cellrank.tl.kernels.Kernel`.
    """

    def __init__(
        self,
        transition_matrix: Optional[
            Union[np.ndarray, spmatrix, KernelExpression, str]
        ] = None,
        adata: Optional[AnnData] = None,
        backward: bool = False,
        compute_cond_num: bool = False,
        **kwargs: Any,
    ):
        # Raises:
        #   ValueError - a key is given without `adata`, the source kernel has no
        #                transition matrix yet, the matrix is not a square 2-D
        #                matrix, or its rows do not sum to 1.
        #   TypeError  - the resolved matrix is neither a NumPy array nor a
        #                SciPy sparse matrix.
        origin = "'array'"
        params: Dict[str, Any] = {}

        if transition_matrix is None:
            # fall back to the default key for the requested direction
            transition_matrix = Key.uns.kernel(backward)
            logg.info(f"Accessing `adata.obsp[{transition_matrix!r}]`")

        if isinstance(transition_matrix, str):
            if adata is None:
                raise ValueError(
                    "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
                )
            origin = f"adata.obsp[{transition_matrix!r}]"
            # the direction is derived from the key itself, which overrides
            # whatever was passed as `backward`
            backward = Key.uns.kernel(bwd=True) == transition_matrix
            transition_matrix = _read_graph_data(adata, transition_matrix)
        elif isinstance(transition_matrix, KernelExpression):
            if transition_matrix._transition_matrix is None:
                raise ValueError(
                    "Compute transition matrix first as `.compute_transition_matrix()`."
                )
            if adata is not None and adata is not transition_matrix.adata:
                logg.warning(
                    "Ignoring supplied `adata` object because it differs from the kernel's `adata` object."
                )
            # use `str` rather than `repr` because it captures the parameters
            origin = str(transition_matrix).strip("~<>")
            params = transition_matrix.params.copy()
            backward = transition_matrix.backward
            adata = transition_matrix.adata
            # must come last: rebinding the name drops the kernel expression
            transition_matrix = transition_matrix.transition_matrix

        if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
            raise TypeError(
                f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
                f"found `{type(transition_matrix).__name__}`."
            )

        # check the dimensionality first so that 0-D/1-D arrays are rejected
        # with the same `ValueError` instead of an `IndexError` on `shape[1]`
        if (
            transition_matrix.ndim != 2
            or transition_matrix.shape[0] != transition_matrix.shape[1]
        ):
            raise ValueError(
                f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
            )

        if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
            raise ValueError("Not a valid transition matrix, not all rows sum to 1.")

        if adata is None:
            # placeholder with one row per row of the transition matrix
            logg.warning("Creating empty `AnnData` object")
            adata = AnnData(
                csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32)
            )

        super().__init__(
            adata, backward=backward, compute_cond_num=compute_cond_num, **kwargs
        )

        self._transition_matrix = csr_matrix(transition_matrix)
        self._maybe_compute_cond_num()
        self._params = params
        self._origin = origin

    def _read_from_adata(self, **kwargs: Any) -> None:
        """Read nothing from ``adata``; only set ``_conn`` to `None`."""
        self._conn = None

    @d.dedent
    def copy(self) -> "PrecomputedKernel":
        """%(copy)s"""  # noqa
        pk = PrecomputedKernel(
            copy(self.transition_matrix),
            adata=self.adata,
            backward=self.backward,
            compute_cond_num=False,
        )
        # reuse the stored condition number instead of recomputing it
        pk._cond_num = self.condition_number
        pk._origin = self._origin
        pk._params = self._params.copy()

        return pk

    def compute_transition_matrix(
        self, *args: Any, **kwargs: Any
    ) -> "PrecomputedKernel":
        """Return self."""
        return self

    def __invert__(self) -> "PrecomputedKernel":
        """Flip the direction of this kernel in place and return it."""
        # do not call parent's invert, since it removes the transition matrix
        self._backward = not self.backward
        return self

    def __repr__(self) -> str:
        """Return ``<ClassName[origin=...]>``, prefixed by ``~`` for a parent-less backward kernel."""
        return (
            f"{'~' if self.backward and self._parent is None else ''}"
            f"<{self.__class__.__name__}[origin={self._origin}]>"
        )

    def __str__(self) -> str:
        """Return the same text as :meth:`__repr__`."""
        return repr(self)