# Source code for cellrank.models._pygam_model

import collections
import contextlib
import copy
import enum
import io
import logging
import types
from collections.abc import Mapping
from typing import Any, Literal

import numpy as np
from anndata import AnnData
from pygam import GAM as pGAM
from pygam import (
    ExpectileGAM,
    GammaGAM,
    InvGaussGAM,
    LinearGAM,
    LogisticGAM,
    PoissonGAM,
    s,
)

from cellrank._utils._docs import d
from cellrank._utils._enum import ModeEnum
from cellrank._utils._utils import _filter_kwargs
from cellrank.models import BaseModel

logger = logging.getLogger(__name__)
__all__ = ["GAM"]


class GamLinkFunction(ModeEnum):
    """Link functions accepted by :class:`GAM` through its ``link`` argument.

    ``GAM.__init__`` converts the user-supplied string with ``GamLinkFunction(link)``.
    The member values come from :func:`enum.auto` and are presumably the lower-cased
    member names generated by ``ModeEnum`` (e.g. ``'inv_squared'``, matching the
    ``link`` literal of ``GAM.__init__``). TODO confirm in ``cellrank._utils._enum``.
    """

    IDENTITY = enum.auto()
    LOGIT = enum.auto()
    INVERSE = enum.auto()
    LOG = enum.auto()
    INV_SQUARED = enum.auto()


class GamDistribution(ModeEnum):
    """Distributions accepted by :class:`GAM` through its ``distribution`` argument.

    ``GAM.__init__`` converts the user-supplied string with ``GamDistribution(distribution)``
    and then replaces ``GAUSSIAN`` with ``NORMAL``, i.e. ``'gaussian'`` is an alias.
    The member values come from :func:`enum.auto`. They are presumably the lower-cased
    member names generated by ``ModeEnum``; TODO confirm in ``cellrank._utils._enum``.
    """

    NORMAL = enum.auto()
    BINOMIAL = enum.auto()
    POISSON = enum.auto()
    GAMMA = enum.auto()
    # alias of NORMAL; remapped in ``GAM.__init__``
    GAUSSIAN = enum.auto()
    INV_GAUSS = enum.auto()


# Lookup table from a ``(distribution, link)`` pair to the pygam estimator class used
# for it. Pairs without an explicit entry fall back to the generic ``pGAM`` through the
# default factory. Note that ``defaultdict`` also stores that fallback under the
# missing key the first time it is looked up.
_gams = collections.defaultdict(lambda: pGAM)
_gams.update(
    {
        (GamDistribution.NORMAL, GamLinkFunction.IDENTITY): LinearGAM,
        (GamDistribution.BINOMIAL, GamLinkFunction.LOGIT): LogisticGAM,
        (GamDistribution.POISSON, GamLinkFunction.LOG): PoissonGAM,
        (GamDistribution.GAMMA, GamLinkFunction.LOG): GammaGAM,
        (GamDistribution.INV_GAUSS, GamLinkFunction.LOG): InvGaussGAM,
    }
)


