# Source code for cellrank.kernels.utils._similarity

import abc
import enum
import functools
from typing import Any, Tuple

import numba as nb
import numpy as np

from cellrank._utils._docs import d
from cellrank._utils._enum import ModeEnum
from cellrank.kernels._utils import jit_kwargs, norm, np_mean

__all__ = ["DotProduct", "Cosine", "Correlation", "SimilarityABC"]


class Similarity(ModeEnum):
    DOT_PRODUCT = enum.auto()
    COSINE = enum.auto()
    CORRELATION = enum.auto()
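
# Each member of `Similarity` maps onto a concrete scheme defined below via
# `SimilarityABC.create`, e.g., `SimilarityABC.create(Similarity.COSINE)`
# returns a `Cosine` instance.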


try:
    import jax.numpy as jnp
    from jax import hessian, jit

    @jit
    def _softmax_jax(x: np.ndarray, softmax_scale: float) -> np.ndarray:
        numerator = x * softmax_scale
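        # subtracting the max before exponentiating is the standard numerical-stability
        # trick: softmax is shift-invariant, so this bounds every exponent by 0 and
        # avoids overflow without changing the result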
        numerator = jnp.exp(numerator - jnp.max(numerator))
        return numerator / jnp.sum(numerator)

    @jit
    def _softmax_masked_jax(x: np.ndarray, mask: np.ndarray, softmax_scale: float) -> np.ndarray:
        numerator = x * softmax_scale
        numerator = jnp.exp(numerator - jnp.nanmax(numerator))
        numerator = jnp.where(mask, 0, numerator)  # essential: masked (zero-norm) entries get no probability mass

        return numerator / jnp.nansum(numerator)

    @functools.partial(jit, static_argnums=(3, 4))
    def _predict_transition_probabilities_jax(
        X: np.ndarray,
        W: np.ndarray,
        softmax_scale: float = 1.0,
        center_mean: bool = True,
        scale_by_norm: bool = True,
    ) -> np.ndarray:
        if center_mean:
            # centering yields Pearson correlation; without it, cosine similarity
            W -= W.mean(axis=1)[:, None]
            X -= X.mean()

        if scale_by_norm:
            denom = jnp.linalg.norm(X) * jnp.linalg.norm(W, axis=1)
            mask = jnp.isclose(denom, 0)
            denom = jnp.where(mask, 1, denom)  # essential: avoid division by zero; masked entries are zeroed in the softmax
            return _softmax_masked_jax(W.dot(X) / denom, mask, softmax_scale)

        return _softmax_jax(W.dot(X), softmax_scale)

    _predict_transition_probabilities_jax_H = hessian(_predict_transition_probabilities_jax)
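    # NB: `jax.hessian` differentiates w.r.t. the first argument (`X`) by default,
    # so for `X` of shape `(n_genes,)` and an output of shape `(n_neighbors,)`,
    # the Hessian has shape `(n_neighbors, n_genes, n_genes)`.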

    _HAS_JAX = True
except ImportError:
    _HAS_JAX = False

    hessian = jit = lambda _: _

    def _predict_transition_probabilities_jax(*_args, **_kwargs):
        raise NotImplementedError("No `jax` installation found.")

    _predict_transition_probabilities_jax_H = _predict_transition_probabilities_jax
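
# Hedged usage sketch (illustrative, not part of the original module): with
# ``center_mean=False`` and ``scale_by_norm=True``, the JAX path computes a
# softmax over cosine similarities between the velocity vector and each
# displacement row. Shapes below are assumptions for illustration only.
#
#     >>> if _HAS_JAX:
#     ...     v = np.random.normal(size=10)        # velocity vector, (n_genes,)
#     ...     D = np.random.normal(size=(5, 10))   # displacements, (n_neighbors, n_genes)
#     ...     p = _predict_transition_probabilities_jax(v, D, 1.0, False, True)
#     ...     assert np.isclose(np.sum(p), 1.0)    # probabilities sum to one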


@nb.njit(**jit_kwargs)
def _softmax(x: np.ndarray, softmax_scale: float) -> Tuple[np.ndarray, np.ndarray]:
    numerator = x * softmax_scale
    numerator = np.exp(numerator - np.max(numerator))
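    # return the probabilities together with the unscaled logits `x`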
    return numerator / np.sum(numerator), x


@nb.njit(**jit_kwargs)
def _softmax_masked(x: np.ndarray, mask: np.ndarray, softmax_scale: float) -> Tuple[np.ndarray, np.ndarray]:
    numerator = x * softmax_scale
    numerator = np.exp(numerator - np.nanmax(numerator))
    numerator = np.where(mask, 0, numerator)  # essential: masked (zero-norm) entries get no probability mass

    return numerator / np.nansum(numerator), x


@nb.njit(parallel=False, **jit_kwargs)
def _predict_transition_probabilities_numpy(
    X: np.ndarray,
    W: np.ndarray,
    softmax_scale: float = 1.0,
    center_mean: bool = True,
    scale_by_norm: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    if center_mean:
        # pearson correlation
        W -= np.expand_dims(np_mean(W, axis=1), axis=1)

    if X.shape[0] == 1:
        if center_mean:
            # pearson correlation
            X = X - np.mean(X)

        if scale_by_norm:
            # cosine or pearson correlation
            denom = np.linalg.norm(X) * norm(W, axis=1)
            mask = denom == 0
            denom[mask] = 1
            return _softmax_masked(W.dot(X[0]) / denom, mask, softmax_scale)

        return _softmax(W.dot(X[0]), softmax_scale)

    assert X.shape[0] == W.shape[0], "Shape mismatch: `X` and `W` must have the same number of rows."

    if center_mean:
        X = X - np.expand_dims(np_mean(X, axis=1), axis=1)

    if scale_by_norm:
        denom = norm(X, axis=1) * norm(W, axis=1)
        mask = denom == 0
        denom[mask] = 1
        return _softmax_masked(
            np.array([np.dot(X[i], W[i]) for i in range(X.shape[0])]) / denom,
            mask,
            softmax_scale,
        )

    return _softmax(np.array([np.dot(X[i], W[i]) for i in range(X.shape[0])]), softmax_scale)
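
# Hedged usage sketch (illustrative, not part of the original module): the numba
# backend expects a 2-D velocity array, either ``(1, n_genes)`` for the forward
# process or ``(n_neighbors, n_genes)`` for the backward one.
#
#     >>> rng = np.random.default_rng(0)
#     >>> v = rng.normal(size=(1, 10))       # one velocity vector, (1, n_genes)
#     >>> D = rng.normal(size=(5, 10))       # displacements, (n_neighbors, n_genes)
#     >>> probs, logits = _predict_transition_probabilities_numpy(v, D)
#     >>> probs.shape, bool(np.isclose(probs.sum(), 1.0))
#     ((5,), True)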


class Hessian(abc.ABC):
    @d.get_full_description(base="hessian")
    @d.get_sections(base="hessian", sections=["Parameters", "Returns"])
    @abc.abstractmethod
    def hessian(self, v: np.ndarray, D: np.ndarray, softmax_scale: float = 1.0) -> np.ndarray:
        """Compute the Hessian.

        Parameters
        ----------
        v
            Array of shape ``(n_genes,)`` containing the velocity vector.
        D
            Array of shape ``(n_neighbors, n_genes)`` corresponding to the transcriptomic displacement of the current
            cell with respect to its nearest neighbors.
        softmax_scale
            Scaling factor for the softmax function.

        Returns
        -------
        The full Hessian of shape ``(n_neighbors, n_genes, n_genes)`` or only its diagonal of shape
        ``(n_neighbors, n_genes)``.

        Developer notes
        ---------------
        This class should be used in conjunction with any class that implements the ``__call__`` method; it should
        compute the Hessian of that function. It does not need to handle the case of a backward process,
        i.e., when the velocity vector :math:`v` is of shape ``(n_neighbors, n_genes)``.

        If using :mod:`jax` to compute the Hessian, please specify the class attribute ``__use_jax__ = True``.
        """


class SimilarityABC(abc.ABC):
    """Base class for all similarity schemes."""

    @d.get_full_description(base="sim_scheme")
    @d.get_sections(base="sim_scheme", sections=["Parameters", "Returns"])
    @abc.abstractmethod
    def __call__(self, v: np.ndarray, D: np.ndarray, softmax_scale: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
        """Compute transition probability of a cell to its nearest neighbors using RNA velocity.

        Parameters
        ----------
        v
            Array of shape ``(n_genes,)`` or ``(n_neighbors, n_genes)`` containing the velocity vector(s).
            The second case is used for the backward process.
        D
            Array of shape ``(n_neighbors, n_genes)`` corresponding to the transcriptomic displacement of the
            current cell with respect to its nearest neighbors.
        softmax_scale
            Scaling factor for the softmax function.

        Returns
        -------
        The probability and unscaled logits arrays of shape ``(n_neighbors,)``.
        """

    @staticmethod
    def create(scheme: Similarity) -> "SimilarityABC":
        if scheme == Similarity.CORRELATION:
            return Correlation()
        if scheme == Similarity.COSINE:
            return Cosine()
        if scheme == Similarity.DOT_PRODUCT:
            return DotProduct()
        raise NotImplementedError(scheme)

    def __repr__(self):
        return f"<{self.__class__.__name__}>"

    def __str__(self):
        return repr(self)
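

# Hedged usage sketch (illustrative, not part of the original module): `create`
# maps a `Similarity` member onto its concrete scheme. The numba backend expects
# a 2-D velocity array, hence the `v[None, :]` below.
#
#     >>> scheme = SimilarityABC.create(Similarity.COSINE)
#     >>> v = np.random.normal(size=10)
#     >>> D = np.random.normal(size=(5, 10))
#     >>> probs, logits = scheme(v[None, :], D)
#     >>> probs.shape
#     (5,)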


class SimilarityHessian(SimilarityABC, Hessian):
    """Base class for all similarity schemes as defined in :cite:`li:20`.

    Parameters
    ----------
    center_mean
        Whether to center the velocity vectors and the transcriptomic displacement matrix.
    scale_by_norm
        Whether to scale the velocity vectors and the transcriptomic displacement matrix by their norms.
    """

    __use_jax__: bool = True

    def __init__(self, center_mean: bool, scale_by_norm: bool):
        self._center_mean = center_mean
        self._scale_by_norm = scale_by_norm

    @d.dedent
    def __call__(self, v: np.ndarray, D: np.ndarray, softmax_scale: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
        """%(sim_scheme.full_desc)s

        Parameters
        ----------
        %(sim_scheme.parameters)s

        Returns
        -------
        %(sim_scheme.returns)s
        """  # noqa: D400
        return _predict_transition_probabilities_numpy(v, D, softmax_scale, self._center_mean, self._scale_by_norm)

    @d.dedent
    def hessian(self, v: np.ndarray, D: np.ndarray, softmax_scale: float = 1.0) -> np.ndarray:
        """%(hessian.full_desc)s

        Parameters
        ----------
        %(hessian.parameters)s

        Returns
        -------
        %(hessian.returns)s
        """  # noqa: D400
        return _predict_transition_probabilities_jax_H(v, D, softmax_scale, self._center_mean, self._scale_by_norm)

    def __getattribute__(self, name: str) -> Any:
        if name == "hessian" and self.__use_jax__ and not _HAS_JAX:
            raise NotImplementedError("No `jax` installation found.")
        return super().__getattribute__(name)


class DotProduct(SimilarityHessian):
    r"""Dot product scheme as defined in eq. (4.9) :cite:`li:20`.

    .. math::
        v(s_i, s_j) := g(\delta_{i, j}^T v_i)

    where :math:`v_i` is the velocity vector of cell :math:`i`, :math:`\delta_{i, j}` corresponds to the
    transcriptional displacement between cells :math:`i` and :math:`j`, and :math:`g` is a softmax function
    with some scaling parameter.
    """

    def __init__(self):
        super().__init__(center_mean=False, scale_by_norm=False)


class Cosine(SimilarityHessian):
    r"""Cosine similarity scheme as defined in eq. (4.7) :cite:`li:20`.

    .. math::
        v(s_i, s_j) := g(\cos(\delta_{i, j}, v_i))

    where :math:`v_i` is the velocity vector of cell :math:`i`, :math:`\delta_{i, j}` corresponds to the
    transcriptional displacement between cells :math:`i` and :math:`j`, and :math:`g` is a softmax function
    with some scaling parameter.
    """

    def __init__(self):
        super().__init__(center_mean=False, scale_by_norm=True)


class Correlation(SimilarityHessian):
    r"""Pearson correlation scheme as defined in eq. (4.8) :cite:`li:20`.

    .. math::
        v(s_i, s_j) := g(\mathrm{corr}(\delta_{i, j}, v_i))

    where :math:`v_i` is the velocity vector of cell :math:`i`, :math:`\delta_{i, j}` corresponds to the
    transcriptional displacement between cells :math:`i` and :math:`j`, and :math:`g` is a softmax function
    with some scaling parameter.
    """

    def __init__(self):
        super().__init__(center_mean=True, scale_by_norm=True)
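

# Hedged verification sketch (illustrative, not part of the original module):
# `Correlation` should favor the neighbor whose displacement is most
# Pearson-correlated with the velocity vector. Note that the numba backend
# centers `D` in place when `center_mean=True`.
#
#     >>> rng = np.random.default_rng(0)
#     >>> v = rng.normal(size=(1, 10))      # (1, n_genes)
#     >>> D = rng.normal(size=(5, 10))      # (n_neighbors, n_genes)
#     >>> D[3] = 2.0 * v[0] + 1.0           # perfectly correlated with `v`
#     >>> probs, logits = Correlation()(v, D)
#     >>> int(np.argmax(probs))
#     3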