import itertools
import os
import pathlib
import types
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype
import scanpy as sc
from anndata import AnnData
from cellrank import logging as logg
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import ModeEnum
from cellrank._utils._utils import _normalize
from cellrank.kernels._base_kernel import UnidirectionalKernel
from cellrank.settings import settings
__all__ = ["RealTimeKernel"]
if TYPE_CHECKING:
from moscot.problems.spatiotemporal import SpatioTemporalProblem
from moscot.problems.time import LineageProblem, TemporalProblem
class SelfTransitions(ModeEnum):
    """Modes for filling the within-time-point (self transition) blocks of the stitched matrix."""

    # row-normalized matrix of ones
    UNIFORM = "uniform"
    # identity matrix
    DIAGONAL = "diagonal"
    # transition matrix computed by the `ConnectivityKernel`
    CONNECTIVITIES = "connectivities"
    # connectivity-weighted self transitions for every source key
    ALL = "all"
# key identifying a single coupling: a `(source, target)` pair of time points
Key_t = tuple[Any, Any]
# sparsification threshold: a percentage in `[0, 100]` or one of the automatic modes
Threshold_t = Union[int, float, Literal["auto", "auto_local"]]
# accepted representations of a single transport coupling
Coupling_t = Union[np.ndarray, sp.spmatrix, AnnData]
# TODO(michalk8): subclass the `ExperimentalTimeKernel`
[docs]
@d.dedent
class RealTimeKernel(UnidirectionalKernel):
"""Kernel which computes transition matrix using optimal transport couplings.
.. seealso::
- See :doc:`../../../notebooks/tutorials/kernels/500_real_time` on how to
compute the :attr:`~cellrank.kernels.RealTimeKernel.transition_matrix` between experimental time points
using `optimal transport <https://en.wikipedia.org/wiki/Transportation_theory_(mathematics)>`_.
This class should be constructed using either:
1. the :meth:`from_moscot` or :meth:`from_wot` method,
2. explicitly passing the pre-computed couplings, or
3. overriding the :meth:`compute_coupling` method.
Parameters
----------
adata
Annotated data object.
time_key
Key in :attr:`~anndata.AnnData.obs` containing the experimental time.
couplings
Pre-computed transport couplings. The keys should correspond to a :class:`tuple` of categories from
the :attr:`time`. If :obj:`None`, the keys will be constructed using the ``policy``
and the :meth:`compute_coupling` method must be overridden.
policy
How to construct the keys from the :attr:`time` when ``couplings = None``:
- ``policy = 'sequential'`` - the keys will be set to ``[(t1, t2), (t2, t3), ...]``.
- ``policy = 'triu'`` - the keys will be set to ``[(t1, t2), (t1, t3), ..., (t2, t3), ...]``.
kwargs
Keyword arguments for the :class:`~cellrank.kernels.Kernel`.
"""
def __init__(
    self,
    adata: AnnData,
    time_key: str,
    couplings: Optional[Mapping[Key_t, Optional[Coupling_t]]] = None,
    policy: Literal["sequential", "triu"] = "sequential",
    **kwargs: Any,
):
    # `_get_default_coupling` reads `self._time`, so the parent constructor has to run first
    # (presumably it populates the time through `_read_from_adata` - confirm in the base class)
    super().__init__(adata, time_key=time_key, **kwargs)

    # cell-level metadata only becomes available once the transition matrix is computed
    self._obs: Optional[pd.DataFrame] = None

    resolved = self._get_default_coupling(couplings, policy)
    self._couplings = resolved[0]
    self._reference = resolved[1]
def _read_from_adata(
    self,
    time_key: str,
    **kwargs: Any,
) -> None:
    """Read the experimental time from :attr:`adata` and index its categories.

    Parameters
    ----------
    time_key
        Key in :attr:`~anndata.AnnData.obs` containing the categorical experimental time.
    kwargs
        Keyword arguments forwarded to the parent implementation.

    Raises
    ------
    TypeError
        If the time column is not categorical.
    """
    super()._read_from_adata(**kwargs)

    time = self.adata.obs[time_key].copy()
    self._time = time
    is_categorical = isinstance(time.dtype, pd.CategoricalDtype)
    if not is_categorical:
        raise TypeError(f"Expected `adata.obs[{time_key!r}]` to be categorical, found `{infer_dtype(self._time)}`.")

    self._time = time.cat.remove_unused_categories()
    # position of every time point among the (used) categories
    self._time_to_ix = {cat: ix for ix, cat in enumerate(self._time.cat.categories)}
