from typing import Any, Dict, Tuple, Union, Mapping, Optional
from typing_extensions import Literal
from types import MappingProxyType
from tqdm.auto import tqdm
import scanpy as sc
from anndata import AnnData
from cellrank import logging as logg
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import _maybe_subset_hvgs
from cellrank.external.kernels._utils import MarkerGenes
from cellrank.tl.kernels._exp_time_kernel import LastTimePoint
import numpy as np
import pandas as pd
# `wot` is an optional dependency: if it (or the real kernel base class) cannot be
# imported, fall back to an error-raising placeholder kernel and remember the error
_error = None
try:
    import wot
    from cellrank.tl.kernels import TransportMapKernel as Kernel
except ImportError as e:
    from cellrank.external.kernels._import_error_kernel import ErroredKernel as Kernel

    # `e` is unbound after the `except` block ends, so keep a reference to it here
    _error = e
    wot = None
class nstr(str):  # used for params+cache (precomputed cost matrices)
    """
    String class that is not equal to any other string (including itself).

    Used as a cache parameter value so that a comparison against previously cached
    parameters never reports a match.
    """

    def __eq__(self, other: object) -> bool:
        return False

    def __ne__(self, other: object) -> bool:
        # `str` defines its own `__ne__`, which would otherwise be inherited and compare
        # the string values, making `!=` disagree with `__eq__` above
        return True

    # defining `__eq__` implicitly sets `__hash__ = None`; keep instances hashable
    __hash__ = str.__hash__
