import abc
import contextlib
import copy as copy_
import inspect
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import numpy as np
import pandas as pd
import scipy.sparse as sp
from anndata import AnnData
from cellrank._utils._docs import d
from cellrank._utils._key import Key
from cellrank._utils._lineage import Lineage
from cellrank.estimators.mixins import KernelMixin
from cellrank.kernels import PrecomputedKernel
from cellrank.kernels._base_kernel import KernelExpression
from cellrank.kernels.mixins import AnnDataMixin, IOMixin
# Public API of this module: only the abstract estimator base class is exported.
__all__ = ["BaseEstimator"]
# Names of the `AnnData` attributes that can be kept/copied by `BaseEstimator.to_adata`.
# NOTE: this must be the `Literal` type itself; a trailing comma would turn the alias into a 1-tuple,
# which is not a valid type for annotations such as `attr: Attr_t`.
Attr_t = Literal["X", "raw", "layers", "obs", "var", "obsm", "varm", "obsp", "varp", "uns"]
[docs]
@d.get_sections(base="base_estimator", sections=["Parameters"])
class BaseEstimator(IOMixin, KernelMixin, AnnDataMixin, abc.ABC):
    """Base class for all estimators.

    Parameters
    ----------
    object
        Can be one of the following types:

        - :class:`~anndata.AnnData` - annotated data object.
        - :class:`~scipy.sparse.spmatrix`, :class:`~numpy.ndarray` - row-normalized transition matrix.
        - :class:`~cellrank.kernels.KernelExpression` - kernel expression. Its transition matrix must
          already be computed, otherwise a :class:`RuntimeError` is raised.
        - :class:`str` - key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored and
          ``adata`` must be provided in this case.
        - :class:`bool` - directionality of the transition matrix that will be used to infer its storage location.
          If :obj:`None`, the directionality will be determined automatically and
          ``adata`` must be provided in this case.
    kwargs
        Keyword arguments for the :class:`~cellrank.kernels.PrecomputedKernel`. Only used when ``object``
        is not a :class:`~cellrank.kernels.KernelExpression`.
    """
def __init__(
    self,
    object: Union[str, bool, np.ndarray, sp.spmatrix, AnnData, KernelExpression],
    **kwargs: Any,
):
    # Kernel expressions are used as they are (once their transition matrix exists);
    # every other input is wrapped in a precomputed kernel, which receives ``kwargs``.
    kernel = object
    if not isinstance(kernel, KernelExpression):
        kernel = PrecomputedKernel(kernel, copy=False, **kwargs)
    elif kernel.transition_matrix is None:
        raise RuntimeError("Compute transition matrix first as `.compute_transition_matrix()`.")
    super().__init__(kernel=kernel)

    self._params: Dict[str, Any] = {}

    # The shadow object mirrors the shape and the obs/var indices of `adata`,
    # but starts with an empty sparse `X` and no obs/var columns.
    source = self.adata
    empty_x = sp.csr_matrix(source.shape, dtype=source.X.dtype)
    obs_index_only = source.obs[[]].copy()
    var_index_only = source.var[[]].copy()
    raw = source.raw.to_adata() if source.raw is not None else None
    self._shadow_adata = AnnData(X=empty_x, obs=obs_index_only, var=var_index_only, raw=raw)
def __init_subclass__(cls, **kwargs: Any):
    # Subclass-creation hook: ``kwargs`` are accepted here, but not forwarded to the parent hook.
    parent_hook = super().__init_subclass__
    parent_hook()
[docs]
@abc.abstractmethod
def fit(self, *args: Any, **kwargs: Any) -> "BaseEstimator":
    """Fit the estimator.

    Abstract method without a default implementation; concrete estimators must override it.

    Parameters
    ----------
    args
        Positional arguments.
    kwargs
        Keyword arguments.

    Returns
    -------
    Self.
    """
[docs]
@abc.abstractmethod
def predict(self, *args: Any, **kwargs: Any) -> "BaseEstimator":
    """Run the prediction.

    Abstract method without a default implementation; concrete estimators must override it.

    Parameters
    ----------
    args
        Positional arguments.
    kwargs
        Keyword arguments.

    Returns
    -------
    Self.
    """
