Source code for cellrank._utils._lineage

import copy as copy_
import enum
import functools
import inspect
import itertools
import pathlib
import types
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Literal,
    Mapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np
import pandas as pd
import scipy.stats as st
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
from matplotlib import colors

from anndata import AnnData
from anndata._io.specs.methods import H5Group, ZarrGroup, write_basic
from anndata._io.specs.registry import _REGISTRY, IOSpec

from cellrank import logging as logg
from cellrank._utils._colors import (
    _compute_mean_color,
    _create_categorical_colors,
    _get_bg_fg_colors,
)
from cellrank._utils._docs import d, inject_docs
from cellrank._utils._enum import ModeEnum
from cellrank._utils._key import Key

__all__ = ["Lineage", "LineageView"]

ColorLike = TypeVar("ColorLike")
_ERROR_NOT_ITERABLE = "Expected `{}` to be iterable, found type `{}`."
_ERROR_WRONG_SIZE = "Expected `{}` to be of size `{{}}`, found `{{}}`."

_HT_CELLS = 10  # head and tail cells to show
_HTML_REPR_THRESH = 100
_DUMMY_CELL = "<td style='text-align: right;'>...</td>"
_ORDER = "C"


class PrimingDegree(ModeEnum):  # noqa: D101
    KL_DIVERGENCE = enum.auto()
    ENTROPY = enum.auto()


class DistanceMeasure(ModeEnum):  # noqa: D101
    COSINE_SIM = enum.auto()
    WASSERSTEIN_DIST = enum.auto()
    KL_DIV = enum.auto()
    JS_DIV = enum.auto()
    MUTUAL_INFO = enum.auto()
    EQUAL = enum.auto()


class NormWeights(ModeEnum):  # noqa: D101
    SCALE = enum.auto()
    SOFTMAX = enum.auto()


class Reduction(ModeEnum):  # noqa: D101
    DIST = enum.auto()
    SCALE = enum.auto()


class LinKind(ModeEnum):  # noqa: D101
    MACROSTATES = enum.auto()
    TERM_STATES = enum.auto()
    FATE_PROBS = enum.auto()


def _at_least_2d(array: np.ndarray, dim: int):
    return np.expand_dims(array, dim) if array.ndim < 2 else array


def wrap(numpy_func: Callable) -> Callable:
    """Wrap a :mod:`numpy` function.

    Modifies the behavior of some functions, e.g., ignoring `.squeeze` or retaining dimensions after reductions.

    Parameters
    ----------
    numpy_func
        Function to be wrapped.

    Returns
    -------
    Wrapped function which takes a :class:`cellrank.Lineage` and returns a :class:`cellrank.Lineage`.
    """

    @functools.wraps(numpy_func)
    def decorator(array: "Lineage", *args, **kwargs):
        if not isinstance(array, Lineage):
            raise TypeError(f"Expected array to be of type `Lineage`, found `{type(array).__name__}`.")
        if fname == "squeeze":
            return array
        if fname == "array_repr":
            return repr(array)

        if "axis" in kwargs:
            axis = kwargs["axis"]
        elif axis_ix < len(args):
            axis = args[axis_ix]
        else:
            axis = default_axis

        res = np.array(numpy_func(array.X, *args, **kwargs), copy=False)

        # handle expand_dim
        if res.ndim > 2:
            return array

        # handle reductions
        if not res.shape:
            return Lineage(np.array([[res]]), names=[fname], colors=["grey"])
        if res.shape == array.shape:
            return Lineage(res, names=array.names, colors=array.colors)

        res = np.expand_dims(res, axis)

        is_t = int(array._is_transposed)
        if is_t:
            res = res.T

        lin = None
        if res.shape[0] == array.shape[is_t]:
            lin = Lineage(
                res,
                names=[f"{fname} of {', '.join(array.names)}"],
                colors=["grey"],
            )

        if res.shape[1] == array.shape[1 - is_t]:
            lin = Lineage(res, names=[f"{fname} of {n}" for n in array.names], colors=array.colors)

        if lin is not None:
            return lin.T if is_t else lin

        raise RuntimeError(
            f"Unable to interpret result of function `{fname}` called " f"with args `{args}`, kwargs: `{kwargs}`."
        )

    params = inspect.signature(numpy_func).parameters
    if "axis" in params:
        axis_ix = list(params.keys()).index("axis") - 1
        default_axis = params["axis"].default
    else:
        axis_ix = 256
        default_axis = None
    assert axis_ix >= 0, f"Expected argument `'axis'` not to be first for function `{numpy_func.__name__}`."

    fname = numpy_func.__name__
    if fname == "amin":
        fname = "min"
    elif fname == "amax":
        fname = "max"

    return decorator
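
# Illustrative sketch only (not part of the library): ``wrap`` turns a reduction such as
# ``np.mean`` into a callable that maps a ``Lineage`` back to a ``Lineage``, naming the
# result after the reduced columns. Assuming a small random object::
#
#     lin = Lineage(np.random.rand(5, 2), names=["Alpha", "Beta"])
#     wrap(np.mean)(lin, axis=0)  # 1 x 2 Lineage named "mean of Alpha", "mean of Beta"
#     wrap(np.mean)(lin, axis=1)  # 5 x 1 Lineage named "mean of Alpha, Beta"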


def _register_handled_functions():
    # adapted from:
    # https://github.com/numpy/numpy/blob/v1.26.0/numpy/testing/overrides.py#L50
    try:
        from numpy.core.overrides import ARRAY_FUNCTIONS
    except ImportError:
        ARRAY_FUNCTIONS = [getattr(np, attr) for attr in dir(np)]

    handled_fns = {}
    for fn in ARRAY_FUNCTIONS:
        try:
            sig = inspect.signature(fn)
            if "axis" in sig.parameters:
                handled_fns[fn] = wrap(fn)
        except Exception:  # noqa: BLE001
            pass

    handled_fns.pop(np.expand_dims, None)

    handled_fns[np.allclose] = wrap(np.allclose)
    handled_fns[np.array_repr] = wrap(np.array_repr)
    handled_fns[st.entropy] = wrap(st.entropy)  # qol change

    return handled_fns


_HANDLED_FUNCTIONS = _register_handled_functions()


class LineageMeta(type):
    """Metaclass for Lineage.

    It registers the functions handled by :class:`Lineage` and overloads common attributes, such as `.sum`,
    with their wrapped versions.
    """

    __overloaded_functions__ = dict(  # noqa
        sum=np.sum,
        mean=np.mean,
        min=np.min,
        argmin=np.argmin,
        max=np.max,
        argmax=np.argmax,
        std=np.std,
        var=np.var,
        sort=np.sort,
        squeeze=np.squeeze,
        entropy=st.entropy,
    )

    def __new__(cls, clsname, superclasses, attributedict):  # noqa
        res = type.__new__(cls, clsname, superclasses, attributedict)
        for attrname, fn in LineageMeta.__overloaded_functions__.items():
            wrapped_fn = _HANDLED_FUNCTIONS.get(fn, None)
            if wrapped_fn:
                setattr(res, attrname, wrapped_fn)

        return res
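
# Illustrative sketch only (not part of the library): because ``LineageMeta`` binds the wrapped
# functions as attributes and ``Lineage.__array_function__`` dispatches to ``_HANDLED_FUNCTIONS``,
# both spellings below return a ``Lineage`` rather than a bare array::
#
#     lin = Lineage(np.random.rand(5, 2), names=["Alpha", "Beta"])
#     lin.sum(axis=1)      # method injected by the metaclass
#     np.sum(lin, axis=1)  # dispatched via __array_function__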


class Lineage(np.ndarray, metaclass=LineageMeta):
    """Lightweight :class:`~numpy.ndarray` wrapper that adds names and colors.

    Parameters
    ----------
    input_array
        Input array containing lineage probabilities stored in columns.
    names
        Lineage names.
    colors
        Lineage colors.
    """

    def __new__(
        cls,
        input_array: np.ndarray,
        *,
        names: Iterable[str],
        colors: Optional[Iterable[ColorLike]] = None,
    ) -> "Lineage":
        """Create and return a new object."""
        if not isinstance(input_array, np.ndarray):
            raise TypeError(f"Input array must be of type `numpy.ndarray`, found `{type(input_array).__name__!r}`.")
        if input_array.ndim == 1:
            input_array = np.expand_dims(input_array, -1)
        elif input_array.ndim > 2:
            raise ValueError(f"Input array must be 2-dimensional, found `{input_array.ndim}`.")

        if not input_array.shape[0]:
            raise ValueError("Expected number of cells to be at least `1`, found `0`.")
        if not input_array.shape[1]:
            raise ValueError("Expected number of lineages to be at least `1`, found `0`.")

        obj = np.array(input_array, copy=True).view(cls)
        obj._n_lineages = obj.shape[1]
        obj._is_transposed = False
        obj.names = names  # these always create a copy, which is a good thing
        obj.colors = colors

        return obj

    def __array_finalize__(self, obj) -> None:
        if obj is None:
            return

        _names = getattr(obj, "_names", None)
        if _names is not None:
            self._names = _names
            self._n_lineages = len(_names)
            self._names_to_ixs = {n: i for i, n in enumerate(self.names)}
        else:
            self._names = None
            self._names_to_ixs = None
            self._n_lineages = getattr(obj, "_n_lineages", obj.shape[1] if obj.ndim == 2 else 0)

        self._colors = getattr(obj, "colors", None)
        self._is_transposed = getattr(obj, "_is_transposed", False)

    def __array_function__(self, func, types, args, kwargs):
        if func not in _HANDLED_FUNCTIONS:
            return NotImplemented
        # Note: this allows subclasses that don't override
        # __array_function__ to handle MyArray objects
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented

        return _HANDLED_FUNCTIONS[func](*args, **kwargs)

    def __getitem__(self, item) -> "Lineage":
        was_transposed = False
        if self._is_transposed:
            was_transposed = True
            self = self.T
            if isinstance(item, tuple):
                item = item[::-1]

        obj = self.__getitem(item)
        return obj.T if was_transposed else obj

    def _mix_lineages(self, rows, mixtures: Iterable[Union[str, Any]]) -> "Lineage":
        from cellrank._utils._utils import _unique_order_preserving

        def unsplit(names: str) -> Tuple[str, ...]:
            return tuple(sorted({name.strip(" ") for name in names.strip(" ,").split(",")}))

        keys = [
            tuple(self._maybe_convert_names(unsplit(mixture), default=mixture))
            if isinstance(mixture, str)
            else (mixture,)
            for mixture in mixtures
        ]
        keys = _unique_order_preserving(keys)

        # check the `keys` are unique
        overlap = [set(ks) for ks in keys]
        for c1, c2 in itertools.combinations(overlap, 2):
            overlap = c1 & c2
            if overlap:
                raise ValueError(f"Found overlapping keys: `{self.names[list(overlap)]}`.")

        names, colors, res = [], [], []
        for key in map(list, keys):
            if key:
                res.append(self[rows, key].X.sum(1))
                names.append(", ".join(self.names[key]))
                colors.append(_compute_mean_color(self.colors[key]))

        return Lineage(np.stack(res, axis=-1), names=names, colors=colors)

    def __getitem(self, item):
        if isinstance(item, tuple):
            if len(item) > 2:
                raise ValueError(f"Expected key to be of length `2`, found `{len(item)}`.")
            item = list(item)
            if item[0] is Ellipsis or item[0] is None:
                item[0] = range(self.shape[0])
            if len(item) == 2 and (item[1] is Ellipsis or item[1] is None):
                item[1] = range(self.shape[1])
            item = tuple(item)
        is_tuple_len_2 = (
            isinstance(item, tuple)
            and len(item) == 2
            and isinstance(item[0], (int, np.integer, range, slice, tuple, list, np.ndarray))
        )
        if is_tuple_len_2:
            rows, col = item

            if isinstance(col, (int, np.integer, str)):
                col = [col]

            try:
                # slicing an array where row/col are like 2D indices
                if 1 < len(col) == len(rows) and len(rows) == self.shape[0]:
                    col = self._maybe_convert_names(col, make_unique=False)
                    return Lineage(
                        # never remove this expand_dims - it's critical
                        np.expand_dims(self.X[rows, col], axis=-1),
                        names=["mixture"],
                        colors=["#000000"],
                    )
            except TypeError:  # because of range
                pass

            if isinstance(col, (list, tuple, np.ndarray)):
                if any(isinstance(i, str) and "," in i for i in col):
                    return self._mix_lineages(rows, col)
                col = self._maybe_convert_names(col)

            item = rows, col
        else:
            if isinstance(item, (int, np.integer, str)):
                item = [item]

            col = range(len(self.names))

            if isinstance(item, (tuple, list, np.ndarray)):
                if any(isinstance(i, str) and "," in i for i in item):
                    return self._mix_lineages(slice(None, None, None), item)
                if any(isinstance(i, str) for i in item):
                    item = (slice(None, None, None), self._maybe_convert_names(item))
                    col = item[1]

        shape, row_order, col_order = None, None, None
        if is_tuple_len_2 and not isinstance(item[0], slice) and not isinstance(item[1], slice):
            item_0 = np.array(item[0]) if not isinstance(item[0], np.ndarray) else item[0]
            item_1 = np.array(item[1]) if not isinstance(item[1], np.ndarray) else item[1]
            item_0 = _at_least_2d(item_0, -1)
            item_1 = _at_least_2d(item_1, 0)

            # handle boolean indexing
            if item_1.dtype == bool:
                if item_0.dtype != bool:
                    if not issubclass(item_0.dtype.type, np.integer):
                        raise TypeError(f"Invalid type `{item_0.dtype.type}`.")
                    row_order = (
                        item_0[:, 0] if item_0.shape[0] == self.shape[0] else np.argsort(np.argsort(item_0[:, 0]))
                    )
                    item_0 = _at_least_2d(np.isin(np.arange(self.shape[0]), item_0), -1)
                item = item_0 * item_1
                shape = np.max(np.sum(item, axis=0)), np.max(np.sum(item, axis=1))
                col = np.where(np.all(item_1, axis=0))[0]
            elif item_0.dtype == bool:
                if item_1.dtype != bool:
                    if not issubclass(item_1.dtype.type, np.integer):
                        raise TypeError(f"Invalid type `{item_1.dtype.type}`.")
                    col_order = (
                        item_1[0, :] if item_1.shape[1] == self.shape[1] else np.argsort(np.argsort(item_1[0, :]))
                    )
                    item_1 = _at_least_2d(np.isin(np.arange(self.shape[1]), item_1), 0)
                item = item_0 * item_1
                shape = np.max(np.sum(item, axis=0)), np.max(np.sum(item, axis=1))
                col = np.where(np.all(item_1, axis=0))[0]
            else:
                # defer to numpy
                item = (item_0, item_1)

        obj = super().__getitem__(item)

        # keep the resulting shape
        if shape is not None:
            obj = obj.reshape(shape)
        # correctly reorder
        if row_order is not None:
            obj = obj[row_order, :]
        if col_order is not None:
            obj = obj[:, col_order]

        # correctly set names and colors
        if isinstance(obj, Lineage):
            obj._names = np.atleast_1d(self.names[col])
            obj._colors = np.atleast_1d(self.colors[col])
            if col_order is not None:
                obj._names = obj._names[col_order]
                obj._colors = obj._colors[col_order]
            obj._names_to_ixs = {name: i for i, name in enumerate(obj._names)}

        return obj

    @property
    def names(self) -> np.ndarray:
        """Lineage names."""
        return self._names

    @names.setter
    def names(self, value: Iterable[str]) -> None:
        if not isinstance(value, Iterable):
            raise TypeError(_ERROR_NOT_ITERABLE.format("names", type(value).__name__))
        value = [str(v) for v in value]
        value = self._check_axis1_shape(value, _ERROR_WRONG_SIZE.format("names"))
        if len(set(value)) != len(value):
            raise ValueError(f"Not all lineage names are unique: `{value}`.")
        self._names = self._prepare_annotation(value)
        self._names_to_ixs = {name: ix for ix, name in enumerate(self.names)}

    @property
    def colors(self) -> np.ndarray:
        """Lineage colors."""
        return self._colors

    @colors.setter
    def colors(self, value: Optional[Iterable[ColorLike]]) -> None:
        if value is None:
            value = _create_categorical_colors(self._n_lineages)
        elif not isinstance(value, Iterable):
            raise TypeError(_ERROR_NOT_ITERABLE.format("colors", type(value).__name__))

        value = self._check_axis1_shape(value, _ERROR_WRONG_SIZE.format("colors"))
        self._colors = self._prepare_annotation(
            value,
            checker=colors.is_color_like,
            transformer=colors.to_hex,
            checker_msg="Value `{}` is not a valid color.",
        )

    @property
    def X(self) -> np.ndarray:
        """Convert self to an array."""
        return np.array(self, copy=False)

    @property
    def T(self):
        """Transpose of self."""
        obj = self.transpose()
        obj._is_transposed = not self._is_transposed
        return obj

    @property
    def nlin(self) -> int:
        """Number of lineages."""
        return self.shape[1]

    @d.get_full_description(base="lin_pd")
    @d.get_sections(base="lin_pd", sections=["Parameters", "Returns"])
    def priming_degree(
        self,
        method: Literal["kl_divergence", "entropy"] = "kl_divergence",
        early_cells: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Compute the degree of lineage priming.

        It returns a score in :math:`[0, 1]` where :math:`0` stands for naive and :math:`1` stands for committed.

        Parameters
        ----------
        method
            The method used to compute the degree of lineage priming. Valid options are:

            - ``'kl_divergence'`` - as in :cite:`velten:17`, computes KL-divergence between the fate probabilities
              of a cell and the average fate probabilities. Computation of average fate probabilities can be
              restricted to a set of user-defined ``early_cells``.
            - ``'entropy'`` - as in :cite:`setty:19`, computes entropy over a cell's fate probabilities.
        early_cells
            Cell IDs or a mask marking early cells. If :obj:`None`, use all cells.
            Only used when ``method = 'kl_divergence'``.

        Returns
        -------
        The priming degree.
        """
        early_cells = np.ones((len(self),), dtype=np.bool_) if early_cells is None else np.asarray(early_cells)
        if not np.issubdtype(early_cells.dtype, np.bool_):
            early_cells = np.unique(early_cells)

        method = PrimingDegree(method)
        probs = self.X
        early_subset = probs[early_cells, :]
        if not len(early_subset):
            raise ValueError("No early cells have been specified.")

        with np.errstate(divide="ignore", invalid="ignore"):
            if method == PrimingDegree.KL_DIVERGENCE:
                probs = np.nan_to_num(
                    np.sum(probs * np.log2(probs / np.mean(early_subset, axis=0)), axis=1),
                    nan=1.0,
                    copy=False,
                )
            elif method == PrimingDegree.ENTROPY:
                probs = st.entropy(probs, axis=1)
                probs = np.max(probs) - probs
            else:
                raise NotImplementedError(f"Method `{method}` is not yet implemented")

        minn, maxx = np.min(probs), np.max(probs)
        return (probs - minn) / (maxx - minn)
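    # Illustrative sketch only (not part of the library): assuming ``lin`` is a fate-probability
    # ``Lineage`` whose rows sum to 1, both calls return an array of length ``lin.shape[0]``
    # scaled to ``[0, 1]``; ``early_mask`` is a hypothetical boolean mask of naive cells::
    #
    #     lin.priming_degree(method="entropy")
    #     lin.priming_degree(method="kl_divergence", early_cells=early_mask)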

    @d.dedent
    def plot_pie(
        self,
        reduction: Callable,
        title: Optional[str] = None,
        legend_loc: Optional[str] = "on data",
        legend_kwargs: Mapping = types.MappingProxyType({}),
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[float] = None,
        save: Optional[Union[pathlib.Path, str]] = None,
        **kwargs: Any,
    ) -> None:
        """Plot a pie chart visualizing aggregated lineage probabilities.

        Parameters
        ----------
        reduction
            Function that will be applied lineage-wise.
        title
            Title of the figure.
        legend_loc
            Location of the legend. If :obj:`None`, it is not shown.
        legend_kwargs
            Keyword arguments for :meth:`~matplotlib.axes.Axes.legend`.
        %(plotting)s
        kwargs
            Keyword arguments for :meth:`~matplotlib.axes.Axes.pie`.

        Returns
        -------
        %(just_plots)s
        """
        from cellrank._utils._utils import save_fig

        if len(self.names) == 1:
            raise ValueError("Cannot plot pie chart for only 1 lineage.")

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        if "autopct" not in kwargs:
            autopct_found = False
            autopct = "{:.1f}%".format  # we don't show the percentage, only the value
        else:
            autopct_found = True
            autopct = kwargs.pop("autopct")

        if title is None:
            title = reduction.__name__ if hasattr(reduction, "__name__") else None

        reduction = reduction(self, axis=int(self._is_transposed)).X.squeeze()
        reduction_norm = reduction / np.sum(reduction)

        wedges, texts, *autotexts = ax.pie(
            reduction_norm.squeeze(),
            labels=self.names if legend_loc == "on data" else None,
            autopct=autopct,
            wedgeprops={"edgecolor": "w"},
            colors=self.colors,
            **kwargs,
        )

        # if autopct is not None
        if len(autotexts):
            autotexts = autotexts[0]
            for name, at in zip(self.names, autotexts):
                ix = self._names_to_ixs[name]
                at.set_color(_get_bg_fg_colors(self.colors[ix])[1])
                if not autopct_found:
                    at.set_text(f"{reduction[ix]:.4f}")

        if legend_loc not in (None, "none", "on data"):
            ax.legend(
                wedges,
                self.names,
                title="lineages",
                loc=legend_loc,
                **legend_kwargs,
            )

        ax.set_title(title)
        ax.set_aspect("equal")

        if save is not None:
            save_fig(fig, save)
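    # Illustrative sketch only (not part of the library): plot the lineage-wise mean of the
    # fate probabilities as a pie chart; any reduction accepting an ``axis`` argument works::
    #
    #     lin.plot_pie(np.mean, title="mean fate probabilities", legend_loc="best")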

    @d.dedent
    @inject_docs(m=Reduction, dm=DistanceMeasure, nw=NormWeights)
    def reduce(
        self,
        *keys: str,
        mode: Literal["dist", "scale"] = Reduction.DIST,
        dist_measure: Literal[
            "cosine_sim", "wasserstein_dist", "kl_div", "js_div", "mutual_info", "equal"
        ] = DistanceMeasure.MUTUAL_INFO,
        normalize_weights: Literal["scale", "softmax"] = NormWeights.SOFTMAX,
        softmax_scale: float = 1.0,
        return_weights: bool = False,
    ) -> Union["Lineage", Tuple["Lineage", Optional[pd.DataFrame]]]:
        """Subset states and normalize them so that they again sum to :math:`1`.

        Parameters
        ----------
        keys
            List of keys that define the states to which this object will be reduced by projecting
            the values of the other states.
        mode
            Reduction mode to use. Valid options are:

            - ``{m.DIST!r}`` - use a distance measure ``dist_measure`` to compute weights.
            - ``{m.SCALE!r}`` - just rescale the values.
        dist_measure
            Used to quantify similarity between query and reference states. Valid options are:

            - ``{dm.COSINE_SIM!r}`` - cosine similarity.
            - ``{dm.WASSERSTEIN_DIST!r}`` - Wasserstein distance.
            - ``{dm.KL_DIV!r}`` - Kullback–Leibler divergence.
            - ``{dm.JS_DIV!r}`` - Jensen–Shannon divergence.
            - ``{dm.MUTUAL_INFO!r}`` - mutual information.
            - ``{dm.EQUAL!r}`` - equally redistribute the mass among the rest.

            Only used when ``mode = {m.DIST!r}``.
        normalize_weights
            How to row-normalize the weights. Valid options are:

            - ``{nw.SCALE!r}`` - divide by the sum.
            - ``{nw.SOFTMAX!r}`` - use a softmax.

            Only used when ``mode = {m.DIST!r}``.
        softmax_scale
            Scaling factor in the softmax, used for normalizing the weights to sum to :math:`1`.
        return_weights
            If `True`, a :class:`~pandas.DataFrame` of the weights used for the projection is also returned.
            If ``mode = {m.SCALE!r}``, the weights will be `None`.

        Returns
        -------
        The lineage object, reduced to the %(initial_or_terminal)s states.
        The weights used for the projection of shape ``(n_query, n_reference)``, if ``return_weights = True``.
        """
        mode = Reduction(mode)
        dist_measure = DistanceMeasure(dist_measure)
        normalize_weights = NormWeights(normalize_weights)

        if self._is_transposed:
            raise RuntimeError("This method works only on non-transposed lineages.")
        if not len(keys):
            raise ValueError("Unable to perform the reduction, no keys specified.")
        # check the lineage object
        if not np.allclose(np.sum(self.X, axis=1), 1.0):
            raise ValueError("Memberships do not sum to one row-wise.")

        if len(keys) == 1:
            tmp = self[:, keys]
            return Lineage(
                np.ones((self.shape[0], 1), dtype=self.dtype),
                names=tmp.names,
                colors=tmp.colors,
            )

        # check input parameters
        if return_weights and mode == Reduction.SCALE:
            logg.warning(f"If `mode={mode!r}`, no weights are computed. Returning `None`")

        reference = self[:, keys]
        rest = [k for k in self.names if all(k not in rk for rk in reference.names)]
        if not rest:
            logg.warning(
                "Unable to perform reduction because all keys have been selected. Returning combined object only"
            )
            return (reference.copy(), None) if return_weights else reference.copy()

        query = self[:, rest]

        if mode == Reduction.SCALE:
            reference = _row_normalize(reference)
        elif mode == Reduction.DIST:
            # compute a set of weights of shape (n_query x n_reference)
            if dist_measure == DistanceMeasure.COSINE_SIM:
                weights = _cosine_sim(reference.X, query.X)
            elif dist_measure == DistanceMeasure.WASSERSTEIN_DIST:
                weights = _wasserstein_dist(reference.X, query.X)
            elif dist_measure == DistanceMeasure.KL_DIV:
                weights = _kl_div(reference.X, query.X)
            elif dist_measure == DistanceMeasure.JS_DIV:
                weights = _js_div(reference.X, query.X)
            elif dist_measure == DistanceMeasure.MUTUAL_INFO:
                weights = _mutual_info(reference.X, query.X)
            elif dist_measure == DistanceMeasure.EQUAL:
                weights = _row_normalize(np.ones((query.shape[1], reference.shape[1])))
            else:
                raise NotImplementedError(f"Distance measure `{dist_measure}` is not yet implemented.")

            # make some checks on the weights
            if weights.shape != (query.shape[1], reference.shape[1]):
                raise ValueError(
                    f"Expected weight matrix to be of shape `({query.shape[1]}, {reference.shape[1]})`, "
                    f"found `{weights.shape}`."
                )
            if not np.isfinite(weights).all():
                raise ValueError("Weights matrix contains elements that are not finite.")
            if (weights < 0).any():
                raise ValueError("Weights matrix contains negative elements.")
            if (weights == 0).any():
                logg.warning("Weights matrix contains exact zeros.")

            # normalize the weights to row-sum to one
            if normalize_weights == NormWeights.SCALE:
                weights_n = _row_normalize(weights)
            elif normalize_weights == NormWeights.SOFTMAX:
                weights_n = _softmax(_row_normalize(weights), softmax_scale)
            else:
                raise NotImplementedError(f"Normalization method `{normalize_weights}` is not yet implemented.")

            # check that the weights row-sum to one now
            if not np.allclose(weights_n.sum(1), 1.0):
                raise ValueError("Weights do not sum to 1 row-wise.")

            # use the weights to re-distribute probability mass from query to reference
            for i, w in enumerate(weights_n):
                reference += np.dot(query[:, i].X, w[None, :])
        else:
            raise NotImplementedError(f"Reduction mode `{mode}` is not yet implemented.")

        # check that the lineages row-sum to one now
        if not np.allclose(reference.sum(1), 1.0):
            raise ValueError("Reduced lineage rows do not sum to 1.")

        # potentially create a weights-df and return everything
        if return_weights:
            if mode == Reduction.DIST:
                return (
                    reference,
                    pd.DataFrame(data=weights_n, columns=reference.names, index=query.names),
                )
            return reference, None

        return reference
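    # Illustrative sketch only (not part of the library): assuming ``lin`` has three lineages
    # with the hypothetical names below and rows summing to 1, reduce it to two terminal states,
    # redistributing the mass of the remaining lineage via cosine similarity::
    #
    #     reduced, weights = lin.reduce(
    #         "Alpha", "Beta", mode="dist", dist_measure="cosine_sim", return_weights=True
    #     )
    #     reduced.X.sum(1)  # rows sum to 1 again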

    @classmethod
    @d.dedent
    @inject_docs(lk=LinKind)
    def from_adata(
        cls,
        adata: AnnData,
        backward: bool = False,
        estimator_backward: Optional[bool] = None,
        kind: Literal["macrostates", "term_states", "fate_probs"] = LinKind.FATE_PROBS,
        copy: bool = False,
    ) -> "Lineage":
        """Reconstruct the :class:`~cellrank.Lineage` object from an :class:`~anndata.AnnData` object.

        Parameters
        ----------
        %(adata)s
        %(backward)s
        estimator_backward
            Key which helps determine whether these states are initial or terminal.
        kind
            Which kind of object to reconstruct. Valid options are:

            - ``{lk.MACROSTATES!r}`` - macrostate memberships from :class:`cellrank.estimators.GPCCA`.
            - ``{lk.TERM_STATES!r}`` - terminal state memberships from :class:`cellrank.estimators.GPCCA`.
            - ``{lk.FATE_PROBS!r}`` - fate probabilities.
        copy
            Whether to return a copy of the underlying array.

        Returns
        -------
        The reconstructed lineage object.
        """
        kind = LinKind(kind)
        if kind == LinKind.MACROSTATES:
            nkey = Key.obs.macrostates(backward)
            key = Key.obsm.memberships(nkey)
        elif kind == LinKind.TERM_STATES:
            nkey = Key.obs.term_states(estim_bwd=estimator_backward, bwd=backward)
            key = Key.obsm.memberships(nkey)
        elif kind == LinKind.FATE_PROBS:
            nkey = Key.obs.term_states(estim_bwd=estimator_backward, bwd=backward)
            key = Key.obsm.fate_probs(backward)
        else:
            raise NotImplementedError(f"Lineage kind `{kind}` is not yet implemented.")
        ckey = Key.uns.colors(nkey)

        if key not in adata.obsm:
            raise KeyError(f"Unable to find lineage data in `adata.obsm[{key!r}]`.")
        data: Union[np.ndarray, Lineage] = adata.obsm[key]
        if copy:
            data = copy_.copy(data)
        if isinstance(data, Lineage):
            return data
        if data.ndim != 2:
            raise ValueError(f"Expected 2 dimensional data, found `{data.ndim}`.")

        states = adata.obs.get(nkey, None)
        if states is None:
            logg.warning(f"Unable to find states in `adata.obs[{nkey!r}]`. Using default names")
        elif not isinstance(states.dtype, pd.CategoricalDtype):
            logg.warning(
                f"Expected `adata.obs[{key!r}]` to be `categorical`, "
                f"found `{infer_dtype(adata.obs[nkey])}`. Using default names"
            )
        else:
            states = list(states.cat.categories)
            if len(states) != data.shape[1]:
                logg.warning(
                    f"Expected to find `{data.shape[1]}` names, found `{len(states)}`. Using default names"
                )
        if states is None or len(states) != data.shape[1]:
            states = [str(i) for i in range(data.shape[1])]

        colors = adata.uns.get(ckey, None)
        if colors is None:
            logg.warning(f"Unable to find colors in `adata.uns[{ckey!r}]`. Using default colors")
        elif len(colors) != data.shape[1]:
            logg.warning(f"Expected to find `{data.shape[1]}` colors, found `{len(colors)}`. Using default colors")
            colors = None

        return Lineage(data, names=states, colors=colors)
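    # Illustrative sketch only (not part of the library): rebuild a ``Lineage`` from the
    # fate-probability matrix, state names and colors that an estimator previously wrote
    # into ``adata.obsm`` / ``adata.obs`` / ``adata.uns``::
    #
    #     lin = Lineage.from_adata(adata, backward=False, kind="fate_probs")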

    def view(self, dtype=None, type=None, *_, **__) -> "LineageView":
        """Return a view of self."""
        return LineageView(self)

    def __repr__(self) -> str:
        return f'{super().__repr__()[:-1]},\n names([{", ".join(self.names)}]))'

    def __str__(self):
        return f'{super().__str__()}\n names=[{", ".join(self.names)}]'

    @property
    def _fmt(self) -> Callable[[Any], str]:
        return "{:.06f}".format if np.issubdtype(self.dtype, float) else "{}".format

    def _repr_html_(self) -> str:
        def format_row(r):
            rng = (
                range(self.shape[1])
                if not self._is_transposed or (self._is_transposed and self.shape[1] <= _HTML_REPR_THRESH)
                else list(range(_HT_CELLS)) + [...] + list(range(self.shape[1] - _HT_CELLS - 1, self.shape[1] - 1))
            )
            cells = "".join(
                f"<td style='text-align: right;'>" f"{self._fmt(self.X[r, c])}" f"</td>"
                if isinstance(c, int)
                else _DUMMY_CELL
                for c in rng
            )
            return f"<tr>{(names[r] if self._is_transposed else '') + cells}</tr>"

        def dummy_row() -> str:
            values = "".join(_DUMMY_CELL for _ in range(self.shape[1]))
            return f"<tr>{values}</tr>"

        if self.names is None or self.colors is None:
            raise RuntimeError(
                f"Names or colors are `None`. This can happen when running `array.view({type(self).__name__})`."
            )
        if self.ndim != 2:
            return repr(self)

        styles = [
            f"'background-color: {bg}; color: {fg}; text-align: center; word-wrap: break-word; max-width: 100px'"
            for bg, fg in map(_get_bg_fg_colors, self.colors)
        ]
        names = [f"<th style={style}>{n}</th>" for n, style in zip(self.names, styles)]
        header = f"<tr>{''.join(names)}</tr>"

        if self.shape[0] > _HTML_REPR_THRESH:
            body = "".join(format_row(i) for i in range(_HT_CELLS))
            body += dummy_row()
            body += "".join(format_row(i) for i in range(self.shape[0] - _HT_CELLS - 1, self.shape[0] - 1))
        else:
            body = "".join(format_row(i) for i in range(self.shape[0]))

        cells = "cells" if self.shape[0] > 1 else "cell"
        lineages = "lineages" if self.shape[1] > 1 else "lineage"
        if self._is_transposed:
            cells, lineages = lineages, cells
        metadata = f"<p>{self.shape[0]} {cells} x {self.shape[1]} {lineages}</p>"

        if self._is_transposed:
            header = ""

        return (
            f"<div style='scoped' class='rendered_html'>"
            f"<table class='dataframe'>{header}{body}</table>{metadata}"
            f"</div>"
        )

    def __format__(self, format_spec):
        if self.shape == (1, 1):
            return format_spec.format(self.X[0, 0])
        if self.shape == (1,):
            return format_spec.format(self.X[0])
        return NotImplemented

    def __setstate__(self, state, *_, **__):
        *state, names, colors, is_t = state
        names = names[-1]
        colors = colors[-1]

        self._names = np.empty(names[1])
        self._colors = np.empty(colors[1])

        super().__setstate__(tuple(state))
        self._names.__setstate__(tuple(names))
        self._colors.__setstate__(tuple(colors))
        self._is_transposed = is_t

        self._n_lineages = len(self.names)
        self._names_to_ixs = {name: ix for ix, name in enumerate(self.names)}

    def __reduce__(self):
        res = list(super().__reduce__())
        names = self.names.__reduce__()
        colors = self.colors.__reduce__()
        res[-1] += (names, colors, self._is_transposed)

        return tuple(res)

    def copy(self, _="C") -> "Lineage":
        """Return a copy of itself."""
        obj = Lineage(
            self.T if self._is_transposed else self,
            names=np.array(self.names, copy=True, order=_ORDER),
            colors=np.array(self.colors, copy=True, order=_ORDER),
        )
        return obj.T if self._is_transposed else obj

    def __copy__(self):
        return self.copy()

    def _check_axis1_shape(self, array: Iterable[Union[str, ColorLike]], msg: str) -> List[Union[str, ColorLike]]:
        """Check whether the 1D array has the correct length."""
        array = list(array)
        if len(array) != self._n_lineages:
            raise ValueError(msg.format(self._n_lineages, len(array)))

        return array

    def _maybe_convert_names(
        self,
        names: Iterable[Union[int, str, bool]],
        is_singleton: bool = False,
        default: Optional[Union[int, str]] = None,
        make_unique: bool = True,
    ) -> Union[int, List[int], List[bool]]:
        """Convert string indices to their corresponding int indices."""
        from cellrank._utils._utils import _unique_order_preserving

        if all(isinstance(n, (bool, np.bool_)) for n in names):
            return list(names)
        res = []
        for name in names:
            if isinstance(name, str):
                if name in self._names_to_ixs:
                    name = self._names_to_ixs[name]
                elif default is not None:
                    if isinstance(default, str):
                        if default not in self._names_to_ixs:
                            raise KeyError(
                                f"Invalid lineage name: `{name}`. " f"Valid names are: `{list(self.names)}`."
                            )
                        name = self._names_to_ixs[default]
                    else:
                        name = default
                else:
                    raise KeyError(f"Invalid lineage name `{name!r}`. Valid names are: `{list(self.names)}`.")
            res.append(name)

        if make_unique:
            res = _unique_order_preserving(res)

        return res[0] if is_singleton else res

    @staticmethod
    def _prepare_annotation(
        array: List[str],
        checker: Optional[Callable] = None,
        transformer: Optional[Callable] = None,
        checker_msg: Optional[str] = None,
    ) -> np.ndarray:
        if checker is not None:
            assert checker_msg, "Please provide a message when `checker` is not `None`."
            for v in array:
                if not checker(v):
                    raise ValueError(checker_msg.format(v))
        if transformer is not None:
            array = np.array([transformer(v) for v in array])

        return np.array(array)
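
# Illustrative sketch only (not part of the library): indexing by name returns a subset with
# the matching names/colors, and a comma-separated key mixes (sums) lineages; the names below
# are hypothetical::
#
#     lin = Lineage(np.random.dirichlet([1, 1, 1], size=5), names=["Alpha", "Beta", "Gamma"])
#     lin[:, "Alpha"]           # 5 x 1 Lineage
#     lin[:, ["Alpha, Beta"]]   # 5 x 1 Lineage named "Alpha, Beta" with an averaged color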

class LineageView(Lineage):
    """View of :class:`~cellrank.Lineage`."""

    def __new__(cls, lineage: Lineage) -> "LineageView":
        """Create a LineageView."""
        if not isinstance(lineage, Lineage):
            raise TypeError(f"Cannot create a `{cls.__name__}` of `{type(lineage).__name__}`.")

        view = np.array(lineage, copy=False).view(cls)
        view._owner = lineage
        view._names = lineage.names
        view._n_lineages = len(view.names)
        view._names_to_ixs = lineage._names_to_ixs
        view._colors = lineage.colors
        view._is_transposed = lineage._is_transposed

        return view

    @property
    def names(self) -> np.ndarray:
        """Lineage names."""
        return super().names

    @property
    def owner(self) -> Lineage:
        """Return the lineage associated with this view."""
        return self._owner

    @names.setter
    def names(self, _):
        raise RuntimeError(f"Unable to set names of `{type(self).__name__}`.")

    @property
    def colors(self) -> np.ndarray:
        """Lineage colors."""
        return super().colors

    @colors.setter
    def colors(self, _):
        raise RuntimeError(f"Unable to set colors of `{type(self).__name__}`.")

    def view(self, dtype=None, type=None, *_, **__) -> "LineageView":
        """Return self."""
        return self

    def copy(self, _="C") -> Lineage:
        """Return a copy of self."""
        was_transposed = False
        if self._is_transposed:
            self = self.T
            was_transposed = True

        obj = Lineage(
            self,
            names=np.array(self.names, copy=True, order=_ORDER),
            colors=np.array(self.colors, copy=True, order=_ORDER),
        )
        return obj.T if was_transposed else obj


def _remove_zero_rows(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    if a.shape[0] != b.shape[0]:
        raise ValueError("Lineage objects have unequal cell numbers")

    bool_a = (a == 0).any(axis=1)
    bool_b = (b == 0).any(axis=1)
    mask = ~np.logical_or(bool_a, bool_b)

    logg.warning(f"Removed {a.shape[0] - np.sum(mask)} rows because they contained zeros")

    return a[mask, :], b[mask, :]


def _softmax(X, beta: float = 1):
    return np.exp(X * beta) / np.expand_dims(np.sum(np.exp(X * beta), axis=1), -1)


def _row_normalize(X: Union[np.ndarray, Lineage]) -> Union[np.ndarray, Lineage]:
    if isinstance(X, Lineage):
        return X / X.sum(1)  # lineage is shape-preserving
    return X / X.sum(1, keepdims=True)


def _col_normalize(X, norm_ord=2):
    from numpy.linalg import norm

    return X / norm(X, ord=norm_ord, axis=0)


def _cosine_sim(reference, query):
    # the cosine similarity is symmetric
    # normalize these to have 2-norm 1
    reference_n, query_n = _col_normalize(reference, 2), _col_normalize(query, 2)
    return (reference_n.T @ query_n).T


def _point_wise_distance(reference, query, distance):
    # utility function for all point-wise distances/divergences
    # take care of rows that contain zeros
    reference_no_zero, query_no_zero = _remove_zero_rows(reference, query)

    # normalize these to be valid probability distributions (column-wise)
    reference_n, query_n = (
        _col_normalize(reference_no_zero, 1),
        _col_normalize(query_no_zero, 1),
    )

    # loop over query and reference columns and compute pairwise distances
    weights = np.zeros((query.shape[1], reference.shape[1]))
    for i, q_d in enumerate(query_n.T):
        for j, r_d in enumerate(reference_n.T):
            weights[i, j] = 1.0 / distance(q_d, r_d)

    return weights


def _wasserstein_dist(reference, query):
    # the Wasserstein distance is symmetric
    return _point_wise_distance(reference, query, st.wasserstein_distance)


def _kl_div(reference, query):
    return _point_wise_distance(reference, query, st.entropy)


def _js_div(reference, query):
    # the Jensen-Shannon divergence is symmetric
    from scipy.spatial.distance import jensenshannon

    return _point_wise_distance(reference, query, jensenshannon)


def _mutual_info(reference, query):
    # mutual information is not symmetric; no need to normalize the vectors, it is invariant under scaling
    from sklearn.feature_selection import mutual_info_regression

    weights = np.zeros((query.shape[1], reference.shape[1]))
    for i, target in enumerate(query.T):
        weights[i, :] = mutual_info_regression(reference, target)

    return weights


@_REGISTRY.register_write(H5Group, Lineage, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, LineageView, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, Lineage, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, LineageView, IOSpec("array", "0.2.0"))
def _write_lineage(
    f: Any,
    k: str,
    elem: Union[Lineage, LineageView],
    _writer: Any,
    dataset_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> None:
    write_basic(f, k, elem=elem.X, _writer=_writer, dataset_kwargs=dataset_kwargs)
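
# Illustrative note (not part of the library): the registrations above make ``anndata`` write a
# ``Lineage`` stored in ``adata.obsm`` as a plain array, so an on-disk round-trip keeps the
# probabilities but drops names/colors; ``Lineage.from_adata`` can re-attach them from
# ``adata.obs`` / ``adata.uns``::
#
#     adata.obsm["lineages_fwd"] = lin   # hypothetical key
#     adata.write_h5ad("adata.h5ad")     # Lineage serialized as a plain array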