import collections
import contextlib
import copy
import enum
import io
import logging
import types
from collections.abc import Mapping
from typing import Any, Literal
import numpy as np
from anndata import AnnData
from pygam import GAM as pGAM
from pygam import (
ExpectileGAM,
GammaGAM,
InvGaussGAM,
LinearGAM,
LogisticGAM,
PoissonGAM,
s,
)
from cellrank._utils._docs import d
from cellrank._utils._enum import ModeEnum
from cellrank._utils._utils import _filter_kwargs
from cellrank.models import BaseModel
logger = logging.getLogger(__name__)
__all__ = ["GAM"]
class GamLinkFunction(ModeEnum):
    """Link functions that can be requested for :class:`GAM`.

    The member names mirror the options listed for the ``link`` argument of :class:`GAM`,
    which converts the user-supplied value via ``GamLinkFunction(link)``.
    """

    # NOTE(review): the concrete member values are produced by ``ModeEnum``'s handling of
    # ``enum.auto()``, which is defined outside this file - confirm before relying on them.
    IDENTITY = enum.auto()
    LOGIT = enum.auto()
    INVERSE = enum.auto()
    LOG = enum.auto()
    INV_SQUARED = enum.auto()
class GamDistribution(ModeEnum):
    """Distributions that can be requested for :class:`GAM`.

    The member names mirror the options listed for the ``distribution`` argument of :class:`GAM`,
    which converts the user-supplied value via ``GamDistribution(distribution)``.
    """

    # NOTE(review): the concrete member values are produced by ``ModeEnum``'s handling of
    # ``enum.auto()``, which is defined outside this file - confirm before relying on them.
    NORMAL = enum.auto()
    BINOMIAL = enum.auto()
    POISSON = enum.auto()
    GAMMA = enum.auto()
    # ``GAM.__init__`` replaces GAUSSIAN with NORMAL, i.e. it is treated as an alias.
    GAUSSIAN = enum.auto()
    INV_GAUSS = enum.auto()