[docs]
def compute_coupling(self, src: Any, tgt: Any, **kwargs: Any) -> Coupling_t:
    """Return the coupling between a pair of time points.

    .. note::
        The default implementation is a plain lookup in :attr:`couplings`,
        see :meth:`from_moscot` or :meth:`from_wot` on how to populate them.

    Parameters
    ----------
    src
        Source key in :attr:`time`.
    tgt
        Target key in :attr:`time`.
    kwargs
        Additional keyword arguments, ignored by this implementation.

    Returns
    -------
    Array of shape ``(n_source_cells, n_target_cells)``.

    Raises
    ------
    KeyError
        If the pair is not present in :attr:`couplings`.
    ValueError
        If the pair is present, but its coupling has not been computed.
    """
    del kwargs

    available = self.couplings
    if (src, tgt) in available:
        coupling = available[src, tgt]
        if coupling is not None:
            return coupling
        raise ValueError(
            f"Coupling for `{src, tgt}` is not computed. " f"Consider overriding the `compute_coupling` method."
        )

    raise KeyError(f"Key `{src, tgt}` not found in the `couplings`.")
[docs]
@inject_docs(st=SelfTransitions)
def compute_transition_matrix(
    self,
    threshold: Optional[Threshold_t] = "auto",
    self_transitions: Union[
        Literal["uniform", "diagonal", "connectivities", "all"],
        Sequence[Any],
    ] = SelfTransitions.CONNECTIVITIES,
    conn_weight: Optional[float] = None,
    conn_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
    **kwargs: Any,
) -> "RealTimeKernel":
    """Compute transition matrix from optimal transport couplings.

    Parameters
    ----------
    threshold
        How to remove small non-zero values from the transition matrix. Valid options are:

        - ``'auto'`` - find the maximum threshold value which will not remove every non-zero value from any row.
        - ``'auto_local'`` - same as above, but done for each transport separately.
        - :class:`float` - value in :math:`[0, 100]` corresponding to a percentage of non-zeros to remove in
          the couplings.
        - :obj:`None` - do not sparsify the couplings.

        Rows where all values are removed will have a uniform distribution and a warning will be issued.
    self_transitions
        How to define transitions within the blocks that correspond to transitions within the same key.
        Valid options are:

        - ``{st.UNIFORM!r}`` - row-normalized matrix of :math:`1`.
        - ``{st.DIAGONAL!r}`` - identity matrix.
        - ``{st.CONNECTIVITIES!r}`` - transition matrix from the :class:`~cellrank.kernels.ConnectivityKernel`.
        - :class:`~typing.Sequence` - sequence of source keys defining which blocks should be weighted
          by the connectivities.
        - ``{st.ALL!r}`` - same as above, but for all keys.

        The first 3 options are applied to the block specified by the :attr:`reference`.
    conn_weight
        Weight of connectivities' self transitions. Only used when ``self_transitions = {st.ALL!r}`` or
        a sequence of source keys is passed.
    conn_kwargs
        Keyword arguments for :func:`~scanpy.pp.neighbors` or
        :meth:`~cellrank.kernels.ConnectivityKernel.compute_transition_matrix` when using
        ``self_transitions = 'connectivities'``.
    kwargs
        Keyword arguments for :meth:`compute_coupling`.

    Returns
    -------
    Returns self and updates :attr:`transition_matrix`, :attr:`couplings` and :attr:`params`.

    Raises
    ------
    ValueError
        If connectivity-weighted self transitions are requested and ``conn_weight`` is not in ``(0, 1)``.
    """
    # validate the arguments first, so that a rejected call never reaches the cache bookkeeping
    if (self_transitions == "all" or isinstance(self_transitions, (tuple, list))) and (
        conn_weight is None or not (0 < conn_weight < 1)
    ):
        raise ValueError(f"Expected `conn_weight` to be in interval `(0, 1)`, found `{conn_weight}`.")

    cache_params = dict(kwargs)
    cache_params["threshold"] = threshold
    cache_params["self_transitions"] = str(self_transitions)
    # the connectivity settings shape the result as well; leaving them out of the cache key
    # would let a call that only changes them reuse a stale transition matrix
    cache_params["conn_weight"] = conn_weight
    cache_params["conn_kwargs"] = dict(conn_kwargs)
    if self._reuse_cache(cache_params):
        return self

    # only values of existing keys are replaced, so iterating over the dictionary is safe
    for src, tgt in tqdm(self._couplings, unit="time pair"):
        coupling = self.compute_coupling(src, tgt, **kwargs)
        self.couplings[src, tgt] = self._coupling_to_adata(src, tgt, coupling)

    if threshold is not None:
        self._sparsify_couplings(self.couplings, threshold=threshold, copy=False)

    tmap = self._restich_couplings(
        self.couplings,
        self_transitions=self_transitions,
        conn_weight=conn_weight,
        **conn_kwargs,
    )

    self.transition_matrix = tmap.X
    self._obs = tmap.obs

    return self
