Source code for

# -*- coding: utf-8 -*-
"""Precomputed kernel module."""
from copy import copy
from typing import Union, Optional

import numpy as np
from scipy.sparse import eye as speye
from scipy.sparse import spmatrix, csr_matrix

from cellrank import logging as logg
from cellrank.ul._docs import d
from cellrank.ul._utils import _read_graph_data
from import Kernel
from import Direction, _transition
from import _RTOL, AnnData, KernelExpression

[docs]@d.dedent class PrecomputedKernel(Kernel): """ Kernel which contains a precomputed transition matrix. Parameters ---------- transition_matrix Row-normalized transition matrix or a key in :paramref:`adata` ``.obsp`` or a :class:`` with the computed transition matrix. If `None`, try to determine the key based on ``backward``. %(adata)s %(backward)s """ def __init__( self, transition_matrix: Optional[ Union[np.ndarray, spmatrix, KernelExpression, str] ] = None, adata: Optional[AnnData] = None, backward: bool = False, compute_cond_num: bool = False, ): from anndata import AnnData as _AnnData self._origin = "'array'" params = {} if transition_matrix is None: transition_matrix = _transition( Direction.BACKWARD if backward else Direction.FORWARD ) logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`") if isinstance(transition_matrix, str): if adata is None: raise ValueError( "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None." ) self._origin = f"adata.obsp[{transition_matrix!r}]" transition_matrix = _read_graph_data(adata, transition_matrix) elif isinstance(transition_matrix, KernelExpression): if transition_matrix._transition_matrix is None: raise ValueError( "Compute transition matrix first as `.compute_transition_matrix()`." ) if adata is not None and adata is not transition_matrix.adata: logg.warning( "Ignoring supplied `adata` object because it differs from the kernel's `adata` object." ) # use `str` because it captures the params self._origin = str(transition_matrix).strip("~<>") params = transition_matrix.params.copy() backward = transition_matrix.backward adata = transition_matrix.adata transition_matrix = transition_matrix.transition_matrix if not isinstance(transition_matrix, (np.ndarray, spmatrix)): raise TypeError( f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, " f"found `{type(transition_matrix).__name__!r}`." ) if transition_matrix.shape[0] != transition_matrix.shape[1]: raise ValueError( f"Expected transition matrix to be square, found `{transition_matrix.shape}`." ) if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL): raise ValueError("Not a valid transition matrix, not all rows sum to 1") if adata is None: logg.warning("Creating empty `AnnData` object") adata = _AnnData( csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32) ) super().__init__(adata, backward=backward, compute_cond_num=compute_cond_num) self._params = params self._transition_matrix = csr_matrix(transition_matrix) self._maybe_compute_cond_num() def _read_from_adata(self, **kwargs): pass
[docs] @d.dedent def copy(self) -> "PrecomputedKernel": """%(copy)s""" # noqa pk = PrecomputedKernel( copy(self.transition_matrix), adata=self.adata, backward=self.backward, compute_cond_num=False, ) pk._cond_num = self.condition_number pk._origin = self._origin pk._params = self._params.copy() return pk
[docs] def compute_transition_matrix(self, *args, **kwargs) -> "PrecomputedKernel": """Return self.""" return self
def __invert__(self) -> "PrecomputedKernel": # do not call parent's invert, since it removes the transition matrix self._direction = Direction.FORWARD if self.backward else Direction.BACKWARD return self def __repr__(self): return f"{'~' if self.backward and self._parent is None else ''}<Precomputed[origin={self._origin}]>" def __str__(self): return repr(self)
@d.dedent class DummyKernel(PrecomputedKernel): """ Kernel with 1s on the diagonal. Parameters ---------- %(adata)s %(backward)s """ def __init__( self, adata: AnnData, backward: bool = False, ): super().__init__( speye(adata.n_obs, format="csr"), adata, backward=backward, compute_cond_num=False, )