Source code for cellrank.estimators._base_estimator

import abc
import contextlib
import copy as copy_
import inspect
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import Any, Literal

import numpy as np
import pandas as pd
import scipy.sparse as sp
from anndata import AnnData

from cellrank._utils._docs import d
from cellrank._utils._key import Key
from cellrank._utils._lineage import Lineage
from cellrank.estimators.mixins import KernelMixin
from cellrank.kernels import PrecomputedKernel
from cellrank.kernels._base_kernel import KernelExpression
from cellrank.kernels.mixins import AnnDataMixin, IOMixin

# Public API of this module: only the estimator base class is exported.
__all__ = ["BaseEstimator"]

# Names of the `AnnData` attributes that `to_adata` can keep/copy.
# NOTE: this must be the `Literal` itself, not a 1-tuple wrapping it, so it is usable as a type alias.
Attr_t = Literal["X", "raw", "layers", "obs", "var", "obsm", "varm", "obsp", "varp", "uns"]


@d.get_sections(base="base_estimator", sections=["Parameters"])
class BaseEstimator(IOMixin, KernelMixin, AnnDataMixin, abc.ABC):
    """Base class for all estimators.

    Parameters
    ----------
    object
        Can be one of the following types:

        - :class:`~anndata.AnnData` - annotated data object.
        - :class:`~scipy.sparse.spmatrix`, :class:`~numpy.ndarray` - row-normalized transition matrix.
        - :class:`~cellrank.kernels.KernelExpression` - kernel expression.
        - :class:`str` - key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored and
          ``adata`` must be provided in this case.
        - :class:`bool` - directionality of the transition matrix that will be used to infer its storage
          location. If :obj:`None`, the directionality will be determined automatically and ``adata``
          must be provided in this case.
    kwargs
        Keyword arguments for the :class:`~cellrank.kernels.PrecomputedKernel`. Not used when ``object``
        is a :class:`~cellrank.kernels.KernelExpression`.
    """

    def __init__(
        self,
        object: str | bool | np.ndarray | sp.spmatrix | AnnData | KernelExpression,
        **kwargs: Any,
    ):
        if isinstance(object, KernelExpression):
            # an existing kernel is used as-is; `kwargs` only apply when building a `PrecomputedKernel`
            if object.transition_matrix is None:
                raise RuntimeError("Compute transition matrix first as `.compute_transition_matrix()`.")
        else:
            object = PrecomputedKernel(object, copy=False, **kwargs)
        super().__init__(kernel=object)

        self._params: dict[str, Any] = {}

        # `AnnData.X` may be `None`; in that case fall back to the sparse matrix's default dtype
        X = self.adata.X
        dtype = None if X is None else X.dtype
        # lightweight stand-in for `adata`: all-zero sparse `X` and `obs`/`var` with the index but no columns;
        # swapped in by `_shadow` and used as the base of `to_adata`
        self._shadow_adata = AnnData(
            X=sp.csr_matrix(self.adata.shape, dtype=dtype),
            obs=self.adata.obs[[]].copy(),
            var=self.adata.var[[]].copy(),
            raw=None if self.adata.raw is None else self.adata.raw.to_adata(),
        )

    def __init_subclass__(cls, **kwargs: Any):
        # NOTE(review): `kwargs` are accepted but not forwarded to the parent hook - confirm this is intentional
        super().__init_subclass__()

    @abc.abstractmethod
    def fit(self, *args: Any, **kwargs: Any) -> "BaseEstimator":
        """Fit the estimator.

        .. deprecated:: 2.1
            Will be removed in CellRank 3.0. Use the individual methods directly
            (e.g., :meth:`compute_schur`, :meth:`compute_macrostates`).

        Parameters
        ----------
        args
            Positional arguments.
        kwargs
            Keyword arguments.

        Returns
        -------
        Self.
        """

    @abc.abstractmethod
    def predict(self, *args: Any, **kwargs: Any) -> "BaseEstimator":
        """Run the prediction.

        .. deprecated:: 2.1
            Will be removed in CellRank 3.0. Use the individual methods directly
            (e.g., :meth:`predict_terminal_states`).

        Parameters
        ----------
        args
            Positional arguments.
        kwargs
            Keyword arguments.

        Returns
        -------
        Self.
        """

    def _set(
        self,
        attr: str | None = None,
        obj: pd.DataFrame | Mapping[str, Any] | None = None,
        key: str | None = None,
        value: np.ndarray | pd.Series | pd.DataFrame | Lineage | AnnData | dict[str, Any] | None = None,
        copy: bool = True,
        shadow_only: bool = False,
    ) -> None:
        """Set an attribute and optionally update ``obj['{key}']``.

        Parameters
        ----------
        attr
            Attribute to set. Only updated when we're not in the shadow. If :obj:`None`, don't update anything.
            See :attr:`_in_shadow` and ``obj`` for more information.
        obj
            Object which to update with ``value`` alongside the ``attr``.
            Usually, an attribute of :attr:`adata` is passed here.
        key
            Key in ``obj`` to update with ``value``. Only used when ``obj != None``.
        value
            Value to set. If :obj:`None` and ``key != None``, it removes the values under ``obj['{key}']``,
            if present.
        copy
            Whether to shallow-copy the ``value`` before setting it in ``obj``.
        shadow_only
            If :obj:`True` and we are not in the shadow, only ``attr`` is set and ``obj`` is left untouched,
            i.e., ``obj`` is only updated while in the shadow.

        Returns
        -------
        Nothing, just optionally updates ``attr`` and/or ``obj[{key}]``.

        Raises
        ------
        AttributeError
            If ``attr`` doesn't exist.
        """
        if not self._in_shadow:
            if attr is not None:
                if not hasattr(self, attr):
                    raise AttributeError(attr)
                setattr(self, attr, value)
            if shadow_only:
                return

        if obj is None:
            return

        if key is not None:
            if value is None:
                # removing a missing key is not an error
                with contextlib.suppress(KeyError):
                    del obj[key]
            else:
                obj[key] = copy_.copy(value) if copy else value

    def _get(
        self,
        *,
        obj: pd.DataFrame | Mapping[str, Any],
        key: str,
        shadow_attr: Literal["obs", "obsm", "var", "varm", "uns"] | None = None,
        dtype: type | tuple[type, ...] | None = None,
        copy: bool = True,
        allow_missing: bool = False,
    ) -> Any:
        """Get data from an object and optionally save it in :attr:`_shadow_adata`.

        Parameters
        ----------
        obj
            Object from which to extract the data.
        key
            Key in ``obj`` where the data is stored.
        shadow_attr
            Attribute of :attr:`_shadow_adata` where to save the extracted data.
            If :obj:`None`, don't update it.
        dtype
            Valid type(s) of the extracted data.
        copy
            Whether to shallow-copy the data before returning it and saving it in :attr:`_shadow_adata`.
        allow_missing
            Whether to allow ``key`` to be missing in ``obj``.

        Returns
        -------
        The extracted values (or :obj:`None` if missing and ``allow_missing = True``) and optionally
        updates :attr:`_shadow_adata` ``.{shadow_attr}``.

        Raises
        ------
        AttributeError
            If ``shadow_attr`` is not an attribute of :attr:`_shadow_adata`.
        TypeError
            If ``dtype != None`` and the extracted values are not instances of ``dtype``.
        KeyError
            If ``allow_missing = False`` and ``key`` was not found in ``obj``.
        """
        if shadow_attr is not None and not hasattr(self._shadow_adata, shadow_attr):
            raise AttributeError(shadow_attr)

        # only the lookup itself may legitimately raise `KeyError`
        try:
            data = obj[key]
        except KeyError:
            if not allow_missing:
                raise
            return None

        if dtype is not None and not isinstance(data, dtype):
            raise TypeError(f"Expected object to be of type `{dtype}`, found `{type(data).__name__}`.")
        if copy:
            data = copy_.copy(data)
        if shadow_attr is not None:
            getattr(self._shadow_adata, shadow_attr)[key] = data

        return data

    @property
    @contextlib.contextmanager
    def _shadow(self) -> Iterator[None]:
        """Temporarily set :attr:`adata` to :attr:`_shadow_adata`.

        Used to construct the serialization object in :meth:`to_adata`. Re-entrant: if already in the
        shadow, nothing is swapped.
        """
        if self._in_shadow:
            yield
        else:
            adata = self.adata
            try:
                self.adata = self._shadow_adata
                yield
            finally:
                # always restore the original object, even if the body raised
                self.adata = adata

    @property
    def _in_shadow(self) -> bool:
        """Return `True` if :attr:`adata` is :attr:`_shadow_adata`."""
        return self.adata is self._shadow_adata

    def _create_params(
        self,
        locs: Mapping[str, Any] | None = None,
        func: Callable | None = None,
        remove: Sequence[str] = (),
    ) -> dict[str, Any]:
        """Create parameters of interest from a function call.

        Parameters
        ----------
        locs
            Environment from which to get the parameters. If `None`, get the caller's environment.
        func
            Function of interest. If :obj:`None`, use the caller.
        remove
            Keys in ``locs`` which should not be included in the result.

        Returns
        -------
        The parameters as a :class:`dict`. An empty :class:`dict` is returned if the caller's frame is
        unavailable or ``func`` cannot be resolved to a callable.

        Notes
        -----
        *args/**kwargs are always ignored and the values in ``locs`` are not copied.
        """
        frame = inspect.currentframe()
        try:
            if locs is None:
                locs = frame.f_back.f_locals
            if func is None:
                name = frame.f_back.f_code.co_name
                # resolve the caller by name directly; unlike `inspect.getmembers(self)`, this does not
                # evaluate every property of the estimator
                func = getattr(self, name, None)
            if not callable(func):
                raise TypeError(f"Expected `func` to be `callable`, found `{type(func).__name__}`.")

            params = {}
            for name, param in inspect.signature(func).parameters.items():
                if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                    continue
                if name in remove:
                    continue
                if name in locs:
                    params[name] = locs[name]
            return params
        except AttributeError:
            # frame can be None
            return {}
        except TypeError:
            return {}
        finally:
            # break the reference cycle created by holding on to a frame object
            del frame

    def _read_params(self, key: str) -> dict[str, Any]:
        """Read ``key`` from estimator params in :attr:`adata`.

        Usually called in :meth:`_read_adata` during :meth:`from_adata`.
        """
        ekey = Key.uns.estimator(self.backward) + "_params"
        return dict(self.adata.uns.get(ekey, {}).get(key, {}))

    @d.dedent
    def to_adata(
        self,
        keep: Literal["all"] | Sequence[Attr_t] = ("X", "raw"),
        *,
        copy: bool | Sequence[Attr_t] = True,
    ) -> AnnData:
        """%(to_adata.full_desc)s

        Parameters
        ----------
        keep
            Which attributes to keep from the underlying :attr:`adata`. Valid options are:

            - ``'all'`` - keep all attributes specified in the signature.
            - :class:`~typing.Sequence` - keep only subset of these attributes.
            - :class:`dict` - the keys correspond the attribute names and values to a subset of keys
              which to keep from this attribute. If the values are specified either as :obj:`True` or
              ``'all'``, everything from this attribute will be kept.
        copy
            Whether to copy the data. Can be specified on per-attribute basis. Useful for attributes
            that are array-like.

        Returns
        -------
        Annotated data object.
        """  # noqa: D400

        def handle_attribute(attr: Attr_t, keys: list[str], *, copy: bool) -> None:
            """Transfer ``keys`` of ``self.adata.{attr}`` into the output object ``adata``."""
            try:
                if attr == "X":
                    adata.X = copy_.deepcopy(self.adata.X) if copy else self.adata.X
                    return
                if attr == "raw":
                    adata.raw = self.adata.raw.to_adata() if self.adata.raw is not None else None
                    return

                old = getattr(self.adata, attr)
                new = getattr(adata, attr)
                if keys == ["all"]:
                    keys = list(old.keys())

                # fmt: off
                if isinstance(new, pd.DataFrame):
                    old = old[keys]
                    # avoid duplicates
                    old = old[old.columns.difference(new.columns)]
                    setattr(adata, attr, pd.merge(new, old, how="inner", left_index=True, right_index=True, copy=copy))
                elif isinstance(new, Mapping):
                    old = {k: old[k] for k in keys}
                    # old has preference, since it's user supplied
                    setattr(adata, attr, {**new, **(copy_.deepcopy(old) if copy else old)})
                else:
                    raise TypeError(f"Expected `adata.{attr}` to be either `Mapping` or `pandas.DataFrame`, "
                                    f"found `{type(new).__name__}`.")
                # fmt: on
            except KeyError:
                missing = sorted(k for k in keys if k not in old)
                raise KeyError(f"Unable to find key(s) `{missing}` in `adata.{attr}`.") from None

        adata = self._shadow_adata.copy()
        _adata = self.adata
        try:
            # kernel and estimator share the adata
            self.adata = adata
            self.kernel.write_to_adata()
        finally:
            self.adata = _adata

        key = Key.uns.estimator(self.backward) + "_params"
        adata.uns[key] = copy_.deepcopy(self.params)

        # fmt: off
        if isinstance(keep, str):
            if keep == "all":
                keep = ["X", "raw", "layers", "obs", "var", "obsm", "varm", "obsp", "varp", "uns"]
            else:
                keep = [keep]
        if not isinstance(keep, Mapping):
            keep = dict.fromkeys(keep, True)

        if isinstance(copy, bool):
            copy = dict.fromkeys(keep, copy)
        elif isinstance(copy, str):
            copy = [copy]
        if not isinstance(copy, Mapping):
            copy = dict.fromkeys(copy, True)
        # fmt: on

        for attr, keys in keep.items():
            if keys is True:
                keys = ["all"]
            elif isinstance(keys, str):
                keys = [keys]
            if keys is False or not len(keys):
                continue
            handle_attribute(attr, keys=list(keys), copy=copy.get(attr, False))

        return adata

    @classmethod
    @d.dedent
    def from_adata(cls, adata: AnnData, obsp_key: str) -> "BaseEstimator":
        """%(from_adata.full_desc)s

        Parameters
        ----------
        %(adata)s
        obsp_key
            Key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored.

        Returns
        -------
        %(from_adata.returns)s
        """  # noqa: D400
        return super().from_adata(adata, obsp_key=obsp_key)

    @d.dedent
    def copy(self, *, deep: bool = False) -> "BaseEstimator":
        """Return a copy of self.

        Parameters
        ----------
        deep
            Whether to return a deep copy or not. If :obj:`True`, this also copies the :attr:`adata`.

        Returns
        -------
        A copy of self.
        """
        kernel = copy_.deepcopy(self.kernel) if deep else copy_.copy(self.kernel)
        res = type(self)(kernel)
        for key, value in self.__dict__.items():
            if isinstance(value, Mapping):
                # mappings (e.g. `_params`) are always deep-copied, regardless of `deep`
                res.__dict__[key] = copy_.deepcopy(value)
            elif key != "_kernel":
                # the kernel was already (deep-)copied and passed to the constructor above
                res.__dict__[key] = copy_.deepcopy(value) if deep else copy_.copy(value)

        return res

    def __copy__(self) -> "BaseEstimator":
        return self.copy(deep=False)

    def _format_params(self) -> str:
        return f"kernel={self.kernel!s}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}[{self._format_params()}]"

    def __str__(self) -> str:
        return repr(self)

    @property
    def params(self) -> dict[str, Any]:
        """Estimator parameters."""
        return self._params