Source code for cellrank.models._base_model

import abc
import collections
import copy
import enum
import re
import warnings
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import wrapt

import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype
from scipy.ndimage import convolve

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits.axes_grid1 import make_axes_locatable

from anndata import AnnData
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation

from cellrank import logging as logg
from cellrank._utils._docs import d
from cellrank._utils._enum import ModeEnum
from cellrank._utils._lineage import Lineage
from cellrank._utils._utils import _densify_squeeze, _minmax, save_fig, valuedispatch
from cellrank.kernels.mixins import IOMixin

__all__ = ["BaseModel"]

_dup_spaces = re.compile(r" +")  # used on repr for underlying model's repr
ArrayLike = Union[np.ndarray, sp.spmatrix, List, Tuple]


class UnknownModelError(RuntimeError):
    pass


class FailedReturnType(ModeEnum):
    PREPARE = enum.auto()
    FIT = enum.auto()
    PREDICT = enum.auto()
    CONFIDENCE_INTERVAL = enum.auto()
    DEFAULT_CONFIDENCE_INTERVAL = enum.auto()
    PLOT = enum.auto()


class ColorType(ModeEnum):
    CONT = enum.auto()
    CAT = enum.auto()
    STR = enum.auto()


def _handle_exception(return_type: FailedReturnType, func: Callable) -> Callable:
    def handle(*, exception_handler: Callable):
        @wrapt.decorator
        def wrapper(wrapped, instance, args, kwargs):
            try:
                if isinstance(instance, FailedModel):
                    instance.reraise()
                return wrapped(*args, **kwargs)
            except Exception as e:  # noqa: BLE001
                return exception_handler(instance, e, *args, **kwargs)

        return wrapper

    def array_output(instance: "BaseModel", exc: BaseException, *_args, **kwargs) -> np.ndarray:
        if not instance._is_bulk:
            raise exc from None

        if kwargs.get("x_test", None) is not None:
            n_obs = kwargs["x_test"].shape[0]
        elif instance.x_test is not None:
            n_obs = instance.x_test.shape[0]
        else:
            n_obs = 200

        return np.full((n_obs, array_shape), np.nan, dtype=instance._dtype)

    def model_output(instance: "BaseModel", exc: BaseException, *_args, **_kwargs) -> "FailedModel":
        if not isinstance(instance, FailedModel):
            instance = FailedModel(instance, exc=exc)
        if not instance._is_bulk:
            instance.reraise()

        return instance

    def no_output(instance: "BaseModel", exc: BaseException, *_args, **_kwargs) -> None:
        if not instance._is_bulk:
            raise exc from None

    @valuedispatch
    def wrapper(mode: FailedReturnType, *_args, **_kwargs):
        raise NotImplementedError(mode)

    @wrapper.register(FailedReturnType.PREPARE)
    @wrapper.register(FailedReturnType.FIT)
    def _(func: Callable) -> Callable:
        return handle(exception_handler=model_output)(func)

    @wrapper.register(FailedReturnType.PREDICT)
    @wrapper.register(FailedReturnType.CONFIDENCE_INTERVAL)
    @wrapper.register(FailedReturnType.DEFAULT_CONFIDENCE_INTERVAL)
    def _(func: Callable) -> Callable:
        return handle(exception_handler=array_output)(func)

    @wrapper.register(FailedReturnType.PLOT)
    def _(func: Callable):
        return handle(exception_handler=no_output)(func)

    return_type = FailedReturnType(return_type)
    array_shape = 1 + (
        return_type
        in (
            FailedReturnType.CONFIDENCE_INTERVAL,
            FailedReturnType.DEFAULT_CONFIDENCE_INTERVAL,
        )
    )

    return wrapper(return_type, func)


class BaseModelMeta(abc.ABCMeta):
    """Metaclass for all base models."""

    def __new__(cls, clsname: str, superclasses: Tuple[type, ...], attributedict: Dict[str, Any]):
        """Create a new instance.

        Parameters
        ----------
        clsname
            Name of class to be constructed.
        superclasses
            List of superclasses.
        attributedict
            Dictionary of attributes.
        """
        obj = super().__new__(cls, clsname, superclasses, attributedict)
        for fun_name in list(FailedReturnType):
            setattr(
                obj,
                fun_name,
                _handle_exception(FailedReturnType(fun_name), getattr(obj, fun_name)),
            )

        return obj


