# Source code for the CellRank ``cluster_lineage`` plotting function.

from typing import Any, Dict, Tuple, Union, Optional, Sequence

from types import MappingProxyType
from pathlib import Path

import scanpy as sc
from anndata import AnnData
from cellrank import logging as logg
from cellrank._key import Key
from import _DEFAULT_BACKEND, Backend_t
from cellrank.ul._docs import d
from import (
from import save_fig, _unique_order_preserving
from cellrank.ul._utils import _genesymbols, _get_n_cores, _check_collection

import numpy as np
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, is_color_like
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

@d.dedent
@_genesymbols
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    covariate_key: Optional[Union[str, Sequence[str]]] = None,
    ratio: float = 0.05,
    cmap: Optional[str] = "viridis",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: Backend_t = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs: Any,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see :cite:`setty:19`. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_range)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered. Useful when
        plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in :attr:`anndata.AnnData.obs` where the pseudotime is stored.
    covariate_key
        Key(s) in :attr:`anndata.AnnData.obs` containing observations to be plotted at the bottom of each plot.
    %(gene_symbols)s
    ratio
        Height ratio of each covariate in ``covariate_key``.
    cmap
        Colormap to use for continuous covariates in ``covariate_key``.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key
        Key in :attr:`anndata.AnnData.uns` where to save the results. If `None`, it will be saved as
        ``'lineage_{lineage}_trend'``.
    random_state
        Random seed for reproducibility.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.leiden`.
    %(return_models)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

    Also updates ``adata.uns`` with the following:

        - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of
          shape `(n_genes, n_points)` containing the clustered genes.
    """

    def plot_cluster(
        row: int, col: int, cluster: str, sharey_ax: Optional[plt.Axes] = None
    ) -> Optional[plt.Axes]:
        """
        Draw the trends of one cluster, plus optional covariate strips, into grid column ``col``.

        The trend axis spans grid rows ``row`` to ``row + row_delta``. When ``sharey`` is set,
        the trend axis is returned so that the next cluster can share its y-axis. Otherwise
        `None` is returned.
        """
        gss = GridSpecFromSubplotSpec(
            row_delta,
            1,
            subplot_spec=gs[row : row + row_delta, col],
            hspace=0,
            height_ratios=[1.0] + [ratio] * (row_delta - 1),
        )
        ax = fig.add_subplot(gss[0, 0], sharey=sharey_ax)

        # select by this function's own argument, not by the caller's loop variable
        data = trends[trends.obs["clusters"] == cluster].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        for trend in data:
            ax.plot(trend, color="gray", lw=0.5)
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)), mean - sd, mean + sd, color="black", alpha=0.1)

        ax.set_title(f"cluster {cluster}")
        ax.set_xticks([])
        if sharey:
            ax.set_yticks([])
        ax.margins(0)

        if covariate_colors is not None:
            for i, colors in enumerate(covariate_colors):
                ax_cov = fig.add_subplot(gss[i + 1, 0])
                if is_color_like(colors[0]):
                    # e.g. categorical: ``colors`` is a sequence of colors, drawn one per column
                    cm = ListedColormap(colors, N=len(colors))
                    ax_cov.imshow(np.arange(cm.N)[None, :], cmap=cm, aspect="auto")
                else:
                    # continuous: raw values are mapped through ``cmap``
                    ax_cov.imshow(colors[None, :], cmap=plt.get_cmap(cmap), aspect="auto")
                ax_cov.set_yticks([])
                # label the strip with the (pseudo)time range of the test points
                ax_cov.set_xticks(np.linspace(0, len(colors), 5))
                ax_cov.set_xticklabels([f"{v:.3f}" for v in np.linspace(tmin, tmax, 5)])

        return ax if sharey else None

    use_raw = kwargs.get("use_raw", False)
    lineage_key = Key.obsm.abs_probs(backward)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages not found in `adata.obsm[{lineage_key!r}]`.")
    # evaluated only for its side effect: fails early if `lineage` is not a valid lineage name
    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", use_raw=use_raw)

    if key is None:
        key = f"lineage_{lineage}_trend"

    if recompute or key not in adata.uns:
        kwargs["backward"] = backward
        kwargs["time_key"] = time_key
        kwargs["n_test_points"] = n_points
        models = _create_models(model, genes, [lineage])
        all_models, models, genes, _ = _fit_bulk(
            models,
            _create_callbacks(adata, callback, genes, [lineage], **kwargs),
            genes,
            lineage,
            time_range,
            return_models=True,  # always return (better error messages)
            filter_all_failed=True,
            parallel_kwargs={
                "show_progress_bar": show_progress_bar,
                "n_jobs": _get_n_cores(n_jobs, len(genes)),
                "backend": _get_backend(models, backend),
            },
            **kwargs,
        )

        # one column per successfully fitted gene (transposed back to `n_genes x n_test_points` below)
        trends = np.vstack([gene_models[lineage].y_test for gene_models in models.values()]).T

        if norm:
            logg.debug("Normalizing trends")
            # use the returned array; in-place scaling (copy=False) is only guaranteed for float ndarrays
            trends = StandardScaler().fit_transform(trends)

        # take the test points from a successfully fitted model (`all_models` may also contain failed ones)
        x_test = next(iter(models.values()))[lineage].x_test

        trends = AnnData(trends.T)
        # sanity check, before assigning names/annotations that would otherwise fail with a less clear error
        if trends.n_obs != len(genes):
            raise RuntimeError(f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`.")
        if trends.n_vars != n_points:
            raise RuntimeError(f"Expected to find `{n_points}` points, found `{trends.n_vars}`.")
        trends.obs_names = genes
        trends.var["x_test"] = x_test

        # derive one seed shared by PCA, neighbors and clustering
        random_state = np.random.mtrand.RandomState(random_state).randint(2 ** 16)

        pca_kwargs = dict(pca_kwargs)
        pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1)
        pca_kwargs.setdefault("random_state", random_state)
        sc.pp.pca(trends, **pca_kwargs)

        neighbors_kwargs = dict(neighbors_kwargs)
        neighbors_kwargs.setdefault("random_state", random_state)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        clustering_kwargs = dict(clustering_kwargs)
        clustering_kwargs["key_added"] = "clusters"
        clustering_kwargs.setdefault("random_state", random_state)
        sc.tl.leiden(trends, **clustering_kwargs)

        logg.info(f"Saving data to `adata.uns[{key!r}]`")
        adata.uns[key] = trends
    else:
        all_models = None
        logg.info(f"Loading data from `adata.uns[{key!r}]`")
        trends = adata.uns[key]
        x_test = trends.var["x_test"]

    if "clusters" not in trends.obs:
        raise KeyError("Unable to find the clustering in `trends.obs['clusters']`.")

    valid_clusters = trends.obs["clusters"].cat.categories
    if clusters is None:
        clusters = valid_clusters
    for c in clusters:
        if c not in valid_clusters:
            raise ValueError(
                f"Invalid cluster name `{c!r}`. "
                f"Valid options are `{list(valid_clusters)}`."
            )

    nrows = int(np.ceil(len(clusters) / ncols))
    fig = plt.figure(
        dpi=dpi,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        tight_layout=True,
    )

    if covariate_key is None:
        covariate_colors, row_delta = None, 1
    else:
        # `tmin`/`tmax` are read by `plot_cluster` only when covariate strips are drawn
        tmin, tmax = np.min(x_test), np.max(x_test)
        covariate_colors = _get_sorted_colors(
            adata,
            covariate_key,
            time_key,
            tmin=tmin,
            tmax=tmax,
        )
        # each cluster occupies one trend row plus one row per covariate
        row_delta = len(covariate_colors) + 1

    gs = GridSpec(nrows=nrows * row_delta, ncols=ncols, figure=fig)

    row, sharey_ax = -row_delta, None
    for i, c in enumerate(clusters):
        if i % ncols == 0:
            row += row_delta
        sharey_ax = plot_cluster(row, i % ncols, c, sharey_ax=sharey_ax)

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models