import collections
import copy
import enum
import types
import warnings
from typing import Any, Literal, Mapping, Optional, Union
from pygam import GAM as pGAM
from pygam import (
ExpectileGAM,
GammaGAM,
InvGaussGAM,
LinearGAM,
LogisticGAM,
PoissonGAM,
s,
)
import numpy as np
from anndata import AnnData
from cellrank import logging as logg
from cellrank._utils._docs import d
from cellrank._utils._enum import ModeEnum
from cellrank._utils._utils import _filter_kwargs
from cellrank.models import BaseModel
__all__ = ["GAM"]
class GamLinkFunction(ModeEnum):
    """Link functions accepted by the ``link`` argument of :class:`GAM`.

    Member values are produced by :func:`enum.auto`, i.e. by whatever value
    generation :class:`ModeEnum` defines.
    """

    IDENTITY = enum.auto()
    LOGIT = enum.auto()
    INVERSE = enum.auto()
    LOG = enum.auto()
    INV_SQUARED = enum.auto()
class GamDistribution(ModeEnum):
    """Distributions accepted by the ``distribution`` argument of :class:`GAM`.

    Member values are produced by :func:`enum.auto`, i.e. by whatever value
    generation :class:`ModeEnum` defines.
    """

    NORMAL = enum.auto()
    BINOMIAL = enum.auto()
    POISSON = enum.auto()
    GAMMA = enum.auto()
    # ``GAM.__init__`` replaces GAUSSIAN with NORMAL before selecting a model.
    GAUSSIAN = enum.auto()
    INV_GAUSS = enum.auto()