[docs]
@classmethod
def from_moscot(
    cls,
    problem: "Union[TemporalProblem, LineageProblem, SpatioTemporalProblem]",
    sparse_mode: Optional[Literal["threshold", "percentile", "min_row"]] = None,
    sparsify_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
    copy: bool = False,
    **kwargs: Any,
) -> "RealTimeKernel":
    """Construct the kernel from :mod:`moscot` :cite:`klein:23`.

    Parameters
    ----------
    problem
        Solved :mod:`moscot` problem.
    sparse_mode
        Sparsification mode for :meth:`~moscot.base.output.BaseSolverOutput.sparsify`.
        If :obj:`None`, the outputs are not sparsified. Note that :meth:`compute_transition_matrix`
        can also sparsify the final :attr:`transition_matrix`.
    sparsify_kwargs
        Keyword arguments for the sparsification.
    copy
        Whether to copy the underlying arrays. Note that :class:`jax arrays <jax.Array>` are always copied.
    kwargs
        Keyword arguments for :class:`~cellrank.kernels.RealTimeKernel`.

    Returns
    -------
    The kernel.

    Raises
    ------
    RuntimeError
        If the problem has not been solved yet.
    NotImplementedError
        If the problem uses a policy other than the sequential or the triangular one.

    Examples
    --------
    .. code-block:: python

        import moscot as mt
        import cellrank as cr

        adata = mt.datasets.hspc()
        adata.obs["day"] = adata.obs["day"].astype("category")

        problem = mt.problems.TemporalProblem(adata)
        problem = problem.prepare(time_key="day").solve()

        rtk = cr.kernels.RealTimeKernel.from_moscot(problem)
        rtk = rtk.compute_transition_matrix()
    """
    from moscot.utils.subset_policy import SequentialPolicy, TriangularPolicy

    if not problem.solutions:
        raise RuntimeError("Problem contains no solutions. Please run `problem.solve(...)` first.")

    # translate the moscot policy object into the policy name understood by this kernel
    policy = problem._policy
    if isinstance(policy, SequentialPolicy):
        policy_name = "sequential"
    elif isinstance(policy, TriangularPolicy):
        policy_name = "triu"
    else:
        raise NotImplementedError(f"Handling `{type(policy)}` policy is not yet implemented.")

    couplings = {}
    for (t1, t2), solution in problem.solutions.items():
        source_obs = problem[t1, t2].adata_src.obs
        target_obs = problem[t1, t2].adata_tgt.obs

        if sparse_mode is not None:
            solution = solution.sparsify(mode=sparse_mode, **sparsify_kwargs)
        tmat = solution.transport_matrix

        is_host_array = isinstance(tmat, np.ndarray) or sp.issparse(tmat)
        if is_host_array:
            if copy:
                tmat = tmat.copy()
        else:
            # convert from, e.g., `jax`
            # we always copy because otherwise the array is read-only and we may sparsify
            tmat = np.array(tmat, copy=True)

        couplings[t1, t2] = AnnData(tmat, obs=source_obs, var=target_obs)

    return cls(
        problem.adata,
        couplings=couplings,
        time_key=problem.temporal_key,
        policy=policy_name,
        **kwargs,
    )
