"""Source code for :mod:`cellrank.ul.models._pygam_model`: GAMs backed by :mod:`pygam`."""

from typing import Any, Union, Mapping, Optional
from typing_extensions import Literal

import warnings
from copy import copy as _copy
from copy import deepcopy
from enum import auto
from types import MappingProxyType
from collections import defaultdict

from cellrank import logging as logg
from import ModeEnum
from cellrank.ul._docs import d
from cellrank.ul.models import BaseModel
from import _filter_kwargs
from cellrank.ul.models._base_model import AnnData

import numpy as np
from pygam import GAM as pGAM
from pygam import (

class GamLinkFunction(ModeEnum):  # noqa: D101
    """
    Link functions accepted by the ``link`` parameter of :class:`GAM`.

    Members use :func:`enum.auto` values, except ``INV_SQUARED`` whose value is
    explicitly ``'inverse-squared'``.
    """

    # NOTE(review): the concrete ``auto()`` values come from ``ModeEnum``'s
    # ``_generate_next_value_`` (not visible here); presumably the lowercase
    # member name, matching the ``Literal`` options of ``GAM.__init__`` — confirm.
    IDENTITY = auto()
    LOGIT = auto()
    INVERSE = auto()
    LOG = auto()
    INV_SQUARED = "inverse-squared"

class GamDistribution(ModeEnum):  # noqa: D101
    """
    Distributions accepted by the ``distribution`` parameter of :class:`GAM`.

    ``GAUSSIAN`` is accepted as input, but :class:`GAM` converts it to ``NORMAL``
    before choosing the underlying :mod:`pygam` model.
    """

    # NOTE(review): values are generated by ``ModeEnum`` via ``auto()``;
    # presumably the lowercase member name (e.g. ``'inv_gauss'``) — confirm.
    NORMAL = auto()
    BINOMIAL = auto()
    POISSON = auto()
    GAMMA = auto()
    GAUSSIAN = auto()
    INV_GAUSS = auto()

# Maps a ``(distribution, link)`` pair to the specialized pyGAM estimator class that
# implements it. Any other combination falls back to the generic :class:`pygam.GAM`
# (``pGAM``) through the ``defaultdict`` factory.
_gams = defaultdict(
    lambda: pGAM,
    {
        (GamDistribution.NORMAL, GamLinkFunction.IDENTITY): LinearGAM,
        (GamDistribution.BINOMIAL, GamLinkFunction.LOGIT): LogisticGAM,
        (GamDistribution.POISSON, GamLinkFunction.LOG): PoissonGAM,
        (GamDistribution.GAMMA, GamLinkFunction.LOG): GammaGAM,
        (GamDistribution.INV_GAUSS, GamLinkFunction.LOG): InvGaussGAM,
    },
)

@d.dedent
class GAM(BaseModel):
    """
    Fit Generalized Additive Models (GAMs) using :mod:`pygam`.

    Parameters
    ----------
    %(adata)s
    n_knots
        Number of knots, passed as ``n_splines`` to :func:`pygam.s`.
    spline_order
        Order of the splines, i.e. `3` for cubic splines.
    distribution
        Name of the distribution. One of `'normal'`, `'binomial'`, `'poisson'`, `'gamma'`,
        `'gaussian'` (an alias for `'normal'`) or `'inv_gauss'`.
    link
        Name of the link function. One of `'identity'`, `'logit'`, `'inverse'`, `'log'`
        or `'inverse-squared'`.
    max_iter
        Maximum number of iterations for optimization.
    expectile
        Expectile for :class:`pygam.pygam.ExpectileGAM`. This forces the distribution to be
        `'normal'` and link function to `'identity'`. Must be in interval `(0, 1)`.
    grid
        Whether to perform a grid search. If a mapping, keys correspond to parameter names
        and values to the ranges to be searched. If `'default'`, use the default grid.
        If `None` (or any other string), don't perform a grid search.
    spline_kwargs
        Keyword arguments for :class:`pygam.s`. `lam` defaults to `3` unless given here.
    kwargs
        Keyword arguments for :class:`pygam.pygam.GAM`.

    Raises
    ------
    ValueError
        If `expectile` is not in `(0, 1)`.
    TypeError
        If `kwargs` contains arguments the selected :mod:`pygam` model does not accept,
        or if `grid` is not a mapping, a string or `None`.
    """

    def __init__(
        self,
        adata: AnnData,
        n_knots: Optional[int] = 6,
        spline_order: int = 3,
        distribution: Literal[
            "normal", "binomial", "poisson", "gamma", "gaussian", "inv_gauss"
        ] = "gamma",
        link: Literal["identity", "logit", "inverse", "log", "inverse-squared"] = "log",
        max_iter: int = 2000,
        expectile: Optional[float] = None,
        grid: Optional[Union[str, Mapping[str, Any]]] = None,
        spline_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs: Any,
    ):
        # A single penalized spline term on feature `0`; `lam=3` is only a default
        # and is overridden by a `lam` entry in `spline_kwargs`.
        term = s(
            0,
            spline_order=spline_order,
            n_splines=n_knots,
            penalties=["derivative", "l2"],
            **_filter_kwargs(s, **{**{"lam": 3}, **spline_kwargs}),
        )

        link = GamLinkFunction(link)
        distribution = GamDistribution(distribution)
        if distribution == GamDistribution.GAUSSIAN:
            distribution = GamDistribution.NORMAL

        if expectile is not None:
            if not (0 < expectile < 1):
                raise ValueError(
                    f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`."
                )
            # Compare against the enum members rather than raw strings, so the check
            # does not depend on `ModeEnum` being str-valued.
            if (
                distribution != GamDistribution.NORMAL
                or link != GamLinkFunction.IDENTITY
            ):
                logg.warning(
                    f"Expectile GAM works only with `normal` distribution and `identity` link function, "
                    f"found `{distribution!r}` distribution and `{link!r}` link function."
                )
            model = ExpectileGAM(
                term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs
            )
        else:
            # Constructing the model this way (instead of a fixed class) ensures that
            # the user can still specify e.g. `scale` through `kwargs`.
            gam = _gams[distribution, link]
            filtered_kwargs = _filter_kwargs(gam.__init__, **kwargs)
            if len(kwargs) != len(filtered_kwargs):
                # `gam` is a class, so report its own name (not its metaclass's).
                raise TypeError(
                    f"Invalid arguments `{sorted(set(kwargs) - set(filtered_kwargs))}` "
                    f"for `{gam.__name__!r}`."
                )
            # pyGAM registers the inverse-squared link as `'inv_squared'`; the
            # user-facing value here is `'inverse-squared'`.
            filtered_kwargs["link"] = (
                "inv_squared" if link == GamLinkFunction.INV_SQUARED else link
            )
            filtered_kwargs["distribution"] = distribution
            model = gam(
                term,
                max_iter=max_iter,
                verbose=False,
                # Re-filter so `link`/`distribution` are only forwarded when
                # `gam.__init__` accepts them.
                **_filter_kwargs(gam.__init__, **filtered_kwargs),
            )

        super().__init__(adata, model=model)

        if grid is None:
            self._grid = None
        elif isinstance(grid, str):
            # An `object()` sentinel means "use pyGAM's default grid"; any other
            # string disables the grid search.
            self._grid = object() if grid == "default" else None
        elif isinstance(grid, Mapping):
            # Store a plain `dict` copy, which `fit` recognizes as a user grid.
            self._grid = dict(grid)
        else:
            raise TypeError(
                f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`."
            )

    @d.dedent
    def fit(
        self,
        x: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        w: Optional[np.ndarray] = None,
        **kwargs,
    ) -> "GAM":
        """
        %(base_model_fit.full_desc)s

        Parameters
        ----------
        %(base_model_fit.parameters)s

        Returns
        -------
        Fits the model and returns self.
        """  # noqa
        super().fit(x, y, w, **kwargs)

        with warnings.catch_warnings():
            # Silence NumPy's "deprecated alias for the builtin" warnings
            # (presumably raised from inside pyGAM).
            warnings.filterwarnings(
                "ignore",
                category=DeprecationWarning,
                message=".* is a deprecated alias for the builtin",
            )

            grid_search_failed = False
            if self._grid is not None:
                # A non-dict grid is the "default" sentinel: let pyGAM use its default grid.
                grid = self._grid if isinstance(self._grid, dict) else {}
                try:
                    self.model.gridsearch(
                        self.x,
                        self.y,
                        weights=self.w,
                        keep_best=True,
                        progress=False,
                        **grid,
                        **kwargs,
                    )
                    return self
                except Exception as e:  # noqa: B902
                    logg.error(
                        f"Grid search failed, reason: `{e}`. Fitting with default values"
                    )
                    grid_search_failed = True

            try:
                if grid_search_failed:
                    # Workaround for a pyGAM issue: after a failed grid search the
                    # model is fit one extra time before the final fit below.
                    self.model.fit(self.x, self.y, weights=self.w, **kwargs)
                self.model.fit(self.x, self.y, weights=self.w, **kwargs)
                return self
            except Exception as e:  # noqa: B902
                raise RuntimeError(
                    f"Unable to fit `{type(self).__name__}` for gene "
                    f"`{self._gene!r}` in lineage `{self._lineage!r}`. Reason: `{e}`"
                ) from e

    @d.dedent
    def predict(
        self,
        x_test: Optional[np.ndarray] = None,
        key_added: Optional[str] = "_x_test",
        **kwargs,
    ) -> np.ndarray:
        """
        %(base_model_predict.full_desc)s

        Parameters
        ----------
        %(base_model_predict.parameters)s

        Returns
        -------
        %(base_model_predict.returns)s
        """  # noqa
        x_test = self._check(key_added, x_test)

        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                category=DeprecationWarning,
                message=".* is a deprecated alias for the builtin",
            )
            self._y_test = self.model.predict(x_test, **kwargs)

        # Drop singleton dimensions and cast to the model's dtype.
        self._y_test = np.squeeze(self._y_test).astype(self._dtype)

        return self.y_test

    @d.dedent
    def confidence_interval(
        self, x_test: Optional[np.ndarray] = None, **kwargs
    ) -> np.ndarray:
        """
        %(base_model_ci.summary)s

        Parameters
        ----------
        %(base_model_ci.parameters)s

        Returns
        -------
        %(base_model_ci.returns)s
        """  # noqa
        x_test = self._check("_x_test", x_test)

        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                category=DeprecationWarning,
                message=".* is a deprecated alias for the builtin",
            )
            self._conf_int = self.model.confidence_intervals(x_test, **kwargs).astype(
                self._dtype
            )

        return self.conf_int

    @d.dedent
    def copy(self) -> "BaseModel":
        """%(copy)s"""  # noqa
        res = GAM(self.adata)
        self._shallowcopy_attributes(res)
        # `_grid` may be a user dict, which must not be shared between copies.
        res._grid = deepcopy(self._grid)

        return res