# Source code for cellrank.tl.kernels._cytotrace_kernel

from typing import Any, List, Tuple, Optional
from typing_extensions import Literal

from enum import auto

from anndata import AnnData
from cellrank import logging as logg
from cellrank._key import Key
from cellrank.tl._enum import ModeEnum
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import _correlation_test
from cellrank.tl.kernels._pseudotime_kernel import PseudotimeKernel

import numpy as np
import pandas as pd
from scipy.stats import gmean, hmean
from scipy.sparse import issparse

__all__ = ("CytoTRACEKernel",)


class CytoTRACEAggregation(ModeEnum):
    """
    Aggregation modes for combining the expression of the top-correlating genes into a CytoTRACE score.

    The members are dispatched on in ``CytoTRACEKernel._compute_score``.
    """

    # NOTE(review): the concrete member values are produced by ``auto()`` through ``ModeEnum``, which is
    # not visible here; presumably they are the lower-cased names ('mean', 'median', ...) - confirm.
    MEAN = auto()  # arithmetic mean
    MEDIAN = auto()  # median
    GMEAN = auto()  # geometric mean (``scipy.stats.gmean``)
    HMEAN = auto()  # harmonic mean (``scipy.stats.hmean``)


@d.dedent
class CytoTRACEKernel(PseudotimeKernel):
    """
    Kernel which computes directed transition probabilities based on a KNN graph and the CytoTRACE score \
    :cite:`gulati:20`.

    The KNN graph contains information about the (undirected) connectivities among cells, reflecting their similarity.
    CytoTRACE can be used to estimate cellular plasticity and in turn, a pseudotemporal ordering of cells from more
    plastic to less plastic states.
    It relies on the assumption that differentiated cells express, on average, fewer genes than naive cells.

    This kernel internally uses the :class:`cellrank.kernels.PseudotimeKernel` to direct the kNN graph
    on the basis of the CytoTRACE-derived pseudotime.

    %(density_correction)s

    Parameters
    ----------
    %(adata)s
    %(backward)s
    kwargs
        Keyword arguments for :class:`cellrank.kernels.PseudotimeKernel`.

    Example
    -------
    Workflow::

        # import packages and load data
        import scanpy as sc
        import scvelo as scv
        import cellrank as cr
        adata = cr.datasets.pancreas()

        # standard pre-processing
        sc.pp.filter_genes(adata, min_cells=10)
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        sc.pp.highly_variable_genes(adata)

        # CytoTRACE by default uses imputed data - a simple way to compute kNN-imputed data is to use scVelo's moments
        # function. However, note that this function expects `spliced` counts because it's designed for RNA velocity,
        # so we're using a simple hack here:
        if 'spliced' not in adata.layers or 'unspliced' not in adata.layers:
            adata.layers['spliced'] = adata.X
            adata.layers['unspliced'] = adata.X

        # compute kNN-imputation using scVelo's moments function
        scv.pp.moments(adata)

        # import and initialize the CytoTRACE kernel, compute transition matrix - done!
        from cellrank.kernels import CytoTRACEKernel
        ctk = CytoTRACEKernel(adata).compute_cytotrace().compute_transition_matrix()
    """

    def __init__(
        self,
        adata: AnnData,
        backward: bool = False,
        **kwargs: Any,
    ):
        # The time key is always the CytoTRACE pseudotime; a user-supplied ``time_key`` is dropped so that
        # it cannot clash with the explicit keyword passed to the parent below.
        kwargs.pop("time_key", None)
        super().__init__(
            adata,
            backward=backward,
            time_key=Key.cytotrace("pseudotime"),
            **kwargs,
        )

    def _read_from_adata(
        self,
        time_key: str = Key.cytotrace("pseudotime"),
        **kwargs: Any,
    ) -> None:
        """
        Read the pseudotime from the underlying object, tolerating its absence.

        If the parent raises a :class:`KeyError` mentioning a missing pseudotime (i.e.
        :meth:`compute_cytotrace` has not been run yet), the pseudotime is set to `None` and the
        initialization of the class following :class:`PseudotimeKernel` in the MRO is used instead.
        Any other :class:`KeyError` is re-raised.
        """
        try:
            super()._read_from_adata(time_key=time_key, **kwargs)
        except KeyError as e:
            self._pseudotime: Optional[np.ndarray] = None
            if "Unable to find pseudotime" not in str(e):
                raise
            # skip `PseudotimeKernel` in the MRO, since it is the one requiring the pseudotime
            super(PseudotimeKernel, self)._read_from_adata(**kwargs)

    @d.get_sections(base="cytotrace", sections=["Parameters"])
    @inject_docs(ct=CytoTRACEAggregation)
    def compute_cytotrace(
        self,
        layer: Optional[str] = "Ms",
        aggregation: Literal[
            "mean", "median", "hmean", "gmean"
        ] = CytoTRACEAggregation.MEAN,
        use_raw: bool = False,
        n_genes: int = 200,
    ) -> "CytoTRACEKernel":
        """
        Re-implementation of the CytoTRACE algorithm :cite:`gulati:20` to estimate cellular plasticity.

        Computes the number of genes expressed per cell and ranks genes according to their correlation with this
        measure. Next, it selects the top-correlating genes and aggregates their (imputed) expression to obtain
        the CytoTRACE score. A high score stands for high differentiation potential (naive, plastic cells) and
        a low score stands for low differentiation potential (mature, differentiated cells).

        Parameters
        ----------
        layer
            Key in :attr:`anndata.AnnData.layers` or `'X'` for :attr:`anndata.AnnData.X`
            from where to get the expression. `None` is treated the same as `'X'`.
        aggregation
            How to aggregate expression of the top-correlating genes. Valid options are:

                - `{ct.MEAN!r}` - arithmetic mean.
                - `{ct.MEDIAN!r}` - median.
                - `{ct.HMEAN!r}` - harmonic mean.
                - `{ct.GMEAN!r}` - geometric mean.
        use_raw
            Whether to use the :attr:`anndata.AnnData.raw` to compute the number of genes expressed per cell
            (#genes/cell) and the correlation of gene expression across cells with #genes/cell.
        n_genes
            Number of top positively correlated genes to compute the CytoTRACE score.

        Returns
        -------
        Self, and modifies :attr:`anndata.AnnData.obs` with the following keys:

            - `'ct_score'` - the normalized CytoTRACE score.
            - `'ct_pseudotime'` - associated pseudotime, essentially `1 - CytoTRACE score`.
            - `'ct_num_exp_genes'` - the number of genes expressed per cell, basis of the CytoTRACE score.

        It also modifies :attr:`anndata.AnnData.var` with the following keys:

            - `'ct_gene_corr'` - the correlation as specified above.
            - `'ct_correlates'` - indication of the genes used to compute the CytoTRACE score, i.e. the ones that
              correlated positively with `'ct_num_exp_genes'`.

        The parameters used for the computation are saved in :attr:`anndata.AnnData.uns`.

        Notes
        -----
        This will not exactly reproduce the results of the original CytoTRACE algorithm :cite:`gulati:20` because we
        allow for any normalization and imputation techniques whereas CytoTRACE has built-in specific methods for that.
        """
        from cellrank.tl import Lineage

        aggregation = CytoTRACEAggregation(aggregation)
        if use_raw and self.adata.raw is None:
            logg.warning("`adata.raw` is `None`. Setting `use_raw=False`")
            use_raw = False
        if layer not in (None, "X") and layer not in self.adata.layers:
            raise KeyError(
                f"Unable to find `{layer!r}` in `adata.layers`. "
                f"Valid option are: `{sorted({'X'} | set(self.adata.layers.keys()))}`."
            )

        adata_mraw = self.adata.raw if use_raw else self.adata
        msg = f"Computing CytoTRACE score with `{adata_mraw.n_vars}` genes"
        if adata_mraw.n_vars < 10000:
            msg += ". Consider using more than `10000` genes"
        start = logg.info(msg)

        # compute number of expressed genes per cell
        # `reshape(-1)` (rather than `squeeze()`) keeps the result 1-dimensional even for a single cell
        num_exp_genes = np.asarray((adata_mraw.X > 0).sum(axis=1)).reshape(-1)

        logg.debug("Correlating all genes with number of genes expressed per cell")
        gene_corr = _correlation_test(
            adata_mraw.X,
            Lineage(num_exp_genes[:, None], names=["gene"]),
            gene_names=adata_mraw.var_names,
        )["gene_corr"]

        cytotrace_score, pos_top_genes = self._compute_score(
            gene_corr,
            layer=layer,
            aggregation=aggregation,
            ascending=False,
            n_genes=n_genes,
        )
        # min-max scale the score, so that the pseudotime below is simply its complement
        cytotrace_score -= cytotrace_score.min()
        cytotrace_score /= cytotrace_score.max()

        self.adata.obs[Key.cytotrace("score")] = cytotrace_score
        self.adata.obs[Key.cytotrace("pseudotime")] = self._pseudotime = (
            1 - cytotrace_score
        )
        self.adata.obs[Key.cytotrace("num_exp_genes")] = num_exp_genes

        self.adata.var[Key.cytotrace("gene_corr")] = gene_corr
        self.adata.var[Key.cytotrace("correlates")] = False
        self.adata.var.loc[pos_top_genes, Key.cytotrace("correlates")] = True

        self.adata.uns[Key.cytotrace("params")] = {
            "aggregation": aggregation,
            "layer": layer,
            "use_raw": use_raw,
            # the number of genes actually used, which can be smaller than the requested `n_genes`
            "n_genes": len(pos_top_genes),
        }

        logg.info(
            f"Adding `adata.obs[{Key.cytotrace('score')!r}]`\n"
            f" `adata.obs[{Key.cytotrace('pseudotime')!r}]`\n"
            f" `adata.obs[{Key.cytotrace('num_exp_genes')!r}]`\n"
            f" `adata.var[{Key.cytotrace('gene_corr')!r}]`\n"
            f" `adata.var[{Key.cytotrace('correlates')!r}]`\n"
            f" `adata.uns[{Key.cytotrace('params')!r}]`\n"
            f" Finish",
            time=start,
        )

        return self

    def _compute_score(
        self,
        gene_corr: pd.Series,
        *,
        layer: Optional[str] = None,
        aggregation: CytoTRACEAggregation = CytoTRACEAggregation.MEAN,
        ascending: bool = False,
        n_genes: int = 200,
    ) -> Tuple[np.ndarray, List[str]]:
        """
        Aggregate the expression of the top-correlating genes into a per-cell score.

        Parameters
        ----------
        gene_corr
            Correlation of each gene, indexed by gene name.
        layer
            Key in :attr:`anndata.AnnData.layers` from where to get the expression.
            `None` or `'X'` selects :attr:`anndata.AnnData.X`.
        aggregation
            How to aggregate the expression of the selected genes for each cell.
        ascending
            Whether to select the most negatively (`True`) or most positively (`False`) correlated genes.
        n_genes
            Maximum number of genes to select.

        Returns
        -------
        The per-cell score and the names of the genes used to compute it.

        Raises
        ------
        ValueError
            If ``n_genes`` is not positive, if any of the selected genes has a NaN correlation,
            or if no genes could be selected.
        NotImplementedError
            If the ``aggregation`` is not handled.
        """
        from sklearn.utils.sparsefuncs import csc_median_axis_0

        # fmt: off
        modifier = "negatively" if ascending else "positively"
        if n_genes <= 0:
            raise ValueError(f"Expected number of {modifier} correlated genes to be positive, found `{n_genes}`.")

        # `gene_corr` can contain genes which are not in `adata` (e.g. when computed from `adata.raw`),
        # therefore only the ones present in `adata.var_names` are kept before truncating
        top_genes = [g for g in gene_corr.sort_values(ascending=ascending).index if g in self.adata.var_names]
        top_genes = top_genes[:n_genes]
        invalid_genes = int(gene_corr.loc[top_genes].isnull().sum())
        # fmt: on

        if invalid_genes:
            raise ValueError(
                f"Top `{len(top_genes)}` {modifier} correlated genes contain `{invalid_genes}` NaN values."
            )
        if not top_genes:
            raise ValueError("No genes have been selected.")
        if len(top_genes) != n_genes:
            logg.warning(
                f"Unable to get requested top {modifier} correlated `{n_genes}`. "
                f"Using top `{len(top_genes)}` genes"
            )

        # fmt: off
        # `None` is accepted by `compute_cytotrace` and is the default here; it refers to `adata.X`
        if layer in (None, "X"):
            imputed_exp = self.adata[:, top_genes].X
        else:
            imputed_exp = self.adata[:, top_genes].layers[layer]

        # only the mean and the median below have a sparse implementation, the rest needs dense data
        if issparse(imputed_exp) and aggregation not in (CytoTRACEAggregation.MEAN, CytoTRACEAggregation.MEDIAN):
            imputed_exp = imputed_exp.toarray()

        if aggregation == CytoTRACEAggregation.MEAN:
            cytotrace_score = np.asarray(imputed_exp.mean(axis=1)).reshape((-1,))
        elif aggregation == CytoTRACEAggregation.MEDIAN:
            if issparse(imputed_exp):
                # the helper computes the median over the rows (axis 0), hence the transposition
                cytotrace_score = np.asarray(csc_median_axis_0(imputed_exp.T.tocsc())).reshape((-1,))
            else:
                cytotrace_score = np.median(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.GMEAN:
            cytotrace_score = gmean(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.HMEAN:
            cytotrace_score = hmean(imputed_exp, axis=1)
        else:
            raise NotImplementedError(f"Aggregation method `{aggregation}` is not yet implemented.")
        # fmt: on

        return cytotrace_score, top_genes

    def __invert__(self) -> "CytoTRACEKernel":
        """
        Return a copy of this kernel with the direction flipped.

        The copy does not carry over the transition matrix, its parameters are reset and,
        if the pseudotime is available, it is reversed as `max(pseudotime) - pseudotime`.
        """
        ck = self._copy_ignore("_transition_matrix")
        if ck.pseudotime is not None:
            # already computed
            ck._pseudotime = np.max(ck.pseudotime) - ck.pseudotime
        ck._backward = not self.backward
        ck._params = {}
        return ck