# Source code for cellrank.ul.models._sklearn_model

"""Module containing model which wraps around :mod:`sklearn` estimators."""
from typing import Iterable, Optional

import warnings
from inspect import signature

from cellrank.ul._docs import d
from cellrank.ul.models import BaseModel
from cellrank.ul.models._base_model import AnnData

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_X_y

@d.dedent
class SKLearnModel(BaseModel):
    """
    Wrapper around :class:`sklearn.base.BaseEstimator`.

    Parameters
    ----------
    %(adata)s
    model
        Instance of the underlying :mod:`sklearn` estimator, such as :class:`sklearn.svm.SVR`.
    weight_name
        Name of the weight argument for :attr:`model` ``.fit``. If `None`, try to determine it automatically.
        If an empty string, no weights will be used.
    ignore_raise
        Do not raise an exception if weight argument is not found in the fitting function of :attr:`model`.
        This is useful in case when weight is passed in ``**kwargs`` and cannot be determined from signature.

    Raises
    ------
    TypeError
        If ``model`` is not an instance of :class:`sklearn.base.BaseEstimator`.
    RuntimeError
        If no fitting or prediction function is found, or if ``weight_name`` is `None`
        and the weight argument cannot be determined automatically.
    ValueError
        If ``weight_name`` is a non-empty string that is missing from the signature of
        the fitting function and ``ignore_raise=False``.
    """

    # Candidate attribute names looked up on the wrapped estimator; for the
    # function lookups the first callable match (in tuple order) wins.
    _fit_names = ("fit", "__init__")
    _predict_names = ("predict", "__call__")
    # Accepted spellings of the sample-weight argument of the fitting function.
    _weight_names = ("w", "weights", "sample_weight", "sample_weights")
    _conf_int_names = ("conf_int", "confidence_intervals")

    def __init__(
        self,
        adata: AnnData,
        model: BaseEstimator,
        weight_name: Optional[str] = None,
        ignore_raise: bool = False,
    ):
        if not isinstance(model, BaseEstimator):
            raise TypeError(
                f"Expected model to be of type `BaseEstimator`, found `{type(model).__name__!r}`."
            )

        # NOTE(review): the parent constructor is expected to store ``model``
        # so that :attr:`model` is usable below - confirm against ``BaseModel``.
        super().__init__(adata, model)

        fit_name = self._find_func(self._fit_names)
        predict_name = self._find_func(self._predict_names)
        # A confidence interval function is optional, hence the `None` default.
        ci_name = self._find_func(self._conf_int_names, use_default=True, default=None)

        fit_fn = getattr(self.model, fit_name)

        # An explicit ``weight_name`` (including the empty string, which
        # disables weights) always takes precedence over auto-detection.
        self._weight_name = (
            self._find_arg_name(fit_name, self._weight_names)
            if weight_name is None
            else weight_name
        )

        if self._weight_name is None:
            raise RuntimeError(
                f"Unable to determine weights for function `{fit_name!r}`, searched `{self._weight_names}`. "
                f"Consider specifying it manually as `weight_name=...`."
            )
        if (
            not ignore_raise
            and self._weight_name != ""
            and self._weight_name not in signature(fit_fn).parameters
        ):
            raise ValueError(
                f"Unable to detect `{self._weight_name!r}` in the signature of `{fit_name!r}`. "
                f"If it's in `kwargs`, set `ignore_raise=True`."
            )

        self._fit_fn = fit_fn
        self._pred_fn = getattr(self.model, predict_name)
        self._ci_fn = None if ci_name is None else getattr(self.model, ci_name, None)

    @d.dedent
    def fit(
        self,
        x: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        w: Optional[np.ndarray] = None,
        **kwargs,
    ) -> "SKLearnModel":
        """
        %(base_model_fit.full_desc)s

        Parameters
        ----------
        %(base_model_fit.parameters)s

        Returns
        -------
        Fits the model and returns self.
        """  # noqa: D400
        # NOTE(review): the parent ``fit`` presumably prepares ``self.x``,
        # ``self.y`` and ``self._w``, which are read below - confirm in ``BaseModel``.
        super().fit(x, y, w, **kwargs)

        # Forward the weights under whatever keyword the estimator expects.
        if self._weight_name not in (None, ""):
            kwargs[self._weight_name] = self._w

        # Only the input validation is silenced; warnings emitted by the
        # estimator's own fitting function remain visible.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            x, y = check_X_y(
                X=self.x,
                y=self.y,
                force_all_finite="allow-nan",
                copy=False,
                estimator=self.model,
            )

        fitted = self._fit_fn(x, y, **kwargs)
        # ``__init__`` (a permitted fit name) and some estimators' ``fit``
        # return `None`; keep the current model instead of discarding it.
        if fitted is not None:
            self._model = fitted

        return self

    @d.dedent
    def predict(
        self, x_test: Optional[np.ndarray] = None, key_added: str = "_x_test", **kwargs
    ) -> np.ndarray:
        """
        %(base_model_predict.full_desc)s

        Parameters
        ----------
        %(base_model_predict.parameters)s

        Returns
        -------
        %(base_model_predict.returns)s
        """  # noqa
        # NOTE(review): ``_check`` is defined in the parent class; presumably it
        # validates ``x_test`` and stores it under ``key_added`` - confirm.
        x_test = self._check(key_added, x_test)

        prediction = self._pred_fn(x_test, **kwargs)
        # Drop singleton dimensions and cast to the model's dtype.
        self._y_test = np.squeeze(prediction).astype(self._dtype)

        return self.y_test

    @d.dedent
    def confidence_interval(
        self, x_test: Optional[np.ndarray] = None, **kwargs
    ) -> np.ndarray:
        """
        %(base_model_ci.full_desc)s

        Parameters
        ----------
        %(base_model_ci.parameters)s

        Returns
        -------
        %(base_model_ci.returns)s
        """  # noqa
        # Fall back to the generic implementation when the wrapped estimator
        # does not expose its own confidence interval function.
        if self._ci_fn is None:
            return self.default_confidence_interval(x_test=x_test, **kwargs)

        x_test = self._check("_x_test", x_test)
        self._conf_int = self._ci_fn(x_test, **kwargs)

        return self.conf_int

    def _find_func(
        self,
        func_names: Iterable[str],
        use_default: bool = False,
        default: Optional[str] = None,
    ) -> Optional[str]:
        """
        Find a function in :attr:`model` from given names.

        If `None` is found, use ``default`` or raise a :class:`RuntimeError`.

        Parameters
        ----------
        func_names
            Function names to search. The first one found is returned.
        use_default
            Whether to return the ``default`` if it is `None` or raise :class:`RuntimeError`.
        default
            The default function name to use if `None` was found.

        Returns
        -------
        Name of the function or the default name.

        Raises
        ------
        RuntimeError
            If no function is found and ``use_default=False``.
        """
        for name in func_names:
            # Non-callable attributes of the same name do not count as a match.
            if callable(getattr(self.model, name, None)):
                return name

        if use_default:
            return default

        raise RuntimeError(
            f"Unable to find function and no default specified, searched for `{list(func_names)}`."
        )

    def _find_arg_name(
        self, func_name: Optional[str], param_names: Iterable[str]
    ) -> Optional[str]:
        """
        Find an argument in :attr:`model`'s ``func_name``.

        Parameters
        ----------
        func_name
            Function name of :attr:`model`.
        param_names
            Parameter names to search. The first one found, in the order of
            the function's signature, is returned.

        Returns
        -------
        The parameter name or `None`, if `None` was found or ``func_name`` was `None`.
        """
        if func_name is None:
            return None

        for param in signature(getattr(self.model, func_name)).parameters:
            if param in param_names:
                return param

        return None

    @property
    def model(self) -> BaseEstimator:
        """The underlying :class:`sklearn.base.BaseEstimator`."""  # noqa
        return self._model

    @d.dedent
    def copy(self) -> "SKLearnModel":
        """%(copy)s"""  # noqa
        # ``ignore_raise=True``: the weight name was already validated when
        # this instance was constructed, so the signature check is skipped.
        res = SKLearnModel(
            self.adata, self._model, weight_name=self._weight_name, ignore_raise=True
        )
        self._shallowcopy_attributes(res)  # this deepcopies the underlying model

        return res