import enum
from typing import Any, Literal, Optional
import numpy as np
import pandas as pd
import scipy.sparse as sp
import scipy.stats as st
from anndata import AnnData
from cellrank import logging as logg
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import ModeEnum
from cellrank._utils._key import Key
from cellrank._utils._utils import _correlation_test
from cellrank.kernels._pseudotime_kernel import PseudotimeKernel
# Public API of this module.
__all__ = ["CytoTRACEKernel"]
class CytoTRACEAggregation(ModeEnum):
    """Ways of aggregating the expression of the top-correlating genes into one CytoTRACE score per cell."""

    MEAN = enum.auto()  # arithmetic mean across the selected genes
    MEDIAN = enum.auto()  # median across the selected genes
    GMEAN = enum.auto()  # geometric mean (computed with scipy.stats.gmean)
    HMEAN = enum.auto()  # harmonic mean (computed with scipy.stats.hmean)
@d.dedent
class CytoTRACEKernel(PseudotimeKernel):
    """Kernel which computes directed transition probabilities using the CytoTRACE score :cite:`gulati:20`.

    .. seealso::

        - See :doc:`../../../notebooks/tutorials/kernels/400_cytotrace` on how to
          compute the :attr:`~cellrank.kernels.CytoTRACEKernel.transition_matrix` based on the CytoTRACE score.

    The k-NN graph contains information about the (undirected) connectivities among cells, reflecting their similarity.
    CytoTRACE can be used to estimate cellular plasticity and in turn, a pseudotemporal ordering of cells from more
    plastic to less plastic states. It relies on the assumption that differentiated cells express, on average,
    fewer genes than naive cells.

    This kernel internally uses the :class:`~cellrank.kernels.PseudotimeKernel` to direct the k-NN graph
    on the basis of the CytoTRACE pseudotime.

    %(density_correction)s

    Parameters
    ----------
    %(adata)s
    %(backward)s
    kwargs
        Keyword arguments for the :class:`~cellrank.kernels.PseudotimeKernel`.

    Examples
    --------
    .. code-block:: python

        import scanpy as sc
        import scvelo as scv
        import cellrank as cr

        adata = cr.datasets.pancreas()

        sc.pp.filter_genes(adata, min_cells=10)
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        sc.pp.highly_variable_genes(adata)

        # CytoTRACE by default uses imputed data - a simple way to compute
        # k-NN imputed data is to use scVelo's moments function.
        # However, note that this function expects `spliced` counts because
        # it's designed for RNA velocity, so we're using a simple hack here:
        if 'spliced' not in adata.layers or 'unspliced' not in adata.layers:
            adata.layers['spliced'] = adata.X
            adata.layers['unspliced'] = adata.X

        scv.pp.moments(adata)

        ctk = cr.kernels.CytoTRACEKernel(adata)
        ctk = ctk.compute_cytotrace().compute_transition_matrix()
    """

    def __init__(
        self,
        adata: AnnData,
        backward: bool = False,
        **kwargs: Any,
    ):
        # The pseudotime always comes from the CytoTRACE key, so any user-supplied `time_key` is discarded.
        _ = kwargs.pop("time_key", None)
        super().__init__(
            adata,
            backward=backward,
            time_key=Key.cytotrace("pseudotime"),
            **kwargs,
        )

    def _read_from_adata(
        self,
        time_key: str,
        **kwargs: Any,
    ) -> None:
        """Read data from :attr:`adata`, tolerating a missing CytoTRACE pseudotime.

        If the parent reader fails because the pseudotime cannot be found, :attr:`_pseudotime` is set to ``None``
        and only the readers further up the MRO are run. Any other :class:`KeyError` is re-raised.
        """
        try:
            super()._read_from_adata(time_key=time_key, **kwargs)
        except KeyError as e:
            self._pseudotime: Optional[np.ndarray] = None
            if "Unable to find pseudotime" not in str(e):
                raise
            # Bypass PseudotimeKernel's reader (which requires the pseudotime) and call the next one in the MRO;
            # the pseudotime is populated later by `compute_cytotrace`.
            super(PseudotimeKernel, self)._read_from_adata(**kwargs)

    @d.get_sections(base="cytotrace", sections=["Parameters"])
    @inject_docs(ct=CytoTRACEAggregation)
    def compute_cytotrace(
        self,
        layer: Optional[str] = "Ms",
        aggregation: Literal["mean", "median", "hmean", "gmean"] = CytoTRACEAggregation.MEAN,
        use_raw: bool = False,
        n_genes: int = 200,
    ) -> "CytoTRACEKernel":
        """Re-implementation of the CytoTRACE algorithm :cite:`gulati:20` to estimate cellular plasticity.

        Computes the number of genes expressed per cell and ranks genes according to their correlation with this
        measure. Next, it selects the top-correlating genes and aggregates their (imputed) expression to obtain
        the CytoTRACE score. A high score stands for high differentiation potential (naive, plastic cells) and
        a low score stands for low differentiation potential (mature, differentiated cells).

        Parameters
        ----------
        layer
            Key in :attr:`~anndata.AnnData.layers` or ``'X'`` (or ``None``) for :attr:`~anndata.AnnData.X`
            from where to get the expression.
        aggregation
            How to aggregate expression of the top-correlating genes. Valid options are:

            - ``{ct.MEAN!r}`` - arithmetic mean.
            - ``{ct.MEDIAN!r}`` - median.
            - ``{ct.HMEAN!r}`` - harmonic mean.
            - ``{ct.GMEAN!r}`` - geometric mean.
        use_raw
            Whether to use the :attr:`~anndata.AnnData.raw` to compute the number of genes expressed per cell
            (#genes/cell) and the correlation of gene expression across cells with #genes/cell.
        n_genes
            Number of top positively correlated genes to compute the CytoTRACE score.

        Returns
        -------
        Returns self and modifies :attr:`~anndata.AnnData.obs` with the following keys:

        - ``'ct_score'`` - the normalized CytoTRACE score.
        - ``'ct_pseudotime'`` - associated pseudotime, essentially ``1 - CytoTRACE score``.
        - ``'ct_num_exp_genes'`` - the number of genes expressed per cell, basis of the CytoTRACE score.

        It also modifies :attr:`~anndata.AnnData.var` with the following keys:

        - ``'ct_gene_corr'`` - the correlation as specified above.
        - ``'ct_correlates'`` - indication of the genes used to compute the CytoTRACE score, i.e. the ones that
          correlated positively with ``'ct_num_exp_genes'``.

        Notes
        -----
        This will not exactly reproduce the results of the original CytoTRACE algorithm :cite:`gulati:20` because we
        allow for any normalization and imputation techniques whereas CytoTRACE has built-in specific methods for that.
        """
        from cellrank._utils import Lineage

        aggregation = CytoTRACEAggregation(aggregation)
        if use_raw and self.adata.raw is None:
            logg.warning("`adata.raw` is `None`. Setting `use_raw=False`")
            use_raw = False
        if layer not in (None, "X") and layer not in self.adata.layers:
            raise KeyError(
                f"Unable to find `{layer!r}` in `adata.layers`. "
                f"Valid option are: `{sorted({'X'} | set(self.adata.layers.keys()))}`."
            )

        adata_mraw = self.adata.raw if use_raw else self.adata
        msg = f"Computing CytoTRACE score with `{adata_mraw.n_vars}` genes"
        if adata_mraw.n_vars < 10000:
            msg += ". Consider using more than `10000` genes"
        start = logg.info(msg)

        # Number of expressed genes per cell; `reshape(-1)` (unlike `squeeze`) keeps a 1-D array for a single cell
        # and flattens the (n_obs, 1) matrix that sparse `.sum(axis=1)` returns.
        num_exp_genes = np.asarray((adata_mraw.X > 0).sum(axis=1)).reshape(-1)
        logg.debug("Correlating all genes with number of genes expressed per cell")
        gene_corr = _correlation_test(
            adata_mraw.X,
            Lineage(num_exp_genes[:, None], names=["gene"]),
            gene_names=adata_mraw.var_names,
        )["gene_corr"]
        cytotrace_score, pos_top_genes = self._compute_score(
            gene_corr,
            layer=layer,
            aggregation=aggregation,
            ascending=False,
            n_genes=n_genes,
        )

        # Min-max normalize to [0, 1]; a constant score would otherwise divide by zero and yield all-NaN values.
        cytotrace_score = cytotrace_score - cytotrace_score.min()
        max_score = cytotrace_score.max()
        if max_score == 0:
            logg.warning("CytoTRACE score is constant across all cells, cannot normalize it to `[0, 1]`")
        else:
            cytotrace_score = cytotrace_score / max_score

        self.adata.obs[Key.cytotrace("score")] = cytotrace_score
        self.adata.obs[Key.cytotrace("pseudotime")] = self._pseudotime = 1 - cytotrace_score
        self.adata.obs[Key.cytotrace("num_exp_genes")] = num_exp_genes

        self.adata.var[Key.cytotrace("gene_corr")] = gene_corr
        self.adata.var[Key.cytotrace("correlates")] = False
        self.adata.var.loc[pos_top_genes, Key.cytotrace("correlates")] = True

        self.adata.uns[Key.cytotrace("params")] = {
            "aggregation": str(aggregation),
            "layer": layer,
            "use_raw": use_raw,
            "n_genes": len(pos_top_genes),
        }

        logg.info(
            f"Adding `adata.obs[{Key.cytotrace('score')!r}]`\n"
            f"       `adata.obs[{Key.cytotrace('pseudotime')!r}]`\n"
            f"       `adata.obs[{Key.cytotrace('num_exp_genes')!r}]`\n"
            f"       `adata.var[{Key.cytotrace('gene_corr')!r}]`\n"
            f"       `adata.var[{Key.cytotrace('correlates')!r}]`\n"
            f"       `adata.uns[{Key.cytotrace('params')!r}]`\n"
            f"    Finish",
            time=start,
        )
        return self

    def _compute_score(
        self,
        gene_corr: pd.Series,
        *,
        layer: Optional[str] = None,
        aggregation: CytoTRACEAggregation = CytoTRACEAggregation.MEAN,
        ascending: bool = False,
        n_genes: int = 200,
    ) -> tuple[np.ndarray, list[str]]:
        """Aggregate the expression of the top-correlated genes into a per-cell score.

        Parameters
        ----------
        gene_corr
            Per-gene correlation values, indexed by gene name.
        layer
            Key in :attr:`~anndata.AnnData.layers`, or ``'X'`` / ``None`` for :attr:`~anndata.AnnData.X`.
        aggregation
            How to aggregate the expression of the selected genes across them, per cell.
        ascending
            If ``True``, select the most negatively correlated genes, otherwise the most positively correlated ones.
        n_genes
            Maximum number of genes to select.

        Returns
        -------
        The unnormalized score for each cell and the names of the selected genes.

        Raises
        ------
        ValueError
            If ``n_genes`` is not positive, no genes are selected, or the selected genes have NaN correlations.
        NotImplementedError
            If ``aggregation`` is not handled.
        """
        from sklearn.utils.sparsefuncs import csc_median_axis_0

        modifier = "negatively" if ascending else "positively"
        if n_genes <= 0:
            raise ValueError(f"Expected number of {modifier} correlated genes to be positive, found `{n_genes}`.")

        # Only genes present in `adata` can be aggregated (relevant when the correlation was computed on `raw`).
        var_names = self.adata.var_names
        top_genes = [g for g in gene_corr.sort_values(ascending=ascending).index if g in var_names][:n_genes]
        invalid_genes = int(gene_corr.loc[top_genes].isnull().sum())
        if invalid_genes:
            raise ValueError(
                f"Top `{len(top_genes)}` {modifier} correlated genes contain `{invalid_genes}` NaN values."
            )
        if not len(top_genes):
            raise ValueError("No genes have been selected.")
        if len(top_genes) != n_genes:
            logg.warning(
                f"Unable to get requested top {modifier} correlated `{n_genes}`. Using top `{len(top_genes)}` genes"
            )

        subset = self.adata[:, top_genes]
        # `None` and `'X'` both refer to `adata.X`, matching the validation in `compute_cytotrace`.
        imputed_exp = subset.X if layer in (None, "X") else subset.layers[layer]
        # Only mean and median have sparse-aware implementations below; the scipy means need dense input.
        if sp.issparse(imputed_exp) and aggregation not in (CytoTRACEAggregation.MEAN, CytoTRACEAggregation.MEDIAN):
            imputed_exp = imputed_exp.toarray()

        if aggregation == CytoTRACEAggregation.MEAN:
            cytotrace_score = np.asarray(imputed_exp.mean(axis=1)).reshape((-1,))
        elif aggregation == CytoTRACEAggregation.MEDIAN:
            if sp.issparse(imputed_exp):
                # Transpose so that cells become columns, then take the column-wise median (implicit zeros count).
                cytotrace_score = np.asarray(csc_median_axis_0(imputed_exp.T.tocsc())).reshape((-1,))
            else:
                cytotrace_score = np.median(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.GMEAN:
            cytotrace_score = st.gmean(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.HMEAN:
            cytotrace_score = st.hmean(imputed_exp, axis=1)
        else:
            raise NotImplementedError(f"Aggregation method `{aggregation}` is not yet implemented.")

        return cytotrace_score, top_genes