# Source: cellrank/tl/estimators/_cflare.py

"""Clustering and Filtering of Left and Right Eigenvectors based on Markov chains."""
from typing import List, Tuple, Union, Optional, Sequence

from cellrank import logging as logg
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import (
    _cluster_X,
    _filter_cells,
    _complex_warning,
    _get_connectivities,
)
from cellrank.tl.estimators._constants import A, P
from cellrank.tl.estimators._decomposition import Eigen
from cellrank.tl.estimators._base_estimator import BaseEstimator

import numpy as np
from pandas import Series
from scipy.stats import zscore


@d.dedent
class CFLARE(BaseEstimator, Eigen):
    """
    Compute the initial/terminal states of a Markov chain via spectral heuristics.

    This estimator uses the left eigenvectors of the transition matrix to filter to a set of
    recurrent cells and the right eigenvectors to cluster this set of cells into discrete groups.

    Parameters
    ----------
    %(base_estimator.parameters)s
    """
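    # A minimal usage sketch (variable names are hypothetical; assumes a CellRank
    # kernel, e.g. a `VelocityKernel`, has already produced a transition matrix):
    #
    #     from cellrank.tl.kernels import VelocityKernel
    #
    #     vk = VelocityKernel(adata).compute_transition_matrix()
    #     c = CFLARE(vk)
    #     c.compute_eigendecomposition(k=20)
    #     c.compute_terminal_states(use=3, cluster_key="clusters")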
    @d.dedent
    @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS)
    def compute_terminal_states(
        self,
        use: Optional[Union[int, Tuple[int], List[int], range]] = None,
        percentile: Optional[int] = 98,
        method: str = "kmeans",
        cluster_key: Optional[str] = None,
        n_clusters_kmeans: Optional[int] = None,
        n_neighbors: int = 20,
        resolution: float = 0.1,
        n_matches_min: Optional[int] = 0,
        n_neighbors_filtering: int = 15,
        basis: Optional[str] = None,
        n_comps: int = 5,
        scale: bool = False,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ) -> None:
        """
        Find approximate recurrent classes of the Markov chain.

        Filter to obtain recurrent states in left eigenvectors.
        Cluster to obtain approximate recurrent classes in right eigenvectors.

        Parameters
        ----------
        use
            Which or how many first eigenvectors to use as features for clustering/filtering.
            If `None`, use the `eigengap` statistic.
        percentile
            Threshold used for filtering out cells which are most likely transient states.
            Cells which are in the lower ``percentile`` percent of each eigenvector will be
            removed from the data matrix.
        method
            Method to be used for clustering. Must be one of `'louvain'`, `'leiden'` or `'kmeans'`.
        cluster_key
            If a key to cluster labels is given, :attr:`{fs}` will be associated with these
            for naming and colors.
        n_clusters_kmeans
            If `None`, this is set to ``use + 1``.
        n_neighbors
            If we use `'louvain'` or `'leiden'` for clustering cells, we need to build a KNN graph.
            This is the :math:`K` parameter for that, the number of neighbors for each cell.
        resolution
            Resolution parameter for `'louvain'` or `'leiden'` clustering.
            Should be chosen relatively small.
        n_matches_min
            Filters out cells which don't have at least ``n_matches_min`` neighbors from the
            same class. This filters out some cells which are transient but have been misassigned.
        n_neighbors_filtering
            Parameter for filtering cells. Cells are filtered out if they don't have at least
            ``n_matches_min`` neighbors among their ``n_neighbors_filtering`` nearest cells.
        basis
            Key from :paramref:`adata` ``.obsm`` to be used as additional features for the clustering.
        n_comps
            Number of embedding components to be used when ``basis`` is not `None`.
        scale
            Scale to z-scores. Consider using this if appending the embedding to the features.
        %(en_cutoff_p_thresh)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`
        """

        def _compute_macrostates_prob() -> Series:
            """Compute a global score of being an approximate recurrent class."""
            # get the truncated eigendecomposition
            V, evals = eig["V_l"].real[:, use], eig["D"].real[use]

            # shift and scale
            V_pos = np.abs(V)
            V_shifted = V_pos - np.min(V_pos, axis=0)
            V_scaled = V_shifted / np.max(V_shifted, axis=0)

            # check the ranges are correct
            assert np.allclose(np.min(V_scaled, axis=0), 0), "Lower limit is not zero."
            assert np.allclose(np.max(V_scaled, axis=0), 1), "Upper limit is not one."

            # further scale by the eigenvalues
            V_eigs = V_scaled / evals

            # sum over columns and scale
            c_ = np.sum(V_eigs, axis=1)
            c = c_ / np.max(c_)

            return Series(c, index=self.adata.obs_names)

        def check_use(use) -> List[int]:
            if method not in ["kmeans", "louvain", "leiden"]:
                raise ValueError(
                    f"Invalid method `{method!r}`. Valid options are `'louvain'`, `'leiden'` or `'kmeans'`."
                )

            if use is None:
                use = eig["eigengap"] + 1  # add one because indexing starts at 0

            if isinstance(use, int):
                use = list(range(use))
            elif not isinstance(use, (tuple, list, range)):
                raise TypeError(
                    f"Argument `use` must be either `int`, `tuple`, `list` or `range`, "
                    f"found `{type(use).__name__!r}`."
                )
            else:
                if not all(map(lambda u: isinstance(u, int), use)):
                    raise TypeError("Not all values in `use` argument are integers.")
                use = list(use)

            if len(use) == 0:
                raise ValueError(
                    f"Number of eigenvectors must be larger than `0`, found `{len(use)}`."
                )

            muse = max(use)
            if muse >= eig["V_l"].shape[1] or muse >= eig["V_r"].shape[1]:
                raise ValueError(
                    f"Maximum specified eigenvector `{muse}` is larger "
                    f'than the number of computed eigenvectors `{eig["V_l"].shape[1]}`. '
                    f"Use `.compute_eigendecomposition(k={muse})` to recompute the eigendecomposition."
                )

            return use

        eig = self._get(P.EIG)
        if eig is None:
            raise RuntimeError(
                "Compute eigendecomposition first as `.compute_eigendecomposition()`."
            )
        use = check_use(use)

        start = logg.info("Computing approximate recurrent classes")

        # we check for complex values only in the left eigenvectors; that's okay because
        # the complex pattern will be identical for the left and right ones
        V_l, V_r = eig["V_l"][:, use], eig["V_r"].real[:, use]
        V_l = _complex_warning(V_l, use, use_imag=False)

        # compute a recurrent class probability
        logg.debug("Computing probabilities of approximate recurrent classes")
        self._set(A.TERM_PROBS, _compute_macrostates_prob())

        # retrieve the embedding and concatenate
        if basis is not None:
            bkey = f"X_{basis}"
            if bkey not in self.adata.obsm.keys():
                raise KeyError(f"Basis key `{bkey!r}` not found in `adata.obsm`")

            X_em = self.adata.obsm[bkey][:, :n_comps]
            X = np.concatenate([V_r, X_em], axis=1)
        else:
            logg.debug("Basis is `None`. Setting X equal to the right eigenvectors")
            X = V_r

        # filter out cells which are in the lowest q percentile in absolute value in each eigenvector
        if percentile is not None:
            logg.debug("Filtering out cells according to percentile")
            if percentile < 0 or percentile > 100:
                raise ValueError(
                    f"Percentile must be in interval `[0, 100]`, found `{percentile}`."
                )
            cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
            ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
            X = X[ixs, :]

        # scale
        if scale:
            X = zscore(X, axis=0)

        # cluster X
        if method == "kmeans" and n_clusters_kmeans is None:
            n_clusters_kmeans = len(use) + (percentile is None)
            if X.shape[0] < n_clusters_kmeans:
                raise ValueError(
                    f"Filtering resulted in only {X.shape[0]} cell(s), insufficient to cluster into "
                    f"`{n_clusters_kmeans}` clusters. Consider decreasing the value of `percentile`."
                )

        logg.debug(
            f"Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering"
        )
        labels = _cluster_X(
            X,
            method=method,
            n_clusters=n_clusters_kmeans,
            n_neighbors=n_neighbors,
            resolution=resolution,
        )

        # fill in the labels in case we filtered out cells before
        if percentile is not None:
            rc_labels = np.repeat(None, self.adata.n_obs)
            rc_labels[ixs] = labels
        else:
            rc_labels = labels

        rc_labels = Series(rc_labels, index=self.adata.obs_names, dtype="category")
        rc_labels.cat.categories = list(rc_labels.cat.categories.astype("str"))

        # filtering to get rid of some of the leftover transient states
        if n_matches_min > 0:
            logg.debug(f"Filtering according to `n_matches_min={n_matches_min}`")
            distances = _get_connectivities(
                self.adata, mode="distances", n_neighbors=n_neighbors_filtering
            )
            rc_labels = _filter_cells(
                distances, rc_labels=rc_labels, n_matches_min=n_matches_min
            )

        self.set_terminal_states(
            labels=rc_labels,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
            time=start,
        )
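    # The percentile filter above, as a standalone NumPy sketch (shapes are
    # hypothetical; it mirrors the logic in `compute_terminal_states`, keeping a
    # cell iff it exceeds the per-eigenvector cutoff in at least one eigenvector):
    #
    #     import numpy as np
    #
    #     V_l = np.random.randn(1000, 3)  # |cells| x |eigenvectors|
    #     cutoffs = np.percentile(np.abs(V_l), 98, axis=0)  # one cutoff per eigenvector
    #     keep = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
    #     V_filtered = V_l[keep, :]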
    def _fit_terminal_states(
        self,
        n_lineages: Optional[int] = None,
        keys: Optional[Sequence[str]] = None,
        cluster_key: Optional[str] = None,
        compute_absorption_probabilities: bool = True,
        **kwargs,
    ) -> None:
        self.compute_eigendecomposition(k=20 if n_lineages is None else n_lineages + 1)
        if n_lineages is None:
            n_lineages = self._get(P.EIG)["eigengap"] + 1

        self.compute_terminal_states(
            use=n_lineages,
            cluster_key=cluster_key,
            n_clusters_kmeans=n_lineages,
            method=kwargs.pop("method", "kmeans"),
            **kwargs,
        )
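    # Note: when `n_lineages` is `None` in `_fit_terminal_states`, the number of
    # eigenvectors to use comes from the eigengap statistic of the spectrum; the
    # `+ 1` converts the 0-based index of the gap into a count of eigenvectors.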
    @d.dedent  # because of fit
    @d.dedent
    @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS, ap=P.ABS_PROBS, pd=P.PRIME_DEG)
    def fit(
        self,
        n_lineages: Optional[int],
        keys: Optional[Sequence[str]] = None,
        cluster_key: Optional[str] = None,
        compute_absorption_probabilities: bool = True,
        **kwargs,
    ):
        """
        Run the pipeline, computing the %(initial_or_terminal)s states and optionally
        the absorption probabilities.

        It is equivalent to running::

            compute_eigendecomposition(...)
            compute_terminal_states(...)
            compute_absorption_probabilities(...)

        Parameters
        ----------
        %(fit)s
        kwargs
            Keyword arguments for :meth:`compute_terminal_states`, such as ``n_cells``.

        Returns
        -------
        None
            Nothing, just makes available the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`
                - :attr:`{ap}`
                - :attr:`{pd}`
        """
        super().fit(
            n_lineages=n_lineages,
            keys=keys,
            cluster_key=cluster_key,
            compute_absorption_probabilities=compute_absorption_probabilities,
            **kwargs,
        )
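# A minimal end-to-end sketch of the `fit` pipeline, kept under a main guard so the
# module stays import-safe. The dataset helper, kernel class, and `clusters` key are
# taken from CellRank's public 1.x API; treat the exact calls as illustrative
# assumptions rather than a definitive recipe.
if __name__ == "__main__":
    import cellrank as cr

    # example dataset, assumed to ship with RNA velocities precomputed
    adata = cr.datasets.pancreas_preprocessed()

    # build a velocity-based transition matrix and hand it to the estimator
    vk = cr.tl.kernels.VelocityKernel(adata).compute_transition_matrix()
    c = CFLARE(vk)

    # eigendecomposition -> terminal states -> absorption probabilities
    c.fit(n_lineages=3, cluster_key="clusters")
    print(c.terminal_states)  # categorical Series over `adata.obs_names`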