import abc
import contextlib
import copy as copy_
import inspect
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal
import numpy as np
import pandas as pd
import scipy.sparse as sp
from anndata import AnnData
from cellrank._utils._docs import d
from cellrank._utils._key import Key
from cellrank._utils._lineage import Lineage
from cellrank.estimators.mixins import KernelMixin
from cellrank.kernels import PrecomputedKernel
from cellrank.kernels._base_kernel import KernelExpression
from cellrank.kernels.mixins import AnnDataMixin, IOMixin
# Names exported by ``from <this module> import *``: only the estimator base class.
__all__ = ["BaseEstimator"]
# Names of the AnnData attributes that ``BaseEstimator.to_adata`` can carry over.
# This has to be the ``Literal`` type itself: wrapping it in a 1-tuple (a stray trailing comma)
# makes the alias a tuple object, which is not a valid annotation for a single attribute name.
Attr_t = Literal["X", "raw", "layers", "obs", "var", "obsm", "varm", "obsp", "varp", "uns"]
[docs]
@d.get_sections(base="base_estimator", sections=["Parameters"])
class BaseEstimator(IOMixin, KernelMixin, AnnDataMixin, abc.ABC):
    """Base class for all estimators.

    Parameters
    ----------
    object
        Can be one of the following types:

        - :class:`~anndata.AnnData` - annotated data object.
        - :class:`~scipy.sparse.spmatrix`, :class:`~numpy.ndarray` - row-normalized transition matrix.
        - :class:`~cellrank.kernels.KernelExpression` - kernel expression.
        - :class:`str` - key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored and
          ``adata`` must be provided in this case.
        - :class:`bool` - directionality of the transition matrix that will be used to infer its storage location.
          If :obj:`None`, the directionality will be determined automatically and
          ``adata`` must be provided in this case.
    kwargs
        Keyword arguments for the :class:`~cellrank.kernels.PrecomputedKernel`.

    Raises
    ------
    RuntimeError
        If ``object`` is a :class:`~cellrank.kernels.KernelExpression` whose transition matrix
        has not been computed yet.

    Notes
    -----
    A :class:`~cellrank.kernels.KernelExpression` is used as-is; any other ``object`` is wrapped in a
    :class:`~cellrank.kernels.PrecomputedKernel`, which is the only place where ``kwargs`` are consumed.
    """
def __init__(
    self,
    object: str | bool | np.ndarray | sp.spmatrix | AnnData | KernelExpression,
    **kwargs: Any,
):
    """Wrap ``object`` in a kernel, initialise the mixins and build the shadow :class:`~anndata.AnnData`.

    See the class docstring for the accepted types of ``object`` and for ``kwargs``.
    """
    if not isinstance(object, KernelExpression):
        # Everything that is not already a kernel expression is handed to the precomputed kernel.
        object = PrecomputedKernel(object, copy=False, **kwargs)
    elif object.transition_matrix is None:
        raise RuntimeError("Compute transition matrix first as `.compute_transition_matrix()`.")

    super().__init__(kernel=object)
    self._params: dict[str, Any] = {}

    # The shadow object has the same shape and dtype as ``adata`` but an empty sparse ``X``
    # and column-less ``obs``/``var`` frames (only the indices are kept).
    source = self.adata
    self._shadow_adata = AnnData(
        X=sp.csr_matrix(source.shape, dtype=source.X.dtype),
        obs=source.obs[[]].copy(),
        var=source.var[[]].copy(),
        raw=None if source.raw is None else source.raw.to_adata(),
    )
def __init_subclass__(cls, **kwargs: Any):
    # Class keyword arguments are accepted here but are NOT forwarded to the parent hook.
    # NOTE(review): presumably intentional, so that subclasses may pass keywords freely - confirm.
    super().__init_subclass__()
[docs]
@abc.abstractmethod
def fit(self, *args: Any, **kwargs: Any) -> "BaseEstimator":
    """Fit the estimator.

    This is an abstract method: the base class provides no implementation and
    concrete estimators must override it.

    .. deprecated:: 2.1
        Will be removed in CellRank 3.0. Use the individual methods directly
        (e.g., :meth:`compute_schur`, :meth:`compute_macrostates`).

    Parameters
    ----------
    args
        Positional arguments.
    kwargs
        Keyword arguments.

    Returns
    -------
    Self.
    """