[docs]
@classmethod
def from_wot(
    cls,
    adata: AnnData,
    path: Union[str, pathlib.Path],
    time_key: str,
    **kwargs: Any,
) -> "RealTimeKernel":
    """Construct the kernel from Waddington-OT :cite:`schiebinger:19`.

    Parameters
    ----------
    adata
        Annotated data object.
    path
        Directory where the couplings are stored. Every ``.h5ad`` file in it is read and its name
        is expected to end with ``<source>_<target>.h5ad``, optionally preceded by a prefix.
    time_key
        Key in :attr:`~anndata.AnnData.obs` containing the experimental time.
    kwargs
        Keyword arguments for :class:`~cellrank.kernels.RealTimeKernel`.

    Returns
    -------
    The kernel.

    Raises
    ------
    FileNotFoundError
        If no ``.h5ad`` file is found in ``path``.
    ValueError
        If the source and target time points cannot be parsed from a file name.

    Examples
    --------
    .. code-block:: python

        import wot
        import cellrank as cr

        adata = cr.datasets.reprogramming_schiebinger(subset_to_serum=True)
        adata.obs["day"] = adata.obs["day"].astype(float).astype("category")

        ot_model = wot.ot.OTModel(adata, day_field="day")
        ot_model.compute_all_transport_maps(tmap_out="tmaps/")

        rtk = cr.kernels.RealTimeKernel.from_wot(adata, path="tmaps/", time_key="day")
        rtk = rtk.compute_transition_matrix()
    """
    path = pathlib.Path(path)
    # the time points parsed from the file names are cast to the type of the values in `adata`
    dtype = type(adata.obs[time_key].iloc[0])

    couplings = {}
    for fname in path.glob("*.h5ad"):
        # only look at the file name itself: parsing the whole path would leak the directory
        # into the source key for files named `<source>_<target>.h5ad`
        parts = fname.stem.split("_")
        if len(parts) < 2:
            raise ValueError(
                f"Unable to parse the time points from `{fname.name}`, "
                f"expected a name ending with `<source>_<target>.h5ad`."
            )
        src, tgt = parts[-2:]
        couplings[dtype(src), dtype(tgt)] = sc.read(fname)

    if not couplings:
        # fail early with a clear message instead of an opaque error inside the constructor
        raise FileNotFoundError(f"No `.h5ad` couplings found in `{path}`.")

    return cls(adata, couplings=couplings, time_key=time_key, policy="sequential", **kwargs)