def _set(
    self,
    attr: Optional[str] = None,
    obj: Optional[Union[pd.DataFrame, Mapping[str, Any]]] = None,
    key: Optional[str] = None,
    value: Optional[Union[np.ndarray, pd.Series, pd.DataFrame, Lineage, AnnData, Dict[str, Any]]] = None,
    copy: bool = True,
    shadow_only: bool = False,
) -> None:
    """Assign ``value`` to an attribute of ``self`` and/or store it under ``obj[key]``.

    Parameters
    ----------
    attr
        Name of the attribute to assign. It is only touched when we are not in the shadow
        (see :attr:`_in_shadow`). If :obj:`None`, no attribute is assigned.
    obj
        Container, usually an attribute of :attr:`adata`, which is updated with ``value``.
        If :obj:`None`, no container is updated.
    key
        Key in ``obj`` under which ``value`` is stored. If :obj:`None`, ``obj`` is left untouched.
    value
        Value to set. If :obj:`None`, the entry ``obj[key]`` is removed, if present.
    copy
        Whether to store a shallow copy of ``value`` in ``obj`` instead of ``value`` itself.
    shadow_only
        If :obj:`True`, ``obj`` is only updated while we are in the shadow.

    Returns
    -------
    Nothing, just optionally updates ``attr`` and/or ``obj[key]``.

    Raises
    ------
    AttributeError
        If ``attr`` doesn't exist.
    """
    outside_shadow = not self._in_shadow
    if outside_shadow:
        if attr is not None:
            if not hasattr(self, attr):
                raise AttributeError(attr)
            setattr(self, attr, value)
        if shadow_only:
            # the container is reserved for the shadow pass
            return

    if obj is None or key is None:
        return

    if value is not None:
        obj[key] = copy_.copy(value) if copy else value
        return

    # `None` means "remove"; a key that is already absent is fine
    with contextlib.suppress(KeyError):
        del obj[key]
def _get(
    self,
    *,
    obj: Union[pd.DataFrame, Mapping[str, Any]],
    key: str,
    shadow_attr: Optional[Literal["obs", "obsm", "var", "varm", "uns"]] = None,
    dtype: Optional[Union[type, Tuple[type, ...]]] = None,
    copy: bool = True,
    allow_missing: bool = False,
) -> Any:
    """Get data from an object and optionally mirror it in :attr:`_shadow_adata`.

    Parameters
    ----------
    obj
        Object from which to extract the data.
    key
        Key in ``obj`` where the data is stored.
    shadow_attr
        Attribute of :attr:`_shadow_adata` where to save the extracted data. If :obj:`None`, don't update it.
    dtype
        Valid type(s) of the extracted data.
    copy
        Whether to (shallow) copy the data before returning it and saving it under ``shadow_attr``.
    allow_missing
        Whether to allow ``key`` to be missing in ``obj``.

    Returns
    -------
    The extracted values, or :obj:`None` if ``key`` is missing and ``allow_missing = True``.
    Optionally updates :attr:`_shadow_adata` ``.{shadow_attr}``.

    Raises
    ------
    AttributeError
        If ``shadow_attr`` is not an attribute of :attr:`_shadow_adata`.
    TypeError
        If ``dtype != None`` and the extracted values are not instances of ``dtype``.
    KeyError
        If ``allow_missing = False`` and ``key`` was not found in ``obj``.
    """
    if shadow_attr is not None and not hasattr(self._shadow_adata, shadow_attr):
        raise AttributeError(shadow_attr)

    # Guard only the lookup: a `KeyError` raised while copying the data or while writing
    # to the shadow object must not be mistaken for a missing ``key`` and swallowed.
    try:
        data = obj[key]
    except KeyError:
        if not allow_missing:
            raise
        return None

    if dtype is not None and not isinstance(data, dtype):
        raise TypeError(f"Expected object to be of type `{dtype}`, found `{type(data).__name__}`.")
    if copy:
        data = copy_.copy(data)
    if shadow_attr is not None:
        getattr(self._shadow_adata, shadow_attr)[key] = data
    return data
@property
@contextlib.contextmanager
def _shadow(self) -> None:
    """Context manager which temporarily points :attr:`adata` to :attr:`_shadow_adata`.

    The previous :attr:`adata` is restored on exit, also when the managed block raises.
    Entering while already in the shadow changes nothing.
    """
    if not self._in_shadow:
        previous = self.adata
        try:
            self.adata = self._shadow_adata
            yield
        finally:
            self.adata = previous
        return

    # already in the shadow: nothing to swap and nothing to restore
    yield
