Source code for cellrank.kernels._base_kernel

import abc
import copy
import pathlib
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import scipy.sparse as sp

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from anndata import AnnData

from cellrank import logging as logg
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._utils import _normalize, _read_graph_data, save_fig
from cellrank.kernels._utils import require_tmat
from cellrank.kernels.mixins import BidirectionalMixin, IOMixin, UnidirectionalMixin
from cellrank.kernels.utils import FlowPlotter, RandomWalk, TmatProjection
from cellrank.kernels.utils._random_walk import Indices_t

__all__ = ["Kernel", "UnidirectionalKernel", "BidirectionalKernel"]

Tmat_t = Union[np.ndarray, sp.spmatrix]


class KernelExpression(IOMixin, abc.ABC):
    def __init__(
        self,
        parent: Optional["KernelExpression"] = None,
        **kwargs: Any,
    ):
        super().__init__()
        self._parent = parent
        self._normalize = parent is None
        self._transition_matrix = None
        self._params: Dict[str, Any] = {}
        self._init_kwargs = kwargs  # for `_read_from_adata`

    def __init_subclass__(cls, **_: Any) -> None:
        super().__init_subclass__()

    @abc.abstractmethod
    def compute_transition_matrix(self, *args: Any, **kwargs: Any) -> "KernelExpression":
        """Compute transition matrix.

        Parameters
        ----------
        args
            Positional arguments.
        kwargs
            Keyword arguments.

        Returns
        -------
        Modifies :attr:`transition_matrix` and returns self.
        """

    @abc.abstractmethod
    def copy(self, *, deep: bool = False) -> "KernelExpression":
        """Return a copy of self. The :attr:`adata` object is not copied."""

    @property
    @abc.abstractmethod
    def adata(self) -> AnnData:
        """Annotated data object."""

    @adata.setter
    @abc.abstractmethod
    def adata(self, value: Optional[AnnData]) -> None:
        pass

    @property
    @abc.abstractmethod
    def kernels(self) -> Tuple["KernelExpression", ...]:
        """Underlying base kernels."""

    @property
    @abc.abstractmethod
    def shape(self) -> Tuple[int, int]:
        """``(n_cells, n_cells)``."""

    @property
    @abc.abstractmethod
    def backward(self) -> Optional[bool]:
        """Direction of the process."""

    @abc.abstractmethod
    def __getitem__(self, ix: int) -> "KernelExpression":
        pass

    @abc.abstractmethod
    def __len__(self) -> int:
        pass

    @d.get_full_description(base="plot_single_flow")
    @d.get_sections(base="plot_single_flow", sections=["Parameters", "Returns"])
    @d.dedent
    @require_tmat
    def plot_single_flow(
        self,
        cluster: str,
        cluster_key: str,
        time_key: str,
        clusters: Optional[Sequence[Any]] = None,
        time_points: Optional[Sequence[Union[int, float]]] = None,
        min_flow: float = 0,
        remove_empty_clusters: bool = True,
        ascending: Optional[bool] = False,
        legend_loc: Optional[str] = "upper right out",
        alpha: Optional[float] = 0.8,
        xticks_step_size: Optional[int] = 1,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, pathlib.Path]] = None,
        show: bool = True,
    ) -> Optional[plt.Axes]:
        """Visualize outgoing flow from a cluster of cells :cite:`mittnenzweig:21`.

        Parameters
        ----------
        cluster
            Cluster for which to visualize outgoing flow.
        cluster_key
            Key in :attr:`~anndata.AnnData.obs` where clustering is stored.
        time_key
            Key in :attr:`~anndata.AnnData.obs` where experimental time is stored.
        clusters
            Visualize flow only for these clusters. If :obj:`None`, use all clusters.
        time_points
            Visualize flow only for these time points. If :obj:`None`, use all time points.
        %(flow.parameters)s
        %(plotting)s
        show
            If :obj:`False`, return :class:`~matplotlib.axes.Axes`.

        Returns
        -------
        The axes object, if ``show = False``.
        %(just_plots)s

        Notes
        -----
        This function is a Python re-implementation of the following
        `original R function <https://github.com/tanaylab/embflow/blob/main/scripts/generate_paper_figures/plot_vein.r>`_
        with some minor stylistic differences.
        This function will not recreate the results from :cite:`mittnenzweig:21`, because there, the *Metacell* model
        :cite:`baran:19` was used to compute the flow, whereas here the transition matrix is used.
        """  # noqa: E501
        fp = FlowPlotter(self.adata, self.transition_matrix, cluster_key, time_key)
        fp = fp.prepare(cluster, clusters, time_points)

        ax = fp.plot(
            min_flow=min_flow,
            remove_empty_clusters=remove_empty_clusters,
            ascending=ascending,
            alpha=alpha,
            xticks_step_size=xticks_step_size,
            legend_loc=legend_loc,
            figsize=figsize,
            dpi=dpi,
        )

        if save is not None:
            save_fig(ax.figure, save)

        if not show:
            return ax

    @d.dedent
    @require_tmat
    def plot_random_walks(
        self,
        n_sims: int = 100,
        max_iter: Union[int, float] = 0.25,
        seed: Optional[int] = None,
        successive_hits: int = 0,
        start_ixs: Indices_t = None,
        stop_ixs: Indices_t = None,
        basis: str = "umap",
        cmap: Union[str, LinearSegmentedColormap] = "gnuplot",
        linewidth: float = 1.0,
        linealpha: float = 0.3,
        ixs_legend_loc: Optional[str] = None,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, pathlib.Path]] = None,
        **kwargs: Any,
    ) -> None:
        """Plot random walks in an embedding.

        This method simulates random walks on the Markov chain defined though the corresponding transition matrix. The
        method is intended to give qualitative rather than quantitative insights into the transition matrix. Random
        walks are simulated by iteratively choosing the next cell based on the current cell's transition probabilities.

        Parameters
        ----------
        n_sims
            Number of random walks to simulate.
        %(rw_sim.parameters)s
        start_ixs
            Cells from which to sample the starting points. If :obj:`None`, use all cells.
            %(rw_ixs)s
            For example ``{'dpt_pseudotime': [0, 0.1]}`` means that starting points for random walks
            will be sampled uniformly from cells whose pseudotime is in :math:`[0, 0.1]`.
        stop_ixs
            Cells which when hit, the random walk is terminated. If :obj:`None`, terminate after ``max_iters``.
            %(rw_ixs)s
            For example ``{'clusters': ['Alpha', 'Beta']}`` and ``successive_hits = 3`` means that the random walk will
            stop prematurely after cells in the above specified clusters have been visited successively 3 times in
            a row.
        basis
            Basis in :attr:`~anndata.AnnData.obsm` to use as an embedding.
        cmap
            Colormap for the random walk lines.
        linewidth
            Width of the random walk lines.
        linealpha
            Alpha value of the random walk lines.
        ixs_legend_loc
            Legend location for the start/top indices.
        %(parallel)s
        %(plotting)s
        kwargs
            Keyword arguments for :func:`~scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s
        For each random walk, the first/last cell is marked by the start/end colors of ``cmap``.
        """
        rw = RandomWalk(self.adata, self.transition_matrix, start_ixs=start_ixs, stop_ixs=stop_ixs)
        sims = rw.simulate_many(
            n_sims=n_sims,
            max_iter=max_iter,
            seed=seed,
            n_jobs=n_jobs,
            backend=backend,
            successive_hits=successive_hits,
            show_progress_bar=show_progress_bar,
        )

        rw.plot(
            sims,
            basis=basis,
            cmap=cmap,
            linewidth=linewidth,
            linealpha=linealpha,
            ixs_legend_loc=ixs_legend_loc,
            figsize=figsize,
            dpi=dpi,
            save=save,
            **kwargs,
        )

    @require_tmat
    def plot_projection(
        self,
        basis: str = "umap",
        key_added: Optional[str] = None,
        recompute: bool = False,
        stream: bool = True,
        connectivities: Optional[sp.spmatrix] = None,
        **kwargs: Any,
    ) -> None:
        """Plot :attr:`transition_matrix` as a stream or a grid plot.

        Parameters
        ----------
        basis
            Key in :attr:`~anndata.AnnData.obsm` containing the basis.
        key_added
            If not :obj:`None`, save the result to :attr:`adata.obsm['{key_added}'] <anndata.AnnData.obsm>`.
            Otherwise, save the result to ``'T_fwd_{basis}'`` or ``'T_bwd_{basis}'``, depending on the direction.
        recompute
            Whether to recompute the projection if it already exists.
        stream
            If :obj:`True`, use :func:`~scvelo.pl.velocity_embedding_stream`.
            Otherwise, use :func:`~scvelo.pl.velocity_embedding_grid`.
        connectivities
            Connectivity matrix to use for projection. If :obj:`None`, use ones from the underlying kernel, is possible.
        kwargs
            Keyword argument for the above-mentioned plotting function.

        Returns
        -------
        Nothing, just plots and modifies :attr:`~anndata.AnnData.obsm` with a key based on the ``key_added``.
        """
        proj = TmatProjection(self, basis=basis)
        proj.project(key_added=key_added, recompute=recompute, connectivities=connectivities)
        proj.plot(stream=stream, **kwargs)

    def __add__(self, other: "KernelExpression") -> "KernelExpression":
        return self.__radd__(other)

    def __radd__(self, other: "KernelExpression") -> "KernelExpression":
        def same_level_add(k1: "KernelExpression", k2: "KernelExpression") -> bool:
            if not (isinstance(k1, KernelAdd) and isinstance(k2, KernelMul)):
                return False

            for kexpr in k1:
                if not isinstance(kexpr, KernelMul):
                    return False
                if not kexpr._bin_consts:
                    return False
            return True

        if not isinstance(other, KernelExpression):
            return NotImplemented

        s = self * 1.0 if isinstance(self, Kernel) else self
        o = other * 1.0 if isinstance(other, Kernel) else other

        # (c1 * x + c2 * y + ...) + (c3 * z) => (c1 * x + c2 * y + c3 + ... + * z)
        if same_level_add(s, o):
            return KernelAdd(*tuple(s) + (o,))
        if same_level_add(o, s):
            return KernelAdd(*tuple(o) + (s,))

        # add virtual constant
        if not isinstance(s, KernelMul):
            s = s * 1.0
        if not isinstance(o, KernelMul):
            o = o * 1.0

        return KernelAdd(s, o)

    def __mul__(self, other: Union[float, int, "KernelExpression"]) -> "KernelExpression":
        return self.__rmul__(other)

    def __rmul__(self, other: Union[int, float, "KernelExpression"]) -> "KernelExpression":
        def same_level_mul(k1: "KernelExpression", k2: "KernelExpression") -> bool:
            return isinstance(k1, KernelMul) and isinstance(k2, Constant) and k1._bin_consts

        if isinstance(other, (int, float, np.integer, np.floating)):
            other = Constant(self.adata, other)

        if not isinstance(other, KernelExpression):
            return NotImplemented

        # fmt: off
        s = self if isinstance(self, (KernelMul, Constant)) else KernelMul(Constant(self.adata, 1.0), self)
        o = other if isinstance(other, (KernelMul, Constant)) else KernelMul(Constant(other.adata, 1.0), other)
        # fmt: on

        # at this point, only KernelMul and Constant is possible (two constants are not possible)
        # (c1 * k) * c2 => (c1 * c2) * k
        if same_level_mul(s, o):
            c, expr = s._split_const
            return KernelMul(c * o, expr)
        if same_level_mul(o, s):
            c, expr = o._split_const
            return KernelMul(c * s, expr)

        return KernelMul(s, o)

    @d.get_sections(base="write_to_adata", sections=["Parameters"])
    @inject_docs()  # gets rid of {{}} in %(write_to_adata)s
    @d.dedent
    @require_tmat
    def write_to_adata(self, key: Optional[str] = None, copy: bool = False) -> None:
        """Write the transition matrix and parameters used for computation to the underlying :attr:`adata` object.

        Parameters
        ----------
        key
            Key used when writing transition matrix to :attr:`adata`.
            If :obj:`None`, the key will be determined automatically.
        copy
            Whether to copy the :attr:`transition_matrix`.

        Returns
        -------
        %(write_to_adata)s
        """
        from cellrank._utils._key import Key

        if self.adata is None:
            raise ValueError("Underlying annotated data object is not set.")

        key = Key.uns.kernel(self.backward, key=key)
        # retain the embedding info
        self.adata.uns[f"{key}_params"] = {
            **self.adata.uns.get(f"{key}_params", {}),
            **{"params": self.params},
            **{"init": self._init_kwargs},
        }
        self.adata.obsp[key] = self.transition_matrix.copy() if copy else self.transition_matrix

    @property
    def transition_matrix(self) -> Union[np.ndarray, sp.csr_matrix]:
        """Row-normalized transition matrix."""
        if self._parent is None and self._transition_matrix is None:
            self.compute_transition_matrix()
        return self._transition_matrix

    @transition_matrix.setter
    def transition_matrix(self, matrix: Tmat_t) -> None:
        """Set the transition matrix.

        Parameters
        ----------
        matrix
            Transition matrix. The matrix is row-normalized if necessary.

        Returns
        -------
        Nothing, just updates the :attr:`transition_matrix` and optionally normalizes it.
        """
        # fmt: off
        if matrix.shape != self.shape:
            raise ValueError(
                f"Expected matrix to be of shape `{self.shape}`, found `{matrix.shape}`."
            )

        def should_norm(mat: Tmat_t) -> bool:
            return not np.all(np.isclose(np.asarray(mat.sum(1)).squeeze(), 1.0, rtol=1e-12))

        if sp.issparse(matrix) and not sp.isspmatrix_csr(matrix):
            matrix = sp.csr_matrix(matrix)
        matrix = matrix.astype(np.float64, copy=False)

        force_normalize = (self._parent is None or self._normalize) and should_norm(matrix)
        if force_normalize:
            if np.any((matrix.data if sp.issparse(matrix) else matrix) < 0):
                raise ValueError("Unable to normalize matrix with negative values.")
            matrix = _normalize(matrix)
            if should_norm(matrix):  # some rows are all 0s/contain invalid values
                n_inv = np.sum(~np.isclose(np.asarray(matrix.sum(1)).squeeze(), 1.0, rtol=1e-12))
                raise ValueError(f"Transition matrix is not row stochastic, {n_inv} rows do not sum to 1.")
        # fmt: on

        self._transition_matrix = matrix

    @property
    def params(self) -> Dict[str, Any]:
        """Parameters which are used to compute the transition matrix."""
        if len(self.kernels) == 1:
            return self._params
        return {f"{k!r}:{i}": k.params for i, k in enumerate(self.kernels)}

    def _reuse_cache(self, expected_params: Dict[str, Any], *, time: Optional[Any] = None) -> bool:
        # fmt: off
        try:
            if expected_params == self._params:
                assert self.transition_matrix is not None
                logg.debug("Using cached transition matrix")
                logg.info("    Finish", time=time)
                return True
            return False
        except AssertionError:
            logg.warning("Transition matrix does not exist for the given parameters. Recomputing")
            return False
        except Exception as e:  # noqa: BLE001
            # e.g. the dict is not comparable
            logg.warning(f"Expected and actually parameters are not comparable, reason `{e}`. Recomputing")
            expected_params = {}  # clear the params
            return False
        finally:
            self._params = expected_params
        # fmt: on

    def _get_boundary(self, source: str, target: str, cluster_key: str, graph_key: str = "distances") -> List[int]:
        """Identify source observations at boundary to target cluster.

        Parameters
        ----------
        source
            Name of source cluster.
        target
            Name of target cluster.
        cluster_key
            Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
        graph_key
            Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

        Returns
        -------
        List of observation IDs at boundary to target cluster.
        """
        source_obs_mask = self.adata.obs[cluster_key].isin([source] if isinstance(source, str) else source)
        target_obs_mask = self.adata.obs[cluster_key].isin([target] if isinstance(target, str) else target)

        source_ids = np.where(source_obs_mask)[0]
        boundary_ids = []

        graph = self.adata.obsp[graph_key]
        for source_id in source_ids:
            obs_mask = graph[source_id, :].toarray().squeeze().astype(bool)

            if (obs_mask & target_obs_mask).any():
                boundary_ids.append(source_id)

        return boundary_ids

    def _get_empirical_velocity_field(
        self, boundary_ids: List[int], target_obs_mask, rep: str, graph_key: str = "distances"
    ) -> np.ndarray:
        """Compute an emprical estimate of velocity field between two clusters.

        Parameters
        ----------
        boundary_ids
            List of observation IDs at boundary to target cluster.
        target_obs_mask
            Boolean indicator identifying relevant observations from target.
        graph_key
            Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

        Returns
        -------
        Empirical velocity estimate.
        """
        obs_ids = np.arange(0, self.adata.n_obs)
        graph = self.adata.obsp[graph_key]
        features = self.adata.obsm[rep]
        empirical_velo = np.empty(shape=(len(boundary_ids), features.shape[1]))

        for idx, boundary_id in enumerate(boundary_ids):
            row = graph[boundary_id, :].toarray().squeeze()
            obs_mask = row.astype(bool) & target_obs_mask
            neighbors = obs_ids[obs_mask]
            weights = row[obs_mask]

            empirical_velo[idx, :] = np.sum(
                weights.reshape(-1, 1) * (features[neighbors, :] - features[boundary_id, :]), axis=0
            )

        empirical_velo = np.array(empirical_velo)
        obs_mask = np.isnan(empirical_velo).any(axis=1)
        empirical_velo = empirical_velo[~obs_mask, :]

        return empirical_velo

    def _get_vector_field_estimate(self, rep: str) -> np.ndarray:
        """Compute estimate of vector field under one step of the transition matrix.

        Parameters
        ----------
        rep
            Key in :attr:`~anndata.AnnData.obsm` to use as data representation.

        Returns
        -------
        Vector field estimate based on kernel dynamics.
        """
        extrapolated_gex = self.transition_matrix @ self.adata.obsm[rep]
        return extrapolated_gex - self.adata.obsm[rep]

    # TODO: Add definition/reference to paper
    def cbc(self, source: str, target: str, cluster_key: str, rep: str, graph_key: str = "distances") -> np.ndarray:
        """Compute cross-boundary correctness score between source and target cluster.

        Parameters
        ----------
        source
            Name of the source cluster.
        target
            Name of the target cluster.
        cluster_key
            Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
        rep
            Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
        graph_key
            Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

        Returns
        -------
        Cross-boundary correctness score for each observation.
        """

        def _pearsonr(x: np.ndarray, y: np.ndarray) -> np.ndarray:
            x_centered = x - np.mean(x, axis=1, keepdims=True)
            y_centered = y - np.mean(y, axis=1, keepdims=True)
            denom = np.linalg.norm(x_centered, axis=1) * np.linalg.norm(y_centered, axis=1)

            return np.sum(x_centered * y_centered, axis=1) / denom

        target_obs_mask = self.adata.obs[cluster_key].isin([target])
        boundary_ids = self._get_boundary(source=source, target=target, cluster_key=cluster_key, graph_key=graph_key)
        empirical_velo = self._get_empirical_velocity_field(
            boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph_key=graph_key
        )
        estimated_velo = self._get_vector_field_estimate(rep=rep)[boundary_ids, :]

        return _pearsonr(x=estimated_velo, y=empirical_velo)