[docs]
@abc.abstractmethod
def predict(self, *args: Any, **kwargs: Any) -> "BaseEstimator":
    """Run the prediction.

    This is an abstract method: the base class provides no implementation and
    concrete estimators must override it.

    .. deprecated:: 2.1
        Will be removed in CellRank 3.0. Use the individual methods directly
        (e.g., :meth:`predict_terminal_states`).

    Parameters
    ----------
    args
        Positional arguments.
    kwargs
        Keyword arguments.

    Returns
    -------
    Self.
    """
def _set(
    self,
    attr: str | None = None,
    obj: pd.DataFrame | Mapping[str, Any] | None = None,
    key: str | None = None,
    value: np.ndarray | pd.Series | pd.DataFrame | Lineage | AnnData | dict[str, Any] | None = None,
    copy: bool = True,
    shadow_only: bool = False,
) -> None:
    """Assign ``value`` to an attribute of ``self`` and/or store it under ``obj['{key}']``.

    Parameters
    ----------
    attr
        Name of the attribute on ``self`` to assign. It is only assigned outside of the shadow
        (see :attr:`_in_shadow`). If :obj:`None`, no attribute is assigned.
    obj
        Container that receives ``value`` under ``key``, usually an attribute of :attr:`adata`.
        If :obj:`None`, no container is updated.
    key
        Key in ``obj`` to write to. If :obj:`None`, ``obj`` is left untouched.
    value
        Value to store. If :obj:`None`, the entry ``obj['{key}']`` is removed instead
        (a missing entry is silently ignored).
    copy
        Whether to store a shallow copy of ``value`` in ``obj`` rather than ``value`` itself.
    shadow_only
        If :obj:`True`, ``obj`` is only updated while in the shadow; outside of it,
        the method returns right after handling ``attr``.

    Returns
    -------
    Nothing, just optionally updates ``attr`` and/or ``obj[{key}]``.

    Raises
    ------
    AttributeError
        If ``attr`` is given, we are not in the shadow and ``self`` has no such attribute.
    """
    outside_shadow = not self._in_shadow

    if outside_shadow and attr is not None:
        if not hasattr(self, attr):
            raise AttributeError(attr)
        setattr(self, attr, value)

    # ``shadow_only`` confines the container update to the shadow context.
    if outside_shadow and shadow_only:
        return
    if obj is None or key is None:
        return

    if value is not None:
        obj[key] = copy_.copy(value) if copy else value
        return

    # ``None`` means "remove"; it is fine if there is nothing to remove.
    with contextlib.suppress(KeyError):
        del obj[key]
def _get(
    self,
    *,
    obj: pd.DataFrame | Mapping[str, Any],
    key: str,
    shadow_attr: Literal["obs", "obsm", "var", "varm", "uns"] | None = None,
    dtype: type | tuple[type, ...] | None = None,
    copy: bool = True,
    allow_missing: bool = False,
) -> Any:
    """Get data from an object and optionally mirror it in :attr:`_shadow_adata`.

    Parameters
    ----------
    obj
        Object from which to extract the data.
    key
        Key in ``obj`` where the data is stored.
    shadow_attr
        Attribute of :attr:`_shadow_adata` where to save the extracted data under the same ``key``.
        If :obj:`None`, don't update it.
    dtype
        Valid type(s) of the extracted data.
    copy
        Whether to (shallow) copy the data before storing/returning it.
    allow_missing
        Whether to allow ``key`` to be missing in ``obj``.

    Returns
    -------
    The extracted values, or :obj:`None` if ``key`` is missing and ``allow_missing = True``.
    Optionally updates :attr:`_shadow_adata` ``.{shadow_attr}``.

    Raises
    ------
    AttributeError
        If ``shadow_attr != None`` and :attr:`_shadow_adata` has no such attribute.
    TypeError
        If ``dtype != None`` and the extracted values are not instances of ``dtype``.
    KeyError
        If ``allow_missing = False`` and ``key`` was not found in ``obj``.
    """
    if shadow_attr is not None and not hasattr(self._shadow_adata, shadow_attr):
        raise AttributeError(shadow_attr)

    # Guard only the lookup itself: a ``KeyError`` raised later (e.g. while writing to the
    # shadow object) is a genuine error and must not be reported as "key is missing".
    try:
        data = obj[key]
    except KeyError:
        if not allow_missing:
            raise
        return None

    if dtype is not None and not isinstance(data, dtype):
        raise TypeError(f"Expected object to be of type `{dtype}`, found `{type(data).__name__}`.")
    if copy:
        data = copy_.copy(data)
    if shadow_attr is not None:
        getattr(self._shadow_adata, shadow_attr)[key] = data
    return data
