
"""Abstract base class for all kernel-holding estimators."""
import pickle
from abc import ABC, abstractmethod
from sys import version_info
from copy import copy, deepcopy
from typing import Any, Dict, Tuple, Union, Mapping, TypeVar, Optional, Sequence
from pathlib import Path
from datetime import datetime

from typing_extensions import Literal

import scvelo as scv

import numpy as np
import pandas as pd
from pandas import Series
from scipy.stats import ranksums
from scipy.sparse import spmatrix
from pandas.api.types import infer_dtype, is_categorical_dtype

import matplotlib.pyplot as plt
from matplotlib.colors import is_color_like

from cellrank import logging as logg
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import (
    TestMethod,
    save_fig,
    _pairwise,
    _irreducible,
    _process_series,
    _correlation_test,
    _get_cat_and_null_indices,
    _merge_categorical_series,
    _convert_to_categorical_series,
    _calculate_lineage_absorption_time_means,
)
from cellrank.tl._colors import (
    _map_names_and_colors,
    _convert_to_hex_colors,
    _create_categorical_colors,
)
from cellrank.tl.kernels import PrecomputedKernel
from cellrank.tl._lineage import Lineage
from cellrank.tl._constants import (
    Macro,
    DirPrefix,
    AbsProbKey,
    PrettyEnum,
    TermStatesKey,
    _pd,
    _probs,
    _colors,
    _lin_names,
)
from cellrank.tl._linear_solver import _solve_lin_system
from cellrank.tl.estimators._property import Partitioner, LineageEstimatorMixin
from cellrank.tl.kernels._base_kernel import KernelExpression
from cellrank.tl.estimators._constants import A, P

AnnData = TypeVar("AnnData")
_COMP_TERM_STATES_MSG = (
    "Compute terminal states first as `.compute_terminal_states()` or "
    "set them manually as `.set_terminal_states()`."
)


@d.get_sections(base="base_estimator", sections=["Parameters"])
@d.dedent
class BaseEstimator(LineageEstimatorMixin, Partitioner, ABC):
    """
    Base class for all estimators.

    Parameters
    ----------
    obj
        Either a :class:`cellrank.tl.kernels.Kernel` object, an :class:`anndata.AnnData` object which
        stores the transition matrix in the ``.obsp`` attribute, or a :mod:`numpy` or :mod:`scipy` array.
    inplace
        Whether to modify the :paramref:`adata` object inplace or make a copy.
    read_from_adata
        Whether to read available attributes from :paramref:`adata`, if present.
    obsp_key
        Key in ``obj.obsp`` when ``obj`` is an :class:`anndata.AnnData` object.
    g2m_key
        Key in :paramref:`adata` ``.obs``. Can be used to detect cell-cycle driven start- or endpoints.
    s_key
        Key in :paramref:`adata` ``.obs``. Can be used to detect cell-cycle driven start- or endpoints.
    write_to_adata
        Whether to write the transition matrix to :paramref:`adata` ``.obsp`` and the parameters to
        :paramref:`adata` ``.uns``.
    %(write_to_adata.parameters)s
        Only used when ``write_to_adata=True``.
    """

    def __init__(
        self,
        obj: Union[KernelExpression, AnnData, spmatrix, np.ndarray],
        inplace: bool = True,
        read_from_adata: bool = False,
        obsp_key: Optional[str] = None,
        g2m_key: Optional[str] = "G2M_score",
        s_key: Optional[str] = "S_score",
        write_to_adata: bool = True,
        key: Optional[str] = None,
    ):
        from anndata import AnnData

        super().__init__(obj, obsp_key=obsp_key, key=key, write_to_adata=write_to_adata)

        if isinstance(obj, (KernelExpression, AnnData)) and not inplace:
            self.kernel._adata = self.adata.copy()

        if self.kernel.backward:
            self._term_key = TermStatesKey.BACKWARD.s
            self._abs_prob_key = AbsProbKey.BACKWARD.s
            self._term_abs_prob_key = Macro.BACKWARD.s
        else:
            self._term_key = TermStatesKey.FORWARD.s
            self._abs_prob_key = AbsProbKey.FORWARD.s
            self._term_abs_prob_key = Macro.FORWARD.s

        self._key_added = key
        self._g2m_key = g2m_key
        self._s_key = s_key

        self._G2M_score = None
        self._S_score = None

        self._absorption_time_mean = None
        self._absorption_time_var = None

        if read_from_adata:
            self._read_from_adata()

    def __init_subclass__(cls, **kwargs: Any):
        super().__init_subclass__()

    def _read_from_adata(self) -> None:
        self._set_or_debug(f"eig_{self._direction}", self.adata.uns, "_eig")

        self._set_or_debug(self._g2m_key, self.adata.obs, "_G2M_score")
        self._set_or_debug(self._s_key, self.adata.obs, "_S_score")

        self._set_or_debug(self._term_key, self.adata.obs, A.TERM.s)
        self._set_or_debug(_probs(self._term_key), self.adata.obs, A.TERM_PROBS)
        self._set_or_debug(_colors(self._term_key), self.adata.uns, A.TERM_COLORS)

        self._reconstruct_lineage(A.ABS_PROBS, self._abs_prob_key)
        self._set_or_debug(_pd(self._abs_prob_key), self.adata.obs, A.PRIME_DEG)

    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        probs = self._get(attr)
        if probs is not None:
            if len(names) != probs.shape[1]:
                if isinstance(probs, Lineage):
                    names = probs.names
                else:
                    logg.warning(
                        f"Expected lineage names to be of length `{probs.shape[1]}`, "
                        f"found `{len(names)}`. Creating new names"
                    )
                    names = [f"Lineage {i}" for i in range(probs.shape[1])]

            if len(colors) != probs.shape[1] or not all(
                map(lambda c: isinstance(c, str) and is_color_like(c), colors)
            ):
                if isinstance(probs, Lineage):
                    colors = probs.colors
                else:
                    logg.warning(
                        f"Expected lineage colors to be of length `{probs.shape[1]}`, "
                        f"found `{len(colors)}`. Creating new colors"
                    )
                    colors = _create_categorical_colors(probs.shape[1])

            self._set(attr, Lineage(probs, names=names, colors=colors))

            self.adata.obsm[obsm_key] = self._get(attr)
            self.adata.uns[_lin_names(self._term_key)] = names
            self.adata.uns[_colors(self._term_key)] = colors
    @d.dedent
    @inject_docs(fs=P.TERM.s, fsp=P.TERM_PROBS.s)
    def set_terminal_states(
        self,
        labels: Union[Series, Dict[str, Sequence[Any]]],
        cluster_key: Optional[str] = None,
        en_cutoff: Optional[float] = None,
        p_thresh: Optional[float] = None,
        add_to_existing: bool = False,
        **kwargs,
    ) -> None:
        """
        Set the approximate recurrent classes, if they are known a priori.

        Parameters
        ----------
        labels
            Either a categorical :class:`pandas.Series` with index as cell names, where `NaN` marks
            a cell belonging to a transient state, or a :class:`dict`, where each key is the name of
            a recurrent class and the values are lists of cell names.
        cluster_key
            If a key to cluster labels is given, :paramref:`{fs}` will be associated with these for
            naming and colors.
        %(en_cutoff_p_thresh)s
        add_to_existing
            Whether to add these categories to existing ones. Cells already belonging to recurrent
            classes will be updated if there's an overlap. Raises an error if previous approximate
            recurrent classes have not been calculated.

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{fsp}`
                - :paramref:`{fs}`
        """
        self._set_categorical_labels(
            attr_key=A.TERM.v,
            color_key=A.TERM_COLORS.v,
            pretty_attr_key=P.TERM.v,
            add_to_existing_error_msg=_COMP_TERM_STATES_MSG,
            categories=labels,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=add_to_existing,
        )
        self._write_terminal_states(time=kwargs.get("time", None))
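    # Example (hedged sketch; state and cell names are hypothetical): terminal
    # states can be set a priori from a dict mapping state names to cell names,
    # or from a categorical Series indexed by `adata.obs_names`:
    #
    #     g.set_terminal_states(
    #         {"Alpha": ["cell_001", "cell_042"], "Beta": ["cell_007"]},
    #         cluster_key="clusters",  # associate names/colors with this clustering
    #     )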
    @inject_docs(ts=P.TERM.s)
    def rename_terminal_states(
        self, new_names: Mapping[str, str], update_adata: bool = True
    ) -> None:
        """
        Rename the categories of :paramref:`{ts}`.

        Parameters
        ----------
        new_names
            Mapping where keys are the old names and values are the new names. The new names must be unique.
        update_adata
            Whether to also update the underlying :paramref:`adata` object.

        Returns
        -------
        None
            Nothing, just updates the names of :paramref:`{ts}`.
        """
        term_states = self._get(P.TERM)
        if term_states is None:
            raise RuntimeError(_COMP_TERM_STATES_MSG)
        if not isinstance(new_names, Mapping):
            raise TypeError(f"Expected a `Mapping` type, found `{type(new_names)!r}`.")
        if not len(new_names):
            return

        new_names = {k: str(v) for k, v in new_names.items()}

        mask = np.isin(list(new_names.keys()), term_states.cat.categories)
        if not np.all(mask):
            raise ValueError(
                f"Invalid old terminal states names: `{np.array(list(new_names.keys()))[~mask]}`."
            )

        names_after_renaming = [new_names.get(n, n) for n in term_states.cat.categories]
        if len(set(names_after_renaming)) != len(term_states.cat.categories):
            raise ValueError(
                f"After renaming, the names will not be unique: `{names_after_renaming}`."
            )

        term_states.cat.rename_categories(new_names, inplace=True)
        memberships = (
            self._get(A.TERM_ABS_PROBS) if hasattr(self, A.TERM_ABS_PROBS.s) else None
        )
        if memberships is not None:  # GPCCA
            memberships.names = [new_names.get(n, n) for n in memberships.names]
            self._set(A.TERM_ABS_PROBS, memberships)

        # we may just be computing it and it's not yet saved in adata
        if (
            update_adata
            and self._term_key in self.adata.obs
            and _lin_names(self._term_key) in self.adata.uns
        ):
            self.adata.obs[self._term_key].cat.rename_categories(
                new_names, inplace=True
            )
            self.adata.uns[_lin_names(self._term_key)] = np.array(
                self.adata.obs[self._term_key].cat.categories
            )
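    # Example (names hypothetical): old names must be existing categories and
    # the renaming must keep the names unique, as enforced above:
    #
    #     g.rename_terminal_states({"Alpha": "Alpha-like"})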
    @d.dedent
    @inject_docs(
        abs_prob=P.ABS_PROBS,
        fs=P.TERM.s,
        lat=P.LIN_ABS_TIMES,
    )
    def compute_absorption_probabilities(
        self,
        keys: Optional[Sequence[str]] = None,
        check_irreducibility: bool = False,
        solver: str = "gmres",
        use_petsc: bool = True,
        time_to_absorption: Optional[
            Union[
                str,
                Sequence[Union[str, Sequence[str]]],
                Dict[Union[str, Sequence[str]], str],
            ]
        ] = None,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        tol: float = 1e-6,
        preconditioner: Optional[str] = None,
    ) -> None:
        """
        Compute absorption probabilities of a Markov chain.

        For each cell, this computes the probability of it reaching any of the approximate recurrent
        classes defined by :paramref:`{fs}`.

        Parameters
        ----------
        keys
            Keys defining the recurrent classes.
        check_irreducibility
            Check whether the transition matrix is irreducible.
        solver
            Solver to use for the linear problem. Options are `'direct', 'gmres', 'lgmres', 'bicgstab'
            or 'gcrotmk'` when ``use_petsc=False``, or one of :class:`petsc4py.PETSc.KSP.Type` otherwise.

            Information on the :mod:`scipy` iterative solvers can be found in :mod:`scipy.sparse.linalg`,
            or for the :mod:`petsc4py` solvers
            `here <https://www.mcs.anl.gov/petsc/documentation/linearsolvertable.html>`__.
        use_petsc
            Whether to use solvers from :mod:`petsc4py` or :mod:`scipy`. Recommended for large problems.
            If no installation is found, defaults to :func:`scipy.sparse.linalg.gmres`.
        time_to_absorption
            Whether to compute the mean time to absorption and its variance to specific absorbing states.

            If a :class:`dict`, can be specified as ``{{'Alpha': 'var', ...}}`` to also compute the
            variance. In case the states are a :class:`tuple`, the time to absorption will be computed
            to the subset of these states, such as ``[('Alpha', 'Beta'), ...]`` or
            ``{{('Alpha', 'Beta'): 'mean', ...}}``. Can be specified as ``'all'`` to compute it to any
            absorbing state in ``keys``, which is more efficient than listing all absorbing states
            explicitly.

            It might be beneficial to disable the progress bar as ``show_progress_bar=False``, because
            many linear systems are being solved.
        n_jobs
            Number of parallel jobs to use when using an iterative solver. When ``use_petsc=True`` or
            for quickly-solvable problems, we recommend a higher number (>=8) of jobs in order to fully
            saturate the cores.
        backend
            Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options.
        show_progress_bar
            Whether to show a progress bar when the solver isn't a direct one.
        tol
            Convergence tolerance for the iterative solver. The default is fine for most cases, only
            consider decreasing this for severely ill-conditioned matrices.
        preconditioner
            Preconditioner to use, only available when ``use_petsc=True``. For available values, see
            `here <https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/PC/PCType.html#PCType>`__
            or the values of `petsc4py.PETSc.PC.Type`. We recommend the `'ilu'` preconditioner for badly
            conditioned problems.

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{abs_prob}` - probabilities of being absorbed into the terminal states.
                - :paramref:`{lat}` - mean times until absorption to subset absorbing states and
                  optionally their variances, saved as ``'{{lineage}} mean'`` and ``'{{lineage}} var'``,
                  respectively, for each subset of absorbing states specified in ``time_to_absorption``.
        """
        if self._get(P.TERM) is None:
            raise RuntimeError(_COMP_TERM_STATES_MSG)
        if keys is not None:
            keys = sorted(set(keys))

        start = logg.info("Computing absorption probabilities")

        # get the transition matrix
        t = self.transition_matrix
        if not self.issparse:
            logg.warning(
                "Attempting to solve a potentially large linear system with dense transition matrix"
            )

        # process the current annotations according to `keys`
        terminal_states_, colors_ = _process_series(
            series=self._get(P.TERM), keys=keys, colors=self._get(A.TERM_COLORS)
        )
        # warn in case only one state is left
        keys = list(terminal_states_.cat.categories)
        if len(keys) == 1:
            logg.warning(
                "There is only 1 recurrent class, all cells will have probability 1 of going there"
            )

        lin_abs_times = {}
        if time_to_absorption is not None:
            if isinstance(time_to_absorption, (str, tuple)):
                time_to_absorption = [time_to_absorption]
            if not isinstance(time_to_absorption, dict):
                time_to_absorption = {ln: "mean" for ln in time_to_absorption}

            seen = set()  # deduplicate subsets of absorbing states
            for ln, moment in time_to_absorption.items():
                if moment not in ("mean", "var"):
                    raise ValueError(
                        f"Moment must be either `'mean'` or `'var'`, found `{moment!r}` for `{ln!r}`."
                    )
                if isinstance(ln, str):
                    ln = tuple(keys) if ln == "all" else (ln,)
                sorted_ln = tuple(sorted(ln))
                if sorted_ln not in seen:
                    seen.add(sorted_ln)
                    for lin in ln:
                        if lin not in keys:
                            raise ValueError(
                                f"Invalid absorbing state `{lin!r}` in `{ln}`. "
                                f"Valid options are `{list(terminal_states_.cat.categories)}`."
                            )
                    lin_abs_times[tuple(ln)] = moment  # preserve the user order

        # define the dimensions of this problem
        n_cells = t.shape[0]
        n_macrostates = len(terminal_states_.cat.categories)

        # get indices corresponding to recurrent and transient states
        rec_indices, trans_indices, lookup_dict = _get_cat_and_null_indices(
            terminal_states_
        )
        if not len(trans_indices):
            raise RuntimeError("Cannot proceed - Markov chain is irreducible.")

        # create Q (restriction transient-transient), S (restriction transient-recurrent)
        q = t[trans_indices, :][:, trans_indices]
        s = t[trans_indices, :][:, rec_indices]

        # check for irreducibility
        if check_irreducibility:
            if self.is_irreducible is None:
                self._is_irreducible = _irreducible(self.transition_matrix)
            else:
                if not self.is_irreducible:
                    logg.warning("Transition matrix is not irreducible")
                else:
                    logg.debug("Transition matrix is irreducible")

        logg.debug(f"Found `{n_cells}` cells and `{s.shape[1]}` absorbing states")

        # solve the linear system of equations
        mat_x = _solve_lin_system(
            q,
            s,
            solver=solver,
            use_petsc=use_petsc,
            n_jobs=n_jobs,
            backend=backend,
            tol=tol,
            use_eye=True,
            show_progress_bar=show_progress_bar,
            preconditioner=preconditioner,
        )

        if time_to_absorption is not None:
            abs_time_means = _calculate_lineage_absorption_time_means(
                q,
                s,
                trans_indices,
                n=t.shape[0],
                ixs=lookup_dict,
                lineages=lin_abs_times,
                solver=solver,
                use_petsc=use_petsc,
                n_jobs=n_jobs,
                backend=backend,
                tol=tol,
                show_progress_bar=show_progress_bar,
                preconditioner=preconditioner,
            )
            abs_time_means.index = self.adata.obs_names
        else:
            abs_time_means = None

        # take individual solutions and piece them together to get absorption probabilities towards the classes
        macro_ix_helper = np.cumsum(
            [0] + [len(indices) for indices in lookup_dict.values()]
        )
        _abs_classes = np.concatenate(
            [
                mat_x[:, np.arange(a, b)].sum(1)[:, None]
                for a, b in _pairwise(macro_ix_helper)
            ],
            axis=1,
        )

        # for recurrent states, set their self-absorption probability to one
        abs_classes = np.zeros((len(self), n_macrostates))
        rec_classes_full = {
            cl: np.where(terminal_states_ == cl)[0]
            for cl in terminal_states_.cat.categories
        }
        for col, cl_indices in enumerate(rec_classes_full.values()):
            abs_classes[trans_indices, col] = _abs_classes[:, col]
            abs_classes[cl_indices, col] = 1

        self._set(
            A.ABS_PROBS,
            Lineage(
                abs_classes,
                names=terminal_states_.cat.categories,
                colors=colors_,
            ),
        )

        extra_msg = ""
        if abs_time_means is not None:
            self._set(A.LIN_ABS_TIMES, abs_time_means)
            extra_msg = f"        `.{P.LIN_ABS_TIMES}`\n"

        self._write_absorption_probabilities(time=start, extra_msg=extra_msg)
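    # What the solve above computes (editor's illustration with plain NumPy,
    # not CellRank API): ordering states as transient-then-recurrent partitions
    # the transition matrix into blocks Q (transient -> transient) and S
    # (transient -> recurrent); the absorption probabilities X then satisfy
    # (I - Q) X = S, which `_solve_lin_system(..., use_eye=True)` solves
    # column by column.
    #
    #     import numpy as np
    #
    #     Q = np.array([[0.5, 0.2], [0.3, 0.4]])  # transient block
    #     S = np.array([[0.3, 0.0], [0.0, 0.3]])  # transient -> absorbing block
    #     X = np.linalg.solve(np.eye(2) - Q, S)   # absorption probabilities
    #     assert np.allclose(X.sum(1), 1.0)       # each row is a distribution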
    @d.dedent
    def compute_lineage_priming(
        self,
        method: Literal["kl_divergence", "entropy"] = "kl_divergence",
        early_cells: Optional[Union[Mapping[str, Sequence[str]], Sequence[str]]] = None,
    ) -> pd.Series:
        """
        %(lin_pd.full_desc)s

        Parameters
        ----------
        %(lin_pd.parameters)s
            Cell IDs or a mask marking early cells. If `None`, use all cells. Only used when
            ``method='kl_divergence'``. If a :class:`dict`, the key specifies a cluster key in
            :attr:`anndata.AnnData.obs` and the values specify cluster labels containing early cells.

        Returns
        -------
        %(lin_pd.returns)s
        """  # noqa: D400
        abs_probs: Optional[Lineage] = self._get(P.ABS_PROBS)
        if abs_probs is None:
            raise RuntimeError(
                "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
            )
        if isinstance(early_cells, dict):
            if len(early_cells) != 1:
                raise ValueError(
                    f"Expected a dictionary with only 1 key, found `{len(early_cells)}`."
                )
            key = next(iter(early_cells.keys()))
            if key not in self.adata.obs:
                raise KeyError(f"Unable to find clustering in `adata.obs[{key!r}]`.")
            early_cells = self.adata.obs[key].isin(early_cells[key])
        elif early_cells is not None:
            early_cells = np.asarray(early_cells)
            if not np.issubdtype(early_cells.dtype, np.bool_):
                early_cells = np.isin(self.adata.obs_names, early_cells)

        values = pd.Series(
            abs_probs.priming_degree(method, early_cells), index=self.adata.obs_names
        )

        self._set(A.PRIME_DEG, values)
        self.adata.obs[_pd(self._abs_prob_key)] = values

        logg.info(
            f"Adding `adata.obs[{_pd(self._abs_prob_key)!r}]`\n"
            f"       `.{P.PRIME_DEG}`"
        )

        return values
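    # Example (cluster name hypothetical): priming degree relative to a set of
    # early cells, given either as cell names, a boolean mask, or a
    # {cluster_key: [cluster_labels]} dict as handled above:
    #
    #     deg = g.compute_lineage_priming(
    #         method="kl_divergence", early_cells={"clusters": ["Ngn3 low EP"]}
    #     )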
    @d.get_sections(
        base="lineage_drivers", sections=["Parameters", "Returns", "References"]
    )
    @d.get_full_description(base="lineage_drivers")
    @d.dedent
    @inject_docs(lin_drivers=P.LIN_DRIVERS, tm=TestMethod)
    def compute_lineage_drivers(
        self,
        lineages: Optional[Union[str, Sequence]] = None,
        method: str = TestMethod.FISCHER.s,
        cluster_key: Optional[str] = None,
        clusters: Optional[Union[str, Sequence]] = None,
        layer: str = "X",
        use_raw: bool = False,
        confidence_level: float = 0.95,
        n_perms: int = 1000,
        seed: Optional[int] = None,
        return_drivers: bool = True,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        """
        Compute driver genes per lineage.

        Correlates gene expression with lineage probabilities, for a given lineage and set of clusters.
        Often, it makes sense to restrict this to a set of clusters which are relevant for the specified
        lineages.

        Parameters
        ----------
        lineages
            Either a set of lineage names from :paramref:`absorption_probabilities` `.names` or `None`,
            in which case all lineages are considered.
        method
            Mode to use when calculating p-values and confidence intervals. Can be one of:

                - {tm.FISCHER.s!r} - use Fisher transformation [Fischer21]_.
                - {tm.PERM_TEST.s!r} - use permutation test.
        cluster_key
            Key from :paramref:`adata` ``.obs`` to obtain cluster annotations. These are considered
            for ``clusters``.
        clusters
            Restrict the correlations to these clusters.
        layer
            Key from :paramref:`adata` ``.layers``.
        use_raw
            Whether to use :paramref:`adata` ``.raw`` to correlate gene expression.
            If using a layer other than ``.X``, this must be set to `False`.
        confidence_level
            Confidence level for the confidence interval calculation. Must be in `[0, 1]`.
        n_perms
            Number of permutations to use when ``method={tm.PERM_TEST.s!r}``.
        seed
            Random seed when ``method={tm.PERM_TEST.s!r}``.
        return_drivers
            Whether to return the drivers. The returned object also contains the lower and upper bounds
            of the ``confidence_level`` confidence interval.
        %(parallel)s

        Returns
        -------
        %(correlation_test.returns)s
            Only if ``return_drivers=True``.
        None
            Updates :paramref:`adata` ``.var`` or :paramref:`adata` ``.raw.var``, depending on
            ``use_raw``, with:

                - ``'{{direction}} {{lineage}} corr'`` - the potential lineage drivers.
                - ``'{{direction}} {{lineage}} qval'`` - the corrected p-values.

            Updates the following fields:

                - :paramref:`{lin_drivers}` - same as the returned values.

        References
        ----------
        .. [Fischer21] Fisher, R. A. (1921),
            *On the “probable error” of a coefficient of correlation deduced from a small sample.*,
            `Metron 1: 3–32 <http://hdl.handle.net/2440/15169>`__.
        """
        # check that lineage probs have been computed
        method = TestMethod(method)
        abs_probs = self._get(P.ABS_PROBS)
        prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD

        if abs_probs is None:
            raise RuntimeError(
                "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
            )
        elif abs_probs.shape[1] == 1:
            logg.warning(
                "There is only 1 lineage present. Using the stationary distribution instead"
            )
            abs_probs = Lineage(
                self._get(P.TERM_PROBS).values,
                names=abs_probs.names,
                colors=abs_probs.colors,
            )

        # check all lin_keys exist in self.lin_names
        if isinstance(lineages, str):
            lineages = [lineages]
        if lineages is not None:
            _ = abs_probs[lineages]
        else:
            lineages = abs_probs.names

        if not len(lineages):
            raise ValueError("No lineages have been selected.")

        # use `cluster_key` and `clusters` to subset the data
        if clusters is not None:
            if cluster_key not in self.adata.obs.keys():
                raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
            if isinstance(clusters, str):
                clusters = [clusters]

            all_clusters = np.array(self.adata.obs[cluster_key].cat.categories)
            cluster_mask = np.array([name not in all_clusters for name in clusters])

            if any(cluster_mask):
                raise KeyError(
                    f"Clusters `{list(np.array(clusters)[cluster_mask])}` not found in "
                    f"`adata.obs[{cluster_key!r}]`."
                )

            subset_mask = np.in1d(self.adata.obs[cluster_key], clusters)
            adata_comp = self.adata[subset_mask]
            lin_probs = abs_probs[subset_mask, :]
        else:
            adata_comp = self.adata
            lin_probs = abs_probs

        # check that the layer exists, and that `use_raw` is only used with layer X
        if layer != "X":
            if layer not in self.adata.layers:
                raise KeyError(f"Layer `{layer!r}` not found in `adata.layers`.")
            if use_raw:
                raise ValueError("For `use_raw=True`, layer must be 'X'.")
            data = adata_comp.layers[layer]
            var_names = adata_comp.var_names
        else:
            if use_raw and self.adata.raw is None:
                logg.warning("No raw attribute set. Using `.X` instead")
                use_raw = False
            data = adata_comp.raw.X if use_raw else adata_comp.X
            var_names = adata_comp.raw.var_names if use_raw else adata_comp.var_names

        start = logg.info(
            f"Computing correlations for lineages `{lineages}` restricted to clusters `{clusters}` in "
            f"layer `{layer}` with `use_raw={use_raw}`"
        )

        drivers = _correlation_test(
            data,
            lin_probs[lineages],
            gene_names=var_names,
            method=method,
            n_perms=n_perms,
            seed=seed,
            confidence_level=confidence_level,
            **kwargs,
        )
        self._set(A.LIN_DRIVERS, drivers)

        corrs = [f"{lin} corr" for lin in lineages]
        qvals = [f"{lin} qval" for lin in lineages]
        if use_raw:
            self.adata.raw.var[[f"{prefix} {col}" for col in corrs]] = drivers[corrs]
            self.adata.raw.var[[f"{prefix} {col}" for col in qvals]] = drivers[qvals]
        else:
            self.adata.var[[f"{prefix} {col}" for col in corrs]] = drivers[corrs]
            self.adata.var[[f"{prefix} {col}" for col in qvals]] = drivers[qvals]

        field = "raw.var" if use_raw else "var"
        keys_added = [f"`adata.{field}['{prefix} {lin} corr']`" for lin in lineages]

        logg.info(
            f"Adding `.{P.LIN_DRIVERS}`\n       "
            + "\n       ".join(keys_added)
            + "\n    Finish",
            time=start,
        )

        if return_drivers:
            return drivers
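    # Example (lineage/cluster names hypothetical): correlate expression with
    # the fate probabilities of one lineage, restricted to relevant clusters:
    #
    #     drivers = g.compute_lineage_drivers(
    #         lineages="Alpha", cluster_key="clusters", clusters=["Alpha"]
    #     )
    #     drivers.sort_values("Alpha corr", ascending=False).head()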
    @d.get_sections(base="plot_lineage_drivers", sections=["Parameters"])
    @d.dedent
    def plot_lineage_drivers(
        self,
        lineage: str,
        n_genes: int = 8,
        ncols: Optional[int] = None,
        use_raw: bool = False,
        title_fmt: str = "{gene} qval={qval:.4e}",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        **kwargs,
    ) -> None:
        """
        Plot lineage drivers discovered by :meth:`compute_lineage_drivers`.

        Parameters
        ----------
        lineage
            Lineage for which to plot the driver genes.
        n_genes
            Number of top correlated genes to plot.
        ncols
            Number of columns.
        use_raw
            Whether to look in :paramref:`adata` ``.raw.var`` or :paramref:`adata` ``.var``.
        title_fmt
            Title format. Possible keywords include `{gene}`, `{pval}`, `{qval}` and `{corr}` for gene
            name, p-value, q-value and correlation, respectively.
        %(plotting)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s
        """

        def prepare_format(
            gene: str,
            *,
            pval: Optional[float],
            qval: Optional[float],
            corr: Optional[float],
        ) -> Dict[str, Any]:
            kwargs = {}
            if "{gene" in title_fmt:
                kwargs["gene"] = gene
            if "{pval" in title_fmt:
                kwargs["pval"] = float(pval) if pval is not None else np.nan
            if "{qval" in title_fmt:
                kwargs["qval"] = float(qval) if qval is not None else np.nan
            if "{corr" in title_fmt:
                kwargs["corr"] = float(corr) if corr is not None else np.nan

            return kwargs

        lin_drivers = self._get(P.LIN_DRIVERS)
        if lin_drivers is None:
            raise RuntimeError(
                f"Compute `.{P.LIN_DRIVERS}` first as `.compute_lineage_drivers()`."
            )

        key = f"{lineage} corr"
        if key not in lin_drivers:
            raise KeyError(
                f"Lineage `{key!r}` not found in `{list(lin_drivers.columns)}`."
            )

        if n_genes <= 0:
            raise ValueError(f"Expected `n_genes` to be positive, found `{n_genes}`.")

        kwargs.pop("save", None)
        genes = lin_drivers.sort_values(by=key, ascending=False).head(n_genes)

        ncols = 4 if ncols is None else ncols
        nrows = int(np.ceil(len(genes) / ncols))

        fig, axes = plt.subplots(
            ncols=ncols,
            nrows=nrows,
            dpi=dpi,
            figsize=(ncols * 6, nrows * 4) if figsize is None else figsize,
        )
        axes = np.ravel([axes])

        _i = 0
        for _i, (gene, ax) in enumerate(zip(genes.index, axes)):
            data = genes.loc[gene]
            scv.pl.scatter(
                self.adata,
                color=gene,
                ncols=ncols,
                use_raw=use_raw,
                ax=ax,
                show=False,
                title=title_fmt.format(
                    **prepare_format(
                        gene,
                        pval=data.get(f"{lineage} pval", None),
                        qval=data.get(f"{lineage} qval", None),
                        corr=data.get(f"{lineage} corr", None),
                    )
                ),
                **kwargs,
            )

        for j in range(_i + 1, len(axes)):
            axes[j].remove()

        if save is not None:
            save_fig(fig, save)
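    # Example: `title_fmt` may reference {gene}, {pval}, {qval} and {corr};
    # values missing from the drivers table are rendered as NaN:
    #
    #     g.plot_lineage_drivers(
    #         "Alpha", n_genes=8, title_fmt="{gene} corr={corr:.2f}"
    #     )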
    def _detect_cc_stages(self, rc_labels: Series, p_thresh: float = 1e-15) -> None:
        """
        Detect cell-cycle driven start- or endpoints.

        Parameters
        ----------
        rc_labels
            Approximate recurrent classes.
        p_thresh
            P-value threshold for the rank-sum test for a group to be considered cell-cycle driven.

        Returns
        -------
        None
            Nothing, but warns if a group is cell-cycle driven.
        """
        # initialize the groups (start or end clusters) and scores
        groups = rc_labels.cat.categories
        scores = []
        if self._G2M_score is not None:
            scores.append(self._G2M_score)
        if self._S_score is not None:
            scores.append(self._S_score)

        # loop over groups and scores
        for group in groups:
            mask = rc_labels == group
            for score in scores:
                a, b = score[mask], score[~mask]
                statistic, pvalue = ranksums(a, b)
                if statistic > 0 and pvalue < p_thresh:
                    logg.warning(f"Group `{group!r}` appears to be cell-cycle driven")
                    break

    def _set_categorical_labels(
        self,
        attr_key: str,
        color_key: str,
        pretty_attr_key: str,
        categories: Union[Series, Dict[Any, Any]],
        add_to_existing_error_msg: Optional[str] = None,
        cluster_key: Optional[str] = None,
        en_cutoff: Optional[float] = None,
        p_thresh: Optional[float] = None,
        add_to_existing: bool = False,
    ) -> None:
        if isinstance(categories, dict):
            categories = _convert_to_categorical_series(
                categories, list(self.adata.obs_names)
            )
        if not is_categorical_dtype(categories):
            raise TypeError(
                f"Object must be `categorical`, found `{infer_dtype(categories)}`."
            )

        if add_to_existing:
            if getattr(self, attr_key) is None:
                raise RuntimeError(add_to_existing_error_msg)
            categories = _merge_categorical_series(
                getattr(self, attr_key), categories, inplace=False
            )

        if cluster_key is not None:
            logg.debug(f"Creating colors based on `{cluster_key}`")

            # check that we can load the reference series from adata
            if cluster_key not in self.adata.obs:
                raise KeyError(
                    f"Cluster key `{cluster_key!r}` not found in `adata.obs`."
                )
            series_query, series_reference = categories, self.adata.obs[cluster_key]

            # load the reference colors if they exist
            if _colors(cluster_key) in self.adata.uns.keys():
                colors_reference = _convert_to_hex_colors(
                    self.adata.uns[_colors(cluster_key)]
                )
            else:
                colors_reference = _create_categorical_colors(
                    len(series_reference.cat.categories)
                )

            approx_rcs_names, colors = _map_names_and_colors(
                series_reference=series_reference,
                series_query=series_query,
                colors_reference=colors_reference,
                en_cutoff=en_cutoff,
            )
            setattr(self, color_key, colors)
            # if `approx_rcs_names` is categorical, the info is taken from `.cat.categories`
            categories.cat.categories = approx_rcs_names
        else:
            setattr(
                self,
                color_key,
                _create_categorical_colors(len(categories.cat.categories)),
            )

        if p_thresh is not None:
            self._detect_cc_stages(categories, p_thresh=p_thresh)

        # write to class and adata
        if getattr(self, attr_key) is not None:
            logg.debug(f"Overwriting `.{pretty_attr_key}`")

        setattr(self, attr_key, categories)

    def _write_terminal_states(self, time=None) -> None:
        self.adata.obs[self._term_key] = self._get(P.TERM)
        self.adata.obs[_probs(self._term_key)] = self._get(P.TERM_PROBS)

        self.adata.uns[_colors(self._term_key)] = self._get(A.TERM_COLORS)
        self.adata.uns[_lin_names(self._term_key)] = np.array(
            self._get(P.TERM).cat.categories
        )

        extra_msg = ""
        if getattr(self, A.TERM_ABS_PROBS.s, None) is not None and hasattr(
            self, "_term_abs_prob_key"
        ):
            # checking for None because terminal states can be set using `set_terminal_states`
            # without the probabilities in GPCCA
            self.adata.obsm[self._term_abs_prob_key] = self._get(A.TERM_ABS_PROBS)
            extra_msg = f"       `adata.obsm[{self._term_abs_prob_key!r}]`\n"

        logg.info(
            f"Adding `adata.obs[{_probs(self._term_key)!r}]`\n"
            f"       `adata.obs[{self._term_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.TERM_PROBS}`\n"
            f"       `.{P.TERM}`\n"
            "    Finish",
            time=time,
        )

    @abstractmethod
    def _fit_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """
        High-level API helper method, called inside :meth:`fit`, that should compute the terminal states.

        This method would usually call :meth:`compute_terminal_states` after running all the functions
        that are required beforehand.

        Parameters
        ----------
        args
            Positional arguments.
        kwargs
            Keyword arguments.

        Returns
        -------
        None
            Nothing, just sets the terminal states.

        See also
        --------
        See :meth:`cellrank.tl.estimators.GPCCA._fit_terminal_states` for an example implementation.
        """
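    # Illustration of the rank-sum test used in `_detect_cc_stages` above
    # (standalone SciPy sketch, not CellRank API): a group is flagged when its
    # cell-cycle scores are significantly shifted upwards relative to all
    # other cells at the given p-value threshold.
    #
    #     import numpy as np
    #     from scipy.stats import ranksums
    #
    #     rng = np.random.default_rng(0)
    #     in_group = rng.normal(1.0, 0.1, 100)  # high G2M scores inside the group
    #     rest = rng.normal(0.0, 0.1, 900)      # background scores
    #     statistic, pvalue = ranksums(in_group, rest)
    #     assert statistic > 0 and pvalue < 1e-15  # would trigger the warning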
    @d.dedent
    @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS, ap=P.ABS_PROBS, pd=P.PRIME_DEG)
    def fit(
        self,
        keys: Optional[Sequence] = None,
        compute_absorption_probabilities: bool = True,
        **kwargs,
    ) -> None:
        """
        Run the pipeline.

        Parameters
        ----------
        keys
            States for which to compute absorption probabilities.
        compute_absorption_probabilities
            Whether to compute absorption probabilities or only the %(initial_or_terminal)s states.
        kwargs
            Keyword arguments forwarded to :meth:`_fit_terminal_states`.

        Returns
        -------
        None
            Nothing, just makes available the following fields:

                - :paramref:`{fsp}`
                - :paramref:`{fs}`
                - :paramref:`{ap}`
                - :paramref:`{pd}`
        """
        self._fit_terminal_states(**kwargs)
        if compute_absorption_probabilities:
            self.compute_absorption_probabilities(keys=keys)
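    # Example: the high-level pipeline; extra keyword arguments are forwarded
    # to the subclass's `_fit_terminal_states` (lineage names hypothetical):
    #
    #     g.fit(keys=["Alpha", "Beta"])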
    def _write_absorption_probabilities(
        self, time: datetime, extra_msg: str = ""
    ) -> None:
        self.adata.obsm[self._abs_prob_key] = self._get(P.ABS_PROBS)

        abs_prob = self._get(P.ABS_PROBS)
        self.adata.uns[_lin_names(self._abs_prob_key)] = abs_prob.names
        self.adata.uns[_colors(self._abs_prob_key)] = abs_prob.colors

        logg.info(
            f"Adding `adata.obsm[{self._abs_prob_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.ABS_PROBS}`\n"
            "    Finish",
            time=time,
        )

    def _set(self, n: Union[str, PrettyEnum], v: Any) -> None:
        setattr(self, n.s if isinstance(n, PrettyEnum) else n, v)

    def _get(self, n: Union[str, PrettyEnum]) -> Any:
        return getattr(self, n.s if isinstance(n, PrettyEnum) else n)

    def _set_or_debug(
        self, needle: str, haystack, attr: Optional[Union[str, PrettyEnum]] = None
    ) -> Optional[Any]:
        if isinstance(attr, PrettyEnum):
            attr = attr.s
        if needle in haystack:
            if attr is None:
                return haystack[needle]
            setattr(self, attr, haystack[needle])
        elif attr is not None:
            logg.debug(f"Unable to set attribute `.{attr}`, skipping")
    def copy(self) -> "BaseEstimator":
        """Return a copy of self, including the underlying :paramref:`adata` object."""
        kernel = deepcopy(self.kernel)  # ensure we copy the adata object
        res = type(self)(kernel, read_from_adata=False)
        for k, v in self.__dict__.items():
            if isinstance(v, dict):
                res.__dict__[k] = deepcopy(v)
            elif k != "_kernel":
                res.__dict__[k] = copy(v)

        return res
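    # Example: the copy is independent of the original, since the kernel (and
    # hence `adata`) is deep-copied:
    #
    #     g2 = g.copy()
    #     assert g2.adata is not g.adata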
    def __copy__(self) -> "BaseEstimator":
        return self.copy()
    @d.dedent
    def write(self, fname: Union[str, Path], ext: Optional[str] = "pickle") -> None:
        """
        %(pickleable.full_desc)s

        Parameters
        ----------
        %(pickleable.parameters)s

        Returns
        -------
        %(pickleable.returns)s
        """  # noqa
        fname = str(fname)
        if ext is not None:
            if not ext.startswith("."):
                ext = "." + ext
            if not fname.endswith(ext):
                fname += ext

        logg.debug(f"Writing to `{fname}`")

        with open(fname, "wb") as fout:
            if version_info[:2] > (3, 6):
                pickle.dump(self, fout)
            else:
                # we need to use PrecomputedKernel because Python 3.6 can't pickle Enums
                # and they are present in VelocityKernel
                logg.warning("Saving kernel as `cellrank.tl.kernels.PrecomputedKernel`")
                orig_kernel = self.kernel
                self._kernel = PrecomputedKernel(self.kernel)
                try:
                    pickle.dump(self, fout)
                finally:
                    self._kernel = orig_kernel
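
# Example (editor's sketch): persisting an estimator with `write` and restoring
# it with plain pickle, since `write` pickles the whole object:
#
#     g.write("estimator.pickle")
#
#     import pickle
#     with open("estimator.pickle", "rb") as fin:
#         g_restored = pickle.load(fin)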