[docs] @d.get_sections(base="base_model", sections=["Parameters"]) @d.dedent class BaseModel(IOMixin, abc.ABC, metaclass=BaseModelMeta): """Base class for all model classes. Parameters ---------- %(adata)s model The underlying model that is used for fitting and prediction. """ _dtype = np.float64 def __init__( self, adata: Optional[AnnData], model: Any, ): if not isinstance(adata, AnnData) and not isinstance(self, FittedModel): # FittedModel doesn't need it raise TypeError(f"Expected `adata` to be of type `anndata.AnnData`, found `{type(adata).__name__}`.") super().__init__() self._adata = adata self._n_obs = 0 if adata is None else adata.n_obs self._model = model self._gene = None self._use_raw = False self._data_key = None self._lineage = None self._prepared = False self._obs_names = None self._is_bulk = False self._x_all = None self._y_all = None self._w_all = None self._x = None self._y = None self._w = None self._x_test = None self._y_test = None self._x_hat = None self._y_hat = None self._conf_int = None @d.get_summary(base="base_model_prepared") @property def prepared(self): """Whether the model is prepared for fitting.""" return self._prepared @property @d.dedent def adata(self) -> AnnData: """Annotated data object.""" return self._adata @adata.setter def adata(self, adata: Optional[AnnData]) -> None: self._adata = adata @property def shape(self) -> Tuple[int]: """Number of cells in :attr:`adata`.""" return (self._n_obs,) @property def model(self) -> Any: """Underlying model.""" return self._model @property @d.get_summary(base="base_model_x_all") def x_all(self) -> np.ndarray: """Unfiltered independent variables of shape ``(n_cells, 1)``.""" return self._x_all @property @d.get_summary(base="base_model_y_all") def y_all(self) -> np.ndarray: """Unfiltered dependent variables of shape ``(n_cells, 1)``.""" return self._y_all @property @d.get_summary(base="base_model_w_all") def w_all(self) -> np.ndarray: """Unfiltered weights of shape ``(n_cells,)``.""" return self._w_all @property @d.get_summary(base="base_model_x") def x(self) -> np.ndarray: """Filtered independent variables of shape ``(n_filtered_cells, 1)`` used for fitting.""" return self._x @property @d.get_summary(base="base_model_y") def y(self) -> np.ndarray: """Filtered dependent variables of shape ``(n_filtered_cells, 1)`` used for fitting.""" return self._y @property @d.get_summary(base="base_model_w") def w(self) -> np.ndarray: """Filtered weights of shape ``(n_filtered_cells,)`` used for fitting.""" return self._w @property @d.get_summary(base="base_model_x_test") def x_test(self) -> np.ndarray: """Independent variables of shape ``(n_samples, 1)`` used for prediction.""" return self._x_test @property @d.get_summary(base="base_model_y_test") def y_test(self) -> np.ndarray: """Prediction values of shape ``(n_samples,)`` for :attr:`x_test`.""" return self._y_test @property @d.get_summary(base="base_model_x_hat") def x_hat(self) -> np.ndarray: """Filtered independent variables used when calculating default confidence interval, usually same as :attr:`x`.""" # noqa: E501 return self._x_hat @property @d.get_summary(base="base_model_y_hat") def y_hat(self) -> np.ndarray: """Filtered dependent variables used when calculating default confidence interval, usually same as :attr:`y`.""" return self._y_hat @property @d.get_summary(base="base_model_conf_int") def conf_int(self) -> np.ndarray: """Array of shape ``(n_samples, 2)`` containing the lower and upper bound of the confidence interval.""" return self._conf_int
[docs] @d.get_sections(base="base_model_prepare", sections=["Parameters", "Returns"]) @d.get_full_description(base="base_model_prepare") @d.dedent def prepare( self, gene: str, lineage: Optional[str], time_key: str, backward: bool = False, time_range: Optional[Union[float, Tuple[float, float]]] = None, data_key: Optional[str] = "X", use_raw: bool = False, threshold: Optional[float] = None, weight_threshold: Union[float, Tuple[float, float]] = (0.01, 0.01), filter_cells: Optional[float] = None, n_test_points: int = 200, ) -> "BaseModel": """Prepare the model to be ready for fitting. Parameters ---------- gene Gene in :attr:`~anndata.AnnData.var_names`. lineage Name of the lineage. If :obj:`None`, all weights will be set to :math:`1`. time_key Key in :attr:`~anndata.AnnData.obs` where the pseudotime is stored. %(backward)s %(time_range)s data_key Key in :attr:`~anndata.AnnData.layers` or ``'X'`` for :attr:`~anndata.AnnData.X`. If ``use_raw = True``, it's always set to ``'X'``. use_raw Whether to access :attr:`~anndata.AnnData.raw`. threshold Consider only cells with weights > ``threshold`` when estimating the test endpoint. If :obj:`None`, use the median of the weights. weight_threshold Set all weights below ``weight_threshold`` to ``weight_threshold`` if a :class:`float`, or to the second value, if a :class:`tuple`. filter_cells Filter out all cells with expression values lower than this threshold. n_test_points Number of test points. If :obj:`None`, use the original points based on ``threshold``. Returns ------- Nothing, just updates the following fields: - :attr:`x` - %(base_model_x.summary)s - :attr:`y` - %(base_model_y.summary)s - :attr:`w` - %(base_model_w.summary)s - :attr:`x_all` - %(base_model_x_all.summary)s - :attr:`y_all` - %(base_model_y_all.summary)s - :attr:`w_all` - %(base_model_w_all.summary)s - :attr:`x_test` - %(base_model_x_test.summary)s - :attr:`prepared` - %(base_model_prepared.summary)s """ if use_raw: if self.adata.raw is None: raise AttributeError("AnnData object has no attribute `.raw`.") if data_key != "X": data_key = "X" self._use_raw = use_raw if data_key not in ["X", "obs", None] + list(self.adata.layers.keys()): raise KeyError( f"Data key must be a key of `adata.layers`: `{list(self.adata.layers.keys())}`, " f"`adata.X` or `adata.obs`." ) probs = Lineage.from_adata(self.adata, backward=backward) if lineage is not None: probs = probs[lineage] if time_key not in self.adata.obs: raise KeyError(f"Time key `{time_key!r}` not found in `adata.obs`.") if data_key == "obs": if gene not in self.adata.obs: raise KeyError(f"Unable to find data in `adata.obs[{gene!r}]`.") else: if use_raw and gene not in self.adata.raw.var_names: raise KeyError(f"Gene `{gene!r}` not found in `adata.raw.var_names`.") if not use_raw and gene not in self.adata.var_names: raise KeyError(f"Gene `{gene!r}` not found in `adata.var_names`.") if not isinstance(time_range, (type(None), float, int, tuple)): raise TypeError( f"Expected time range to be either `None`, " f"`float` or `tuple`, found `{type(time_range).__name__!r}`." ) if isinstance(time_range, tuple): if len(time_range) != 2: raise ValueError(f"Expected time range to be a `tuple` of length `2`, found `{len(time_range)}`.") if not isinstance(time_range[0], (float, int, type(None))) or not isinstance( time_range[1], (float, int, type(None)) ): raise TypeError( f"Expected values in time ranges be of types `float` or `int` or `None`, " f"got `{type(time_range[0]).__name__!r}` and `{type(time_range[1]).__name__!r}`." ) val_start, val_end = time_range elif time_range is None: val_start, val_end = None, None else: val_start, val_end = None, time_range if isinstance(weight_threshold, (int, float, np.number)): weight_threshold = (weight_threshold, weight_threshold) if len(weight_threshold) != 2: raise ValueError(f"Expected `weight_threshold` to be of size `2`, found `{len(weight_threshold)}`.") self._obs_names = self.adata.obs_names.values[:] x = np.array(self.adata.obs[time_key]).astype(self._dtype) adata = self.adata.raw.to_adata() if use_raw else self.adata gene_ix = np.where(adata.var_names == gene)[0] if data_key in ("X", None): y = adata.X[:, gene_ix] self._data_key = None elif data_key == "obs": y = adata.obs[gene].values # don't set data_key, it's just when `cell_color = ...` elif data_key in adata.layers: y = adata.layers[data_key][:, gene_ix] self._data_key = data_key else: raise NotImplementedError(f"Data key `{data_key!r}` is not implemented.") if lineage is not None: weight_threshold, val = weight_threshold w = _densify_squeeze(probs.X, self._dtype) w[w < weight_threshold] = val else: w = np.ones(len(x), dtype=self._dtype) if use_raw: correct_ixs = np.isin(self.adata.obs_names, adata.obs_names) x = x[correct_ixs] y = y[correct_ixs] w = w[correct_ixs] self._obs_names = self._obs_names[correct_ixs] del adata self._x_all, self._y_all, self._w_all = ( _densify_squeeze(x, self._dtype)[:, np.newaxis], _densify_squeeze(y, self._dtype)[:, np.newaxis], w, ) x, y, w = self.x_all[:], self.y_all[:], self.w_all[:] # sanity checks if self._x_all.shape[0] != self._y_all.shape[0]: raise ValueError( f"Independent variable's first dimension ({self._x_all.shape[0]}) " f"differs from dependent variable's first dimension ({self._y_all.shape[0]})." ) if self._x_all.shape[0] != self._w_all.shape[0]: raise ValueError( f"Independent variable's first dimension ({self._x_all.shape[0]}) " f"differs from weights' first dimension ({self._w_all.shape[0]})." ) x, ixs = np.unique(x, return_index=True) # GAMR (mgcv) needs unique y = y[ixs] w = w[ixs] ixs = np.argsort(x) x, y, w = x[ixs], y[ixs], w[ixs] self._obs_names = self._obs_names[ixs] if np.allclose(w, w[0]): # degenerate case val_start = w[0] - 1 val_end = w[0] + 1 if val_start is None: val_start = np.min(x) if val_end is None: if threshold is None: threshold = np.nanmedian(w) # use `>=` because weights can all be 1 w_test = w[w >= threshold] n_window = n_test_points // 20 tmp = convolve(w_test, np.ones(n_window) / n_window, mode="nearest") val_end = x[w >= threshold][-1 if lineage is None else np.nanargmax(tmp)] if val_start > val_end: val_start, val_end = val_end, val_start val_start, val_end = (max(val_start, np.min(x)), min(val_end, np.max(x))) fil = (x >= val_start) & (x <= val_end) x_test = np.linspace(val_start, val_end, n_test_points) if n_test_points is not None else x[fil] x, y, w = x[fil], y[fil], w[fil] self._obs_names = self._obs_names[fil] if filter_cells is not None: tmp = y.squeeze() fil = (tmp >= filter_cells) & (~np.isclose(tmp, filter_cells).astype(bool)) x, y, w = x[fil], y[fil], w[fil] self._obs_names = self._obs_names[fil] self._x, self._y, self._w = ( self._reshape_and_retype(x), self._reshape_and_retype(y), self._reshape_and_retype(w).squeeze(-1), ) self._x_test = self._reshape_and_retype(x_test) self._y_test = None self._x_hat = None self._y_hat = None self._conf_int = None if self.x.shape[0] == 0: raise RuntimeError("Unable to proceed, no values to fit.") if len({self.x.shape[0], self.y.shape[0], self.w.shape[0]}) != 1: raise RuntimeError( f"Values have different shapes: `{self.x.shape[0]}`, `{self.y.shape[0]}`, `{self.w.shape[0]}`." ) self._gene = gene self._lineage = lineage self._prepared = True return self
[docs] @abc.abstractmethod @d.get_sections(base="base_model_fit", sections=["Parameters"]) @d.get_full_description(base="base_model_fit") def fit( self, x: Optional[np.ndarray] = None, y: Optional[np.ndarray] = None, w: Optional[np.ndarray] = None, **kwargs: Any, ) -> "BaseModel": """Fit the model. Parameters ---------- x Independent variables, array of shape ``(n_samples, 1)``. If :obj:`None`, use :attr:`x`. y Dependent variables, array of shape ``(n_samples, 1)``. If :obj:`None`, use :attr:`y`. w Optional weights of :attr:`x`, array of shape ``(n_samples,)``. If :obj:`None`, use :attr:`w`. kwargs Keyword arguments for underlying :attr:`model`'s fitting function. Returns ------- Fits the :attr:`model` and returns self. """ if not self.prepared: raise RuntimeError("The model has not been prepared yet, call `.prepare()` first.") self._check("_x", x) self._check("_y", y) self._check("_w", w, ndim=1) if self._x.shape != self._y.shape: raise ValueError(f"Inputs and targets differ in shape: `{self._x.shape}` vs. `{self._y.shape}`.") if self._y.shape[0] != self._w.shape[0]: raise ValueError(f"Inputs and weights differ in shape: `{self._y.shape[0]}` vs. `{self._w.shape[0]}`.") return self
[docs] @abc.abstractmethod @d.get_sections(base="base_model_predict", sections=["Parameters", "Returns"]) @d.get_full_description(base="base_model_predict") @d.dedent def predict( self, x_test: Optional[np.ndarray] = None, key_added: Optional[str] = "_x_test", **kwargs: Any, ) -> np.ndarray: """Run the prediction. Parameters ---------- x_test Array of shape ``(n_samples,)`` used for prediction. If :obj:`None`, use :attr:`x_test`. key_added Attribute name where to save the :attr:`x_test` for later use. If :obj:`None`, don't save it. kwargs Keyword arguments for underlying :attr:`model`'s prediction method. Returns ------- Returns and updates the following fields: - :attr:`y_test` - %(base_model_y_test.summary)s """
[docs] @abc.abstractmethod @d.get_sections(base="base_model_ci", sections=["Parameters", "Returns"]) @d.get_summary(base="base_model_ci") @d.get_full_description(base="base_model_ci") @d.dedent def confidence_interval(self, x_test: Optional[np.ndarray] = None, **kwargs) -> np.ndarray: """Calculate the confidence interval. Use :meth:`default_confidence_interval` function if underlying :attr:`model` has no method for confidence interval calculation. Parameters ---------- x_test Array of shape ``(n_samples,)`` used for confidence interval calculation. If :obj:`None`, use :attr:`x_test`. kwargs Keyword arguments for underlying :attr:`model`'s confidence method or for :meth:`default_confidence_interval`. Returns ------- Returns self and updates the following fields: - :attr:`conf_int` - %(base_model_conf_int.summary)s """
[docs] @d.dedent def default_confidence_interval( self, x_test: Optional[np.ndarray] = None, **kwargs: Any, ) -> np.ndarray: """Calculate the confidence interval, if the underlying :attr:`model` has no method for it. This formula is taken from :cite:`desalvo:70`, eq. 5. Parameters ---------- %(base_model_ci.parameters)s Returns ------- %(base_model_ci.returns)s Also updates the following fields: - :attr:`x_hat` - %(base_model_x_hat.summary)s - :attr:`y_hat` - %(base_model_y_hat.summary)s """ use_ixs = self.w > 0 x_hat = self.x[use_ixs] if x_test is None: x_test = self.x_test self._y_hat = self.predict(x_hat, key_added="_x_hat", **kwargs) self._y_test = self.predict(x_test, key_added="_x_test", **kwargs) n = np.sum(use_ixs) sigma_hat = np.sqrt(((self.y_hat - self.y[use_ixs].squeeze()) ** 2).sum() / (n - 2)) mean = np.mean(self.x) # fmt: off with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) stds = sigma_hat * np.sqrt(1 + 1 / n + ((self.x_test - mean) ** 2) / ((self.x - mean) ** 2).sum()) # fmt: on stds = np.squeeze(stds) self._conf_int = np.c_[self._y_test - stds / 2.0, self._y_test + stds / 2.0] return self.conf_int
[docs] @d.dedent def plot( self, figsize: Tuple[float, float] = (8, 5), same_plot: bool = False, hide_cells: bool = False, perc: Tuple[float, float] = None, fate_prob_cmap: colors.ListedColormap = cm.viridis, cell_color: Optional[str] = None, lineage_color: str = "black", alpha: float = 0.8, lineage_alpha: float = 0.2, title: Optional[str] = None, size: int = 15, lw: float = 2, cbar: bool = True, margins: float = 0.015, xlabel: str = "pseudotime", ylabel: str = "expression", conf_int: bool = True, lineage_probability: bool = False, lineage_probability_conf_int: Union[bool, float] = False, lineage_probability_color: Optional[str] = None, obs_legend_loc: Optional[str] = "best", dpi: int = None, fig: mpl.figure.Figure = None, ax: mpl.axes.Axes = None, return_fig: bool = False, save: Optional[str] = None, **kwargs: Any, ) -> Optional[mpl.figure.Figure]: """Plot the smoothed gene expression. Parameters ---------- figsize Size of the figure. same_plot Whether to plot all trends in the same plot. hide_cells Whether to hide the cells. perc Percentile by which to clip the fate probabilities. fate_prob_cmap Colormap to use when coloring in the fate probabilities. cell_color Key in :attr:`~anndata.AnnData.obs` or :attr:`~anndata.AnnData.var_names` used for coloring the cells. lineage_color Color for the lineage. alpha Alpha value in :math:`[0, 1]` for the transparency of cells. lineage_alpha Alpha value in :math:`[0, 1]` for the transparency lineage confidence intervals. title Title of the plot. size Size of the points. lw Line width for the smoothed values. cbar Whether to show the colorbar. margins Margins around the plot. xlabel Label on the x-axis. ylabel Label on the y-axis. conf_int Whether to show the confidence interval. lineage_probability Whether to show smoothed lineage probability as a dashed line. Note that this will require 1 additional model fit. lineage_probability_conf_int Whether to compute and show smoothed lineage probability confidence interval. lineage_probability_color Color to use when plotting the smoothed ``lineage_probability``. If :obj:`None`, it's the same as ``lineage_color``. Only used when ``show_lineage_probability = True``. obs_legend_loc Location of the legend when ``cell_color`` corresponds to a categorical variable. dpi Dots per inch. fig Figure to use. If :obj:`None`, create a new one. ax Ax to use. If :obj:`None`, create a new one. return_fig If :obj:`True`, return the figure object. save Filename where to save the plot. If :obj:`None`, just shows the plots. kwargs Keyword arguments for :meth:`~matplotlib.axes.Axes.legend`. Returns ------- %(just_plots)s """ if self.y_test is None: raise RuntimeError("Run `.predict()` first.") if fig is None or ax is None: fig, ax = plt.subplots(figsize=figsize, constrained_layout=True) if dpi is not None: fig.set_dpi(dpi) conf_int = conf_int and self.conf_int is not None hide_cells = hide_cells or self.x_all is None or self.w_all is None or self.y_all is None lineage_probability_color = lineage_color if lineage_probability_color is None else lineage_probability_color scaler = kwargs.pop( "scaler", self._create_scaler( lineage_probability, show_conf_int=conf_int, ), ) if lineage_probability and ylabel in ("expression", self._gene): ylabel = f"scaled {ylabel}" vmin, vmax = None, None key, color, typp, mapper = self._get_colors(cell_color, same_plot=same_plot) if not hide_cells: cbar = cbar and (typp == ColorType.CONT) if typp == ColorType.CONT: vmin, vmax = _minmax(color, perc) _ = ax.scatter( self.x_all.squeeze(), scaler(self.y_all.squeeze()), c=color, s=size, cmap=fate_prob_cmap, vmin=vmin, vmax=vmax, alpha=alpha, ) if title is None: title = f"{self._gene} @ {self._lineage}" if self._lineage is not None else f"{self._gene}" ax.plot(self.x_test, scaler(self.y_test), color=lineage_color, lw=lw, label=title) if title is not None: ax.set_title(title) if ylabel is not None: ax.set_ylabel(ylabel) if xlabel is not None: ax.set_xlabel(xlabel) ax.margins(margins) if conf_int: ax.fill_between( self.x_test.squeeze(), scaler(self.conf_int[:, 0]), scaler(self.conf_int[:, 1]), alpha=lineage_alpha, color=lineage_color, linestyle="--", ) if lineage_probability and not isinstance(self, FittedModel) and not np.allclose(self.w, 1.0): from cellrank.pl._utils import _is_any_gam_mgcv model = copy.deepcopy(self) model._y = self._reshape_and_retype(self.w).copy() model = model.fit() if not lineage_probability_conf_int: y = model.predict() elif _is_any_gam_mgcv(model): y = model.predict( level=lineage_probability_conf_int if isinstance(lineage_probability_conf_int, float) else 0.95 ) else: y = model.predict() model.confidence_interval() ax.fill_between( model.x_test.squeeze(), model.conf_int[:, 0], model.conf_int[:, 1], alpha=lineage_alpha, color=lineage_probability_color, linestyle="--", ) handle = ax.plot( model.x_test, y, color=lineage_probability_color, lw=lw, linestyle="--", zorder=-1, label="probability", ) self._maybe_add_legend(fig, ax, mapper=handle, title=None, **kwargs) if cbar and not hide_cells and not same_plot and not np.allclose(self.w_all, 1.0): norm = colors.Normalize(vmin=vmin, vmax=vmax) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="2%", pad=0.1) _ = mpl.colorbar.ColorbarBase( cax, norm=norm, cmap=fate_prob_cmap, ticks=np.linspace(norm.vmin, norm.vmax, 5), ) cax.set_ylabel(key) elif typp == ColorType.CAT: self._maybe_add_legend(fig, ax, mapper, title=key, loc=obs_legend_loc, is_line=False) if save is not None: save_fig(fig, save) if return_fig: return fig
def _maybe_add_legend( self, fig: mpl.figure.Figure, ax: mpl.axes.Axes, mapper: Union[Sequence[Any], Mapping[str, Any]], title: Optional[str] = None, loc: Optional[str] = "best", is_line: bool = True, **kwargs: Any, ) -> None: from cellrank.pl._utils import _position_legend if loc in ("none", None): return if isinstance(mapper, dict): handles = [ ax.plot([], [], label=name, color=color)[0] if is_line else ax.scatter([], [], label=name, color=color) for name, color in mapper.items() ] else: handles = mapper legend = _position_legend( ax, legend_loc=loc, handles=handles, title=title, **kwargs, ) fig.add_artist(legend) def _reshape_and_retype(self, arr: np.ndarray) -> np.ndarray: """Convert 1D or 2D array to 2D array of shape ``(n, 1)`` and set the data type. Parameters ---------- arr Array to convert. Returns ------- Array of shape ``(n, 1)`` with dtype set to :attr:`_dtype`. """ if arr.ndim not in (1, 2): raise ValueError(f"Expected array to be 1 or 2 dimensional, found `{arr.ndim}` dimension(s).") if arr.ndim == 2 and arr.shape[1] != 1: raise ValueError(f"Expected the 2nd dimension to be 1, found `{arr.shape[1]}.`") return np.reshape(arr, (-1, 1)).astype(self._dtype) def _check(self, attr_name: Optional[str], arr: Optional[np.ndarray], ndim: int = 2) -> Optional[np.ndarray]: """Check if the attribute exists with the correct dimension and optionally set it. Parameters ---------- attr_name Attribute name to check. arr Value to set. If :obj:`None`, just perform the checking. ndim Expected number of dimensions of the ``arr``. Returns ------- The attribute under ``attr_name``. """ if attr_name is None: return if arr is None: # already called prepare if not hasattr(self, attr_name): raise AttributeError(f"No attribute `{attr_name!r}` found.") if not isinstance(getattr(self, attr_name, None), np.ndarray): raise AttributeError( f"Expected `{attr_name!r}` to be `numpy.ndarray`, " f"found `{type(getattr(self, attr_name, None)).__name__!r}`." ) if getattr(self, attr_name).ndim != ndim: raise ValueError( f"Expected attribute `{attr_name!r}` to have `{ndim}` dimensions, " f"found `{getattr(self, attr_name).ndim}` dimensions." ) return getattr(self, attr_name) setattr(self, attr_name, self._reshape_and_retype(arr)) if attr_name.startswith("_"): try: getattr(self, attr_name[1:]) except AttributeError: setattr( self, attr_name[1:], property(lambda self: getattr(self, attr_name)), ) return getattr(self, attr_name) def _deepcopy_attributes(self, dst: "BaseModel") -> None: # __deepcopy__ will usually call `_shallowcopy_attributes` twice, since it calls `.copy()`, # which should always call it (it copies the lineage and gene names + deepcopies the model) self._shallowcopy_attributes(dst) for attr in [ "_x_all", "_y_all", "_w_all", "_x", "_y", "_w", "_x_test", "_y_test", "_x_hat", "_y_hat", "_conf_int", "_prepared", ]: setattr(dst, attr, copy.copy(getattr(self, attr))) def _shallowcopy_attributes(self, dst: "BaseModel") -> None: for attr in ["_gene", "_lineage", "_use_raw", "_data_key", "_is_bulk"]: setattr(dst, attr, copy.copy(getattr(self, attr))) # user is not exposed to this dst._obs_names = self._obs_names # always deepcopy the model (we're manipulating it in multiple threads/processed) dst._model = copy.deepcopy(self.model)
[docs] @abc.abstractmethod @d.dedent def copy(self) -> "BaseModel": """%(copy)s""" # noqa
def __copy__(self) -> "BaseModel": return self.copy() def __deepcopy__(self, memodict={}) -> "BaseModel": # noqa # deepcopy expects that `.copy()` makes a really shallow copy (i.e. only references to the arrays) # it should also not copy the `.prepared` attribute, since copying is happening mostly during # parallelization and it serves as 1 extra sanity check (a precaution that's not necessary, per-se, but highly # desirable) res = self.copy() # we don't copy the adata object for 2 reasons: # 1. these objects are meant to be as lightweight as possible # 2. in `.plot`, we deepcopy the model when plotting smoothed probabilities self._deepcopy_attributes(res) memodict[id(self)] = res return res def __bool__(self): return True def __str__(self) -> str: return repr(self) def __repr__(self) -> str: return "<{}[gene={!r}, lineage={!r}, model={}]>".format( self.__class__.__name__, self._gene, self._lineage, None if self.model is None else _dup_spaces.sub(" ", str(self.model).replace("\n", " ")).strip(), ) def _get_colors( self, key: Optional[str], *, same_plot: bool = False, ) -> Tuple[Optional[str], Optional[Union[str, np.ndarray]], ColorType, Optional[Dict[str, Any]],]: """ Get color array. Parameters ---------- key Key in :attr:`~anndata.AnnData.obs`, :attr:`~anndata.AnnData.var_names` or :attr:`~anndata.AnnData.raw.var_names`. Search first starts in `.obs`, then `.raw.var_names` and lastly `.layers`, using :meth:`~anndata.AnnData.obs_vector`. Returns ------- Triple of the following: - name of the colorbar label. - array of values to map. - type of the color. - color mapper for categorical colors. """ if colors.is_color_like(key): return None, key, ColorType.STR, None if key in self.adata.obs: if isinstance(self.adata.obs[key].dtype, pd.CategoricalDtype): add_colors_for_categorical_sample_annotation( self.adata, key=key, force_update_colors=False, palette=None, ) col_dict = collections.defaultdict( lambda: colors.to_rgb("grey"), zip( self.adata.obs[key].cat.categories, [colors.to_rgb(i) for i in self.adata.uns[f"{key}_colors"]], ), ) return ( key, np.array([col_dict[v] for v in self.adata.obs[key]]), ColorType.CAT, col_dict, ) if np.issubdtype(self.adata.obs[key].dtype, np.number): return key, self.adata.obs[key].values, ColorType.CONT, None logg.debug(f"Unable to interpret cell color from type `{infer_dtype(self.adata.obs[key])}`") return None, "black", ColorType.STR, None if self._use_raw and key in self.adata.raw.var_names: return ( key, _densify_squeeze(self.adata.raw[:, key], np.float64), ColorType.CONT, None, ) try: # can in principle return data from `.obs`, in which case it's mostly invalid return ( key, self.adata.obs_vector(key, layer=self._data_key), ColorType.CONT, None, ) except KeyError: logg.debug( f"Key `{key!r}` not found in `adata.obs` or " f"`adata{'.raw' if self._use_raw else ''}.var_names`. Ignoring`" ) if same_plot or np.allclose(self.w_all, 1.0): return None, "black", ColorType.STR, None return "fate probability", np.squeeze(self.w_all), ColorType.CONT, None def _create_scaler(self, show_lineage_probability: bool, show_conf_int: bool): if not show_lineage_probability: return lambda _: _ minn, maxx = self._return_min_max(show_conf_int) return lambda x: (x - minn) / (maxx - minn) def _return_min_max(self, show_conf_int: bool): if self.y_test is None: raise RuntimeError("Run `.predict()` first.") vals = [self.y_test] # FittedModel Does not need to have these if self.y_all is not None: vals.append(self.y_all) if show_conf_int and self.conf_int is not None: vals.append(self.conf_int) minn = min(map(np.min, vals)) maxx = max(map(np.max, vals)) return minn, maxx
class FailedModel(BaseModel): """Model representing a failure of the original :attr:`model`. Parameters ---------- model The original model which has failed. exc The exception that caused the :attr:`model` to fail or a :class:`str` containing the message. In the latter case, :meth:`~cellrank.models.FailedModel.reraise` a :class:`RuntimeError` with that message. If :obj:`None`, :class`UnknownModelError` will eventually be raised. """ # in a functional programming language like Haskell essentially BaseModel would be a Maybe monad and this Nothing def __init__(self, model: BaseModel, exc: Optional[Union[BaseException, str]] = None): if not isinstance(model, BaseModel): raise TypeError(f"Expected `model` to be of type `BaseModel`, found `{type(model).__name__!r}`.") if exc is not None: if isinstance(exc, str): exc = RuntimeError(exc) if not isinstance(exc, BaseException): raise TypeError( f"Expected `exc` to be either a string or a `BaseException`, found `{type(exc).__name__!r}`." ) else: exc = UnknownModelError() super().__init__(model.adata, model) self._gene = model._gene self._lineage = model._lineage self._prepared = model._prepared self._is_bulk = model._is_bulk self._exc = exc def prepare( self, *_args, **_kwargs, ) -> "FailedModel": """Do nothing and return self.""" return self def fit( self, x: Optional[np.ndarray] = None, y: Optional[np.ndarray] = None, w: Optional[np.ndarray] = None, **kwargs: Any, ) -> "FailedModel": """Do nothing and return self.""" return self def predict( self, x_test: Optional[np.ndarray] = None, key_added: Optional[str] = "_x_test", **kwargs: Any, ) -> np.ndarray: """Do nothing.""" def confidence_interval(self, x_test: Optional[np.ndarray] = None, **kwargs) -> np.ndarray: """Do nothing.""" def default_confidence_interval( self, x_test: Optional[np.ndarray] = None, **kwargs: Any, ) -> np.ndarray: """Do nothing.""" def reraise(self) -> None: """Raise the original exception with additional model information.""" # retain the exception type and also the original exception raise type(self._exc)(f"Fatal model failure `{self}`.") from self._exc def _get_colors( self, key: Optional[str], *, same_plot: bool = False, ) -> Tuple[Optional[str], Optional[Union[str, np.ndarray]], ColorType, Optional[Dict[str, Any]],]: return None, "black", ColorType.STR, None def _return_min_max(self, show_conf_int: bool): return np.inf, -np.inf @d.dedent def copy(self) -> "FailedModel": """%(copy)s""" # noqa return FailedModel(self.model.copy(), exc=self._exc) def __bool__(self): return False def __str__(self) -> str: return repr(self) def __repr__(self) -> str: return f"<{self.__class__.__name__}[origin={repr(self.model).strip('<>')}]>" @d.dedent class FittedModel(BaseModel): """Class representing an already fitted model. Useful when the smoothed gene expression was computed externally. Parameters ---------- x_test %(base_model_x_test.summary)s y_test %(base_model_y_test.summary)s conf_int %(base_model_conf_int.summary)s If :obj:`None`, always set ``conf_int=False`` in :meth:`~cellrank.models.BaseModel.plot`. x_all %(base_model_x_all.summary)s If :obj:`None`, always sets ``hide_cells=True`` in :meth:`~cellrank.models.BaseModel.plot`. y_all %(base_model_y_all.summary)s If :obj:`None`, always sets `hide_cells=True`` in :meth:`~cellrank.models.BaseModel.plot`. w_all %(base_model_w_all.summary)s If :obj:`None` and :attr:`x_all` and :attr:`y_all` are present, it will be set an array of :math:`1`. """ def __init__( self, x_test: ArrayLike, y_test: ArrayLike, conf_int: Optional[ArrayLike] = None, x_all: Optional[ArrayLike] = None, y_all: Optional[ArrayLike] = None, w_all: Optional[ArrayLike] = None, ): super().__init__(None, None) self._x_test = _densify_squeeze(x_test, self._dtype)[:, np.newaxis] self._y_test = _densify_squeeze(y_test, self._dtype) if self.x_test.ndim != 2 or self.x_test.shape[1] != 1: raise ValueError(f"Expected `x_test` to be of shape `(..., 1)`, found `{self.x_test.shape}`.") if self.y_test.shape != (self.x_test.shape[0],): raise ValueError( f"Expected `y_test` to be of shape `({self.x_test.shape[0]},)`, " f"found `{self.y_test.shape}`." ) if conf_int is not None: self._conf_int = _densify_squeeze(conf_int, self._dtype) if self.conf_int.shape != (self.x_test.shape[0], 2): raise ValueError(f"Expected `conf_int` to be of shape `({self.x_test.shape[0]}, 2)`.") else: logg.debug("No `conf_int` have been supplied, will be ignored during plotting") if x_all is not None and y_all is not None: self._x_all = _densify_squeeze(x_all, self._dtype)[:, np.newaxis] if self.x_all.ndim != 2 or self.x_all.shape[1] != 1: raise ValueError(f"Expected `x_all` to be of shape `(..., 1)`, found `{self.x_all.shape}`.") self._y_all = _densify_squeeze(y_all, self._dtype)[:, np.newaxis] if self.y_all.shape != self.x_all.shape: raise ValueError( f"Expected `y_all` to be of shape `({self.x_all.shape[0]}, 1)`, found `{self.y_all.shape}`." ) if w_all is not None: self._w_all = _densify_squeeze(w_all, self._dtype) if self.w_all.ndim != 1 or self.w_all.shape[0] != self.x_all.shape[0]: raise ValueError( f"Expected `w_all` to be of shape `({self.x_all.shape[0]},)`, " f"found `{self.w_all.shape}`." ) else: logg.debug("Setting `w_all` to an array of `1`") self._w_all = np.ones_like(self.x_all, dtype=np.float64).squeeze(1) else: logg.debug( "None or partially incomplete `x_all` and `y_all` have been supplied, " "will be ignored during plotting" ) self._prepared = True def _get_colors( self, key: Optional[str], *, same_plot: bool = False, ) -> Tuple[Optional[str], Optional[Union[str, np.ndarray]], ColorType, Optional[Dict[str, Any]],]: # w_all does not need to be defined if same_plot or self.w_all is None or np.allclose(self.w_all, 1.0): return None, "black", ColorType.STR, None return "fate probability", np.squeeze(self.w_all), ColorType.CONT, None def prepare(self, *_args, **_kwargs) -> "FittedModel": """Do nothing and return self.""" return self def fit( self, x: Optional[np.ndarray] = None, y: Optional[np.ndarray] = None, w: Optional[np.ndarray] = None, **kwargs: Any, ) -> "FittedModel": """Do nothing and return self.""" return self @d.dedent def predict( self, x_test: Optional[np.ndarray] = None, key_added: Optional[str] = "_x_test", **kwargs: Any, ) -> np.ndarray: """%(base_model_y_test.summary)s.""" return self._y_test @d.dedent def confidence_interval(self, x_test: Optional[np.ndarray] = None, **kwargs) -> np.ndarray: """%(base_model_conf_int.summary)s Raise a :class:`RuntimeError` if not present.""" if self.conf_int is None: raise RuntimeError( "No confidence interval has been supplied. " "Use `conf_int=...` when instantiating this class." ) return self.conf_int @d.dedent def default_confidence_interval( self, x_test: Optional[np.ndarray] = None, **kwargs: Any, ) -> np.ndarray: """%(base_model_conf_int.summary)s Raise a :class:`RuntimeError` if not present.""" return self.confidence_interval() @d.dedent def copy(self) -> "FittedModel": """%(copy)s""" # noqa # here we return a deepcopy since it doesn't make sense to make a shallow one return FittedModel.from_model(self) @staticmethod def from_model(model: BaseModel) -> "FittedModel": """Create a :class:`~cellrank.models.FittedModel` from a :class:`~cellrank.models.BaseModel`.""" if not isinstance(model, BaseModel): raise TypeError(f"Expected `model` to be of type `BaseModel`, found `{type(model).__name__!r}`.") fm = FittedModel( model.x_test, model.y_test, conf_int=model.conf_int, x_all=model.x_all, y_all=model.y_all, w_all=model.w_all, ) fm._gene = model._gene fm._lineage = model._lineage fm._is_bulk = model._is_bulk fm._prepared = True return fm