Source code for cellrank.models._gamr_model

import copy
import enum
from typing import Any, Literal

import numpy as np
import pandas as pd
import scipy.stats as st
from anndata import AnnData

from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import ModeEnum
from cellrank.models import BaseModel, FailedModel
from cellrank.models._utils import _OFFSET_KEY, _get_knotlocs, _get_offset

__all__ = ["GAMR"]

# Module-level cache of the most recently imported R library (`_r_lib`) and
# the name it was imported under (`_r_lib_name`). `_maybe_import_r_lib` reads
# and updates both so that asking for the same library again skips the import.
_r_lib = None
_r_lib_name = None


class KnotLocs(ModeEnum):
    """Strategies for placing the knots of the smoothing term used by ``GAMR``."""

    # let `mgcv` handle the knot positions
    AUTO = enum.auto()
    # position the knots based on the density of the pseudotime
    DENSITY = enum.auto()


@inject_docs(key=_OFFSET_KEY, kloc=KnotLocs)
@d.dedent
class GAMR(BaseModel):
    """Wrapper around R's `mgcv <https://cran.r-project.org/web/packages/mgcv/>`__ package for fitting GAMs.

    Parameters
    ----------
    %(adata)s
    n_knots
        Number of knots.
    distribution
        Distribution family in `rpy2.robjects.r`, such as ``'gaussian'`` or ``'nb'`` for negative binomial.
        If ``'nb'``, raw count data in :attr:`~anndata.AnnData.raw` is always used.
    basis
        Basis for the smoothing term.
        See `here <https://www.rdocumentation.org/packages/mgcv/versions/1.8-33/topics/s>`__ for valid options.
    knotlocs
        Position of the knots. Can be one of the following:

        - ``{kloc.AUTO!r}`` - let `mgcv` handle the knot positions.
        - ``{kloc.DENSITY!r}`` - position the knots based on the density of the pseudotime.
    offset
        Offset term for the GAM. Only available when ``distribution='nb'``. If `'default'`, it is calculated
        according to :cite:`robinson:10`. The values are saved in :attr:`adata.obs['{key}'] <anndata.AnnData.obs>`.
        If :obj:`None`, no offset is used.
    smoothing_penalty
        Penalty for the smoothing term. The larger the value, the smoother the fitted curve.
    kwargs
        Keyword arguments for
        `gam control <https://www.rdocumentation.org/packages/mgcv/versions/1.8-33/topics/gam.control>`__.
    """

    def __init__(
        self,
        adata: AnnData,
        n_knots: int = 5,
        distribution: str = "gaussian",
        basis: str = "cr",
        knotlocs: Literal["auto", "density"] = KnotLocs.AUTO,
        offset: np.ndarray | Literal["default"] | None = "default",
        smoothing_penalty: float = 1.0,
        **kwargs: Any,
    ):
        if n_knots <= 0:
            raise ValueError(f"Expected `n_knots` to be positive, found `{n_knots}`.")
        if smoothing_penalty < 0:
            raise ValueError(f"Expected `smoothing_penalty` to be non-negative, found `{smoothing_penalty}`.")

        # `perform_import_check` is an internal switch (used by `copy`), not a `gam.control` option;
        # strip it before the remaining keyword arguments are stored for `gam.control`
        perform_import_check = kwargs.pop("perform_import_check", True)

        super().__init__(adata, model=None)

        self._n_knots = n_knots
        self._family = distribution
        self._formula = f"y ~ s(x, bs={basis!r}, k={self._n_knots}, sp={smoothing_penalty})"
        self._design_mat = None
        self._offset = None
        self._knotslocs = KnotLocs(knotlocs)
        self._control_kwargs = copy.copy(kwargs)
        # confidence level at which `_conf_int` was last computed, see `confidence_interval`
        self._conf_int_level = None

        self._lib = None
        self._lib_name = None
        # it's a bit costly to import, copying just passes the reference
        if perform_import_check:
            self._lib, self._lib_name = _maybe_import_r_lib("mgcv")

        if distribution == "nb" and offset is not None:
            if not isinstance(offset, (np.ndarray, str)):
                raise TypeError(
                    f"Expected `offset` to be either `'default'` or `numpy.ndarray`, "
                    f"got `{type(offset).__name__}`."
                )
            if isinstance(offset, str):
                if offset != "default":
                    raise ValueError("Only value `'default'` is allowed when `offset` is a string.")
                offset = _get_offset(adata, use_raw=True, recompute=False)

            offset = np.asarray(offset, dtype=self._dtype)
            if offset.shape != (adata.n_obs,):
                raise ValueError(f"Expected offset to be of shape `{(adata.n_obs,)}`, found `{offset.shape}`.")
            adata.obs[_OFFSET_KEY] = offset

            self._formula += " + offset(offset)"
            self._offset = offset

    @d.dedent
    def prepare(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> "GAMR":
        """%(base_model_prepare.full_desc)s

        This also removes the zero and negative weights and prepares the design matrix.

        Parameters
        ----------
        %(base_model_prepare.parameters)s

        Returns
        -------
        %(base_model_prepare.returns)s
        """  # noqa: D400
        if self._family == "nb":
            # the negative binomial family always works on the raw data
            kwargs["use_raw"] = True

        prepared = super().prepare(*args, **kwargs)
        if isinstance(prepared, FailedModel):
            return prepared

        # keep only the observations with a strictly positive weight
        use_ixs = self.w > 0
        self._x = self.x[use_ixs]
        self._y = self.y[use_ixs]
        self._w = self.w[use_ixs]

        self._design_mat = pd.DataFrame(
            np.c_[self.x, self.y],
            columns=["x", "y"],
        )
        if self._offset is not None:
            # the offset spans all observations of `adata`; select the retained ones
            mask = np.isin(self.adata.obs_names, self._obs_names[use_ixs])
            self._design_mat["offset"] = self._offset[mask]

        return self

    @d.dedent
    def fit(
        self,
        x: np.ndarray | None = None,
        y: np.ndarray | None = None,
        w: np.ndarray | None = None,
        **kwargs: Any,
    ) -> "GAMR":
        """%(base_model_fit.full_desc)s

        Parameters
        ----------
        %(base_model_fit.parameters)s

        Returns
        -------
        Fits the model and returns self. Updates the following fields by filtering out :math:`0` weights :attr:`w`:

        - :attr:`x` - %(base_model_x.summary)s
        - :attr:`y` - %(base_model_y.summary)s
        - :attr:`w` - %(base_model_w.summary)s
        """  # noqa: D400
        import rpy2.robjects as ro
        from rpy2.robjects import pandas2ri
        from rpy2.robjects.conversion import localconverter

        if self._lib is None:
            # same guard as in `predict`; without it, `self._lib.gam` fails with an opaque `AttributeError`
            raise RuntimeError(
                f"Unable to fit the model, R package `{(self._lib_name or 'mgcv')!r}` is not imported."
            )

        super().fit(x, y, w, **kwargs)

        # order seems to matter, pandas2ri + default_converter results in:
        # Conversion 'py2rpy' not defined for objects of type '<class 'rpy2.rlike.container.OrdDict'>'
        with localconverter(pandas2ri.converter + ro.default_converter):
            family = getattr(ro.r, self._family)

            gam_kwargs = {}
            if self._knotslocs != KnotLocs.AUTO:
                # explicit knot positions; otherwise `mgcv` chooses them
                gam_kwargs["knots"] = pd.DataFrame(
                    _get_knotlocs(
                        self.x,
                        self._n_knots,
                        uniform=False,
                    ),
                    columns=["x"],
                )

            self._model = self._lib.gam(
                ro.Formula(self._formula),
                data=self._design_mat,
                family=family,
                weights=pd.Series(self.w),
                control=self._lib.gam_control(**self._control_kwargs),
                **gam_kwargs,
            )

        return self

    def _get_x_test(self, x_test: np.ndarray | None = None, key_added: str = "_x_test") -> pd.DataFrame:
        """Build the data frame of new data passed to R's ``predict``.

        The ``x`` column holds the checked test points. If the design matrix has an ``offset`` column,
        its mean is used as the offset of every test point.
        """
        newdata = pd.DataFrame(self._check(key_added, x_test), columns=["x"])
        if "offset" in self._design_mat:
            newdata["offset"] = np.mean(self._design_mat["offset"])

        return newdata

    @d.dedent
    def predict(
        self,
        x_test: np.ndarray | None = None,
        key_added: str = "_x_test",
        level: float | None = None,
        **kwargs,
    ) -> np.ndarray:
        """%(base_model_predict.full_desc)s

        This method can also compute the confidence interval.

        Parameters
        ----------
        %(base_model_predict.parameters)s
        level
            Confidence level for confidence interval calculation. If :obj:`None`, don't compute the confidence interval.
            Must be in :math:`[0, 1]`.

        Returns
        -------
        %(base_model_predict.returns)s
        """  # noqa: D400
        import rpy2.robjects as ro
        from rpy2.robjects import pandas2ri
        from rpy2.robjects.conversion import localconverter

        if self.model is None:
            raise RuntimeError("Trying to call an uninitialized model. To initialize it, run `.fit()` first.")
        if self._lib is None:
            raise RuntimeError(
                f"Unable to run prediction, R package `{(self._lib_name or 'mgcv')!r}` is not imported."
            )
        if level is not None and not (0 <= level <= 1):
            raise ValueError(f"Expected the confidence level to be in interval `[0, 1]`, found `{level}`.")

        # forward `key_added` so that a caller-supplied attribute name is honored
        newdata = self._get_x_test(x_test, key_added)

        with localconverter(pandas2ri.converter + ro.default_converter):
            res = ro.r.predict(
                self.model,
                newdata=pandas2ri.py2rpy(newdata),
                type="response",
                se=level is not None,
            )

        if level is None:
            self._y_test = np.array(res).squeeze().astype(self._dtype)
        else:
            self._y_test = np.array(res.rx2("fit")).squeeze().astype(self._dtype)
            se = np.array(res.rx2("se.fit")).squeeze().astype(self._dtype)

            # two-sided normal quantile for the requested confidence level
            quantile = st.norm.ppf(level + (1 - level) / 2)
            self._conf_int = np.c_[self.y_test - quantile * se, self.y_test + quantile * se]
            # remember the level so that `confidence_interval` knows what the cache holds
            self._conf_int_level = level

        return self.y_test

    @d.dedent
    def confidence_interval(self, x_test: np.ndarray | None = None, level: float = 0.95, **kwargs) -> np.ndarray:
        """%(base_model_ci.summary)s

        Internally, this method calls :meth:`~cellrank.models.GAMR.predict` to extract the confidence interval, if needed.

        Parameters
        ----------
        %(base_model_ci.parameters)s
        level
            Confidence level.

        Returns
        -------
        %(base_model_ci.returns)s
        """  # noqa: D400
        # this is 2x as fast as opposed to calling `robjects.r.predict` again
        # on my PC (Michal):
        # - `.predict` takes ~6.5ms (without se=True, it's ~5.5ms, 20% slowdown)
        # - `.predict` without se=True + `.confidence_interval` takes ~12ms
        # this impl. achieves the 2x speedup without sacrificing the 20%
        #
        # the cached interval is reused only if it was computed at the requested level;
        # `getattr` covers instances that have an interval but no recorded level
        cached_level = getattr(self, "_conf_int_level", None)
        if x_test is not None or self._conf_int is None or cached_level != level:
            _ = self.predict(x_test=x_test, level=level)

        return self._conf_int

    def _deepcopy_attributes(self, dst: "GAMR") -> None:
        """Deep-copy the attributes into ``dst``, including the offset and the design matrix."""
        super()._deepcopy_attributes(dst)
        dst._offset = copy.deepcopy(self._offset)
        dst._design_mat = copy.deepcopy(self._design_mat)
        dst._conf_int_level = getattr(self, "_conf_int_level", None)

    @d.dedent
    def copy(self) -> "GAMR":
        """%(copy)s"""  # noqa: D400, D401
        res = GAMR(
            self.adata,
            self._n_knots,
            distribution=self._family,
            offset=self._offset,
            knotlocs=self._knotslocs,
            perform_import_check=False,
        )
        self._shallowcopy_attributes(res)  # does not copy `prepared`, since we're not copying the arrays

        res._formula = self._formula  # we don't save the basis inside
        res._design_mat = self._design_mat
        res._offset = self._offset
        res._lib = self._lib  # just passing the reference
        res._lib_name = self._lib_name
        res._control_kwargs = self._control_kwargs
        res._conf_int_level = getattr(self, "_conf_int_level", None)

        return res

    def __getstate__(self) -> dict:
        """Return the instance state without the handle to the imported R library."""
        return {k: v for k, v in self.__dict__.items() if k != "_lib"}

    def __setstate__(self, state: dict):
        """Restore the instance state and re-import the R library, if one was loaded before."""
        self.__dict__ = state
        if self._lib_name is None:
            # the library was never imported for this instance, so there is nothing to re-import
            self._lib = None
        else:
            self._lib, self._lib_name = _maybe_import_r_lib(self._lib_name, raise_exc=True)
def _maybe_import_r_lib(name: str, raise_exc: bool = False) -> tuple[Any | None, str | None]: global _r_lib, _r_lib_name if name == _r_lib_name and _r_lib is not None: return _r_lib, _r_lib_name try: from logging import ERROR from rpy2.rinterface_lib.callbacks import logger from rpy2.robjects import r from rpy2.robjects.packages import PackageNotInstalledError, importr logger.setLevel(ERROR) r["options"](warn=-1) except ImportError as e: raise ImportError("Unable to import `rpy2`, install it first as `pip install rpy2` version `>=3.3.0`.") from e try: _r_lib = importr(name) _r_lib_name = name return _r_lib, _r_lib_name except PackageNotInstalledError as e: if not raise_exc: return None, None raise RuntimeError(f"Install R library `{name!r}` first as `install.packages({name!r}).`") from e