cellrank.ul.models.SKLearnModel

class cellrank.ul.models.SKLearnModel(adata, model, weight_name=None, ignore_raise=False)

Wrapper around sklearn.base.BaseEstimator.

Parameters
  • adata (anndata.AnnData) – Annotated data object.

  • model (BaseEstimator) – Instance of the underlying sklearn estimator, such as sklearn.svm.SVR.

  • weight_name (Optional[str]) – Name of the weight argument for model.fit. If None, determine it automatically. If an empty string, no weights will be used.

  • ignore_raise (bool) – Do not raise an exception if the weight argument is not found in the signature of the model's fit function. This is useful when the weight is passed via **kwargs and cannot be determined from the signature.
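The interplay of weight_name and ignore_raise can be illustrated with a small sketch that inspects the signature of an estimator's fit method. The helper find_weight_name and the two dummy estimators are hypothetical and are not part of cellrank; in particular, the rule "pick the first parameter whose name contains 'weight'" is an assumption chosen for illustration. The dummy WeightedEstimator mirrors the fit(X, y, sample_weight=None) signature used by sklearn estimators such as sklearn.svm.SVR.

```python
import inspect
from typing import Optional


def find_weight_name(model, weight_name: Optional[str] = None, ignore_raise: bool = False) -> Optional[str]:
    """Hypothetical helper: resolve the weight argument name of ``model.fit``."""
    if weight_name == "":
        # Empty string: weights are explicitly disabled.
        return None
    params = inspect.signature(model.fit).parameters
    if weight_name is not None:
        # Explicit name: accept it if present, or if the caller opts out of the check
        # (e.g. because the weight is forwarded through **kwargs).
        if weight_name in params or ignore_raise:
            return weight_name
        raise ValueError(f"Unable to find weight argument {weight_name!r} in `fit`.")
    # None: assumed heuristic, take the first parameter whose name contains "weight".
    for name in params:
        if "weight" in name:
            return name
    if ignore_raise:
        return None
    raise ValueError("Unable to determine the weight argument of `fit`.")


class WeightedEstimator:
    # Mirrors the signature of sklearn estimators such as SVR.fit.
    def fit(self, X, y, sample_weight=None):
        return self


class KwargsEstimator:
    # A weight hidden in **kwargs cannot be discovered from the signature.
    def fit(self, X, y, **kwargs):
        return self
```

Here find_weight_name(WeightedEstimator()) resolves to "sample_weight", while KwargsEstimator only works when the name is given explicitly together with ignore_raise=True.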

Attributes

adata

Annotated data object.

conf_int

Array of shape (n_samples, 2) containing the lower and upper bounds of the confidence interval.

model

The underlying sklearn.base.BaseEstimator.

prepared

Whether the model is prepared for fitting.

w

Filtered weights of shape (n_filtered_cells,) used for fitting.

w_all

Unfiltered weights of shape (n_cells,).

x

Filtered independent variables of shape (n_filtered_cells, 1) used for fitting.

x_all

Unfiltered independent variables of shape (n_cells, 1).

x_hat

Filtered independent variables used when calculating the default confidence interval; usually the same as x.

x_test

Independent variables of shape (n_samples, 1) used for prediction.

y

Filtered dependent variables of shape (n_filtered_cells, 1) used for fitting.

y_all

Unfiltered dependent variables of shape (n_cells, 1).

y_hat

Filtered dependent variables used when calculating the default confidence interval; usually the same as y.

y_test

Prediction values of shape (n_samples,) for x_test.
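The relationship between the unfiltered attributes (x_all, y_all, w_all), their filtered counterparts (x, y, w) and the test grid (x_test) can be sketched with NumPy. This is a minimal illustration only: the filtering rule (drop cells whose weight is at or below a threshold), the threshold value 0.01, the helper filter_by_weight and the equally spaced test grid are all assumptions, not a description of how cellrank's prepare method actually filters cells.

```python
import numpy as np


def filter_by_weight(x_all, y_all, w_all, weight_threshold=0.01):
    """Hypothetical: keep only cells whose weight exceeds ``weight_threshold``."""
    mask = w_all > weight_threshold
    # x and y keep their (n, 1) column shape; the weights stay 1-D.
    return x_all[mask], y_all[mask], w_all[mask]


n_cells = 6
x_all = np.linspace(0, 1, n_cells).reshape(-1, 1)       # e.g. pseudotime, shape (n_cells, 1)
y_all = np.arange(n_cells, dtype=float).reshape(-1, 1)  # e.g. expression, shape (n_cells, 1)
w_all = np.array([0.0, 0.5, 0.005, 0.9, 1.0, 0.2])      # e.g. lineage weights, shape (n_cells,)

x, y, w = filter_by_weight(x_all, y_all, w_all)

# Assumed test grid: equally spaced points over the filtered range, shape (n_samples, 1).
n_samples = 4
x_test = np.linspace(x.min(), x.max(), n_samples).reshape(-1, 1)
```

With these toy values, two of the six cells fall below the threshold, so x and y have shape (4, 1) and w has shape (4,), matching the (n_filtered_cells, 1) and (n_filtered_cells,) shapes listed above.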

Methods

confidence_interval([x_test])

Calculate the confidence interval.

copy()

Return a copy of self.

default_confidence_interval([x_test])

Calculate the confidence interval when the underlying model has no method for it.

fit([x, y, w])

Fit the model.

plot([figsize, same_plot, hide_cells, perc, …])

Plot the smoothed gene expression.

predict([x_test, key_added])

Run the prediction.

prepare(gene, lineage[, backward, …])

Prepare the model to be ready for fitting.

read(fname)

Deserialize self from a file.

write(fname[, ext])

Serialize self to a file.
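To make the fit, predict and confidence-interval steps concrete, the following sketch runs a comparable pipeline on toy data using only NumPy. It is not cellrank's implementation: a weighted straight-line fit via np.polyfit stands in for the sklearn estimator, and the interval is a simple normal-approximation band, y_test ± z · std(residuals on x_hat, y_hat). The function default_confidence_interval here is a hypothetical stand-in that only reproduces the documented output shape (n_samples, 2).

```python
import numpy as np


def default_confidence_interval(y_test, x_hat, y_hat, predict, z=1.96):
    """Hypothetical residual-based band: y_test +/- z * std(residuals)."""
    residuals = y_hat.ravel() - predict(x_hat).ravel()
    half_width = z * residuals.std()
    # Column 0 holds the lower bound, column 1 the upper bound: shape (n_samples, 2).
    return np.column_stack([y_test - half_width, y_test + half_width])


# Toy filtered data (x, y, w); x_hat and y_hat are taken to be the same as x and y.
x = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
y = np.array([0.1, 0.9, 2.1, 2.9, 4.0])
w = np.ones_like(x)

# Stand-in estimator: np.polyfit weights multiply the unsquared residuals,
# so the square root of the sample weights is passed.
coef = np.polyfit(x, y, deg=1, w=np.sqrt(w))


def predict(xs):
    return np.polyval(coef, np.asarray(xs).ravel())


x_test = np.linspace(0, 1, 3).reshape(-1, 1)  # shape (n_samples, 1)
y_test = predict(x_test)                      # shape (n_samples,)
conf_int = default_confidence_interval(y_test, x.reshape(-1, 1), y.reshape(-1, 1), predict)
```

The band is symmetric around each prediction by construction; a model that provides its own uncertainty estimate would be handled by confidence_interval instead.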