@property
def _in_shadow(self) -> bool:
    """Whether :attr:`adata` currently is the very same object as :attr:`_shadow_adata`."""
    shadow = self._shadow_adata
    return shadow is self.adata
def _create_params(
    self,
    locs: Optional[Mapping[str, Any]] = None,
    func: Optional[Callable] = None,
    remove: Sequence[str] = (),
) -> Dict[str, Any]:
    """Create parameters of interest from a function call.

    Parameters
    ----------
    locs
        Environment from which to get the parameters. If :obj:`None`, get the caller's environment.
    func
        Function of interest. If :obj:`None`, use the attribute of ``self`` named like the caller.
    remove
        Keys in ``locs`` which should not be included in the result.

    Returns
    -------
    The parameters as a :class:`dict`. An empty :class:`dict` is returned when the caller's frame
    is not available, when ``func`` is not callable or when inspecting it raises
    :class:`AttributeError` or :class:`TypeError`.

    Notes
    -----
    ``*args``/``**kwargs`` are always ignored and the values in ``locs`` are not copied.
    """
    frame = inspect.currentframe()
    try:
        if locs is None:
            locs = frame.f_back.f_locals
        if func is None:
            caller_name = frame.f_back.f_code.co_name
            # Fetch only the caller. Scanning all members with `inspect.getmembers` would
            # evaluate every property of the estimator just to find a single method.
            func = getattr(self, caller_name, None)
        if not callable(func):
            return {}

        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        return {
            name: locs[name]
            for name, param in inspect.signature(func).parameters.items()
            if param.kind not in variadic and name not in remove and name in locs
        }
    except (AttributeError, TypeError):
        # the frame (or its parent) can be `None`, or `func` cannot be inspected
        return {}
    finally:
        # don't keep the frame object alive longer than necessary
        del frame
def _read_params(self, key: str) -> Dict[str, Any]:
    """Return a new :class:`dict` with the estimator parameters stored under ``key`` in :attr:`adata`.

    Usually called in :meth:`_read_adata` during :meth:`from_adata`. An empty :class:`dict` is returned
    when either the estimator parameters or ``key`` are not present in ``adata.uns``.
    """
    params_key = Key.uns.estimator(self.backward) + "_params"
    all_params = self.adata.uns.get(params_key, {})
    return dict(all_params.get(key, {}))
[docs]
@d.dedent
def to_adata(
    self,
    keep: Union[Literal["all"], Sequence[Attr_t]] = ("X", "raw"),
    *,
    copy: Union[bool, Sequence[Attr_t]] = True,
) -> AnnData:
    """%(to_adata.full_desc)s

    Parameters
    ----------
    keep
        Which attributes to keep from the underlying :attr:`adata`. Valid options are:

        - ``'all'`` - keep all attributes specified in the signature.
        - :class:`~typing.Sequence` - keep only subset of these attributes.
        - :class:`dict` - the keys correspond to the attribute names and values to a subset of keys
          which to keep from this attribute. If the values are specified either as :obj:`True` or ``'all'``,
          everything from this attribute will be kept. Attributes whose value is :obj:`False` or empty
          are skipped.
    copy
        Whether to copy the data. Can be specified on per-attribute basis.
        Useful for attributes that are array-like. Valid options are:

        - :class:`bool` - applies to every attribute in ``keep``.
        - :class:`str` or :class:`~typing.Sequence` - names of the attributes to copy.
        - :class:`dict` - maps attribute names to copy flags.

        Attributes that are not mentioned are not copied.

    Returns
    -------
    Annotated data object.

    Raises
    ------
    KeyError
        If a requested key is not present in the corresponding attribute of :attr:`adata`.
    TypeError
        If an attribute of the new object is neither a mapping nor a :class:`~pandas.DataFrame`.
    """  # noqa: D400

    def handle_attribute(attr: Attr_t, keys: List[str], *, copy: bool) -> None:
        # Transfer `keys` of `self.adata.{attr}` into the new `adata`, optionally copying them.
        try:
            if attr == "X":
                adata.X = copy_.deepcopy(self.adata.X) if copy else self.adata.X
                return
            if attr == "raw":
                # `copy` is not consulted for `raw`
                adata.raw = self.adata.raw.to_adata() if self.adata.raw is not None else None
                return

            old = getattr(self.adata, attr)
            new = getattr(adata, attr)
            if keys == ["all"]:
                keys = list(old.keys())

            # fmt: off
            if isinstance(new, pd.DataFrame):
                old = old[keys]
                # avoid duplicates
                old = old[old.columns.difference(new.columns)]
                setattr(adata, attr, pd.merge(new, old, how="inner", left_index=True, right_index=True, copy=copy))
            elif isinstance(new, Mapping):
                old = {k: old[k] for k in keys}
                # old has preference, since it's user supplied
                setattr(adata, attr, {**new, **(copy_.deepcopy(old) if copy else old)})
            else:
                raise TypeError(f"Expected `adata.{attr}` to be either `Mapping` or `pandas.DataFrame`, "
                                f"found `{type(new).__name__}`.")
            # fmt: on
        except KeyError:
            # report all missing keys at once instead of only the first one
            missing = sorted(k for k in keys if k not in old)
            raise KeyError(f"Unable to find key(s) `{missing}` in `adata.{attr}`.") from None

    # start from a copy of the shadow object, so that the shadow itself is never modified
    adata = self._shadow_adata.copy()
    _adata = self.adata
    try:
        # kernel and estimator share the adata
        self.adata = adata
        self.kernel.write_to_adata()
    finally:
        self.adata = _adata

    key = Key.uns.estimator(self.backward) + "_params"
    adata.uns[key] = copy_.deepcopy(self.params)

    # normalize `keep` and `copy` to mappings keyed by the attribute name
    # fmt: off
    if isinstance(keep, str):
        if keep == "all":
            keep = ["X", "raw", "layers", "obs", "var", "obsm", "varm", "obsp", "varp", "uns"]
        else:
            keep = [keep]
    if not isinstance(keep, Mapping):
        keep = {attr: True for attr in keep}

    if isinstance(copy, bool):
        copy = {attr: copy for attr in keep}
    elif isinstance(copy, str):
        copy = [copy]
    if not isinstance(copy, Mapping):
        copy = {attr: True for attr in copy}
    # fmt: on

    for attr, keys in keep.items():
        if keys is True:
            keys = ["all"]
        elif isinstance(keys, str):
            keys = [keys]
        if keys is False or not len(keys):
            continue
        handle_attribute(attr, keys=list(keys), copy=copy.get(attr, False))

    return adata
