Source code for ott.geometry.graph

from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp

from ott.geometry import geometry
from ott.math import decomposition, fixed_point_loop
from ott.math import utils as mu

__all__ = ["Graph"]

Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO]


@jax.tree_util.register_pytree_node_class
class Graph(geometry.Geometry):
  r"""Graph distance approximation using heat kernel :cite:`heitz:21,crane:13`.

  Approximates the heat kernel for large ``n_steps``, which for small ``t``
  approximates the geodesic exponential kernel
  :math:`e^{\frac{-d(x, y)^2}{t}}`.

  For sparse graphs, :mod:`sksparse.cholmod` is required to compute the
  Cholesky decomposition.
  Differentiating w.r.t. the edge weights is currently possible only when the
  graph is represented as a dense adjacency matrix.

  Args:
    graph: Graph represented as an adjacency matrix of shape ``[n, n]``.
      If `None`, the symmetric graph Laplacian has to be specified.
    laplacian: Symmetric graph Laplacian. The check for symmetry is **NOT**
      performed. If `None`, the graph has to be specified instead.
    t: Constant used when approximating the geodesic exponential kernel.
      If `None`, use :math:`\left(\frac{1}{|E|} \sum_{(u, v) \in E}
      weight(u, v)\right)^2` :cite:`crane:13`. In this case, the ``graph``
      must be specified and the edge weights are all assumed to be positive.
    n_steps: Maximum number of steps used to approximate the heat kernel.
    numerical_scheme: Numerical scheme used to solve the heat diffusion.
    directed: Whether the ``graph`` is directed. If not, it will be made
      undirected as :math:`G + G^T`. This parameter is ignored when directly
      passing the Laplacian, which is assumed to be symmetric.
    normalize: Whether to normalize the Laplacian as
      :math:`L^{sym} = \left(D^+\right)^{\frac{1}{2}} L
      \left(D^+\right)^{\frac{1}{2}}`, where :math:`L` is the unnormalized
      Laplacian and :math:`D` the degree matrix.
    tol: Relative tolerance with respect to the Hilbert metric, see
      :cite:`peyre:19`, Remark 4.12. Used when iteratively updating scalings.
      If negative, this option is ignored and only ``n_steps`` is used.
    kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
  """

  def __init__(
      self,
      graph: Optional[Union[jnp.ndarray, jesp.BCOO]] = None,
      laplacian: Optional[Union[jnp.ndarray, Sparse_t]] = None,
      t: Optional[float] = 1e-3,
      n_steps: int = 100,
      numerical_scheme: Literal["backward_euler",
                                "crank_nicolson"] = "backward_euler",
      directed: bool = False,
      normalize: bool = False,
      tol: float = -1.,
      **kwargs: Any
  ):
    assert ((graph is None and laplacian is not None) or
            (laplacian is None and graph is not None)), \
        "Please provide a graph or a symmetric graph Laplacian."
    # arbitrary epsilon; can't use `None` as `mean_cost_matrix` would be used
    super().__init__(epsilon=1., **kwargs)
    self._graph = graph
    self._lap = laplacian
    self._solver: Optional[decomposition.CholeskySolver] = None

    self._t = t
    self.n_steps = n_steps
    self.numerical_scheme = numerical_scheme
    self.directed = directed
    self.normalize = normalize
    self._tol = tol
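  # A minimal usage sketch (the 4-node path graph below is a made-up,
  # illustrative example, not part of the library):
  #
  #   adj = jnp.array([[0., 1., 0., 0.],
  #                    [1., 0., 1., 0.],
  #                    [0., 1., 0., 1.],
  #                    [0., 0., 1., 0.]])
  #   geom = Graph(graph=adj, t=1e-3, n_steps=100)
  #   cost = geom.cost_matrix  # heat-kernel approximation of squared geodesics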
  def apply_kernel(
      self,
      scaling: jnp.ndarray,
      eps: Optional[float] = None,
      axis: int = 0,
  ) -> jnp.ndarray:

    def cond_fn(
        iteration: int, solver_lap: Tuple[decomposition.CholeskySolver,
                                          Optional[Union[jnp.ndarray,
                                                         Sparse_t]]],
        old_new: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> bool:
      del iteration, solver_lap

      x_old, x_new = old_new
      x_old, x_new = mu.safe_log(x_old), mu.safe_log(x_new)
      # center
      x_old, x_new = x_old - jnp.nanmax(x_old), x_new - jnp.nanmax(x_new)
      # Hilbert metric, see Remark 4.12 in `Computational Optimal Transport`
      f = x_new - x_old
      return (jnp.nanmax(f) - jnp.nanmin(f)) > self._tol

    def body_fn(
        iteration: int, solver_lap: Tuple[decomposition.CholeskySolver,
                                          Optional[Union[jnp.ndarray,
                                                         Sparse_t]]],
        old_new: Tuple[jnp.ndarray, jnp.ndarray], compute_errors: bool
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
      del iteration, compute_errors

      solver, scaled_lap = solver_lap
      _, b = old_new
      if self.numerical_scheme == "crank_nicolson":
        # the update below is the preferred way of specifying it (albeit
        # using more FLOPS), as CSR/CSC/COO matrices currently don't support
        # adding a diagonal matrix:
        # b' = (2 * I - M) @ b = (2 * I - (I + c * L)) @ b = (I - c * L) @ b
        b = b - scaled_lap @ b
      return b, solver.solve(b)

    # eps we cannot use since it would require a re-solve
    # axis we can ignore since the matrix is symmetric
    del eps, axis

    force_scan = self._tol < 0.
    fixpoint_fn = (
        fixed_point_loop.fixpoint_iter
        if force_scan else fixed_point_loop.fixpoint_iter_backprop
    )
    state = (jnp.full_like(scaling, jnp.nan), scaling)
    if self.numerical_scheme == "crank_nicolson":
      constants = self.solver, self._scaled_laplacian
    else:
      constants = self.solver, None

    return fixpoint_fn(
        cond_fn=(lambda *_, **__: True) if force_scan else cond_fn,
        body_fn=body_fn,
        min_iterations=self.n_steps if force_scan else 1,
        max_iterations=self.n_steps,
        inner_iterations=1,
        constants=constants,
        state=state,
    )[1]
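  # The loop above amounts to implicit time stepping of heat diffusion:
  # backward Euler solves `(I + c * L) x_{k+1} = x_k` at each step, while
  # Crank-Nicolson solves `(I + c * L) x_{k+1} = (I - c * L) x_k`, with `c`
  # given by `_scale` below. A dense, standalone sketch of one backward Euler
  # step (hypothetical names; the class instead reuses a cached Cholesky
  # factorization of `I + c * L`):
  #
  #   c = t / (4. * n_steps)
  #   M = jnp.eye(n) + c * laplacian
  #   x_next = jnp.linalg.solve(M, x)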
  def apply_transport_from_scalings(
      self,
      u: jnp.ndarray,
      v: jnp.ndarray,
      vec: jnp.ndarray,
      axis: int = 0
  ) -> jnp.ndarray:

    def body_fn(carry: None, vec: jnp.ndarray) -> Tuple[None, jnp.ndarray]:
      if axis == 1:
        return carry, u * self.apply_kernel(v * vec, axis=axis)
      return carry, v * self.apply_kernel(u * vec, axis=axis)

    if not self.is_sparse:
      return super().apply_transport_from_scalings(u, v, vec, axis=axis)

    # we solve the triangular systems on host, but
    # batching rules are implemented only for `id_tap`, not for `call`
    if vec.ndim == 1:
      _, res = jax.lax.scan(body_fn, None, vec[None, :])
      return res[0, :]

    _, res = jax.lax.scan(body_fn, None, vec)
    return res
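  # In the dense case the parent class computes, e.g. for `axis=1`,
  # `u * (K @ (v * vec))`; the scan above reproduces the same computation one
  # vector at a time for the sparse case. A hedged single-vector equivalent
  # (`geom` is a hypothetical `Graph` instance):
  #
  #   res = u * geom.apply_kernel(v * vec, axis=1)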
  @property
  def kernel_matrix(self) -> jnp.ndarray:
    n, _ = self.shape
    kernel = self.apply_kernel(jnp.eye(n))
    # force symmetry because of numerical imprecisions
    # happens when `numerical_scheme='backward_euler'` and small `t`
    return (kernel + kernel.T) * .5

  @property
  def cost_matrix(self) -> jnp.ndarray:
    return -self.t * mu.safe_log(self.kernel_matrix)

  @property
  def laplacian(self) -> Union[jnp.ndarray, Sparse_t]:
    """The (normalized) graph Laplacian."""
    return self._norm_laplacian if self.normalize else self._laplacian

  def _degree_matrix(self,
                     *,
                     inv_sqrt: bool = False) -> Union[jnp.ndarray, Sparse_t]:
    if not self.is_sparse:
      data = self.graph.sum(1)
      if inv_sqrt:
        data = jnp.where(data > 0., 1. / jnp.sqrt(data), 0.)
      return jnp.diag(data)

    n, _ = self.shape
    data, ixs = self.graph.sum(1).todense(), jnp.arange(n)
    if inv_sqrt:
      data = jnp.where(data > 0., 1. / jnp.sqrt(data), 0.)
    return jesp.BCOO((data, jnp.c_[ixs, ixs]), shape=(n, n))

  @property
  def _laplacian(self) -> Union[jnp.ndarray, Sparse_t]:
    if self._lap is not None:
      return self._lap
    # in the sparse case, we don't sum duplicates here because
    # we need to know `nnz` a priori for JIT (could be exposed in `__init__`)
    # instead, `ott.math.decomposition._jax_sparse_to_scipy` handles it on host
    return self._degree_matrix() - self.graph

  @property
  def _norm_laplacian(self) -> Union[jnp.ndarray, Sparse_t]:
    # assumes symmetric Laplacian, as mentioned in `__init__`
    lap = self._laplacian
    inv_sqrt_deg = self._degree_matrix(inv_sqrt=True)
    if not self.is_sparse:
      return inv_sqrt_deg @ lap @ inv_sqrt_deg

    inv_sqrt_deg = inv_sqrt_deg.data  # (n,)
    # much faster than doing sparse MM
    return inv_sqrt_deg[:, None] * lap * inv_sqrt_deg[None, :]

  @property
  def t(self) -> float:
    """Constant used when approximating the geodesic exponential kernel."""
    if self._t is None:
      graph = self.graph
      assert graph is not None, "No graph was specified."
      if self.is_sparse:
        return jnp.mean(graph.data) ** 2
      return (jnp.sum(graph) / jnp.sum(graph > 0.)) ** 2
    return self._t

  @property
  def _scale(self) -> float:
    """Constant used to scale the Laplacian."""
    if self.numerical_scheme == "backward_euler":
      return self.t / (4. * self.n_steps)
    if self.numerical_scheme == "crank_nicolson":
      return self.t / (2. * self.n_steps)
    raise NotImplementedError(
        f"Numerical scheme `{self.numerical_scheme}` is not implemented."
    )

  @property
  def _scaled_laplacian(self) -> Union[float, jnp.ndarray, Sparse_t]:
    """Laplacian scaled by a constant, depending on the numerical scheme."""
    if self.is_sparse:
      return mu.sparse_scale(self._scale, self.laplacian)
    return self._scale * self.laplacian

  @property
  def _M(self) -> Union[jnp.ndarray, Sparse_t]:
    n, _ = self.shape
    scaled_lap = self._scaled_laplacian
    # CHOLMOD supports solving `A + beta * I`, we set `beta = 1.0`
    # when instantiating the solver
    return scaled_lap if self.is_sparse else scaled_lap + jnp.eye(n)

  @property
  def solver(self) -> decomposition.CholeskySolver:
    """Instantiate the Cholesky solver and compute the factorization."""
    if self._solver is None:
      # key/beta only used for the sparse solver
      self._solver = decomposition.CholeskySolver.create(
          self._M, beta=1., key=hash(self)
      )
    # compute the factorization to avoid tracer leaks in `apply_kernel`
    # due to the scan/while loop
    _ = self._solver.L
    return self._solver

  @property
  def shape(self) -> Tuple[int, int]:
    arr = self._graph if self._graph is not None else self._lap
    return arr.shape

  @property
  def is_sparse(self) -> bool:
    """Whether :attr:`graph` or :attr:`laplacian` is sparse."""
    if self._lap is not None:
      return isinstance(self._lap, Sparse_t.__args__)
    if isinstance(self._graph, (jesp.CSR, jesp.CSC, jesp.COO)):
      raise NotImplementedError("Graph must be specified in `BCOO` format.")
    return isinstance(self._graph, jesp.BCOO)

  @property
  def graph(self) -> Optional[Union[jnp.ndarray, jesp.BCOO]]:
    """The underlying undirected graph as an adjacency matrix, if provided."""
    if self._graph is None:
      return None
    return (self._graph + self._graph.T) if self.directed else self._graph

  @property
  def is_symmetric(self) -> bool:
    # there may be some numerical imprecisions, but it should be symmetric
    return True

  @property
  def dtype(self) -> jnp.dtype:
    return self._graph.dtype

  # TODO(michalk8): in the future, use mixins for lse/kernel mode
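  # `_norm_laplacian` above implements
  # :math:`L^{sym} = (D^+)^{1/2} L (D^+)^{1/2}`, zeroing out rows/columns of
  # isolated vertices. A dense, standalone sketch (hypothetical names):
  #
  #   deg = adj.sum(1)
  #   inv_sqrt = jnp.where(deg > 0., 1. / jnp.sqrt(deg), 0.)
  #   lap_sym = inv_sqrt[:, None] * (jnp.diag(deg) - adj) * inv_sqrt[None, :]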
  def transport_from_potentials(
      self, f: jnp.ndarray, g: jnp.ndarray
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")
  def apply_transport_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      vec: jnp.ndarray,
      axis: int = 0
  ) -> jnp.ndarray:
    """Since applying from potentials is not feasible on graphs, use scalings."""
    u, v = self.scaling_from_potential(f), self.scaling_from_potential(g)
    return self.apply_transport_from_scalings(u, v, vec, axis=axis)
  def marginal_from_potentials(
      self,
      f: jnp.ndarray,
      g: jnp.ndarray,
      axis: int = 0,
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")
  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    return [self._graph, self._lap, self.solver], {
        "t": self._t,
        "n_steps": self.n_steps,
        "numerical_scheme": self.numerical_scheme,
        "directed": self.directed,
        "normalize": self.normalize,
        "tol": self._tol,
        **self._kwargs,
    }

  @classmethod
  def tree_unflatten(
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "Graph":
    graph, laplacian, solver = children
    obj = cls(graph=graph, laplacian=laplacian, **aux_data)
    obj._solver = solver
    return obj
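# Since `Graph` is registered as a pytree via `tree_flatten`/`tree_unflatten`
# above, instances can cross `jax.jit` boundaries. A usage sketch (assuming a
# hypothetical dense adjacency matrix `adj`):
#
#   @jax.jit
#   def kernel_diagonal(adj: jnp.ndarray) -> jnp.ndarray:
#     return Graph(graph=adj).kernel_matrix.diagonal()
#
#   diag = kernel_diagonal(adj)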