[docs] @d.dedent class Kernel(KernelExpression, abc.ABC): """Base kernel class. Parameters ---------- %(adata)s parent Parent kernel expression. kwargs Keyword arguments for the parent. """ def __init__(self, adata: AnnData, parent: Optional[KernelExpression] = None, **kwargs: Any): super().__init__(parent=parent, **kwargs) self._adata = adata self._n_obs = adata.n_obs self._read_from_adata(**kwargs) def _read_from_adata(self, **kwargs: Any) -> None: pass @property def adata(self) -> AnnData: """Annotated data object.""" return self._adata @adata.setter def adata(self, adata: Optional[AnnData]) -> None: if adata is None: self._adata = None return if not isinstance(adata, AnnData): raise TypeError(f"Expected `adata` to be of type `AnnData`, found `{type(adata).__name__}`.") shape = (adata.n_obs, adata.n_obs) if self.shape != shape: raise ValueError( f"Expected new `AnnData` object to have same shape as the previous `{self.shape}`, found `{shape}`." ) self._adata = adata
[docs] @classmethod @d.dedent def from_adata( cls, adata: AnnData, key: str, copy: bool = False, ) -> "Kernel": """Read the kernel saved using :meth:`write_to_adata`. Parameters ---------- %(adata)s key Key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored. The parameters should be stored in :attr:`adata.uns['{key}_params'] <anndata.AnnData.uns>`. copy Whether to copy the transition matrix. Returns ------- The kernel with explicitly initialized properties: - :attr:`transition_matrix` - the transition matrix. - :attr:`params` - parameters used for computation. """ transition_matrix = _read_graph_data(adata, key=key) try: params = adata.uns[f"{key}_params"]["params"].copy() init_params = adata.uns[f"{key}_params"]["init"].copy() except KeyError as e: raise KeyError(f"Unable to kernel parameters, reason: `{e}`") from e if copy: transition_matrix = transition_matrix.copy() kernel = cls(adata, **init_params) kernel.transition_matrix = transition_matrix kernel._params = params return kernel
[docs] def copy(self, *, deep: bool = False) -> "Kernel": """Return a copy of self. Parameters ---------- deep Whether to use :func:`~copy.deepcopy`. Returns ------- Copy of self. """ with self._remove_adata: k = copy.deepcopy(self) k.adata = self.adata.copy() if deep else self.adata return k
def _copy_ignore(self, *attrs: str) -> "Kernel": # prevent copying attributes that are not necessary, e.g. during inversion sentinel, attrs = object(), set(attrs) objects = [ (attr, obj) for attr, obj in ((attr, getattr(self, attr, sentinel)) for attr in attrs) if obj is not sentinel ] try: for attr, _ in objects: setattr(self, attr, None) return self.copy(deep=False) finally: for attr, obj in objects: setattr(self, attr, obj) def _format_params(self) -> str: n, _ = self.shape params = ", ".join(f"{k}={round(v, 3) if isinstance(v, float) else v!r}" for k, v in self.params.items()) return f"n={n}, {params}" if params else f"n={n}" @property def transition_matrix(self) -> Union[np.ndarray, sp.csr_matrix]: """Row-normalized transition matrix.""" return self._transition_matrix @transition_matrix.setter def transition_matrix(self, matrix: Any) -> None: KernelExpression.transition_matrix.fset(self, matrix) @property def kernels(self) -> Tuple["KernelExpression", ...]: """Underlying base kernels.""" return (self,) @property def shape(self) -> Tuple[int, int]: """``(n_cells, n_cells)``.""" return self._n_obs, self._n_obs def __getitem__(self, ix: int) -> "Kernel": if ix != 0: raise IndexError(ix) return self def __len__(self) -> int: return 1 def __repr__(self) -> str: params = self._format_params() prefix = "~" if self.backward and self._parent is None else "" return f"{prefix}{self.__class__.__name__}[{params}]" def __str__(self) -> str: prefix = "~" if self.backward and self._parent is None else "" return f"{prefix}{self.__class__.__name__}[n={self.shape[0]}]"
[docs] class UnidirectionalKernel(UnidirectionalMixin, Kernel, abc.ABC): pass
[docs] class BidirectionalKernel(BidirectionalMixin, Kernel, abc.ABC): pass
class Constant(UnidirectionalKernel): def __init__(self, adata: AnnData, value: Union[int, float]): super().__init__(adata) self.transition_matrix = value def compute_transition_matrix(self, value: Union[int, float]) -> "Constant": self.transition_matrix = value return self @property def transition_matrix(self) -> Union[int, float]: return self._transition_matrix @transition_matrix.setter def transition_matrix(self, value: Union[int, float]) -> None: if not isinstance(value, (int, float, np.integer, np.floating)): raise TypeError(f"Value must be a `float` or `int`, found `{type(value).__name__}`.") if value <= 0: raise ValueError(f"Expected the scalar to be positive, found `{value}`.") self._transition_matrix = value self._params = {"value": value} def copy(self, *, deep: bool = False) -> "Constant": return Constant(self.adata, self.transition_matrix) # fmt: off def __radd__(self, other: Union[int, float, "KernelExpression"]) -> "Constant": if isinstance(other, (int, float, np.integer, np.floating)): other = Constant(self.adata, other) if isinstance(other, Constant): if self.shape != other.shape: raise ValueError(f"Expected kernel shape to be `{self.shape}`, found `{other.shape}`.") return Constant(self.adata, self.transition_matrix + other.transition_matrix) return super().__radd__(other) def __rmul__(self, other: Union[int, float, "KernelExpression"]) -> "Constant": if isinstance(other, (int, float, np.integer, np.floating)): other = Constant(self.adata, other) if isinstance(other, Constant): if self.shape != other.shape: raise ValueError(f"Expected kernel shape to be `{self.shape}`, found `{other.shape}`.") return Constant(self.adata, self.transition_matrix * other.transition_matrix) return super().__rmul__(other) # fmt: on def __repr__(self) -> str: return repr(round(self.transition_matrix, 3)) def __str__(self) -> str: return str(round(self.transition_matrix, 3)) class NaryKernelExpression(BidirectionalMixin, KernelExpression): def __init__(self, *kexprs: KernelExpression, parent: Optional[KernelExpression] = None): super().__init__(parent=parent) self._validate(kexprs) @abc.abstractmethod def _combine_transition_matrices(self, t1: Tmat_t, t2: Tmat_t) -> Tmat_t: pass @property @abc.abstractmethod def _initial_value(self) -> Union[int, float, Tmat_t]: pass @property @abc.abstractmethod def _combiner(self) -> str: pass def compute_transition_matrix(self) -> "NaryKernelExpression": for kexpr in self: if kexpr.transition_matrix is None: if isinstance(kexpr, Kernel): raise RuntimeError( f"`{kexpr}` is uninitialized. Compute its " f"transition matrix first as `.compute_transition_matrix()`." ) kexpr.compute_transition_matrix() tmat = self._initial_value for kexpr in self: tmat = self._combine_transition_matrices(tmat, kexpr.transition_matrix) self.transition_matrix = tmat return self def _validate(self, kexprs: Sequence[KernelExpression]) -> None: if not len(kexprs): raise ValueError("No kernels to combined.") shapes = {kexpr.shape for kexpr in kexprs} if len(shapes) > 1: raise ValueError(f"Expected all kernels to have the same shapes, found `{sorted(shapes)}`.") directions = {kexpr.backward for kexpr in kexprs} if True in directions and False in directions: raise ValueError("Unable to combine both forward and backward kernels.") self._backward = True if True in directions else False if False in directions else None self._kexprs = kexprs for kexpr in self: kexpr._parent = self @property def adata(self) -> AnnData: """Annotated data object.""" return self[0].adata @adata.setter def adata(self, adata: Optional[AnnData]) -> None: # allow resetting (use for temp. pickling without adata) for kexpr in self: kexpr.adata = adata @property def kernels(self) -> Tuple["KernelExpression", ...]: """Underlying unique basic kernels.""" kernels = [] for kexpr in self: if isinstance(kexpr, Kernel) and not isinstance(kexpr, Constant): kernels.append(kexpr) elif isinstance(kexpr, NaryKernelExpression): # recurse kernels.extend(kexpr.kernels) # return only unique kernels return tuple(set(kernels)) def copy(self, *, deep: bool = False) -> "KernelExpression": kexprs = (k.copy(deep=deep) for k in self) return type(self)(*kexprs, parent=self._parent) @property def shape(self) -> Tuple[int, int]: """``(n_cells, n_cells)``.""" # all kernels have the same shape return self[0].shape def _format(self, formatter: Callable[[KernelExpression], str]) -> str: return ( f"{'~' if self.backward and self._parent is None else ''}(" + f" {self._combiner} ".join(formatter(kexpr) for kexpr in self) + ")" ) def __getitem__(self, ix: int) -> "KernelExpression": return self._kexprs[ix] def __len__(self) -> int: return len(self._kexprs) def __invert__(self) -> "KernelExpression": if self.backward is None: return self.copy() kexprs = tuple(~k if isinstance(k, BidirectionalMixin) else k.copy() for k in self) kexpr = type(self)(*kexprs, parent=self._parent) kexpr._transition_matrix = None return kexpr def __repr__(self) -> str: return self._format(repr) def __str__(self) -> str: return self._format(str) class KernelAdd(NaryKernelExpression): def compute_transition_matrix(self) -> "KernelAdd": self._maybe_recalculate_constants() return super().compute_transition_matrix() def _combine_transition_matrices(self, t1: Tmat_t, t2: Tmat_t) -> Tmat_t: return t1 + t2 def _maybe_recalculate_constants(self) -> None: """Normalize constants to sum to 1.""" constants = self._bin_consts if constants: total = sum((c.transition_matrix for c in constants), 0.0) for c in constants: c.transition_matrix = c.transition_matrix / total for kexpr in self: # don't normalize (c * x) kexpr._normalize = False @property def _bin_consts(self) -> List[KernelExpression]: """Return constant expressions for each binary multiplication children with at least 1 constant.""" return [c for k in self if isinstance(k, KernelMul) for c in k._bin_consts] @property def _combiner(self) -> str: return "+" @property def _initial_value(self) -> Union[int, float, Tmat_t]: return 0.0 class KernelMul(NaryKernelExpression): def _combine_transition_matrices(self, t1: Tmat_t, t2: Tmat_t) -> Tmat_t: if sp.issparse(t1): return t1.multiply(t2) if sp.issparse(t2): return t2.multiply(t1) return t1 * t2 def _format(self, formatter: Callable[[KernelExpression], str]) -> str: fmt = super()._format(formatter) if fmt[0] == "~" or not self._bin_consts: return fmt if fmt.startswith("("): fmt = fmt[1:] if fmt.endswith(")"): fmt = fmt[:-1] return fmt @property def _bin_consts(self) -> List[KernelExpression]: """Return all constants if this expression contains only 2 subexpressions.""" if len(self) != 2: return [] return [k for k in self if isinstance(k, Constant)] @property def _split_const(self) -> Tuple[Optional[Constant], Optional[KernelExpression]]: """Return a constant and the other expression, iff this expression is of length 2 and contains a constant.""" if not self._bin_consts: return None, None k1, k2 = self if isinstance(k1, Constant): return k1, k2 return k2, k1 @property def _combiner(self) -> str: return "*" @property def _initial_value(self) -> Union[int, float, Tmat_t]: return 1.0