@d.dedent
def _restich_couplings(
    self,
    couplings: Mapping[Key_t, AnnData],
    self_transitions: Union[str, SelfTransitions, Sequence[Any]] = SelfTransitions.DIAGONAL,
    conn_weight: Optional[float] = None,
    **kwargs: Any,
) -> AnnData:
    """Group individual transport maps into 1 matrix aligned with :attr:`adata`.

    The couplings are placed into a square grid of blocks indexed by the time points,
    the self-transition blocks are filled according to ``self_transitions`` and the result
    is reordered to follow the observation order of :attr:`adata`.

    Parameters
    ----------
    couplings
        Optimal transport couplings.
    self_transitions
        How to define transitions within the blocks that correspond to transitions within the same key.
    conn_weight
        Weight of connectivities' self transitions. Only used when ``self_transitions = 'all'`` or
        a sequence of source keys is passed.
    kwargs
        Keyword arguments for :func:`~scanpy.pp.neighbors` or
        :meth:`~cellrank.kernels.ConnectivityKernel.compute_transition_matrix` when using
        ``self_transitions = 'connectivities'``.

    Returns
    -------
    Merged transport maps into one :class:`~anndata.AnnData` object.

    Raises
    ------
    TypeError
        If ``self_transitions`` is neither a string/enum nor an iterable.
    NotImplementedError
        If the self-transition mode is not handled.
    """
    # normalize the mode: strings become enum members, `'all'` expands to every source key
    # and any other iterable is frozen into a tuple of source keys
    if isinstance(self_transitions, (str, SelfTransitions)):
        self_transitions = SelfTransitions(self_transitions)
        if self_transitions == SelfTransitions.ALL:
            self_transitions = tuple(src for (src, _) in couplings)
    elif isinstance(self_transitions, Iterable):
        self_transitions = tuple(self_transitions)
    else:
        raise TypeError(
            f"Expected `self_transitions` to be a `str` or a `Sequence`, " f"found `{type(self_transitions)}`."
        )
    # checks the observation names and fills empty rows of the couplings in-place
    self._validate_couplings(couplings)

    # these are forwarded to `_compute_connectivity_tmat`; `copy`/`key_added` are overridden so that
    # the neighbors are written in-place under the default key
    # NOTE(review): presumably required for `ConnectivityKernel` to find the graph - confirm
    conn_kwargs = dict(kwargs)
    conn_kwargs["copy"] = False
    _ = conn_kwargs.pop("key_added", None)

    obs_names, obs = {}, {}
    # square grid of blocks, one row/column per time point
    blocks = [[None] * (len(self._time_to_ix)) for _ in range(len(self._time_to_ix))]
    for (src, tgt), coupling in couplings.items():
        src_ix = self._time_to_ix[src]
        tgt_ix = self._time_to_ix[tgt]
        blocks[src_ix][tgt_ix] = coupling.X
        obs_names[src_ix] = coupling.obs_names
        obs[src_ix] = coupling.obs
        # to prevent blocks from disappearing
        n = np.sum(self._time == src)
        blocks[src_ix][src_ix] = sp.spdiags([0] * n, 0, n, n)
        # names of the reference cells are only available as the columns of a coupling
        if tgt == self._reference:
            obs_names[tgt_ix] = coupling.var_names

    # if policy='sequential', reference is the last key
    ref_ix = self._time_to_ix[self._reference]
    ref_mask = self._time == self._reference
    n = int(np.sum(ref_mask))

    # when `self_transitions` is a tuple, the enum comparisons below are all false
    if self_transitions == SelfTransitions.DIAGONAL:
        blocks[ref_ix][ref_ix] = sp.spdiags([1] * n, 0, n, n)
    elif self_transitions == SelfTransitions.UNIFORM:
        blocks[ref_ix][ref_ix] = np.ones((n, n)) / float(n)
    elif self_transitions == SelfTransitions.CONNECTIVITIES:
        blocks[ref_ix][ref_ix] = _compute_connectivity_tmat(self.adata[ref_mask], **conn_kwargs)
    elif isinstance(self_transitions, tuple):
        verbosity = settings.verbosity
        try:  # ignore overly verbose logging
            settings.verbosity = 0
            for (src, tgt), coupling in couplings.items():
                # fmt: off
                if src not in self_transitions:
                    continue
                src_ix, tgt_ix = self._time_to_ix[src], self._time_to_ix[tgt]
                # mix the within-time-point connectivities with the normalized coupling
                blocks[src_ix][src_ix] = conn_weight * _compute_connectivity_tmat(
                    self.adata[coupling.obs_names], **conn_kwargs
                )
                blocks[src_ix][tgt_ix] = (1 - conn_weight) * _normalize(blocks[src_ix][tgt_ix])
                # fmt: on
            # the reference block has no outgoing coupling, it is filled purely by the connectivities
            blocks[ref_ix][ref_ix] = _compute_connectivity_tmat(self.adata[ref_mask], **conn_kwargs)
        finally:
            # always restore the user's verbosity, even if the computation fails
            settings.verbosity = verbosity
    else:
        raise NotImplementedError(f"Self transitions' mode `{self_transitions}` is not yet implemented.")

    # observation names in the order in which the blocks are stacked
    index = []
    for ix in range(len(blocks)):
        index.extend(obs_names[ix])

    tmp = AnnData(sp.bmat(blocks, format="csr"))
    tmp.obs_names = index
    tmp.var_names = index
    # reorder rows and columns so that they follow `adata.obs_names`
    tmp = tmp[self.adata.obs_names, :][:, self.adata.obs_names]

    # time points without a source coupling (e.g., the reference) have no metadata;
    # `obs.get` yields `None` for them
    obs = pd.concat([obs.get(ix) for ix in range(len(blocks))])
    tmp.obs = pd.merge(
        tmp.obs,
        obs,
        left_index=True,
        right_index=True,
        how="left",
    )

    return tmp