[docs]
@classmethod
@d.dedent
def from_adata(cls, adata: AnnData, obsp_key: str) -> "BaseEstimator":
    """%(from_adata.full_desc)s

    Parameters
    ----------
    %(adata)s
    obsp_key
        Key in :attr:`~anndata.AnnData.obsp` where the transition matrix is stored.

    Returns
    -------
    %(from_adata.returns)s
    """  # noqa: D400
    # the parent implementation does the work; this override only supplies the documentation
    parent = super()
    return parent.from_adata(adata, obsp_key=obsp_key)
[docs]
@d.dedent
def copy(self, *, deep: bool = False) -> "BaseEstimator":
    """Return a copy of self.

    A new estimator of the same type is created from a copy of the kernel and the instance
    attributes are transferred to it. Attributes that are mappings are always deep-copied.

    Parameters
    ----------
    deep
        Whether to return a deep copy or not. If :obj:`True`, this also copies the :attr:`adata`.

    Returns
    -------
    A copy of self.
    """
    duplicate = copy_.deepcopy if deep else copy_.copy
    clone = type(self)(duplicate(self.kernel))

    for name, value in self.__dict__.items():
        if isinstance(value, Mapping):
            clone.__dict__[name] = copy_.deepcopy(value)
            continue
        if name == "_kernel":
            # the kernel has already been copied when creating the clone
            continue
        clone.__dict__[name] = duplicate(value)

    return clone
def __copy__(self) -> "BaseEstimator":
    """Shallow-copy protocol hook; same as :meth:`copy` with ``deep=False``."""
    shallow = self.copy(deep=False)
    return shallow
def _format_params(self) -> str:
    """Return the parameter part of the textual representation, i.e. the kernel as a string."""
    kernel_text = str(self.kernel)
    return f"kernel={kernel_text}"
def __repr__(self) -> str:
    """Return ``ClassName[<formatted parameters>]``."""
    cls_name = self.__class__.__name__
    described = self._format_params()
    return f"{cls_name}[{described}]"
def __str__(self) -> str:
    """Return the same text as :func:`repr`."""
    text = repr(self)
    return text
@property
def params(self) -> Dict[str, Any]:
    """Estimator parameters; the stored dictionary itself is returned, not a copy."""
    stored = self._params
    return stored