# Source code for cellrank.tl.kernels._palantir_kernel

# -*- coding: utf-8 -*-
"""Palantir kernel module."""
from copy import copy

import numpy as np

from cellrank import logging as logg
from cellrank.ul._docs import d
from cellrank.tl._utils import _bias_knn, _connected
from cellrank.tl.kernels import Kernel
from cellrank.tl._constants import Direction
from cellrank.tl.kernels._base_kernel import (
_LOG_USING_CACHE,
_ERROR_EMPTY_CACHE_MSG,
AnnData,
_dtype,
)

[docs]@d.dedent
class PalantirKernel(Kernel):
"""
Kernel which computes transition probabilities in a similar way to *Palantir*, see [Setty19]_.

*Palantir* computes a KNN graph in gene expression space and a pseudotime, which it then uses to direct the edges of
the KNN graph, such that they are more likely to point into the direction of increasing pseudotime. To avoid
disconnecting the graph, it does not remove all edges that point into the direction of decreasing pseudotime
but keeps the ones that point to nodes inside a close radius. This radius is chosen according to the local density.

The implementation presented here won't exactly reproduce the original *Palantir* algorithm (see below)
but the results are qualitatively very similar.

%(density_correction)s

Parameters
----------
%(backward)s
time_key
Key in :paramref:`adata` ``.obs`` where the pseudotime is stored.
compute_cond_num
Whether to compute the condition number of the transition matrix. Note that this might be costly,
since it does not use a sparse implementation.
"""

def __init__(
    self,
    backward: bool = False,
    time_key: str = "dpt_pseudotime",
    compute_cond_num: bool = False,
    check_connectivity: bool = False,
):
    """Initialize the kernel; the parameters are documented on the class docstring."""
    # Every option, including ``time_key``, is handed to the base ``Kernel`` constructor.
    # NOTE(review): the class docstring refers to an ``adata`` parameter and ``AnnData`` is
    # imported, yet no ``adata`` argument is accepted or forwarded here — confirm against
    # the base ``Kernel.__init__`` signature.
    forwarded = dict(
        backward=backward,
        time_key=time_key,
        compute_cond_num=compute_cond_num,
        check_connectivity=check_connectivity,
    )
    super().__init__(**forwarded)
    # Remember which ``adata.obs`` column the pseudotime came from.
    self._time_key = time_key

def _read_from_adata(self, **kwargs):
    """
    Read the pseudotime from ``adata.obs``, validate it and clip it to ``[0, 1]``.

    Parameters
    ----------
    kwargs
        Keyword arguments. ``time_key`` (default ``"dpt_pseudotime"``) names the column
        of ``adata.obs`` that holds the pseudotime; all of them are also passed to the
        base implementation.

    Raises
    ------
    KeyError
        If ``time_key`` is not a column of ``adata.obs``.
    ValueError
        If the pseudotime contains NaN values.
    """
    # NOTE(review): this method's header and two statements were missing from the source
    # (``kwargs`` was undefined, the KeyError was raised unconditionally and
    # ``self._pseudotime`` was never assigned). Restored assuming the base ``Kernel``
    # calls ``_read_from_adata(**kwargs)`` from its constructor, exposes ``self.adata``
    # and sets ``self._conn`` in its own ``_read_from_adata`` — confirm.
    super()._read_from_adata(**kwargs)

    time_key = kwargs.pop("time_key", "dpt_pseudotime")
    # Only fail when the requested column is actually absent.
    if time_key not in self.adata.obs.columns:
        raise KeyError(f"Could not find time key in adata.obs[{time_key!r}].")

    self._pseudotime = np.asarray(self.adata.obs[time_key]).astype(_dtype)

    if np.any(np.isnan(self._pseudotime)):
        raise ValueError("Encountered NaN values in pseudotime.")

    logg.debug("Clipping the pseudotime to 0-1 range")
    self._pseudotime = np.clip(self._pseudotime, 0, 1)

def compute_transition_matrix(
    self, k: int = 3, density_normalize: bool = True
) -> "PalantirKernel":
    """
    Compute transition matrix based on KNN graph and pseudotemporal ordering.

    This is a re-implementation of the Palantir algorithm by [Setty19]_.
    Note that this won't exactly reproduce the original Palantir results, for three reasons:

        - Palantir computes the KNN graph in a scaled space of diffusion components.
        - Palantir uses its own pseudotime to bias the KNN graph which is not implemented here.
        - Palantir uses a slightly different mechanism to ensure the graph remains connected when removing edges
          that point into the "pseudotime past".

    If you would like to reproduce the original results, please use the original Palantir algorithm.

    Parameters
    ----------
    k
        Number of neighbors to keep for each node, regardless of pseudotime.
        This is done to ensure that the graph remains connected.
    density_normalize
        Whether or not to use the underlying KNN graph for density normalization.

    Returns
    -------
    :class:`cellrank.tl.kernels.PalantirKernel`
        Makes :paramref:`transition_matrix` available.
    """
    start = logg.info("Computing transition matrix based on Palantir-like kernel")

    # Number of neighbors the KNN graph was built with; ``_bias_knn`` needs it below.
    # (The receiver of this ``.get`` chain was missing, which made this a syntax error.)
    n_neighbors = (
        self.adata.uns.get("neighbors", {})
        .get("params", {})
        .get("n_neighbors", None)
    )
    if n_neighbors is None:
        logg.warning(
            "Could not find 'n_neighbors' in adata.uns['neighbors']['params']. Using an estimate"
        )
        # Fallback estimate: the smallest row sum of the connectivity matrix.
        n_neighbors = np.min(self._conn.sum(1))

    # Re-use the cached transition matrix when called again with identical parameters.
    params = dict(k=k, dnorm=density_normalize, n_neighs=n_neighbors)  # noqa
    if params == self._params:
        assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
        logg.debug(_LOG_USING_CACHE)
        logg.info("    Finish", time=start)
        return self

    self._params = params

    # In the backward direction, reverse the pseudotime so the biasing favours edges
    # pointing towards decreasing (original) pseudotime.
    pseudotime = (
        np.max(self.pseudotime) - self.pseudotime
        if self._direction == Direction.BACKWARD
        else self.pseudotime
    )
    biased_conn = _bias_knn(
        conn=self._conn, pseudotime=pseudotime, n_neighbors=n_neighbors, k=k
    ).astype(_dtype)
    # Removing edges can disconnect the graph; warn but still proceed.
    if not _connected(biased_conn):
        logg.warning("Biased KNN graph is disconnected")

    self._compute_transition_matrix(
        matrix=biased_conn, density_normalize=density_normalize
    )

    logg.info("    Finish", time=start)

    return self

@property
def pseudotime(self) -> np.ndarray:
    """Pseudotemporal ordering of cells, clipped to the ``[0, 1]`` range when it is read."""
    return self._pseudotime

[docs]    def copy(self) -> "PalantirKernel":
"""Return a copy of self."""