from typing import Any, Dict, List, Union, Optional, Sequence
from typing_extensions import Literal
from anndata import AnnData
from cellrank import logging as logg
from cellrank.ul._docs import d
from cellrank.tl._utils import (
_cluster_X,
_filter_cells,
_complex_warning,
_get_connectivities,
)
from cellrank.tl.kernels._utils import _get_basis
from cellrank.tl.estimators.mixins import EigenMixin, LinDriversMixin
from cellrank.tl.estimators.terminal_states._term_states_estimator import (
TermStatesEstimator,
)
import numpy as np
import pandas as pd
from scipy.stats import zscore
@d.dedent
class CFLARE(TermStatesEstimator, LinDriversMixin, EigenMixin):
"""
    Compute the initial/terminal states of a Markov chain via spectral heuristics.

    This estimator uses the left eigenvectors of the transition matrix to filter to a set of recurrent cells
    and the right eigenvectors to cluster this set of cells into discrete groups.

Parameters
----------
%(base_estimator.parameters)s
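
    Examples
    --------
    A minimal usage sketch. The dataset and kernel below are illustrative, not prescriptive;
    any CellRank kernel with a computed transition matrix can be passed to the constructor::

        import cellrank as cr

        adata = cr.datasets.pancreas_preprocessed()
        ck = cr.tl.kernels.ConnectivityKernel(adata).compute_transition_matrix()

        estimator = CFLARE(ck)
        estimator.fit(k=20).predict(percentile=98)
        estimator.terminal_states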
"""
    @d.dedent
def fit(self, k: int = 20, **kwargs: Any) -> "TermStatesEstimator":
"""
        Prepare self for terminal states prediction.

Parameters
----------
k
Number of eigenvectors to compute.
kwargs
            Keyword arguments for :meth:`compute_eigendecomposition`.

Returns
-------
Self and modifies the following field:
- :attr:`eigendecomposition` - %(eigen.summary)s
"""
self.compute_eigendecomposition(k=k, only_evals=False, **kwargs)
return self
    @d.dedent
def predict(
self,
use: Optional[Union[int, Sequence[int]]] = None,
percentile: Optional[int] = 98,
method: Literal["leiden", "means"] = "leiden",
cluster_key: Optional[str] = None,
n_clusters_kmeans: Optional[int] = None,
n_neighbors: int = 20,
resolution: float = 0.1,
n_matches_min: int = 0,
n_neighbors_filtering: int = 15,
basis: Optional[str] = None,
n_comps: int = 5,
scale: Optional[bool] = None,
) -> None:
"""
        Find approximate recurrent classes of the Markov chain.

        Filter to obtain recurrent states in left eigenvectors.
        Cluster to obtain approximate recurrent classes in right eigenvectors.

Parameters
----------
use
Which or how many first eigenvectors to use as features for filtering and clustering.
If `None`, use the *eigengap* statistic.
percentile
Threshold used for filtering out cells which are most likely transient states. Cells which are in the
lower ``percentile`` percent of each eigenvector will be removed from the data matrix.
method
            Method to be used for clustering. Valid options are:
- `'kmeans'` - :class:`sklearn.cluster.KMeans`.
- `'leiden'` - :func:`scanpy.tl.leiden`.
cluster_key
Key in :attr:`anndata.AnnData.obs` in order to associate names and colors with :attr:`terminal_states`.
n_clusters_kmeans
            If `None`, this is set to ``len(use)`` (plus `1` if ``percentile = None``).
            Only used when ``method = 'kmeans'``.
n_neighbors
            Number of neighbors in the KNN graph, i.e. the number of neighbors computed for each cell.
            Only used when ``method = 'leiden'``.
resolution
Resolution parameter for :func:`scanpy.tl.leiden`. Should be chosen relatively small.
n_matches_min
Filters out cells which don't have at least ``n_matches_min`` neighbors from the same category.
This filters out some cells which are transient but have been misassigned.
n_neighbors_filtering
Parameter for filtering cells. Cells are filtered out if they don't have at least ``n_matches_min``
neighbors among their ``n_neighbors_filtering`` nearest cells.
basis
Key from :attr:`anndata.AnnData.obsm` as additional features for clustering.
If `None`, use only the right eigenvectors.
n_comps
            Number of embedding components to be used when ``basis != None``.
scale
            Scale the values to z-scores. If `None`, the values are scaled only when ``basis != None``.

Returns
-------
Nothing, just updates the following fields:
- :attr:`terminal_states` - %(tse_term_states.summary)s
- :attr:`terminal_states_probabilities` - %(tse_term_states_probs.summary)s
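
        Examples
        --------
        A usage sketch (``estimator`` denotes a :class:`CFLARE` instance on which
        :meth:`fit` has already been called; the parameter values are illustrative)::

            # filter using the first 3 eigenvectors at the 98th percentile,
            # then cluster the remaining recurrent cells with leiden
            estimator.predict(use=3, percentile=98, method="leiden", resolution=0.1)
            estimator.terminal_states_probabilities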
"""
def convert_use(
use: Optional[Union[int, Sequence[int], np.ndarray]]
) -> List[int]:
if method not in ["kmeans", "leiden"]:
raise ValueError(
f"Invalid method `{method!r}`. Valid options are `leiden` or `kmeans`."
)
if use is None:
use = eig["eigengap"] + 1 # add one b/c indexing starts at 0
if isinstance(use, int):
use = list(range(use))
elif not isinstance(use, (np.ndarray, Sequence)):
raise TypeError(
f"Expected `use` to be `int` or a `Sequence`, found `{type(use).__name__}`."
)
use = list(use)
if not use:
raise ValueError("No eigenvectors have been selected.")
            muse, mevecs = max(use), min(eig["V_l"].shape[1], eig["V_r"].shape[1])
            if muse >= mevecs:
                raise ValueError(
                    f"Maximum specified eigenvector `{muse}` is larger than the number of "
                    f"computed eigenvectors `{mevecs}`. Use `.compute_eigendecomposition(k={muse + 1})` "
                    f"to recompute the eigendecomposition."
                )
return use
eig = self.eigendecomposition
if eig is None:
raise RuntimeError(
"Compute eigendecomposition first as `.compute_eigendecomposition()`."
)
use = convert_use(use)
start = logg.info("Computing terminal states")
        # we check for complex values only in the left eigenvectors; this is fine,
        # since the complex pattern is identical for the left and right eigenvectors
V_l, V_r = eig["V_l"][:, use], eig["V_r"].real[:, use]
V_l = _complex_warning(V_l, use, use_imag=False)
# retrieve embedding and concatenate
if basis is not None:
X_em = _get_basis(self.adata, basis=basis)[:, :n_comps]
X = np.concatenate([V_r, X_em], axis=1)
else:
X = V_r
# filter out cells which are in the lowest q percentile in abs value in each eigenvector
if percentile is not None:
logg.debug("Filtering out cells according to percentile")
if not (0 <= percentile <= 100):
raise ValueError(
f"Expected `percentile` to be in interval `[0, 100]`, found `{percentile}`."
)
cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
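            # keep cells that reach the cutoff in absolute value in at least one eigenvector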
ixs = np.any(cutoffs <= np.abs(V_l), axis=1)
X = X[ixs, :]
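        # optionally standardize the features; by default, only when an embedding basis was appended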
if scale is None:
scale = basis is not None
if scale:
X = zscore(X, axis=0)
# cluster X
if method == "kmeans" and n_clusters_kmeans is None:
n_clusters_kmeans = len(use) + (percentile is None)
if X.shape[0] < n_clusters_kmeans:
raise ValueError(
f"Filtering resulted in only {X.shape[0]} cell(s), insufficient to cluster into "
f"`{n_clusters_kmeans}` clusters. Consider decreasing the value of `percentile`."
)
# fmt: off
logg.debug(f"Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering")
clusters = _cluster_X(
X,
method=method,
n_clusters=n_clusters_kmeans,
n_neighbors=n_neighbors,
resolution=resolution,
)
# fill in the labels in case we filtered out cells before
if percentile is not None:
labels = np.repeat(None, len(self))
labels[ixs] = clusters
else:
labels = clusters
labels = pd.Series(labels, index=self.adata.obs_names, dtype="category")
labels = labels.cat.rename_categories({c: str(c) for c in labels.cat.categories})
# filtering to get rid of some of the left over transient states
if n_matches_min > 0:
logg.debug(f"Filtering according to `n_matches_min={n_matches_min}`")
distances = _get_connectivities(self.adata, mode="distances", n_neighbors=n_neighbors_filtering)
labels = _filter_cells(distances, rc_labels=labels, n_matches_min=n_matches_min)
# fmt: on
self.set_terminal_states(
labels=labels,
cluster_key=cluster_key,
probs=self._compute_term_states_probs(eig, use),
params=self._create_params(),
time=start,
)
def _compute_term_states_probs(
self, eig: Dict[str, Any], use: List[int]
) -> pd.Series:
# get the truncated eigendecomposition
V, evals = eig["V_l"].real[:, use], eig["D"].real[use]
# shift and scale
V_pos = np.abs(V)
V_shifted = V_pos - np.min(V_pos, axis=0)
V_scaled = V_shifted / np.max(V_shifted, axis=0)
# check the ranges are correct
np.testing.assert_allclose(np.min(V_scaled, axis=0), 0)
np.testing.assert_allclose(np.max(V_scaled, axis=0), 1)
# further scale by the eigenvalues
V_eigs = V_scaled / evals
# sum over cols and scale
c = np.sum(V_eigs, axis=1)
c /= np.max(c)
return pd.Series(c, index=self.adata.obs_names)
def _read_from_adata(self, adata: AnnData, **kwargs: Any) -> bool:
ok = super()._read_from_adata(adata, **kwargs)
return (
ok
and self._read_eigendecomposition(adata, allow_missing=False)
and self._read_absorption_probabilities(adata)
)