Source code for cellrank.tl.kernels._pseudotime_schemes

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

import numpy as np
from scipy.sparse import csr_matrix

from cellrank.ul._docs import d


class ThresholdSchemeABC(ABC):
    """
    Base class for all connectivity biasing schemes.

    A concrete scheme implements :meth:`__call__`, which biases the connectivities
    of one cell; :meth:`bias_knn` applies it to every row of a sparse KNN graph.
    """

    # NOTE: the docstring below is consumed by ``d.get_summary`` / ``d.get_sections``
    # and registered under the ``pt_scheme`` base; subclasses in this module re-inject
    # it through ``%(pt_scheme.*)s`` placeholders, so keep its section layout intact.
    @d.get_summary(base="pt_scheme")
    @d.get_sections(base="pt_scheme", sections=["Parameters", "Returns"])
    @abstractmethod
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Calculate biased connections for a given cell.

        Parameters
        ----------
        cell_pseudotime
            Pseudotime of the current cell.
        neigh_pseudotime
            Array of shape ``(n_neighbors,)`` containing pseudotimes of neighbors.
        neigh_conn
            Array of shape ``(n_neighbors,)`` containing connectivities of the current cell and its neighbors.

        Returns
        -------
        Array of shape ``(n_neighbors,)`` containing the biased connectivities.
        """

    def bias_knn(
        self, conn: csr_matrix, pseudotime: np.ndarray, **kwargs: Any
    ) -> csr_matrix:
        """
        Bias cell-cell connectivities of a KNN graph.

        The scheme is invoked once per row of ``conn``; the values it returns
        overwrite that row's stored entries in a copy of ``conn``. Entries that
        end up as zero are removed from the result. ``conn`` is not modified.

        Parameters
        ----------
        conn
            Sparse matrix of shape ``(n_cells, n_cells)`` containing the nearest neighbor connectivities.
        pseudotime
            Pseudotemporal ordering of cells.
        kwargs
            Keyword arguments forwarded unchanged to :meth:`__call__`.

        Returns
        -------
        The biased connectivities.

        Raises
        ------
        ValueError
            If the scheme returns an array whose shape differs from the shape
            of the row's stored data.
        """
        biased = conn.copy()
        n_rows = conn.shape[0]

        for cell in range(n_rows):
            neighbors = conn[cell]
            # CSR layout: ``indptr[cell]:indptr[cell + 1]`` delimits this row in ``.data``.
            lo = conn.indptr[cell]
            hi = conn.indptr[cell + 1]

            new_values = self(
                pseudotime[cell],
                pseudotime[neighbors.indices],
                neighbors.data,
                **kwargs,
            )

            expected_shape = neighbors.data.shape
            actual_shape = np.shape(new_values)
            if actual_shape != expected_shape:
                raise ValueError(
                    f"Expected row of shape `{expected_shape}`, found `{actual_shape}`."
                )

            biased.data[lo:hi] = new_values

        # Drop edges whose weight was set to zero by the scheme.
        biased.eliminate_zeros()
        return biased

    def __repr__(self):
        cls_name = type(self).__name__
        return f"<{cls_name}>"

    def __str__(self):
        return repr(self)
class HardThresholdScheme(ThresholdSchemeABC):
    """
    Thresholding scheme inspired by Palantir [Setty19]_.

    The original Palantir results are not reproduced exactly, for three reasons:

        - Palantir computes the KNN graph in a scaled space of diffusion components.
        - Palantir uses its own pseudotime to bias the KNN graph which is not implemented here.
        - Palantir uses a slightly different mechanism to ensure the graph remains connected when removing edges
          that point into the "pseudotime past".
    """

    @d.dedent
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        n_neighs: int,
        frac_to_keep: float = 0.3,
    ) -> np.ndarray:
        """
        Convert the undirected graph of cell-cell similarities into a directed one by removing "past" edges.

        Edges pointing to neighbors with a smaller pseudotime than the current cell are dropped (set to zero).
        The most strongly connected neighbors are exempt from this and are always kept, even if they lie in
        the pseudotime past, to make sure the graph remains connected.

        Parameters
        ----------
        %(pt_scheme.parameters)s
        n_neighs
            Neighbor count from which the number of always-kept closest neighbors is derived,
            see ``frac_to_keep``.
        frac_to_keep
            Fraction of ``n_neighs`` closest neighbors (according to graph connectivities) that are kept, no matter
            whether they lie in the pseudotemporal past or future. Precisely,
            ``floor(n_neighs * frac_to_keep) - 1`` neighbors, clamped to the range ``[0, 30]``, are kept.

        Returns
        -------
        %(pt_scheme.returns)s
        """
        n_protected = int(np.floor(n_neighs * frac_to_keep)) - 1
        n_protected = max(0, min(30, n_protected))

        # Full sort, strongest connectivity first (the original author notes
        # that ``argpartition`` does not work for the code below).
        order = np.argsort(neigh_conn)[::-1]
        protected = order[:n_protected]
        candidates = order[n_protected:]

        # Of the remaining neighbors, retain only those not in the pseudotime past.
        not_in_past = neigh_pseudotime[candidates] >= cell_pseudotime
        retained = candidates[not_in_past]

        result = np.zeros_like(neigh_conn)
        kept = np.concatenate([protected, retained])
        result[kept] = neigh_conn[kept]

        return result
class SoftThresholdScheme(ThresholdSchemeABC):
    """
    Thresholding scheme inspired by [VIA21]_.

    The idea is to downweight edges that point against the direction of increasing pseudotime. Essentially, the
    further "behind" a query cell is in pseudotime with respect to the current reference cell, the more penalized
    will be its graph-connectivity.
    """

    @d.dedent
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        b: float = 20.0,
        nu: float = 1.0,
        perc: Optional[int] = 95,
    ) -> np.ndarray:
        """
        Bias the connectivities by downweighting ones to past cells.

        This function uses `generalized logistic regression
        <https://en.wikipedia.org/wiki/Generalized_logistic_function>`_ to weight the past connectivities.

        Parameters
        ----------
        %(pt_scheme.parameters)s
        %(soft_scheme)s

        Returns
        -------
        %(pt_scheme.returns)s
        """
        # Optionally clip the connectivities to the [100 - perc, perc] percentile range.
        # Empty neighborhoods are skipped: there is nothing to clip and the
        # percentile of an empty array is not a usable bound.
        if perc is not None and np.size(neigh_conn):
            # Order the bounds so that ``perc < 50`` still yields ``lower <= upper``;
            # ``np.clip`` with ``a_min > a_max`` would set every entry to ``a_max``.
            lower, upper = sorted(
                (
                    np.percentile(neigh_conn, 100 - perc),
                    np.percentile(neigh_conn, perc),
                )
            )
            neigh_conn = np.clip(neigh_conn, lower, upper)

        # Neighbors strictly earlier in pseudotime than the current cell.
        past_ixs = np.flatnonzero(neigh_pseudotime < cell_pseudotime)
        if not past_ixs.size:
            return neigh_conn

        # The weights must be floating-point even for integer connectivities;
        # otherwise the fractional weights below would be truncated on assignment.
        weights = np.ones(
            np.shape(neigh_conn), dtype=np.result_type(neigh_conn, np.float32)
        )

        # Generalized logistic weight: decreases as the neighbor lies further in the past.
        dt = cell_pseudotime - neigh_pseudotime[past_ixs]
        weights[past_ixs] = 2.0 / ((1.0 + np.exp(b * dt)) ** (1.0 / nu))

        return neigh_conn * weights
class CustomThresholdScheme(ThresholdSchemeABC):
    """
    Class that wraps a user supplied scheme.

    Calling an instance delegates to the wrapped function as
    ``callback(cell_pseudotime, neigh_pseudotime, neigh_conn, **kwargs)``.

    Parameters
    ----------
    callback
        Function which returns the biased connectivities.
    """

    def __init__(
        self,
        callback: Callable[
            [float, np.ndarray, np.ndarray, np.ndarray, Any], np.ndarray
        ],
    ):
        super().__init__()
        # Stored as-is; it is only invoked when the scheme is called.
        self._callback = callback

    @d.dedent
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        %(pt_scheme.summary)s

        Parameters
        ----------
        %(pt_scheme.parameters)s
        kwargs
            Additional keyword arguments.

        Returns
        -------
        %(pt_scheme.returns)s
        """  # noqa: D400
        user_scheme = self._callback
        biased = user_scheme(cell_pseudotime, neigh_pseudotime, neigh_conn, **kwargs)
        return biased