# Source code for the precomputed kernel module.

"""Precomputed kernel module."""
from copy import copy
from typing import Union, Optional

from anndata import AnnData

import numpy as np
from scipy.sparse import eye as speye
from scipy.sparse import spmatrix, csr_matrix

from cellrank import logging as logg
from cellrank.ul._docs import d
from cellrank.ul._utils import _read_graph_data
from import Kernel
from import Direction, _transition
from import _RTOL, KernelExpression

@d.dedent
class PrecomputedKernel(Kernel):
    """
    Kernel which contains a precomputed transition matrix.

    Parameters
    ----------
    transition_matrix
        Row-normalized transition matrix or a key in :paramref:`adata` ``.obsp`` or a
        :class:`KernelExpression` with the computed transition matrix.
        If `None`, try to determine the key based on ``backward``.
    %(adata)s
    %(backward)s
    %(cond_num)s
    """

    def __init__(
        self,
        transition_matrix: Optional[
            Union[np.ndarray, spmatrix, KernelExpression, str]
        ] = None,
        adata: Optional[AnnData] = None,
        backward: bool = False,
        compute_cond_num: bool = False,
    ):
        # Description of where the matrix came from; shown by `__repr__`.
        self._origin = "'array'"
        params = {}

        if transition_matrix is None:
            # No matrix given: fall back to the default `.obsp` key for this direction.
            transition_matrix = _transition(
                Direction.BACKWARD if backward else Direction.FORWARD
            )
            logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`")

        if isinstance(transition_matrix, str):
            if adata is None:
                raise ValueError(
                    "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
                )
            self._origin = f"adata.obsp[{transition_matrix!r}]"
            transition_matrix = _read_graph_data(adata, transition_matrix)
        elif isinstance(transition_matrix, KernelExpression):
            if transition_matrix._transition_matrix is None:
                raise ValueError(
                    "Compute transition matrix first as `.compute_transition_matrix()`."
                )
            if adata is not None and adata is not transition_matrix.adata:
                logg.warning(
                    "Ignoring supplied `adata` object because it differs from the kernel's `adata` object."
                )
            # use `str` because it captures the params
            self._origin = str(transition_matrix).strip("~<>")
            # Inherit params, direction and data from the source expression.
            params = transition_matrix.params.copy()
            backward = transition_matrix.backward
            adata = transition_matrix.adata
            transition_matrix = transition_matrix.transition_matrix

        if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
            raise TypeError(
                f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
                f"found `{type(transition_matrix).__name__!r}`."
            )
        if transition_matrix.shape[0] != transition_matrix.shape[1]:
            raise ValueError(
                f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
            )
        if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
            raise ValueError("Not a valid transition matrix, not all rows sum to 1")

        if adata is None:
            logg.warning("Creating empty `AnnData` object")
            # Placeholder object with one observation per row of the matrix.
            adata = AnnData(
                csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32)
            )
        elif adata.n_obs != transition_matrix.shape[0]:
            # A matrix whose size disagrees with `adata` would give an inconsistent kernel.
            raise ValueError(
                f"Expected transition matrix to be of shape `{(adata.n_obs, adata.n_obs)}`, "
                f"found `{transition_matrix.shape}`."
            )

        super().__init__(adata, backward=backward, compute_cond_num=compute_cond_num)
        self._params = params
        self._transition_matrix = csr_matrix(transition_matrix)
        self._maybe_compute_cond_num()

    def _read_from_adata(self, **kwargs):
        """Do nothing: the transition matrix is supplied directly, nothing is read here."""
        pass

    @d.dedent
    def copy(self) -> "PrecomputedKernel":
        """%(copy)s"""  # noqa
        # `copy` below resolves to the module-level `copy.copy`, not this method.
        pk = PrecomputedKernel(
            copy(self.transition_matrix),
            adata=self.adata,
            backward=self.backward,
            compute_cond_num=False,
        )
        # Reuse the already-known condition number instead of recomputing it.
        pk._cond_num = self.condition_number
        pk._origin = self._origin
        pk._params = self._params.copy()
        return pk

    def compute_transition_matrix(self, *args, **kwargs) -> "PrecomputedKernel":
        """Return self."""
        return self

    def __invert__(self) -> "PrecomputedKernel":
        """Flip the kernel's direction in place and return it."""
        # do not call parent's invert, since it removes the transition matrix
        self._direction = Direction.FORWARD if self.backward else Direction.BACKWARD
        return self

    def __repr__(self):
        return (
            f"{'~' if self.backward and self._parent is None else ''}"
            f"<{self.__class__.__name__}[origin={self._origin}]>"
        )

    def __str__(self):
        return repr(self)
@d.dedent
class DummyKernel(PrecomputedKernel):
    """
    Kernel whose transition matrix is the sparse identity (1s on the diagonal).

    Parameters
    ----------
    %(adata)s
    %(backward)s
    """

    def __init__(self, adata: AnnData, backward: bool = False):
        # One row/column per observation; every row trivially sums to 1.
        identity = speye(adata.n_obs, format="csr")
        super().__init__(
            identity,
            adata=adata,
            backward=backward,
            compute_cond_num=False,
        )