@d.dedent
class WOTKernel(Kernel, error=_error):
    """
    Waddington optimal transport kernel from :cite:`schiebinger:19`.

    This class requires the `wot` package, which can be installed as
    `pip install git+https://github.com/broadinstitute/wot`.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    time_key
        Key in :attr:`adata` ``.obs`` where experimental time is stored.
        The experimental time can be of either a numeric or an ordered categorical type.
    %(cond_num)s

    Examples
    --------
    Workflow::

        # import packages, load data
        import scanpy as sc
        import cellrank as cr
        adata = cr.datasets.lung()

        # filter, normalize and annotate highly variable genes
        sc.pp.filter_genes(adata, min_cells=10)
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        sc.pp.highly_variable_genes(adata)

        # estimate proliferation and apoptosis from gene sets (see e.g. the WOT tutorial for example lists)
        proliferation_genes = ...
        apoptosis_genes = ...
        sc.tl.score_genes(adata, gene_list=proliferation_genes, score_name='proliferation')
        sc.tl.score_genes(adata, gene_list=apoptosis_genes, score_name='apoptosis')

        # initialize kernel, estimate initial growth rate based on scores from above
        from cellrank.external.kernels import WOTKernel
        ot = WOTKernel(adata, time_key='day')
        ot.compute_initial_growth_rates(proliferation_key='proliferation',
                                        apoptosis_key='apoptosis',
                                        key_added='initial_growth_rates')

        # compute transport maps, aggregate into one single transition matrix
        ot.compute_transition_matrix(growth_rate_key='initial_growth_rates', growth_iters=3)
    """

    __import_error_message__ = (
        "Unable to import the kernel. Please install `wot` first as "
        "`pip install git+https://github.com/broadinstitute/wot`."
    )

    def __init__(
        self,
        adata: AnnData,
        backward: bool = False,
        time_key: str = "exp_time",
        compute_cond_num: bool = False,
        **kwargs: Any,
    ):
        super().__init__(
            adata,
            backward=backward,
            time_key=time_key,
            compute_cond_num=compute_cond_num,
            **kwargs,
        )
        # WOT's requirements: time points must be floats. Relabel the categories as floats
        # and write the relabelled time back to `adata.obs`, which is where WOT reads it from
        # (presumably `experimental_time` exposes `_exp_time` - see the base class)
        cats = self.experimental_time.cat.categories
        self._exp_time = self.experimental_time.cat.rename_categories(
            dict(zip(cats, map(float, cats)))
        )
        self.adata.obs[self._time_key] = self.experimental_time
        self._growth_rates = None

    def _read_from_adata(
        self,
        conn_key: Optional[str] = "connectivities",
        read_conn: bool = True,
        **kwargs: Any,
    ) -> None:
        # connectivities are never read here, regardless of ``read_conn``
        super()._read_from_adata(conn_key=conn_key, read_conn=False, **kwargs)

    def compute_initial_growth_rates(
        self,
        proliferation_key: Optional[str] = None,
        apoptosis_key: Optional[str] = None,
        organism: Optional[Literal["human", "mouse"]] = None,
        beta_min: float = 0.3,
        beta_max: float = 1.7,
        delta_min: float = 0.3,
        delta_max: float = 1.7,
        key_added: Optional[str] = None,
        **kwargs: Any,
    ) -> Optional[pd.Series]:
        r"""
        Estimate initial growth rates using a birth-death process as described in :cite:`schiebinger:19`.

        The doubling time is defined as :math:`\frac{\ln 2}{\beta - \delta}` (similarly defined for half-time).
        The logistic function is used to transform the birth/death rate scores and to smoothly interpolate between
        specified minimum and maximum birth/death rates.

        Parameters
        ----------
        proliferation_key
            Key in :attr:`anndata.AnnData.obs` where the birth rate score is saved.
        apoptosis_key
            Key in :attr:`anndata.AnnData.obs` where the death rate score is saved.
        organism
            Organism for which to calculate the birth/death scores, if they cannot be found in :attr:`adata`.
            In this case, :func:`scanpy.tl.score_genes` is used to calculate the scores based on an organism-dependent
            set of marker genes for proliferation and apoptosis. The computed scores are stored in
            :attr:`anndata.AnnData.obs` under ``proliferation_key``/``apoptosis_key``, or under
            `'proliferation_score'`/`'apoptosis_score'` if the respective key is `None`.
        beta_min
            Minimum birth rate.
        beta_max
            Maximum birth rate.
        delta_min
            Minimum death rate.
        delta_max
            Maximum death rate.
        key_added
            Key in :attr:`adata` ``.obs`` where to add the estimated growth rates. If `None`, just return them.
        kwargs
            Keyword arguments for :func:`scanpy.tl.score_genes`. Only used when ``proliferation_key``
            or ``apoptosis_key`` cannot be found in :attr:`anndata.AnnData.obs`.

        Returns
        -------
        :class:`pandas.Series`
            The estimated initial growth rates if ``key_added = None``, otherwise `None`.

        Raises
        ------
        KeyError
            If the proliferation/apoptosis scores cannot be found and ``organism = None``.

        Notes
        -----
        If you don't have access to proliferation/apoptosis gene sets, you can use the ones defined in :mod:`cellrank`
        for a specific organism. Alternatively, you can also use WOT without an estimate of initial growth rates. In
        that case, make sure to use several iterations in
        :meth:`cellrank.external.kernels.WOTKernel.compute_transition_matrix` by increasing the ``growth_iters``
        parameter. A value around 3 works well in most cases.

        The markers used here were taken from the following sources:

            - human, proliferation - :cite:`tirosh:16:science`.
            - human, apoptosis - `Hallmark Apoptosis, MSigDB <https://www.gsea-msigdb.org/gsea/msigdb/cards/HALLMARK_APOPTOSIS>`_.
            - mouse, proliferation - :cite:`tirosh:16:nature`.
            - mouse, apoptosis - `Hallmark P53 Pathway, MSigDB <https://www.gsea-msigdb.org/gsea/msigdb/cards/HALLMARK_P53_PATHWAY>`_.

        For more information about WOT, see the official `tutorial <https://broadinstitute.github.io/wot/tutorial/>`_.
        """  # noqa: E501

        def get_scores(
            key: Optional[str],
            *,
            kind: Literal["proliferation", "apoptosis"],
            organism: Optional[Literal["human", "mouse"]],
        ) -> np.ndarray:
            # read existing scores; if missing, score the organism's marker genes and read them back
            try:
                return np.asarray(self.adata.obs[key])
            except KeyError:
                if organism is None:
                    raise KeyError(
                        f"Unable to find `{kind}` scores in `adata.obs[{key!r}]`. Consider specifying `organism=...`."
                    ) from None

            logg.info(f"Computing `{kind}` scores")
            score_name = f"{kind}_score" if key is None else key
            sc.tl.score_genes(
                self.adata,
                gene_list=getattr(MarkerGenes, f"{kind}_markers")(organism),
                score_name=score_name,
                **kwargs,
            )
            # `organism=None` so that a second miss raises instead of recursing again
            return get_scores(score_name, kind=kind, organism=None)

        def logistic(x: np.ndarray, L: float, k: float, x0: float = 0) -> np.ndarray:
            return L / (1 + np.exp(-k * (x - x0)))

        def gen_logistic(
            p: np.ndarray, beta_max: float, beta_min: float, center: float, width: float
        ) -> np.ndarray:
            # logistic curve rescaled to the `[beta_min, beta_max]` range
            return beta_min + logistic(p, L=beta_max - beta_min, k=4 / width, x0=center)

        def beta(
            p: np.ndarray,
            beta_max: float = 1.7,
            beta_min: float = 0.3,
            center: float = 0.25,
        ) -> np.ndarray:
            return gen_logistic(p, beta_max, beta_min, center, width=0.5)

        def delta(
            a: np.ndarray,
            delta_max: float = 1.7,
            delta_min: float = 0.3,
            center: float = 0.1,
        ) -> np.ndarray:
            return gen_logistic(a, delta_max, delta_min, center, width=0.2)

        birth = beta(
            get_scores(proliferation_key, kind="proliferation", organism=organism),
            beta_min=beta_min,
            beta_max=beta_max,
        )
        death = delta(
            get_scores(apoptosis_key, kind="apoptosis", organism=organism),
            delta_min=delta_min,
            delta_max=delta_max,
        )
        gr = np.exp(birth - death)

        if key_added is None:
            return pd.Series(gr, index=self.adata.obs_names)
        self.adata.obs[key_added] = gr

    @inject_docs(ltp=LastTimePoint)
    def compute_transition_matrix(
        self,
        cost_matrices: Optional[
            Union[str, Mapping[Tuple[float, float], np.ndarray]]
        ] = None,
        lambda1: float = 1,
        lambda2: float = 50,
        epsilon: float = 0.05,
        growth_iters: int = 1,
        solver: Literal["fixed_iters", "duality_gap"] = "duality_gap",
        growth_rate_key: Optional[str] = None,
        use_highly_variable: Optional[Union[str, bool]] = True,
        last_time_point: Literal[
            "uniform", "diagonal", "connectivities"
        ] = LastTimePoint.UNIFORM,
        threshold: Optional[Union[float, Literal["auto"]]] = "auto",
        conn_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs: Any,
    ) -> "WOTKernel":
        """
        Compute transition matrix using Waddington OT :cite:`schiebinger:19`.

        Computes transport maps linking together pairs of time points for time-series single cell data using unbalanced
        optimal transport, taking into account cell birth and death rates. From the sequence of transition maps linking
        pairs of sequential time points, we construct one large transition matrix which contains the normalized
        transport maps as blocks on the 1st upper diagonal.

        Parameters
        ----------
        cost_matrices
            Cost matrices for each consecutive pair of time points.
            If a :class:`str`, it specifies a key in :attr:`anndata.AnnData.layers` or :attr:`anndata.AnnData.obsm`
            containing cell features that are used to compute cost matrices. If `None`, use `WOT`'s default, i.e.
            compute distances in PCA space derived from :attr:`anndata.AnnData.X` for each time point pair separately.
            If a mapping, it maps time point pairs to precomputed cost matrices; missing pairs use the default.
        lambda1
            Regularization parameter for the marginal constraint on :math:`p`, the transport map row sums.
            Smaller value is useful when precise information about the growth rate is not present.
        lambda2
            Regularization parameter for the marginal constraint on :math:`q`, the transport map column sums.
        epsilon
            Entropy regularization parameter. Larger value gives more entropic descendant distributions.
        growth_iters
            Number of iterations for growth rate estimates. If growth rates are not known, consider using more
            iterations. Values smaller than 1 are treated as 1.
        solver
            Which solver to use.
        growth_rate_key
            Key in :attr:`anndata.AnnData.obs` where initial cell growth rates are stored.
            See :meth:`compute_initial_growth_rates` on how to estimate them.
        use_highly_variable
            Key in :attr:`anndata.AnnData.var` where highly variable genes are stored.
            If `True`, use `'highly_variable'`. If `None`, use all genes.
        last_time_point
            How to define transitions within the last time point. Valid options are:

                - `{ltp.UNIFORM!r}` - row-normalized matrix of 1s for transitions within the last time point.
                - `{ltp.DIAGONAL!r}` - diagonal matrix with 1s on the diagonal.
                - `{ltp.CONNECTIVITIES!r}` - use transitions from :class:`cellrank.tl.kernels.ConnectivityKernel`
                  derived from the last time point subset of :attr:`adata`.

        threshold
            How to remove small non-zero values from the transition matrix. Valid options are:

                - `'auto'` - find the maximum threshold value which will not remove every non-zero value from any row.
                - :class:`float` - value in `[0, 100]` corresponding to a percentage of non-zeros to remove.
                  Rows where all values are removed will have uniform distribution.
                - `None` - do not threshold.

        conn_kwargs
            Keyword arguments for :func:`scanpy.pp.neighbors`, when using ``last_time_point = {ltp.CONNECTIVITIES!r}``.
            Can contain `'density_normalize'` for
            :meth:`cellrank.tl.kernels.ConnectivityKernel.compute_transition_matrix`.
        kwargs
            Additional keyword arguments for optimal transport configuration.

        Returns
        -------
        :class:`cellrank.external.kernels.WOTKernel`
            Makes :attr:`transition_matrix`, :attr:`transport_maps` and :attr:`growth_rates` available.
            Also modifies :attr:`anndata.AnnData.obs` with the following key:

                - `'estimated_growth_rates'` - the estimated final growth rates.

        Notes
        -----
        For more information about WOT, see the official `tutorial <https://broadinstitute.github.io/wot/tutorial/>`_.
        """
        # disallow these params (because they e.g. subset the data); `parameters` is a parameters file
        for disallowed in ("cell_day_filter", "covariate_field", "ncounts", "ncells", "parameters"):
            kwargs.pop(disallowed, None)

        # clamp once, so that WOT and the final growth-rate column lookup agree
        growth_iters = max(growth_iters, 1)
        kwargs["lambda1"] = lambda1
        kwargs["lambda2"] = lambda2
        kwargs["epsilon"] = epsilon
        kwargs["growth_iters"] = growth_iters

        last_time_point = LastTimePoint(last_time_point)

        start = logg.info(
            "Computing transition matrix using Waddington optimal transport"
        )

        adata = _maybe_subset_hvgs(self.adata, use_highly_variable=use_highly_variable)
        cost_matrices, cmat_param = self._generate_cost_matrices(adata, cost_matrices)

        if self._reuse_cache(
            {
                "cost_matrices": cmat_param,
                "solver": solver,
                "growth_rate_key": growth_rate_key,
                "use_highly_variable": use_highly_variable,
                "last_time_point": last_time_point,
                "threshold": threshold,
                # affects the last time point block, so it must invalidate the cache too
                "conn_kwargs": dict(conn_kwargs),
                **kwargs,
            },
            time=start,
        ):
            return self

        tmap = self._compute_pairwise_tmaps(
            adata,
            cost_matrices=cost_matrices,
            solver=solver,
            growth_rate_field=growth_rate_key,
            **kwargs,
        )
        tmap = self._restich_tmaps(tmap, last_time_point, conn_kwargs=conn_kwargs)
        self._growth_rates = tmap.obs

        self._compute_transition_matrix(
            matrix=tmap.X,
            density_normalize=False,
            check_irreducibility=False,
        )
        if threshold:
            self._threshold_transition_matrix(threshold)

        # the column of the last growth rate iteration holds the final estimate
        self.adata.obs["estimated_growth_rates"] = self.growth_rates[f"g{growth_iters}"]

        logg.info(" Finish", time=start)

        return self

    def _compute_pairwise_tmaps(
        self,
        adata: AnnData,
        cost_matrices: Optional[
            Union[str, Mapping[Tuple[float, float], np.ndarray]]
        ] = None,
        solver: Literal["fixed_iters", "duality_gap"] = "duality_gap",
        growth_rate_field: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[Tuple[float, float], AnnData]:
        """
        Compute one WOT transport map per consecutive time pair in ``cost_matrices``.

        Raises
        ------
        TypeError
            If ``kwargs`` contains a key that is not part of WOT's OT configuration,
            or if WOT returns no transport map for a time pair.
        ValueError
            If a transport map contains non-finite values.
        """
        # throwaway model built without `kwargs`, used only to learn which OT configuration keys are valid
        probe_model = wot.ot.OTModel(
            adata,
            day_field=self._time_key,
            covariate_field=None,
            growth_rate_field=growth_rate_field,
        )
        for k in kwargs:
            if k not in probe_model.ot_config:
                raise TypeError(f"WOT got an unexpected keyword argument {k!r}.")

        self._ot_model = wot.ot.OTModel(
            adata,
            day_field=self._time_key,
            covariate_field=None,
            growth_rate_field=growth_rate_field,
            solver=solver,
            **kwargs,
        )
        self._tmaps: Dict[Tuple[float, float], AnnData] = {}

        start = logg.info(
            f"Computing transport maps for `{len(cost_matrices)}` time pairs"
        )
        for tpair, cost_matrix in tqdm(cost_matrices.items(), unit="time pair"):
            tmap: Optional[AnnData] = self._ot_model.compute_transport_map(
                *tpair, cost_matrix=cost_matrix
            )
            if tmap is None:
                raise TypeError(
                    f"Unable to compute transport map for time pair `{tpair}`. "
                    f"Please ensure `adata.obs[{self._time_key!r}]` has the correct dtype (float)."
                )
            tmap.X = tmap.X.astype(np.float64)
            nans = int(np.sum(~np.isfinite(tmap.X)))
            if nans:
                raise ValueError(
                    f"Encountered `{nans}` non-finite values for time pair `{tpair}`."
                )
            self._tmaps[tpair] = tmap

        logg.info(" Finish", time=start)

        return self._tmaps

    def _generate_cost_matrices(
        self,
        adata: AnnData,
        cost_matrices: Optional[
            Union[str, Mapping[Tuple[float, float], np.ndarray]]
        ] = None,
    ) -> Tuple[Mapping[Tuple[float, float], Optional[np.ndarray]], str]:
        """
        Resolve ``cost_matrices`` into one (possibly `None` = WOT default) cost matrix per consecutive time pair.

        Returns the per-pair mapping together with a value describing the choice, used as a cache parameter.

        Raises
        ------
        ValueError
            If a precomputed cost matrix has the wrong shape.
        KeyError
            If a string key is found in neither `adata.layers` nor `adata.obsm`.
        NotImplementedError
            If ``cost_matrices`` is of an unsupported type.
        """
        timepoints = self.experimental_time.cat.categories
        timepoints = list(zip(timepoints[:-1], timepoints[1:]))

        if cost_matrices is None:
            logg.info("Using default cost matrices")
            return {tpair: None for tpair in timepoints}, "default"

        # accept any mapping (not only `dict`), as promised by the type annotation
        if isinstance(cost_matrices, Mapping):
            logg.info("Using precomputed cost matrices")

            cmats = {}
            for tpair in timepoints:
                if tpair not in cost_matrices:
                    logg.warning(
                        f"Unable to find cost matrix for pair `{tpair}`. Using default"
                    )
                cmats[tpair] = cmat = cost_matrices.get(tpair, None)
                if cmat is not None:
                    n_start = len(np.where(self.experimental_time == tpair[0])[0])
                    n_end = len(np.where(self.experimental_time == tpair[1])[0])
                    try:
                        if cmat.shape != (n_start, n_end):
                            raise ValueError(
                                f"Expected cost matrix for time pair `{tpair}` to be "
                                f"of shape `{(n_start, n_end)}`, found `{cmat.shape}`."
                            )
                    except AttributeError:
                        # objects without `.shape` cannot be validated here
                        logg.warning(
                            f"Unable to verify whether supplied cost matrix for time pair `{tpair}` "
                            f"has the correct shape `{(n_start, n_end)}`"
                        )

            # prevent equality comparison when comparing with cache
            return cmats, nstr("precomputed")

        if isinstance(cost_matrices, str):
            logg.info(f"Computing cost matrices using `{cost_matrices!r}` key")
            if cost_matrices == "X":
                # `layer=None` selects `adata.X`
                cost_matrices = None

            try:
                features = adata._get_X(layer=cost_matrices)
                modifier = "layer"
            except KeyError:
                try:
                    features = adata.obsm[cost_matrices]
                    modifier = "obsm"
                except KeyError:
                    raise KeyError(
                        f"Unable to find key `{cost_matrices!r}` in `adata.layers` or `adata.obsm`."
                    ) from None

            cmats = {}
            for tpair in tqdm(timepoints, unit="cost matrix"):
                start_ixs = np.where(self.experimental_time == tpair[0])[0]
                end_ixs = np.where(self.experimental_time == tpair[1])[0]

                # being sparse is handled in WOT's function below
                cmats[tpair] = wot.ot.OTModel.compute_default_cost_matrix(
                    features[start_ixs], features[end_ixs]
                )

            return cmats, f"{modifier}:{cost_matrices}"

        raise NotImplementedError(
            f"Specifying cost matrices as "
            f"`{type(cost_matrices).__name__}` is not yet implemented."
        )

    def _threshold_transition_matrix(
        self, threshold: Union[float, Literal["auto"]]
    ) -> None:
        """
        Zero out small entries of the transition matrix and recompute it without density normalization.

        With `'auto'`, the threshold is the smallest per-row maximum, so no row loses all of its entries
        (entries strictly below it are removed). Otherwise ``threshold`` is a percentile in `[0, 100]` of
        the stored non-zero values, and entries less than or equal to it are removed.

        Raises
        ------
        ValueError
            If a numeric ``threshold`` is outside `[0, 100]`.
        """
        tmat = self.transition_matrix

        if threshold == "auto":
            threshold = min(np.max(tmat[i].data) for i in range(tmat.shape[0]))
            logg.info(f"Using `threshold={threshold}`")
            tmat.data[tmat.data < threshold] = 0.0
        else:
            if not (0 <= threshold <= 100):
                raise ValueError(
                    f"Expected `threshold` to be in `[0, 100]`, found `{threshold}`."
                )
            threshold = np.percentile(tmat.data, threshold)
            logg.info(f"Using `threshold={threshold}`")
            tmat.data[tmat.data <= threshold] = 0.0

        tmat.eliminate_zeros()
        self._compute_transition_matrix(
            matrix=tmat,
            density_normalize=False,
            check_irreducibility=False,
        )

    @property
    def growth_rates(self) -> Optional[pd.DataFrame]:
        """Estimated cell growth rates for each growth rate iteration."""
        return self._growth_rates

    def __invert__(self) -> "WOTKernel":
        super().__invert__()
        # because WOT reads from `adata`
        self.adata.obs[self._time_key] = self.experimental_time
        return self