# Lookup table from a ``(distribution, link)`` pair to the pygam estimator class
# dedicated to it. Any pair without an entry resolves to the generic ``pGAM``
# through the default factory (and, as with every ``defaultdict``, that lookup
# stores the pair in the table).
_gams = collections.defaultdict(lambda: pGAM)
_gams[GamDistribution.NORMAL, GamLinkFunction.IDENTITY] = LinearGAM
_gams[GamDistribution.BINOMIAL, GamLinkFunction.LOGIT] = LogisticGAM
_gams[GamDistribution.POISSON, GamLinkFunction.LOG] = PoissonGAM
_gams[GamDistribution.GAMMA, GamLinkFunction.LOG] = GammaGAM
_gams[GamDistribution.INV_GAUSS, GamLinkFunction.LOG] = InvGaussGAM
@d.dedent
class GAM(BaseModel):
    """Fit Generalized Additive Models (GAMs) using :mod:`pygam`.

    Parameters
    ----------
    %(adata)s
    n_knots
        Number of knots.
    spline_order
        Order of the splines, e.g., :math:`3` for cubic splines.
    distribution
        Name of the distribution. Available distributions can be found
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Distribution:>`__.
    link
        Name of the link function. Available link functions can be found
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Link-function:>`__.
    max_iter
        Maximum number of iterations for optimization.
    expectile
        Expectile for :class:`~pygam.pygam.ExpectileGAM`. This forces the distribution to be ``'normal'``
        and link function to ``'identity'``. Must be in :math:`(0, 1)`.
    grid
        Whether to perform a grid search. Keys correspond to parameter names and values to the ranges to be
        searched. If ``'default'``, use the default grid. If :obj:`None` (or any other string), don't perform
        a grid search.
    spline_kwargs
        Keyword arguments for :func:`~pygam.terms.s`.
    kwargs
        Keyword arguments for the :class:`~pygam.pygam.GAM`.
    """

    def __init__(
        self,
        adata: AnnData,
        n_knots: Optional[int] = 6,
        spline_order: int = 3,
        distribution: Literal["normal", "binomial", "poisson", "gamma", "gaussian", "inv_gauss"] = "gamma",
        link: Literal["identity", "logit", "inverse", "log", "inv_squared"] = "log",
        max_iter: int = 2000,
        expectile: Optional[float] = None,
        grid: Optional[Union[str, Mapping[str, Any]]] = None,
        spline_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ):
        # Single spline term on feature 0; user supplied ``spline_kwargs`` override the defaults.
        term = s(
            0,
            spline_order=spline_order,
            n_splines=n_knots,
            **_filter_kwargs(s, **{"lam": 3, "penalties": ["derivative", "l2"], **spline_kwargs}),
        )
        link = GamLinkFunction(link)
        distribution = GamDistribution(distribution)
        if distribution == GamDistribution.GAUSSIAN:
            # 'gaussian' is accepted as an alias of 'normal'
            distribution = GamDistribution.NORMAL

        if expectile is not None:
            if not (0 < expectile < 1):
                raise ValueError(f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`.")
            if distribution != "normal" or link != "identity":
                logg.warning(
                    f"Expectile GAM works only with `normal` distribution and `identity` link function, "
                    f"found `{distribution!r}` distribution and {link!r} link functions."
                )
            model = ExpectileGAM(term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs)
        else:
            # doing it like this ensures that the user can specify scale
            gam = _gams[distribution, link]
            filtered_kwargs = _filter_kwargs(gam.__init__, **kwargs)
            if len(kwargs) != len(filtered_kwargs):
                # ``gam`` is already a class, so report its own name (not the name of its metaclass)
                unexpected = sorted(set(kwargs) - set(filtered_kwargs))
                raise TypeError(f"Invalid arguments `{unexpected}` for `{gam.__name__!r}`.")

            # Only estimators whose ``__init__`` accepts ``link``/``distribution`` receive them;
            # the second filtering pass below drops them for the others.
            filtered_kwargs["link"] = link
            filtered_kwargs["distribution"] = distribution
            model = gam(
                term,
                max_iter=max_iter,
                verbose=False,
                **_filter_kwargs(gam.__init__, **filtered_kwargs),
            )

        super().__init__(adata, model=model)

        if grid is None:
            self._grid = None
        elif isinstance(grid, str):
            # a non-dict sentinel marks "use the default grid"; any other string disables the search
            self._grid = object() if grid == "default" else None
        elif isinstance(grid, Mapping):
            # accept any mapping (as annotated) and keep a private ``dict`` copy, which is what ``fit`` expects
            self._grid = dict(grid)
        else:
            raise TypeError(f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`.")

    @staticmethod
    def _ignore_alias_deprecations() -> None:
        """Install the filter for "deprecated alias for the builtin" :class:`DeprecationWarning` messages.

        Must be called inside a :func:`warnings.catch_warnings` context so that the filter is temporary.
        """
        warnings.filterwarnings(
            "ignore",
            category=DeprecationWarning,
            message=".* is a deprecated alias for the builtin",
        )

    @d.dedent
    def fit(
        self,
        x: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        w: Optional[np.ndarray] = None,
        **kwargs: Any,
    ) -> "GAM":
        """%(base_model_fit.full_desc)s

        Parameters
        ----------
        %(base_model_fit.parameters)s

        Returns
        -------
        Fits the model and returns self.

        Raises
        ------
        RuntimeError
            If the underlying model could not be fitted.
        """  # noqa
        super().fit(x, y, w, **kwargs)

        with warnings.catch_warnings():
            self._ignore_alias_deprecations()

            if self._grid is not None:
                # a non-dict value is the "default grid" sentinel: let pygam choose the search space
                grid = self._grid if isinstance(self._grid, dict) else {}
                try:
                    self.model.gridsearch(
                        self.x,
                        self.y,
                        weights=self.w,
                        keep_best=True,
                        progress=False,
                        **grid,
                        **kwargs,
                    )
                    return self
                except Exception as e:  # noqa: BLE001
                    # workaround for: https://github.com/dswah/pyGAM/issues/273
                    # Log and fall through to the guarded plain fit below, so that the fallback runs
                    # exactly once and its failure is reported with gene/lineage context.
                    logg.error(f"Grid search failed, reason: `{e}`. Fitting with default values")

            try:
                self.model.fit(self.x, self.y, weights=self.w, **kwargs)
                return self
            except Exception as e:  # noqa: BLE001
                raise RuntimeError(
                    f"Unable to fit `{type(self).__name__}` for gene "
                    f"`{self._gene!r}` in lineage `{self._lineage!r}`. Reason: `{e}`"
                ) from e

    @d.dedent
    def predict(
        self,
        x_test: Optional[np.ndarray] = None,
        key_added: Optional[str] = "_x_test",
        **kwargs: Any,
    ) -> np.ndarray:
        """%(base_model_predict.full_desc)s

        Parameters
        ----------
        %(base_model_predict.parameters)s

        Returns
        -------
        %(base_model_predict.returns)s
        """  # noqa
        x_test = self._check(key_added, x_test)

        with warnings.catch_warnings():
            self._ignore_alias_deprecations()
            self._y_test = self.model.predict(x_test, **kwargs)

        # drop singleton dimensions and cast to the model's dtype
        self._y_test = np.squeeze(self._y_test).astype(self._dtype)

        return self.y_test

    @d.dedent
    def confidence_interval(self, x_test: Optional[np.ndarray] = None, **kwargs: Any) -> np.ndarray:
        """%(base_model_ci.summary)s

        Parameters
        ----------
        %(base_model_ci.parameters)s

        Returns
        -------
        %(base_model_ci.returns)s
        """  # noqa
        x_test = self._check("_x_test", x_test)

        with warnings.catch_warnings():
            self._ignore_alias_deprecations()
            self._conf_int = self.model.confidence_intervals(x_test, **kwargs).astype(self._dtype)

        return self.conf_int

    @d.dedent
    def copy(self) -> "BaseModel":
        """%(copy)s"""  # noqa
        # build a fresh instance, then overwrite its attributes with (shallow) copies of ours
        res = GAM(self.adata)
        self._shallowcopy_attributes(res)
        # the grid specification is the only attribute that is deep-copied here
        res._grid = copy.deepcopy(self._grid)
        return res