# Lookup from a (distribution, link) pair to the pygam class used to build the model.
# Pairs without a dedicated class resolve to the generic ``pygam.GAM`` through the default factory.
_gams = collections.defaultdict(lambda: pGAM)
_gams.update(
    {
        (GamDistribution.NORMAL, GamLinkFunction.IDENTITY): LinearGAM,
        (GamDistribution.BINOMIAL, GamLinkFunction.LOGIT): LogisticGAM,
        (GamDistribution.POISSON, GamLinkFunction.LOG): PoissonGAM,
        (GamDistribution.GAMMA, GamLinkFunction.LOG): GammaGAM,
        (GamDistribution.INV_GAUSS, GamLinkFunction.LOG): InvGaussGAM,
    }
)
@d.dedent
class GAM(BaseModel):
    """Fit Generalized Additive Models (GAMs) using :mod:`pygam`.

    Parameters
    ----------
    %(adata)s
    n_knots
        Number of knots.
    spline_order
        Order of the splines, e.g., :math:`3` for cubic splines.
    distribution
        Name of the distribution. Available distributions can be found
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Distribution:>`__.
    link
        Name of the link function. Available link functions can be found
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Link-function:>`__.
    max_iter
        Maximum number of iterations for optimization.
    expectile
        Expectile for :class:`~pygam.pygam.ExpectileGAM`. This forces the distribution to be ``'normal'``
        and link function to ``'identity'``. Must be in :math:`(0, 1)`.
    grid
        Whether to perform a grid search. Keys correspond to parameter names and values to the ranges
        to be searched. If ``'default'``, use the default grid. If :obj:`None` or any other string,
        don't perform a grid search.
    spline_kwargs
        Keyword arguments for :func:`~pygam.terms.s`.
    kwargs
        Keyword arguments for the :class:`~pygam.pygam.GAM`.
    """

    def __init__(
        self,
        adata: AnnData,
        n_knots: int | None = 6,
        spline_order: int = 3,
        distribution: Literal["normal", "binomial", "poisson", "gamma", "gaussian", "inv_gauss"] = "gamma",
        link: Literal["identity", "logit", "inverse", "log", "inv_squared"] = "log",
        max_iter: int = 2000,
        expectile: float | None = None,
        grid: str | Mapping[str, Any] | None = None,
        spline_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ):
        # Default penalty settings come first so that user-supplied `spline_kwargs` override them.
        term = s(
            0,
            spline_order=spline_order,
            n_splines=n_knots,
            **_filter_kwargs(s, **{"lam": 3, "penalties": ["derivative", "l2"], **spline_kwargs}),
        )

        link = GamLinkFunction(link)
        distribution = GamDistribution(distribution)
        # `gaussian` is accepted as an alias of `normal`.
        if distribution == GamDistribution.GAUSSIAN:
            distribution = GamDistribution.NORMAL

        if expectile is not None:
            if not (0 < expectile < 1):
                raise ValueError(f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`.")
            # Compare against the enum members so the check does not depend on how the enum
            # compares to plain strings.
            if distribution != GamDistribution.NORMAL or link != GamLinkFunction.IDENTITY:
                logger.warning(
                    "Expectile GAM works only with `normal` distribution and `identity` link function, "
                    "found `%r` distribution and %r link functions.",
                    distribution,
                    link,
                )
            model = ExpectileGAM(term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs)
        else:
            # doing it like this ensures that user can specify scale
            gam = _gams[distribution, link]
            filtered_kwargs = _filter_kwargs(gam.__init__, **kwargs)
            if len(kwargs) != len(filtered_kwargs):
                # `gam` is already a class, so report its own name rather than its metaclass'.
                raise TypeError(
                    f"Invalid arguments `{sorted(set(kwargs) - set(filtered_kwargs))}` for `{gam.__name__!r}`."
                )

            # `link` and `distribution` are filtered once more below, so they only reach
            # constructors that accept them.
            filtered_kwargs["link"] = link
            filtered_kwargs["distribution"] = distribution
            model = gam(
                term,
                max_iter=max_iter,
                verbose=False,
                **_filter_kwargs(gam.__init__, **filtered_kwargs),
            )

        super().__init__(adata, model=model)
        # Updated by `fit` whenever the plain (non-grid-search) fit is run.
        self._converged: bool = True

        # `self._grid` is `None` (no grid search), a `dict` (user-defined grid) or
        # a non-dict sentinel (grid search using pygam's defaults).
        if grid is None:
            self._grid = None
        elif isinstance(grid, str):
            self._grid = object() if grid == "default" else None
        elif isinstance(grid, Mapping):
            # Store a shallow `dict` copy; `fit` relies on `isinstance(self._grid, dict)`.
            self._grid = dict(grid)
        else:
            raise TypeError(f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`.")

    @d.dedent
    def fit(
        self,
        x: np.ndarray | None = None,
        y: np.ndarray | None = None,
        w: np.ndarray | None = None,
        **kwargs: Any,
    ) -> "GAM":
        """%(base_model_fit.full_desc)s

        Parameters
        ----------
        %(base_model_fit.parameters)s

        Returns
        -------
        Fits the model and returns self.
        """  # noqa: D400
        super().fit(x, y, w, **kwargs)

        if self._grid is not None:
            # a non-dict grid means: use the default search
            grid = self._grid if isinstance(self._grid, dict) else {}
            try:
                # pygam's output on stdout is discarded
                with contextlib.redirect_stdout(io.StringIO()):
                    self.model.gridsearch(
                        self.x,
                        self.y,
                        weights=self.w,
                        keep_best=True,
                        progress=False,
                        **grid,
                        **kwargs,
                    )
            except Exception as e:  # noqa: BLE001
                # workaround for: https://github.com/dswah/pyGAM/issues/273
                # Only log here; the guarded fit below is the single fallback, so that its failures
                # are wrapped in `RuntimeError` and its convergence status is recorded.
                logger.error("Grid search failed, reason: `%s`. Fitting with default values", e)
            else:
                return self

        # Capture pygam's stdout in order to detect its non-convergence message.
        buf = io.StringIO()
        try:
            with contextlib.redirect_stdout(buf):
                self.model.fit(self.x, self.y, weights=self.w, **kwargs)
        except Exception as e:  # noqa: BLE001
            raise RuntimeError(
                f"Unable to fit `{type(self).__name__}` for gene "
                f"`{self._gene!r}` in lineage `{self._lineage!r}`. Reason: `{e}`"
            ) from e

        self._converged = "did not converge" not in buf.getvalue()
        return self

    @d.dedent
    def predict(
        self,
        x_test: np.ndarray | None = None,
        key_added: str | None = "_x_test",
        **kwargs: Any,
    ) -> np.ndarray:
        """%(base_model_predict.full_desc)s

        Parameters
        ----------
        %(base_model_predict.parameters)s

        Returns
        -------
        %(base_model_predict.returns)s
        """  # noqa: D400
        x_test = self._check(key_added, x_test)

        self._y_test = self.model.predict(x_test, **kwargs)
        # drop singleton dimensions and cast to the model's dtype
        self._y_test = np.squeeze(self._y_test).astype(self._dtype)

        return self.y_test

    @d.dedent
    def confidence_interval(self, x_test: np.ndarray | None = None, **kwargs: Any) -> np.ndarray:
        """%(base_model_ci.summary)s

        Parameters
        ----------
        %(base_model_ci.parameters)s

        Returns
        -------
        %(base_model_ci.returns)s
        """  # noqa: D400
        x_test = self._check("_x_test", x_test)
        self._conf_int = self.model.confidence_intervals(x_test, **kwargs).astype(self._dtype)

        return self.conf_int

    @d.dedent
    def copy(self) -> "BaseModel":
        """%(copy)s"""  # noqa: D400, D401
        res = GAM(self.adata)
        self._shallowcopy_attributes(res)
        # the grid is set separately from the shallow-copied attributes
        res._grid = copy.deepcopy(self._grid)
        return res