@property
@contextlib.contextmanager
def _shadow(self) -> None:
    """Context manager that temporarily points :attr:`adata` at :attr:`_shadow_adata`.

    Used to construct the serialization object in :meth:`to_adata`.
    Entering it while already in the shadow is a no-op; otherwise the previous
    :attr:`adata` is restored on exit, even if the body raises.
    """
    if self._in_shadow:
        # Already swapped by an outer context, which is also responsible for restoring.
        yield
        return

    original = self.adata
    try:
        self.adata = self._shadow_adata
        yield
    finally:
        self.adata = original
@property
def _in_shadow(self) -> bool:
    """Whether :attr:`adata` currently is the very same object as :attr:`_shadow_adata`."""
    current = self.adata
    return current is self._shadow_adata
def _create_params(
self,
locs: Mapping[str, Any] | None = None,
func: Callable | None = None,
remove: Sequence[str] = (),
) -> dict[str, Any]:
"""Create parameters of interest from a function call.
Parameters
----------
locs
Environment from which to get the parameters. If `None`, get the caller's environment.
func
Function of interest. If :obj:`None`, use the caller.
remove
Keys in ``locs`` which should not be included in the result.
Returns
-------
The parameters as a :class:`dict`.
Notes
-----
*args/**kwargs are always ignored and the values in ``locs`` are not copied.
"""
frame = inspect.currentframe()
try:
if locs is None:
locs = frame.f_back.f_locals
if func is None:
name = frame.f_back.f_code.co_name
func = dict(inspect.getmembers(self)).get(name, None)
if not callable(func):
raise TypeError(f"Expected `func` to be `callable`, found `{type(func).__name__}`.")
params = {}
for name, param in inspect.signature(func).parameters.items():
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
continue
if name in remove:
continue
if name in locs:
params[name] = locs[name]
return params
except AttributeError:
# frame can be None
return {}
except TypeError:
return {}
finally:
del frame
def _read_params(self, key: str) -> dict[str, Any]:
    """Return a fresh :class:`dict` with the estimator parameters stored under ``key`` in :attr:`adata`.

    The parameters are looked up in :attr:`adata` ``.uns`` under the estimator key suffixed with
    ``'_params'``. If either that entry or ``key`` within it is missing, an empty :class:`dict`
    is returned. Usually called in :meth:`_read_adata` during :meth:`from_adata`.
    """
    params_key = Key.uns.estimator(self.backward) + "_params"
    stored = self.adata.uns.get(params_key, {})
    return dict(stored.get(key, {}))