@d.get_sections(base="tmk_thresh", sections=["Parameters"])
def _sparsify_couplings(
    self,
    couplings: dict[Key_t, AnnData],
    threshold: Threshold_t,
    copy: bool = False,
) -> Optional[Mapping[Key_t, AnnData]]:
    """Remove small non-zero values from :attr:`transition_matrix`.

    .. warning::
        Only dense :attr:`~anndata.AnnData.X` will be sparsified.

    Parameters
    ----------
    couplings
        Optimal transport couplings to sparsify.
    threshold
        How to remove small non-zero values from the transition matrix. Valid options are:

        - ``'auto'`` - find the maximum threshold value which will not remove every non-zero value from any row.
        - ``'auto_local'`` - same as above, but done for each transport separately.
        - :class:`float` - value in :math:`[0, 100]` corresponding to a percentage of non-zeros to remove in
          the couplings.

        Rows where all values are removed will have a uniform distribution and a warning will be issued.
    copy
        Whether to return a copy of the ``couplings`` or modify in-place.

    Returns
    -------
    If ``copy = True``, returns sparsified couplings. Otherwise, modifies ``couplings`` in-place.

    Raises
    ------
    ValueError
        If a numeric ``threshold`` is outside of ``[0, 100]``.
    """
    if threshold == "auto":
        # smallest row-wise maximum across all couplings: every row keeps at least its largest value
        thresh = min(adata.X[i].max() for adata in couplings.values() for i in range(adata.n_obs))
        logg.info(f"Using automatic `threshold={thresh}`")
    elif threshold == "auto_local":
        logg.info("Using automatic `threshold` for each coupling separately")
    elif not (0 <= threshold <= 100):
        raise ValueError(f"Expected `threshold` to be in `[0, 100]`, found `{threshold}`.")

    # with `copy=True` the results go into a fresh mapping, leaving the caller's dictionary untouched;
    # otherwise only values of existing keys are replaced, which is safe while iterating
    result = {} if copy else couplings
    for key, adata in couplings.items():
        if sp.issparse(adata.X):
            # already sparse couplings are kept as they are
            if copy:
                result[key] = adata.copy()
            continue

        tmat = adata.X
        if threshold == "auto_local":
            thresh = min(tmat[i].max() for i in range(tmat.shape[0]))
            logg.debug(f"Using `threshold={thresh}` at `{key}`")
        elif isinstance(threshold, (int, float, np.number)):
            # the percentile is taken over the non-zero entries only, as documented;
            # `tmat` is dense here, so the zeros have to be masked out explicitly
            nonzero = tmat[tmat != 0]
            thresh = np.percentile(nonzero, threshold) if nonzero.size else 0.0
            logg.debug(f"Using `threshold={thresh}` at `{key}`")

        # the conversion allocates new arrays, so the dense input is never modified
        tmat = sp.csr_matrix(tmat, dtype=tmat.dtype)
        tmat.data[tmat.data < thresh] = 0.0
        tmat.eliminate_zeros()

        obs = adata.obs.copy() if copy else adata.obs
        var = adata.var.copy() if copy else adata.var
        result[key] = AnnData(tmat, obs=obs, var=var)

    return result if copy else None
def _validate_couplings(
    self,
    couplings: Mapping[Key_t, AnnData],
) -> None:
    """Validate that transport maps conform to various invariants.

    The observation/variable names of every coupling must match the cells of the source/target
    time point in :attr:`adata` (in any order). Empty rows of a coupling are replaced by a uniform
    distribution, which modifies the coupling's matrix in-place.

    Parameters
    ----------
    couplings
        Optimal transport couplings.

    Returns
    -------
    Nothing, just validates the couplings and fills their empty rows in-place.

    Raises
    ------
    IndexError
        If the names of a coupling don't match the cells of its source or target time point.
    """

    def assert_same(expected: Sequence[Any], actual: Sequence[Any], msg: Optional[str] = None) -> None:
        # the names are compared as sorted collections, their order is irrelevant
        try:
            pd.testing.assert_series_equal(
                pd.Series(sorted(expected)),
                pd.Series(sorted(actual)),
                check_names=False,
                check_flags=False,
                check_index=True,
            )
        except AssertionError as e:
            raise IndexError(msg) from e

    for (src, tgt), coupling in couplings.items():
        src_obs = self.adata.obs_names[self.time == src]
        tgt_obs = self.adata.obs_names[self.time == tgt]
        assert_same(
            src_obs,
            coupling.obs_names,
            msg=f"Source observations for `{src, tgt}` don't match with `adata.obs_names`.",
        )
        assert_same(
            tgt_obs,
            coupling.var_names,
            msg=f"Target observations for `{src, tgt}` don't match with `adata.var_names`.",
        )

    for key, coupling in couplings.items():
        tmat = coupling.X
        # `ravel` (unlike `squeeze`) keeps a 1-dimensional result even for a single-row coupling
        row_sums = np.asarray(tmat.sum(1)).ravel()
        empty_rows = np.flatnonzero(row_sums == 0)
        # test the number of empty rows, not their indices: `np.any` on the indices
        # would miss the case where the only empty row is the first one
        if empty_rows.size:
            logg.warning(
                f"Coupling at `{key}` contains `{len(empty_rows)}` empty row(s), e.g., due to "
                f"unbalancedness or 0s in the source marginals. Using uniform distribution"
            )
            tmat[empty_rows] = np.ones((len(empty_rows), tmat.shape[1]), dtype=tmat.dtype) / tmat.shape[1]
