Source code for cellrank.tl.kernels._base_kernel

from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence

from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path

from anndata import AnnData
from cellrank import logging as logg
from cellrank.ul._docs import d, inject_docs
from cellrank.tl._utils import save_fig, _normalize
from cellrank.tl._mixins import IOMixin
from cellrank.tl.kernels.utils import RandomWalk, FlowPlotter
from cellrank.tl.kernels._utils import require_tmat
from cellrank.tl.kernels._mixins import BidirectionalMixin, UnidirectionalMixin
from cellrank.tl.kernels.utils._projection import Projector
from cellrank.tl.kernels.utils._random_walk import Indices_t

import numpy as np
from scipy.sparse import issparse, spmatrix, csr_matrix, isspmatrix_csr

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Public API: the two abstract kernel base classes concrete kernels derive from.
__all__ = (
    "UnidirectionalKernel",
    "BidirectionalKernel",
)

# A transition matrix can be stored either densely or sparsely.
Tmat_t = Union[np.ndarray, spmatrix]


class KernelExpression(IOMixin, ABC):
    """
    Abstract base class for kernels and arithmetic expressions of kernels.

    Expressions are built using ``+`` and ``*`` and expose a row-normalized :attr:`transition_matrix`.
    """

    def __init__(
        self,
        parent: Optional["KernelExpression"] = None,
        **kwargs: Any,
    ):
        super().__init__()
        self._parent = parent
        # only the root of an expression tree normalizes by default, see the `transition_matrix` setter
        self._normalize = parent is None
        self._transition_matrix = None
        self._params: Dict[str, Any] = {}

    def __init_subclass__(cls, **_: Any) -> None:
        # accept and discard any class keyword arguments
        super().__init_subclass__()

    @abstractmethod
    def compute_transition_matrix(
        self, *args: Any, **kwargs: Any
    ) -> "KernelExpression":
        """
        Compute transition matrix.

        Parameters
        ----------
        args
            Positional arguments.
        kwargs
            Keyword arguments.

        Returns
        -------
        Modifies :attr:`transition_matrix` and returns self.
        """

    @abstractmethod
    def copy(self, *, deep: bool = False) -> "KernelExpression":
        """Return a copy of itself. The underlying :attr:`adata` object is not copied."""

    @property
    @abstractmethod
    def adata(self) -> AnnData:
        """Annotated data object."""

    @adata.setter
    @abstractmethod
    def adata(self, value: Optional[AnnData]) -> None:
        """Set the annotated data object."""

    @property
    @abstractmethod
    def kernels(self) -> Tuple["KernelExpression", ...]:
        """Underlying base kernels."""

    @property
    @abstractmethod
    def shape(self) -> Tuple[int, int]:
        """``(n_cells, n_cells)``."""

    @property
    @abstractmethod
    def backward(self) -> Optional[bool]:
        """Direction of the process."""

    @abstractmethod
    def __getitem__(self, ix: int) -> "KernelExpression":
        """Return the sub-expression at index ``ix``."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of sub-expressions."""

    @d.get_full_description(base="plot_single_flow")
    @d.get_sections(base="plot_single_flow", sections=["Parameters", "Returns"])
    @d.dedent
    @require_tmat
    def plot_single_flow(
        self,
        cluster: str,
        cluster_key: str,
        time_key: str,
        clusters: Optional[Sequence[Any]] = None,
        time_points: Optional[Sequence[Union[int, float]]] = None,
        min_flow: float = 0,
        remove_empty_clusters: bool = True,
        ascending: Optional[bool] = False,
        legend_loc: Optional[str] = "upper right out",
        alpha: Optional[float] = 0.8,
        xticks_step_size: Optional[int] = 1,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        show: bool = True,
    ) -> Optional[plt.Axes]:
        """
        Visualize outgoing flow from a cluster of cells :cite:`mittnenzweig:21`.

        Parameters
        ----------
        cluster
            Cluster for which to visualize outgoing flow.
        cluster_key
            Key in :attr:`anndata.AnnData.obs` where clustering is stored.
        time_key
            Key in :attr:`anndata.AnnData.obs` where experimental time is stored.
        clusters
            Visualize flow only for these clusters. If `None`, use all clusters.
        time_points
            Visualize flow only for these time points. If `None`, use all time points.
        %(flow.parameters)s
        %(plotting)s
        show
            If `False`, return :class:`matplotlib.pyplot.Axes`.

        Returns
        -------
        The axes object, if ``show = False``.
        %(just_plots)s

        Notes
        -----
        This function is a Python reimplementation of the following
        `original R function <https://github.com/tanaylab/embflow/blob/main/scripts/generate_paper_figures/plot_vein.r>`_
        with some minor stylistic differences.
        This function will not recreate the results from :cite:`mittnenzweig:21`, because there, the Metacell model
        :cite:`baran:19` was used to compute the flow, whereas here the transition matrix is used.
        """  # noqa: E501
        fp = FlowPlotter(self.adata, self.transition_matrix, cluster_key, time_key)
        fp = fp.prepare(cluster, clusters, time_points)

        ax = fp.plot(
            min_flow=min_flow,
            remove_empty_clusters=remove_empty_clusters,
            ascending=ascending,
            alpha=alpha,
            xticks_step_size=xticks_step_size,
            legend_loc=legend_loc,
            figsize=figsize,
            dpi=dpi,
        )

        if save is not None:
            save_fig(ax.figure, save)

        if not show:
            return ax

    @d.dedent
    @require_tmat
    def plot_random_walks(
        self,
        n_sims: int = 100,
        max_iter: Union[int, float] = 0.25,
        seed: Optional[int] = None,
        successive_hits: int = 0,
        start_ixs: Indices_t = None,
        stop_ixs: Indices_t = None,
        basis: str = "umap",
        cmap: Union[str, LinearSegmentedColormap] = "gnuplot",
        linewidth: float = 1.0,
        linealpha: float = 0.3,
        ixs_legend_loc: Optional[str] = None,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Plot random walks in an embedding.

        This method simulates random walks on the Markov chain defined through the corresponding transition matrix. The
        method is intended to give qualitative rather than quantitative insights into the transition matrix. Random
        walks are simulated by iteratively choosing the next cell based on the current cell's transition probabilities.

        Parameters
        ----------
        n_sims
            Number of random walks to simulate.
        %(rw_sim.parameters)s
        start_ixs
            Cells from which to sample the starting points. If `None`, use all cells.
            %(rw_ixs)s
            For example ``{'dpt_pseudotime': [0, 0.1]}`` means that starting points for random walks
            will be sampled uniformly from cells whose pseudotime is in `[0, 0.1]`.
        stop_ixs
            Cells which when hit, the random walk is terminated. If `None`, terminate after ``max_iter``.
            %(rw_ixs)s
            For example ``{'clusters': ['Alpha', 'Beta']}`` and ``successive_hits = 3`` means that the random walk will
            stop prematurely after cells in the above specified clusters have been visited successively 3 times in a
            row.
        basis
            Basis in :attr:`anndata.AnnData.obsm` to use as an embedding.
        cmap
            Colormap for the random walk lines.
        linewidth
            Width of the random walk lines.
        linealpha
            Alpha value of the random walk lines.
        ixs_legend_loc
            Legend location for the start/stop indices.
        %(parallel)s
        %(plotting)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s
        For each random walk, the first/last cell is marked by the start/end colors of ``cmap``.
        """
        rw = RandomWalk(
            self.adata, self.transition_matrix, start_ixs=start_ixs, stop_ixs=stop_ixs
        )
        sims = rw.simulate_many(
            n_sims=n_sims,
            max_iter=max_iter,
            seed=seed,
            n_jobs=n_jobs,
            backend=backend,
            successive_hits=successive_hits,
            show_progress_bar=show_progress_bar,
        )

        rw.plot(
            sims,
            basis=basis,
            cmap=cmap,
            linewidth=linewidth,
            linealpha=linealpha,
            ixs_legend_loc=ixs_legend_loc,
            figsize=figsize,
            dpi=dpi,
            save=save,
            **kwargs,
        )

    @require_tmat
    def plot_projection(
        self,
        basis: str = "umap",
        key_added: Optional[str] = None,
        recompute: bool = False,
        stream: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Plot :attr:`transition_matrix` as a stream or a grid plot.

        Parameters
        ----------
        basis
            Key in :attr:`anndata.AnnData.obsm` containing the basis onto which to project.
        key_added
            If not `None`, save the result to :attr:`anndata.AnnData.obsm` ``['{key_added}']``.
            Otherwise, save the result to `'T_fwd_{basis}'` or `'T_bwd_{basis}'`, depending on the direction.
        recompute
            Whether to recompute the projection if it already exists.
        stream
            If ``True``, use :func:`scvelo.pl.velocity_embedding_stream`.
            Otherwise, use :func:`scvelo.pl.velocity_embedding_grid`.
        kwargs
            Keyword arguments for the plotting function.

        Returns
        -------
        Nothing, just plots and modifies :attr:`anndata.AnnData.obsm` with a key based on ``key_added``.
        """
        proj = Projector(self, basis=basis)
        proj.project(key_added=key_added, recompute=recompute)
        proj.plot(stream=stream, **kwargs)

    def __add__(self, other: "KernelExpression") -> "KernelExpression":
        return self.__radd__(other)

    def __radd__(self, other: "KernelExpression") -> "KernelExpression":
        def same_level_add(k1: "KernelExpression", k2: "KernelExpression") -> bool:
            # `k1` is a sum whose every term is a binary `constant * expression` product
            if not (isinstance(k1, KernelAdd) and isinstance(k2, KernelMul)):
                return False

            for kexpr in k1:
                if not isinstance(kexpr, KernelMul):
                    return False
                if not kexpr._bin_consts:
                    return False

            return True

        if not isinstance(other, KernelExpression):
            return NotImplemented

        # wrap bare kernels as `1.0 * kernel`
        s = self * 1.0 if isinstance(self, Kernel) else self
        o = other * 1.0 if isinstance(other, Kernel) else other

        # (c1 * x + c2 * y + ...) + (c3 * z) => (c1 * x + c2 * y + ... + c3 * z)
        if same_level_add(s, o):
            return KernelAdd(*tuple(s) + (o,))
        if same_level_add(o, s):
            return KernelAdd(*tuple(o) + (s,))

        # add virtual constant
        if not isinstance(s, KernelMul):
            s = s * 1.0
        if not isinstance(o, KernelMul):
            o = o * 1.0

        return KernelAdd(s, o)

    def __mul__(
        self, other: Union[float, int, "KernelExpression"]
    ) -> "KernelExpression":
        return self.__rmul__(other)

    def __rmul__(
        self, other: Union[int, float, "KernelExpression"]
    ) -> "KernelExpression":
        def same_level_mul(k1: "KernelExpression", k2: "KernelExpression") -> bool:
            # `k1` is a binary `constant * expression` product and `k2` is a constant
            return (
                isinstance(k1, KernelMul)
                and isinstance(k2, Constant)
                and k1._bin_consts
            )

        if isinstance(other, (int, float, np.integer, np.floating)):
            other = Constant(self.adata, other)

        if not isinstance(other, KernelExpression):
            return NotImplemented

        # fmt: off
        s = self if isinstance(self, (KernelMul, Constant)) else KernelMul(Constant(self.adata, 1.0), self)
        o = other if isinstance(other, (KernelMul, Constant)) else KernelMul(Constant(other.adata, 1.0), other)
        # fmt: on

        # at this point, only KernelMul and Constant are possible (two constants are not possible)
        # (c1 * k) * c2 => (c1 * c2) * k
        if same_level_mul(s, o):
            c, expr = s._split_const
            return KernelMul(c * o, expr)
        if same_level_mul(o, s):
            c, expr = o._split_const
            return KernelMul(c * s, expr)

        return KernelMul(s, o)

    @d.get_sections(base="write_to_adata", sections=["Parameters"])
    @inject_docs()  # gets rid of {{}} in %(write_to_adata)s
    @d.dedent
    @require_tmat
    def write_to_adata(self, key: Optional[str] = None, copy: bool = False) -> None:
        """
        Write the transition matrix and parameters used for computation to the underlying :attr:`adata` object.

        Parameters
        ----------
        key
            Key used when writing transition matrix to :attr:`adata`.
            If `None`, the key will be determined automatically.

        Returns
        -------
        %(write_to_adata)s
        """
        from cellrank._key import Key

        if self.adata is None:
            raise ValueError("Underlying annotated data object is not set.")

        key = Key.uns.kernel(self.backward, key=key)
        # retain the embedding info
        self.adata.uns[f"{key}_params"] = {
            **self.adata.uns.get(f"{key}_params", {}),
            **{"params": self.params},
        }
        # with `copy=True`, `adata.obsp` does not share the matrix object with this kernel
        self.adata.obsp[key] = (
            self.transition_matrix.copy() if copy else self.transition_matrix
        )

    @property
    def transition_matrix(self) -> Union[np.ndarray, csr_matrix]:
        """Row-normalized transition matrix."""
        # the root of an expression tree computes its matrix lazily on first access
        if self._parent is None and self._transition_matrix is None:
            self.compute_transition_matrix()
        return self._transition_matrix

    @transition_matrix.setter
    def transition_matrix(self, matrix: Union[np.ndarray, spmatrix]) -> None:
        """
        Set the transition matrix.

        Parameters
        ----------
        matrix
            Transition matrix. The matrix is row-normalized if necessary.

        Returns
        -------
        Nothing, just updates the :attr:`transition_matrix` and optionally normalizes it.
        """
        # fmt: off
        if matrix.shape != self.shape:
            raise ValueError(
                f"Expected matrix to be of shape `{self.shape}`, found `{matrix.shape}`."
            )

        def should_norm(mat: Union[np.ndarray, spmatrix]) -> bool:
            return not np.all(np.isclose(np.asarray(mat.sum(1)).squeeze(), 1.0, rtol=1e-12))

        if issparse(matrix) and not isspmatrix_csr(matrix):
            matrix = csr_matrix(matrix)
        matrix = matrix.astype(np.float64, copy=False)

        force_normalize = (self._parent is None or self._normalize) and should_norm(matrix)
        if force_normalize:
            if np.any((matrix.data if issparse(matrix) else matrix) < 0):
                raise ValueError("Unable to normalize matrix with negative values.")
            matrix = _normalize(matrix)
            if should_norm(matrix):  # some rows are all 0s/contain invalid values
                n_inv = np.sum(~np.isclose(np.asarray(matrix.sum(1)).squeeze(), 1.0, rtol=1e-12))
                raise ValueError(f"Transition matrix is not row stochastic, {n_inv} rows do not sum to 1.")
        # fmt: on

        self._transition_matrix = matrix

    @property
    def params(self) -> Dict[str, Any]:
        """Parameters which are used to compute the transition matrix."""
        if len(self.kernels) == 1:
            return self._params
        return {f"{k!r}:{i}": k.params for i, k in enumerate(self.kernels)}

    def _reuse_cache(
        self, expected_params: Dict[str, Any], *, time: Optional[Any] = None
    ) -> bool:
        """
        Check whether the cached transition matrix can be reused.

        Parameters
        ----------
        expected_params
            Parameters requested for the new computation.
        time
            Forwarded as ``time`` to the final log message when the cache is reused.

        Returns
        -------
        `True` if ``expected_params`` equal the current parameters and a transition matrix exists,
        otherwise `False`. Afterwards, the current parameters are always set to ``expected_params``,
        or to an empty dictionary if the two could not be compared.
        """
        # fmt: off
        try:
            if expected_params == self._params:
                # explicit check instead of `assert`, which is stripped when running `python -O`
                if self.transition_matrix is None:
                    logg.warning("Transition matrix does not exist for the given parameters. Recomputing")
                    return False
                logg.debug("Using cached transition matrix")
                logg.info("    Finish", time=time)
                return True
            return False
        except Exception as e:  # noqa: B902
            # e.g. the dict is not comparable
            logg.warning(f"Expected and actual parameters are not comparable, reason `{e}`. Recomputing")
            expected_params = {}  # clear the params
            return False
        finally:
            self._params = expected_params
        # fmt: on


class Kernel(KernelExpression, ABC):
    """Base class for kernels that are built directly from an :class:`anndata.AnnData` object."""

    def __init__(self, adata: AnnData, **kwargs: Any):
        super().__init__(**kwargs)
        self._adata = adata
        self._n_obs = adata.n_obs
        self._read_from_adata(**kwargs)

    def _read_from_adata(self, **kwargs: Any) -> None:
        """Hook for subclasses to read data from :attr:`adata`; does nothing by default."""

    @property
    def adata(self) -> AnnData:
        """Annotated data object."""
        return self._adata

    @adata.setter
    def adata(self, adata: Optional[AnnData]) -> None:
        # `None` is allowed so that the object can be temporarily detached from its data
        if adata is None:
            self._adata = None
            return
        if not isinstance(adata, AnnData):
            raise TypeError(
                f"Expected `adata` to be of type `AnnData`, found `{type(adata).__name__}`."
            )
        shape = (adata.n_obs, adata.n_obs)
        if self.shape != shape:
            raise ValueError(
                f"Expected new `AnnData` object to have same shape as the previous `{self.shape}`, found `{shape}`."
            )
        self._adata = adata

    def copy(self, *, deep: bool = False) -> "Kernel":
        """
        Return a copy of itself.

        Parameters
        ----------
        deep
            Whether to also copy the underlying :attr:`adata`. If `False`, the copy shares it.

        Returns
        -------
        The copied kernel.
        """
        # deep-copy the kernel without its (potentially large) annotated data object
        with self._remove_adata:
            k = deepcopy(self)
        # `adata` may be `None` (allowed by the setter), in which case there is nothing to copy
        k.adata = self.adata.copy() if deep and self.adata is not None else self.adata
        return k

    def _copy_ignore(self, *attrs: str) -> "Kernel":
        # prevent copying attributes that are not necessary, e.g. during inversion
        sentinel, attrs = object(), set(attrs)
        objects = [
            (attr, obj)
            for attr, obj in ((attr, getattr(self, attr, sentinel)) for attr in attrs)
            if obj is not sentinel
        ]
        try:
            for attr, _ in objects:
                setattr(self, attr, None)
            return self.copy(deep=False)
        finally:
            # always restore the temporarily removed attributes on the original object
            for attr, obj in objects:
                setattr(self, attr, obj)

    def _format_params(self) -> str:
        """Format :attr:`params` as ``key=value`` pairs, rounding floats to 3 decimals."""
        return ", ".join(
            f"{k}={round(v, 3) if isinstance(v, float) else v!r}"
            for k, v in self.params.items()
        )

    @property
    def transition_matrix(self) -> Union[np.ndarray, csr_matrix]:
        """Row-normalized transition matrix."""
        # unlike expressions, kernels never compute their matrix implicitly on access
        return self._transition_matrix

    @transition_matrix.setter
    def transition_matrix(self, matrix: Any) -> None:
        KernelExpression.transition_matrix.fset(self, matrix)

    @property
    def kernels(self) -> Tuple["KernelExpression", ...]:
        """Underlying base kernels, i.e. only this kernel."""
        return (self,)

    @property
    def shape(self) -> Tuple[int, int]:
        """``(n_cells, n_cells)``."""
        return self._n_obs, self._n_obs

    def __getitem__(self, ix: int) -> "Kernel":
        if ix != 0:
            raise IndexError(ix)
        return self

    def __len__(self) -> int:
        return 1

    def __repr__(self) -> str:
        return f"{'~' if self.backward and self._parent is None else ''}{self.__class__.__name__}"

    def __str__(self) -> str:
        params_fmt = self._format_params()
        if params_fmt:
            return (
                f"{'~' if self.backward and self._parent is None else ''}"
                f"{self.__class__.__name__}[{params_fmt}]"
            )
        return repr(self)
class UnidirectionalKernel(UnidirectionalMixin, Kernel, ABC):
    """Kernel with no directionality."""


class BidirectionalKernel(BidirectionalMixin, Kernel, ABC):
    """Kernel with either forward or backward direction that can be inverted using :meth:`__invert__`."""


class Constant(UnidirectionalKernel):
    """Positive scalar used as a weight inside kernel expressions."""

    def __init__(self, adata: AnnData, value: Union[int, float]):
        super().__init__(adata)
        self.transition_matrix = value

    def compute_transition_matrix(self, value: Union[int, float]) -> "Constant":
        """Set the constant to ``value`` and return self."""
        self.transition_matrix = value
        return self

    @property
    def transition_matrix(self) -> Union[int, float]:
        """The scalar value of this constant."""
        return self._transition_matrix

    @transition_matrix.setter
    def transition_matrix(self, value: Union[int, float]) -> None:
        if not isinstance(value, (int, float, np.integer, np.floating)):
            raise TypeError(
                f"Value must be a `float` or `int`, found `{type(value).__name__}`."
            )
        # written as `not value > 0` (rather than `value <= 0`) so that NaN is rejected as well
        if not value > 0:
            raise ValueError(f"Expected the scalar to be positive, found `{value}`.")
        self._transition_matrix = value
        self._params = {"value": value}

    def copy(self, *, deep: bool = False) -> "Constant":
        """Return a new constant with the same value; ``deep`` has no effect."""
        return Constant(self.adata, self.transition_matrix)

    # fmt: off
    def __radd__(self, other: Union[int, float, "KernelExpression"]) -> "Constant":
        # adding two constants yields a single constant
        if isinstance(other, (int, float, np.integer, np.floating)):
            other = Constant(self.adata, other)
        if isinstance(other, Constant):
            if self.shape != other.shape:
                raise ValueError(f"Expected kernel shape to be `{self.shape}`, found `{other.shape}`.")
            return Constant(self.adata, self.transition_matrix + other.transition_matrix)

        return super().__radd__(other)

    def __rmul__(self, other: Union[int, float, "KernelExpression"]) -> "Constant":
        # multiplying two constants yields a single constant
        if isinstance(other, (int, float, np.integer, np.floating)):
            other = Constant(self.adata, other)
        if isinstance(other, Constant):
            if self.shape != other.shape:
                raise ValueError(f"Expected kernel shape to be `{self.shape}`, found `{other.shape}`.")
            return Constant(self.adata, self.transition_matrix * other.transition_matrix)

        return super().__rmul__(other)
    # fmt: on

    def __repr__(self) -> str:
        return repr(round(self.transition_matrix, 3))

    def __str__(self) -> str:
        return str(round(self.transition_matrix, 3))


class NaryKernelExpression(BidirectionalMixin, KernelExpression):
    """Expression combining kernel expressions of equal shape and compatible direction with a binary operator."""

    def __init__(
        self, *kexprs: KernelExpression, parent: Optional[KernelExpression] = None
    ):
        super().__init__(parent=parent)
        self._validate(kexprs)

    @abstractmethod
    def _combine_transition_matrices(self, t1: Tmat_t, t2: Tmat_t) -> Tmat_t:
        """Combine an accumulated matrix ``t1`` with the next sub-expression's matrix ``t2``."""

    @property
    @abstractmethod
    def _initial_value(self) -> Union[int, float, Tmat_t]:
        """Starting value of the accumulation in :meth:`compute_transition_matrix`."""

    @property
    @abstractmethod
    def _combiner(self) -> str:
        """Operator symbol used when formatting the expression."""

    def compute_transition_matrix(self) -> "NaryKernelExpression":
        """
        Compute the transition matrix by folding the sub-expressions' matrices.

        Returns
        -------
        Modifies :attr:`transition_matrix` and returns self.

        Raises
        ------
        RuntimeError
            If a base kernel in this expression has no transition matrix yet.
        """
        # make sure every sub-expression has a matrix; base kernels must be computed by the user
        for kexpr in self:
            if kexpr.transition_matrix is None:
                if isinstance(kexpr, Kernel):
                    raise RuntimeError(
                        f"`{kexpr}` is uninitialized. Compute its "
                        f"transition matrix first as `.compute_transition_matrix()`."
                    )
                kexpr.compute_transition_matrix()

        tmat = self._initial_value
        for kexpr in self:
            tmat = self._combine_transition_matrices(tmat, kexpr.transition_matrix)
        self.transition_matrix = tmat

        return self

    def _validate(self, kexprs: Sequence[KernelExpression]) -> None:
        """Check that the sub-expressions are combinable, then store them and become their parent."""
        if not len(kexprs):
            raise ValueError("No kernels to combine.")

        shapes = {kexpr.shape for kexpr in kexprs}
        if len(shapes) > 1:
            raise ValueError(
                f"Expected all kernels to have the same shapes, found `{sorted(shapes)}`."
            )

        # `None` marks direction-less sub-expressions, which can be combined with either direction
        directions = {kexpr.backward for kexpr in kexprs}
        if True in directions and False in directions:
            raise ValueError("Unable to combine both forward and backward kernels.")
        self._backward = (
            True if True in directions else False if False in directions else None
        )

        self._kexprs = kexprs
        for kexpr in self:
            kexpr._parent = self

    @property
    def adata(self) -> AnnData:
        """Annotated data object."""
        return self[0].adata

    @adata.setter
    def adata(self, adata: Optional[AnnData]) -> None:
        # allow resetting (use for temp. pickling without adata)
        for kexpr in self:
            kexpr.adata = adata

    @property
    def kernels(self) -> Tuple["KernelExpression", ...]:
        """Underlying unique basic kernels."""
        kernels = []
        for kexpr in self:
            if isinstance(kexpr, Kernel) and not isinstance(kexpr, Constant):
                kernels.append(kexpr)
            elif isinstance(kexpr, NaryKernelExpression):  # recurse
                kernels.extend(kexpr.kernels)

        # return only unique kernels, in order of first appearance; unlike `set`,
        # `dict.fromkeys` keeps the order deterministic (e.g. for the keys of `params`)
        return tuple(dict.fromkeys(kernels))

    def copy(self, *, deep: bool = False) -> "KernelExpression":
        """Return a copy of this expression with every sub-expression copied."""
        kexprs = (k.copy(deep=deep) for k in self)
        return type(self)(*kexprs, parent=self._parent)

    @property
    def shape(self) -> Tuple[int, int]:
        """``(n_cells, n_cells)``."""
        # all kernels have the same shape
        return self[0].shape

    def _format(self, formatter: Callable[[KernelExpression], str]) -> str:
        return (
            f"{'~' if self.backward and self._parent is None else ''}("
            + f" {self._combiner} ".join(formatter(kexpr) for kexpr in self)
            + ")"
        )

    def __getitem__(self, ix: int) -> "KernelExpression":
        return self._kexprs[ix]

    def __len__(self) -> int:
        return len(self._kexprs)

    def __invert__(self) -> "KernelExpression":
        # direction-less expressions cannot be inverted, only copied
        if self.backward is None:
            return self.copy()

        kexprs = tuple(
            ~k if isinstance(k, BidirectionalMixin) else k.copy() for k in self
        )
        kexpr = type(self)(*kexprs, parent=self._parent)
        kexpr._transition_matrix = None

        return kexpr

    def __repr__(self) -> str:
        return self._format(repr)

    def __str__(self) -> str:
        return self._format(str)


class KernelAdd(NaryKernelExpression):
    """Sum of kernel expressions."""

    def compute_transition_matrix(self) -> "KernelAdd":
        self._maybe_recalculate_constants()

        return super().compute_transition_matrix()

    def _combine_transition_matrices(self, t1: Tmat_t, t2: Tmat_t) -> Tmat_t:
        return t1 + t2

    def _maybe_recalculate_constants(self) -> None:
        """Normalize constants to sum to 1."""
        constants = self._bin_consts
        if constants:
            total = sum((c.transition_matrix for c in constants), 0.0)
            for c in constants:
                c.transition_matrix = c.transition_matrix / total
            for kexpr in self:  # don't normalize (c * x)
                kexpr._normalize = False

    @property
    def _bin_consts(self) -> List[KernelExpression]:
        """Return constant expressions for each binary multiplication children with at least 1 constant."""
        return [c for k in self if isinstance(k, KernelMul) for c in k._bin_consts]

    @property
    def _combiner(self) -> str:
        return "+"

    @property
    def _initial_value(self) -> Union[int, float, Tmat_t]:
        return 0.0


class KernelMul(NaryKernelExpression):
    """Element-wise product of kernel expressions."""

    def _combine_transition_matrices(self, t1: Tmat_t, t2: Tmat_t) -> Tmat_t:
        # use the sparse operand's `multiply` so that the product is element-wise
        if issparse(t1):
            return t1.multiply(t2)
        if issparse(t2):
            return t2.multiply(t1)

        return t1 * t2

    def _format(self, formatter: Callable[[KernelExpression], str]) -> str:
        # drop the surrounding parentheses of `c * x` unless the expression is inverted
        fmt = super()._format(formatter)
        if fmt[0] == "~" or not self._bin_consts:
            return fmt
        if fmt.startswith("("):
            fmt = fmt[1:]
        if fmt.endswith(")"):
            fmt = fmt[:-1]
        return fmt

    @property
    def _bin_consts(self) -> List[KernelExpression]:
        """Return all constants if this expression contains only 2 subexpressions."""
        if len(self) != 2:
            return []
        return [k for k in self if isinstance(k, Constant)]

    @property
    def _split_const(self) -> Tuple[Optional[Constant], Optional[KernelExpression]]:
        """Return a constant and the other expression, iff this expression is of length 2 and contains a constant."""
        if not self._bin_consts:
            return None, None
        k1, k2 = self
        if isinstance(k1, Constant):
            return k1, k2
        return k2, k1

    @property
    def _combiner(self) -> str:
        return "*"

    @property
    def _initial_value(self) -> Union[int, float, Tmat_t]:
        return 1.0