@d.dedent
class GAM(BaseModel):
    """Fit Generalized Additive Models (GAMs) using :mod:`pygam`.

    Parameters
    ----------
    %(adata)s
    n_knots
        Number of knots.
    spline_order
        Order of the splines, e.g., :math:`3` for cubic splines.
    distribution
        Name of the distribution. Available distributions can be found
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Distribution:>`__.
        ``'gaussian'`` is treated as an alias of ``'normal'``.
    link
        Name of the link function. Available link functions can be found
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Link-function:>`__.
    max_iter
        Maximum number of iterations for optimization.
    expectile
        Expectile for :class:`~pygam.pygam.ExpectileGAM`. This forces the distribution to be ``'normal'``
        and link function to ``'identity'``. Must be in :math:`(0, 1)`.
    grid
        Whether to perform a grid search. If a mapping, keys correspond to parameter names and values
        to the ranges to be searched. If ``'default'``, use the default grid.
        If :obj:`None`, don't perform a grid search.
    spline_kwargs
        Keyword arguments for :func:`~pygam.terms.s`.
    kwargs
        Keyword arguments for the :class:`~pygam.pygam.GAM`.

    Raises
    ------
    ValueError
        If ``expectile`` is not in :math:`(0, 1)`.
    TypeError
        If ``expectile`` is :obj:`None` and ``kwargs`` contains arguments that the selected
        :mod:`pygam` estimator does not accept, or if ``grid`` is not a :class:`str`,
        a mapping or :obj:`None`.
    """

    def __init__(
        self,
        adata: AnnData,
        n_knots: int | None = 6,
        spline_order: int = 3,
        distribution: Literal["normal", "binomial", "poisson", "gamma", "gaussian", "inv_gauss"] = "gamma",
        link: Literal["identity", "logit", "inverse", "log", "inv_squared"] = "log",
        max_iter: int = 2000,
        expectile: float | None = None,
        grid: str | Mapping[str, Any] | None = None,
        spline_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ):
        # A single spline term on feature 0. ``spline_kwargs`` overrides the defaults
        # ``lam=3`` and ``penalties=["derivative", "l2"]``. ``_filter_kwargs`` presumably
        # keeps only the keys accepted by ``s``; confirm in ``cellrank._utils._utils``.
        term = s(
            0,
            spline_order=spline_order,
            n_splines=n_knots,
            **_filter_kwargs(s, **{**{"lam": 3, "penalties": ["derivative", "l2"]}, **spline_kwargs}),
        )

        link = GamLinkFunction(link)
        distribution = GamDistribution(distribution)
        # 'gaussian' is only an alias of 'normal'
        if distribution == GamDistribution.GAUSSIAN:
            distribution = GamDistribution.NORMAL

        if expectile is not None:
            if not (0 < expectile < 1):
                raise ValueError(f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`.")
            # compare against the enum members, consistent with the GAUSSIAN check above
            if distribution != GamDistribution.NORMAL or link != GamLinkFunction.IDENTITY:
                logger.warning(
                    "Expectile GAM works only with `normal` distribution and `identity` link function, "
                    "found `%r` distribution and %r link functions.",
                    distribution,
                    link,
                )
            model = ExpectileGAM(term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs)
        else:
            # doing it like this ensures that user can specify scale
            gam = _gams[distribution, link]
            filtered_kwargs = _filter_kwargs(gam.__init__, **kwargs)
            if len(kwargs) != len(filtered_kwargs):
                invalid = sorted(set(kwargs) - set(filtered_kwargs))
                # ``gam`` is already a class, so its own ``__name__`` identifies the estimator
                raise TypeError(f"Invalid arguments `{invalid}` for `{gam.__name__!r}`.")
            filtered_kwargs["link"] = link
            filtered_kwargs["distribution"] = distribution
            model = gam(
                term,
                max_iter=max_iter,
                verbose=False,
                # filter again so that ``link``/``distribution`` are only forwarded when
                # ``gam.__init__`` accepts them (presumably only the generic ``pGAM`` does)
                **_filter_kwargs(gam.__init__, **filtered_kwargs),
            )

        super().__init__(adata, model=model)
        self._converged: bool = True

        if grid is None:
            self._grid = None
        elif isinstance(grid, Mapping):
            # accept any mapping (as annotated) and store a plain ``dict`` copy, which
            # ``fit`` recognises as a user-supplied search grid
            self._grid = dict(grid)
        elif isinstance(grid, str):
            # ``object()`` is a sentinel for "search pygam's default grid";
            # any other string disables the grid search
            self._grid = object() if grid == "default" else None
        else:
            raise TypeError(f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`.")

    @d.dedent
    def fit(
        self,
        x: np.ndarray | None = None,
        y: np.ndarray | None = None,
        w: np.ndarray | None = None,
        **kwargs: Any,
    ) -> "GAM":
        """%(base_model_fit.full_desc)s

        Parameters
        ----------
        %(base_model_fit.parameters)s

        Returns
        -------
        Fits the model and returns self.
        """  # noqa: D400
        super().fit(x, y, w, **kwargs)

        if self._grid is not None:
            # an empty grid means: use the default search
            grid = self._grid if isinstance(self._grid, dict) else {}
            try:
                with contextlib.redirect_stdout(io.StringIO()):
                    self.model.gridsearch(
                        self.x,
                        self.y,
                        weights=self.w,
                        keep_best=True,
                        progress=False,
                        **grid,
                        **kwargs,
                    )
                return self
            except Exception as e:  # noqa: BLE001
                # log first so the reason is reported even if the fallback fit below raises
                logger.error("Grid search failed, reason: `%s`. Fitting with default values", e)
                # workaround for: https://github.com/dswah/pyGAM/issues/273
                # NOTE(review): the model is fitted once more below; confirm this extra fit
                # is still required by the workaround
                with contextlib.redirect_stdout(io.StringIO()):
                    self.model.fit(self.x, self.y, weights=self.w, **kwargs)

        try:
            buf = io.StringIO()
            with contextlib.redirect_stdout(buf):
                self.model.fit(self.x, self.y, weights=self.w, **kwargs)
            # non-convergence is detected from the captured stdout, not from an exception
            self._converged = "did not converge" not in buf.getvalue()
            return self
        except Exception as e:  # noqa: BLE001
            raise RuntimeError(
                f"Unable to fit `{type(self).__name__}` for gene "
                f"`{self._gene!r}` in lineage `{self._lineage!r}`. Reason: `{e}`"
            ) from e

    @d.dedent
    def predict(
        self,
        x_test: np.ndarray | None = None,
        key_added: str | None = "_x_test",
        **kwargs: Any,
    ) -> np.ndarray:
        """%(base_model_predict.full_desc)s

        Parameters
        ----------
        %(base_model_predict.parameters)s

        Returns
        -------
        %(base_model_predict.returns)s
        """  # noqa: D400
        x_test = self._check(key_added, x_test)

        self._y_test = self.model.predict(x_test, **kwargs)
        # drop singleton axes and cast to the model's dtype
        self._y_test = np.squeeze(self._y_test).astype(self._dtype)

        return self.y_test

    @d.dedent
    def confidence_interval(self, x_test: np.ndarray | None = None, **kwargs: Any) -> np.ndarray:
        """%(base_model_ci.summary)s

        Parameters
        ----------
        %(base_model_ci.parameters)s

        Returns
        -------
        %(base_model_ci.returns)s
        """  # noqa: D400
        x_test = self._check("_x_test", x_test)
        self._conf_int = self.model.confidence_intervals(x_test, **kwargs).astype(self._dtype)

        return self.conf_int

    @d.dedent
    def copy(self) -> "BaseModel":
        """%(copy)s"""  # noqa: D400, D401
        res = GAM(self.adata)
        self._shallowcopy_attributes(res)
        # ``copy`` here resolves to the module, not this method (function scope)
        res._grid = copy.deepcopy(self._grid)
        return res