def _coupling_to_adata(self, src: Any, tgt: Any, coupling: Coupling_t) -> AnnData:
    """Wrap the coupling in an :class:`~anndata.AnnData` and make sure sparse data is in CSR format.

    Raw arrays are labeled by the cells of the source (rows) and target (columns) time points,
    an :class:`~anndata.AnnData` input keeps its own names.
    """
    if isinstance(coupling, AnnData):
        adata = coupling
    else:
        adata = AnnData(X=coupling)
        cell_names = np.asarray(self.adata.obs_names)
        adata.obs_names = cell_names[self.time == src]
        adata.var_names = cell_names[self.time == tgt]

    needs_csr = sp.issparse(adata.X) and not sp.isspmatrix_csr(adata.X)
    if needs_csr:
        adata.X = adata.X.tocsr()

    return adata
def _get_default_coupling(
    self,
    couplings: Optional[Mapping[Key_t, Optional[Coupling_t]]],
    policy: Literal["sequential", "triu"],
) -> tuple[dict[Key_t, Optional[Coupling_t]], Any]:
    """Resolve the coupling keys and the reference time point.

    Parameters
    ----------
    couplings
        Pre-computed couplings. If :obj:`None`, the keys are generated from the categories
        of the experimental time according to ``policy`` and the values are set to :obj:`None`.
    policy
        How to generate the keys when ``couplings = None``.

    Returns
    -------
    Copy of the couplings as a :class:`dict` and the reference, i.e., the largest time point
    among the keys.

    Raises
    ------
    NotImplementedError
        If the ``policy`` is not handled.
    ValueError
        If there are no couplings or if a key is not a category of the experimental time.
    """
    cats = self._time.cat.categories

    if policy == "sequential":
        default_keys = zip(cats[:-1], cats[1:])
    elif policy == "triu":
        default_keys = ((src, tgt) for src, tgt in itertools.product(cats, cats) if src < tgt)
    else:
        raise NotImplementedError(f"Handling `{policy}` is not yet implemented.")

    if couplings is None:
        couplings = {(src, tgt): None for src, tgt in default_keys}
    couplings = dict(couplings)

    if not couplings:
        # e.g., an empty mapping was passed or the time has fewer than 2 categories
        raise ValueError(
            f"Expected at least 1 coupling, found none. Please make sure that `couplings` are not empty "
            f"and that the time contains at least 2 categories, found `{len(cats)}`."
        )

    # validate before determining the reference, so that an invalid key of a different type
    # is reported as such rather than as a failed comparison
    for keys in couplings:
        for key in keys:
            if key not in cats:
                raise ValueError(f"Key `{key}` is not in `{sorted(cats)}`.")

    reference = max(k for ks in couplings for k in ks)

    return couplings, reference
@property
def time(self) -> pd.Series:
    """Experimental time of each cell, a categorical copy of the time column in :attr:`adata`."""
    return self._time
@property
def couplings(self) -> Optional[dict[Key_t, AnnData]]:
    """Optimal transport couplings, keyed by the ``(source, target)`` pair of time points.

    The values are converted to :class:`~anndata.AnnData` by :meth:`compute_transition_matrix`,
    before that they are stored as passed to the constructor (possibly :obj:`None`).
    """
    return self._couplings
@property
def obs(self) -> Optional[pd.DataFrame]:
    """Cell-level metadata, :obj:`None` until :meth:`compute_transition_matrix` has been run."""
    return self._obs
def _compute_connectivity_tmat(
    adata: AnnData, density_normalize: bool = True, **kwargs: Any
) -> Union[np.ndarray, sp.spmatrix]:
    """Compute the neighbors on ``adata`` and return the transition matrix of the connectivity kernel.

    ``kwargs`` are forwarded to :func:`scanpy.pp.neighbors`, ``density_normalize`` to
    :meth:`~cellrank.kernels.ConnectivityKernel.compute_transition_matrix`.
    """
    # imported here rather than at the module level (presumably to avoid a circular import)
    from cellrank.kernels import ConnectivityKernel

    sc.pp.neighbors(adata, **kwargs)

    kernel = ConnectivityKernel(adata)
    kernel = kernel.compute_transition_matrix(density_normalize=density_normalize)
    return kernel.transition_matrix