[docs]
@d.dedent
def to_adata(
    self,
    keep: Literal["all"] | Sequence[Attr_t] = ("X", "raw"),
    *,
    copy: bool | Sequence[Attr_t] = True,
) -> AnnData:
    """%(to_adata.full_desc)s

    Parameters
    ----------
    keep
        Which attributes to keep from the underlying :attr:`adata`. Valid options are:

        - ``'all'`` - keep all attributes specified in the signature.
        - :class:`~typing.Sequence` - keep only subset of these attributes.
        - :class:`dict` - the keys correspond to the attribute names and the values to a subset of keys
          which to keep from this attribute. If the values are specified either as :obj:`True` or ``'all'``,
          everything from this attribute will be kept. Attributes whose value is :obj:`False`,
          :obj:`None` or empty are skipped.
    copy
        Whether to copy the data. Can be specified on per-attribute basis, in which case
        only the listed attributes are copied. Useful for attributes that are array-like.

    Returns
    -------
    Annotated data object.

    Raises
    ------
    KeyError
        If a requested key is not present in the corresponding attribute of :attr:`adata`.
    TypeError
        If a kept attribute is neither a :class:`~pandas.DataFrame` nor a mapping.
    """  # noqa: D400

    def handle_attribute(attr: Attr_t, keys: list[str], *, copy: bool) -> None:
        """Transfer ``keys`` of ``self.adata.{attr}`` into the new ``adata``."""
        try:
            if attr == "X":
                adata.X = copy_.deepcopy(self.adata.X) if copy else self.adata.X
                return
            if attr == "raw":
                # ``raw`` is always materialized through ``to_adata()``, regardless of ``copy``.
                adata.raw = self.adata.raw.to_adata() if self.adata.raw is not None else None
                return

            old = getattr(self.adata, attr)
            new = getattr(adata, attr)
            if keys == ["all"]:
                keys = list(old.keys())

            # fmt: off
            if isinstance(new, pd.DataFrame):
                old = old[keys]
                # avoid duplicates
                old = old[old.columns.difference(new.columns)]
                setattr(adata, attr, pd.merge(new, old, how="inner", left_index=True, right_index=True, copy=copy))
            elif isinstance(new, Mapping):
                old = {k: old[k] for k in keys}
                # old has preference, since it's user supplied
                setattr(adata, attr, {**new, **(copy_.deepcopy(old) if copy else old)})
            else:
                raise TypeError(f"Expected `adata.{attr}` to be either `Mapping` or `pandas.DataFrame`, "
                                f"found `{type(new).__name__}`.")
            # fmt: on
        except KeyError:
            # ``old`` is still the unfiltered container here, since the lookup is what failed
            missing = sorted(k for k in keys if k not in old)
            raise KeyError(f"Unable to find key(s) `{missing}` in `adata.{attr}`.") from None

    adata = self._shadow_adata.copy()
    _adata = self.adata
    try:
        # kernel and estimator share the adata
        self.adata = adata
        self.kernel.write_to_adata()
    finally:
        self.adata = _adata
    key = Key.uns.estimator(self.backward) + "_params"
    adata.uns[key] = copy_.deepcopy(self.params)

    # Normalize ``keep`` into a mapping ``attribute -> keys``.
    # fmt: off
    if isinstance(keep, str):
        if keep == "all":
            keep = ["X", "raw", "layers", "obs", "var", "obsm", "varm", "obsp", "varp", "uns"]
        else:
            keep = [keep]
    if not isinstance(keep, Mapping):
        keep = dict.fromkeys(keep, True)

    # Normalize ``copy`` into a mapping ``attribute -> bool``; unlisted attributes are not copied.
    if isinstance(copy, bool):
        copy = dict.fromkeys(keep, copy)
    elif isinstance(copy, str):
        copy = [copy]
    if not isinstance(copy, Mapping):
        copy = dict.fromkeys(copy, True)
    # fmt: on

    for attr, keys in keep.items():
        if keys is True:
            keys = ["all"]
        elif isinstance(keys, str):
            keys = [keys]
        # ``None`` is treated like ``False``: nothing to keep from this attribute
        if keys is None or keys is False or not len(keys):
            continue
        handle_attribute(attr, keys=list(keys), copy=copy.get(attr, False))

    return adata
[docs]
@classmethod
@d.dedent
def from_adata(cls, adata: AnnData, obsp_key: str) -> "BaseEstimator":
    """%(from_adata.full_desc)s

    Parameters
    ----------
    %(adata)s
    obsp_key
        Key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored.

    Returns
    -------
    %(from_adata.returns)s
    """  # noqa: D400
    # Pure delegation to the parent implementation; ``obsp_key`` is forwarded by keyword.
    return super().from_adata(adata, obsp_key=obsp_key)
[docs]
@d.dedent
def copy(self, *, deep: bool = False) -> "BaseEstimator":
    """Return a copy of self.

    A new instance of the same type is created from a copy of the kernel, after which
    every instance attribute except the kernel is copied over. Mapping-like attributes
    are always deep-copied, independent of ``deep``.

    Parameters
    ----------
    deep
        Whether to return a deep copy or not. If :obj:`True`, this also copies the :attr:`adata`.

    Returns
    -------
    A copy of self.
    """
    duplicate = copy_.deepcopy if deep else copy_.copy

    res = type(self)(duplicate(self.kernel))
    for name, val in self.__dict__.items():
        if isinstance(val, Mapping):
            res.__dict__[name] = copy_.deepcopy(val)
        elif name != "_kernel":
            # the kernel has already been copied when constructing ``res``
            res.__dict__[name] = duplicate(val)

    return res
def __copy__(self) -> "BaseEstimator":
    """Return a shallow copy by delegating to :meth:`copy` with ``deep=False``."""
    shallow = self.copy(deep=False)
    return shallow
def _format_params(self) -> str:
    """Return the ``kernel=...`` fragment used inside :meth:`__repr__`."""
    return "kernel=" + str(self.kernel)
def __repr__(self) -> str:
    """Return ``ClassName[<formatted params>]``."""
    return "{}[{}]".format(self.__class__.__name__, self._format_params())
def __str__(self) -> str:
    """Return the same text as :func:`repr`."""
    rendered = repr(self)
    return rendered
@property
def params(self) -> dict[str, Any]:
    """Estimator parameters.

    The internal dictionary itself is returned, not a copy of it.
    """
    stored = self._params
    return stored