import abc
from typing import Any, Callable, Optional, Tuple
import numpy as np
import scipy.sparse as sp
from cellrank._utils._docs import d
from cellrank._utils._parallelize import parallelize
__all__ = ["HardThresholdScheme", "SoftThresholdScheme", "CustomThresholdScheme"]
class ThresholdSchemeABC(abc.ABC):
    """Base class for all connectivity biasing schemes.

    Subclasses implement :meth:`__call__`, which biases the connectivities of a single cell;
    :meth:`bias_knn` applies it to every row of a KNN graph.
    """

    @d.get_summary(base="pt_scheme")
    @d.get_sections(base="pt_scheme", sections=["Parameters", "Returns"])
    @abc.abstractmethod
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        **kwargs: Any,
    ) -> np.ndarray:
        """Calculate biased connections for a given cell.

        Parameters
        ----------
        cell_pseudotime
            Pseudotime of the current cell.
        neigh_pseudotime
            Array of shape ``(n_neighbors,)`` containing pseudotime of neighbors.
        neigh_conn
            Array of shape ``(n_neighbors,)`` containing connectivities of the current cell and its neighbors.

        Returns
        -------
        Array of shape ``(n_neighbors,)`` containing the biased connectivities.
        """

    def _bias_knn_helper(
        self,
        ixs: np.ndarray,
        conn: sp.csr_matrix,
        pseudotime: np.ndarray,
        queue=None,
        **kwargs: Any,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Bias the rows ``ixs`` of ``conn`` and return their CSR pieces.

        Parameters
        ----------
        ixs
            Row indices of ``conn`` handled by this call.
        conn
            Sparse connectivity matrix; its ``indptr`` is read directly.
        pseudotime
            Pseudotime values, indexed by cell.
        queue
            Optional object with a ``put`` method. ``1`` is put after every processed row and
            ``None`` once this call has finished.
        kwargs
            Keyword arguments forwarded to :meth:`__call__`.

        Returns
        -------
        The ``data``, ``indices`` and ``indptr`` arrays for the processed rows. The closing
        ``indptr`` entry is included only when the last row of ``conn`` was processed here.

        Raises
        ------
        ValueError
            If the scheme returns a row whose shape differs from the input row.
        """
        data, indices, indptr = [], [], []
        last_row = None  # stays `None` when `ixs` is empty

        for i in ixs:
            row = conn[i]
            biased_row = self(pseudotime[i], pseudotime[row.indices], row.data, **kwargs)
            if np.shape(biased_row) != row.data.shape:
                raise ValueError(f"Expected row of shape `{row.data.shape}`, found `{np.shape(biased_row)}`.")

            data.extend(biased_row)
            indices.extend(row.indices)
            # The biased row has the same length as the original one (checked above),
            # so the offsets of the input matrix can be reused as-is.
            indptr.append(conn.indptr[i])
            last_row = i

            if queue is not None:
                queue.put(1)

        # Only the call that handled the final row contributes the closing offset.
        if last_row is not None and last_row == conn.shape[0] - 1:
            indptr.append(conn.indptr[-1])

        if queue is not None:
            queue.put(None)

        return np.array(data), np.array(indices), np.array(indptr)

    @d.dedent
    def bias_knn(
        self,
        conn: sp.csr_matrix,
        pseudotime: np.ndarray,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        **kwargs: Any,
    ) -> sp.csr_matrix:
        """Bias cell-cell connectivities of a KNN graph.

        Parameters
        ----------
        conn
            Sparse matrix of shape ``(n_cells, n_cells)`` containing the nearest neighbor connectivities.
        pseudotime
            Pseudotemporal ordering of cells.
        %(parallel)s
        kwargs
            Keyword arguments forwarded to the biasing scheme.

        Returns
        -------
        The biased connectivities, with the same shape as ``conn`` and explicit zeros removed.
        """
        res = parallelize(
            self._bias_knn_helper,
            np.arange(conn.shape[0]),
            as_array=False,
            unit="cell",
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(conn, pseudotime, **kwargs)
        data, indices, indptr = zip(*res)

        # Pass the shape explicitly: without it, the number of columns is inferred from the
        # largest column index, which is too small whenever the last cells are nobody's neighbor.
        biased = sp.csr_matrix(
            (np.concatenate(data), np.concatenate(indices), np.concatenate(indptr)),
            shape=conn.shape,
        )
        biased.eliminate_zeros()

        return biased

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}>"

    def __str__(self) -> str:
        return repr(self)
class HardThresholdScheme(ThresholdSchemeABC):
    """Thresholding scheme inspired by *Palantir* :cite:`setty:19`.

    Note that this won't exactly reproduce the original *Palantir* results:

    - *Palantir* computes the kNN graph in a scaled space of diffusion components.
    - *Palantir* uses its own pseudotime to bias the kNN graph which is not implemented here.
    - *Palantir* uses a slightly different mechanism to ensure the graph remains connected when removing edges
      that point into the "pseudotime past".
    """

    @d.dedent
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        frac_to_keep: float = 0.3,
    ) -> np.ndarray:
        """Convert the undirected graph of cell-cell similarities into a directed one by removing "past" edges.

        Edges pointing into the pseudotime-past are dropped based on a pseudotemporal measure. The closest
        neighbors of each cell are always retained, even when they lie in the pseudotime past, so that the
        graph remains connected.

        Parameters
        ----------
        %(pt_scheme.parameters)s
        frac_to_keep
            The ``frac_to_keep`` * n_neighbors closest neighbors (according to graph connectivities) are kept, no matter
            whether they lie in the pseudotemporal past or future. At most 30 neighbors are kept this way.
            Must be in :math:`[0, 1]`.

        Returns
        -------
        %(pt_scheme.returns)s
        """
        if not (0 <= frac_to_keep <= 1):
            raise ValueError(f"Expected `frac_to_keep` to be in `[0, 1]`, found `{frac_to_keep}`.")

        # Number of unconditionally kept neighbors: a fraction of the neighborhood, capped at 30.
        n_close = min(30, int(np.floor(len(neigh_conn) * frac_to_keep)))
        n_close = max(n_close, 0)

        # Neighbors ordered from the strongest to the weakest connectivity.
        order = np.argsort(neigh_conn)[::-1]
        closest = order[:n_close]
        remaining = order[n_close:]

        # Among the remaining neighbors, retain only those not in the pseudotime past.
        not_in_past = remaining[cell_pseudotime <= neigh_pseudotime[remaining]]

        kept = np.concatenate([closest, not_in_past])
        result = np.zeros_like(neigh_conn)
        result[kept] = neigh_conn[kept]

        return result
class SoftThresholdScheme(ThresholdSchemeABC):
    """Thresholding scheme inspired by :cite:`stassen:21`.

    The idea is to downweight edges that point against the direction of increasing pseudotime. Essentially, the
    further "behind" a query cell is in pseudotime with respect to the current reference cell, the more penalized will
    be its graph-connectivity.
    """

    @d.dedent
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        b: float = 10.0,
        nu: float = 0.5,
    ) -> np.ndarray:
        """Bias the connectivities by downweighting ones to past cells.

        This function uses the `generalized logistic function
        <https://en.wikipedia.org/wiki/Generalized_logistic_function>`_ to weight the past connectivities.
        Connectivities to neighbors that are not in the pseudotime past are left unchanged.

        Parameters
        ----------
        %(pt_scheme.parameters)s
        %(soft_scheme)s

        Returns
        -------
        %(pt_scheme.returns)s
        """
        past_ixs = np.where(neigh_pseudotime < cell_pseudotime)[0]
        if not len(past_ixs):
            # Nothing lies in the past: the input is returned as-is (not copied).
            return neigh_conn

        # The weights are fractional, so they must live in a floating-point array. Inheriting an
        # integer or boolean dtype from `neigh_conn` would truncate every past weight to 0.
        conn_dtype = np.asarray(neigh_conn).dtype
        weight_dtype = conn_dtype if np.issubdtype(conn_dtype, np.floating) else np.float64
        weights = np.ones(np.shape(neigh_conn), dtype=weight_dtype)

        # `dt` is strictly positive for past neighbors; the larger it is, the smaller the weight.
        dt = cell_pseudotime - neigh_pseudotime[past_ixs]
        weights[past_ixs] = 2.0 / ((1.0 + np.exp(b * dt)) ** (1.0 / nu))

        return neigh_conn * weights
class CustomThresholdScheme(ThresholdSchemeABC):
    """Class that wraps a user supplied scheme.

    Parameters
    ----------
    callback
        Function which returns the biased connectivities. It is called as
        ``callback(cell_pseudotime, neigh_pseudotime, neigh_conn, **kwargs)``.

    Raises
    ------
    TypeError
        If ``callback`` is not callable.
    """

    def __init__(
        self,
        # `Callable[..., ...]` because the callback also receives arbitrary keyword arguments,
        # which a fixed positional argument list cannot express.
        callback: Callable[..., np.ndarray],
    ):
        super().__init__()
        # Fail at construction rather than later, inside a (possibly parallel) worker.
        if not callable(callback):
            raise TypeError(f"Expected `callback` to be callable, found `{type(callback).__name__}`.")
        self._callback = callback

    @d.dedent
    def __call__(
        self,
        cell_pseudotime: float,
        neigh_pseudotime: np.ndarray,
        neigh_conn: np.ndarray,
        **kwargs: Any,
    ) -> np.ndarray:
        """%(pt_scheme.summary)s

        Parameters
        ----------
        %(pt_scheme.parameters)s
        kwargs
            Additional keyword arguments.

        Returns
        -------
        %(pt_scheme.returns)s
        """  # noqa: D400
        return self._callback(cell_pseudotime, neigh_pseudotime, neigh_conn, **kwargs)