Source code for ott.math.decomposition

import abc
from typing import (
    Any,
    Dict,
    Generic,
    Hashable,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import jax
import jax.experimental.host_callback as hcb
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import scipy.sparse as sp

try:
  from sksparse import cholmod
except ImportError:
  cholmod = None

__all__ = ["DenseCholeskySolver", "SparseCholeskySolver"]

T = TypeVar("T")


@jax.tree_util.register_pytree_node_class
class CholeskySolver(abc.ABC, Generic[T]):
  """Base class for Cholesky linear solver.

  Args:
    A: Symmetric positive definite matrix of shape ``[n, n]``.
  """

  def __init__(self, A: T):
    self._A = A
    self._L: Optional[T] = None  # Cholesky factor

  def solve(self, b: jnp.ndarray) -> jnp.ndarray:
    """Solve the linear system :math:`A * x = b`.

    Args:
        b: Vector of shape ``[n,]``.

    Returns:
        The solution of shape ``[n,]``.
    """
    return self._solve(self.L, b)

  @abc.abstractmethod
  def _decompose(self, A: T) -> Optional[T]:
    """Decompose matrix ``A`` into Cholesky factor."""

  @abc.abstractmethod
  def _solve(self, L: Optional[T], b: jnp.ndarray) -> jnp.ndarray:
    """Solve a triangular linear system :math:`L * x = b`."""

  @classmethod
  def create(cls, A: Union[T, sp.spmatrix], **kwargs: Any) -> "CholeskySolver":
    """Instantiate sparse or dense Cholesky solver.

    And optionally convert :class:`scipy.sparse.spmatrix` to its
    :mod:`jax` equivalent.

    Args:
      A: Symmetric positive definite matrix of shape ``[n, n]``.
      kwargs: Keyword arguments for the solver initialization.

    Returns:
      Sparse or dense Cholesky solver.
    """
    if isinstance(A, sp.spmatrix):
      A = _scipy_sparse_to_jax(A)
    if isinstance(A, (jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO)):
      return SparseCholeskySolver(A, **kwargs)
    return DenseCholeskySolver(A, **kwargs)

  @property
  def A(self) -> T:
    """Symmetric positive definite matrix of shape ``[n, n]``."""
    return self._A

  @property
  def L(self) -> Optional[T]:
    """Cholesky factor of :attr:`A`."""
    if self._L is None:
      self._L = self._decompose(self.A)
    return self._L

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    return (self.A, self.L), {}

  @classmethod
  def tree_unflatten(
      cls, aux_data: Mapping[str, Any], children: Sequence[Any]
  ) -> "CholeskySolver":
    A, L = children
    obj = cls(A, **aux_data)
    obj._L = L
    return obj


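# Illustrative sketch, not part of the original module: the base class
# computes the Cholesky factor lazily, caching it on first access to ``L``,
# and the pytree registration preserves that cached factor across a
# flatten/unflatten round-trip. ``DenseCholeskySolver`` is defined just
# below; the function body only resolves the name when called. The helper
# name ``_example_pytree_roundtrip`` is hypothetical.
def _example_pytree_roundtrip() -> None:
  A = 2.0 * jnp.eye(3)  # trivially symmetric positive definite
  solver = DenseCholeskySolver(A)
  children, aux = solver.tree_flatten()  # flattening also caches the factor
  rebuilt = DenseCholeskySolver.tree_unflatten(aux, children)
  assert jnp.allclose(rebuilt.L, solver.L)  # factor survives the round-trip
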
@jax.tree_util.register_pytree_node_class
class DenseCholeskySolver(CholeskySolver[jnp.ndarray]):
  """Dense Cholesky solver.

  Args:
    A: Symmetric positive definite matrix of shape ``[n, n]``.
    lower: Whether to compute the lower-triangular Cholesky factor.
    kwargs: Additional keyword arguments, currently ignored.
  """

  def __init__(self, A: T, lower: bool = True, **kwargs: Any):
    del kwargs
    super().__init__(A)
    self._lower = lower

  def _decompose(self, A: T) -> Optional[T]:
    return jsp.linalg.cholesky(A, lower=self._lower)

  def _solve(self, L: Optional[T], b: jnp.ndarray) -> jnp.ndarray:
    # `cho_solve` performs both triangular solves, so the result solves
    # `A * x = b`, matching the full solve done by the sparse backend.
    return jsp.linalg.cho_solve((L, self._lower), b)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    children, aux_data = super().tree_flatten()
    aux_data["lower"] = self._lower
    return children, aux_data
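
# Illustrative sketch, not part of the original module: solve a small dense
# SPD system and check the result against ``jnp.linalg.solve``. The helper
# name ``_example_dense_solve`` is hypothetical.
def _example_dense_solve() -> None:
  rng = jax.random.PRNGKey(0)
  M = jax.random.normal(rng, (4, 4))
  A = M @ M.T + 4.0 * jnp.eye(4)  # make A symmetric positive definite
  b = jnp.arange(4.0)
  x = DenseCholeskySolver(A).solve(b)
  assert jnp.allclose(x, jnp.linalg.solve(A, b), atol=1e-4)
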
@jax.tree_util.register_pytree_node_class
class SparseCholeskySolver(
    CholeskySolver[Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO]]
):
  r"""Sparse Cholesky solver using :func:`jax.experimental.host_callback.call`.

  Uses the CHOLMOD :cite:`cholmod:08` bindings from :mod:`sksparse.cholmod`.

  Args:
    A: Symmetric positive definite matrix of shape ``[n, n]``.
    beta: Decompose :math:`A + \beta * I` instead of :math:`A`.
    key: Key used to cache :class:`sksparse.cholmod.Factor`. This key
      **must** be unique to ``A`` to achieve correct results. If ``None``,
      use the :func:`hash` of this object.
    kwargs: Keyword arguments for :func:`sksparse.cholmod.cholesky`.
  """

  # TODO(michalk8): in the future, define a jax primitive + use CHOLMOD directly
  _FACTOR_CACHE = {}

  def __init__(
      self,
      A: T,
      beta: float = 0.0,
      key: Optional[Hashable] = None,
      **kwargs: Any,
  ):
    if cholmod is None:
      raise ImportError(
          "Unable to import scikit-sparse. "
          "Please install it as `pip install scikit-sparse`."
      )
    super().__init__(A)
    self._key = key
    self._beta = beta
    self._kwargs = kwargs

  def _host_decompose(self, A: T) -> None:
    # use float64 because CHOLMOD uses it internally;
    # convert to CSC explicitly for efficiency and to avoid warnings
    mat = _jax_sparse_to_scipy(A, sum_duplicates=True, dtype=float).tocsc()
    self._FACTOR_CACHE[hash(self)] = cholmod.cholesky(
        mat, beta=self._beta, **self._kwargs
    )

  def _decompose(self, A: T) -> Optional[T]:
    return hcb.call(self._host_decompose, A, result_shape=None)

  def _host_solve(self, b: jnp.ndarray) -> jnp.ndarray:
    factor = self._FACTOR_CACHE[hash(self)]
    return factor.solve_A(np.asarray(b, dtype=float))

  def _solve(self, _: Optional[T], b: jnp.ndarray) -> jnp.ndarray:
    return hcb.call(self._host_solve, b, result_shape=b)

  @property
  def L(self) -> None:
    """Compute the lower-triangular factor of :attr:`A` and cache the result.

    The factor is not returned, but it is used in subsequent
    :meth:`solve` calls.
    """
    if hash(self) not in self._FACTOR_CACHE:
      self._decompose(jax.lax.stop_gradient(self.A))
  @classmethod
  def clear_factor_cache(cls) -> None:
    """Clear the :class:`sksparse.cholmod.Factor` cache."""
    cls._FACTOR_CACHE.clear()
  def __hash__(self) -> int:
    # `__hash__` must return an `int`, so hash the user-provided key
    return object.__hash__(self) if self._key is None else hash(self._key)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    children, aux_data = super().tree_flatten()
    return children, {
        **aux_data,
        "beta": self._beta,
        "key": self._key,
        **self._kwargs,
    }
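
# Illustrative sketch, not part of the original module; requires
# `scikit-sparse`. It dispatches through `CholeskySolver.create` with a
# :mod:`scipy.sparse` matrix. The `key` must uniquely identify `A`,
# otherwise cached factors of different matrices would be mixed up. The
# helper name `_example_sparse_solve` and the key value are hypothetical.
def _example_sparse_solve() -> None:
  A = sp.diags([1.0, 2.0, 4.0]).tocsr()  # SPD diagonal matrix
  solver = CholeskySolver.create(A, key="example-diag")  # sparse solver
  x = solver.solve(jnp.ones(3))
  # expected solution: [1.0, 0.5, 0.25]
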
def _jax_sparse_to_scipy(
    A: Union[jesp.CSR, jesp.CSC, jesp.COO],
    sum_duplicates: bool = False,
    **kwargs: Any,
) -> sp.spmatrix:
  toarr = np.asarray

  if isinstance(A, jesp.CSR):
    data, indices, indptr = toarr(A.data), toarr(A.indices), toarr(A.indptr)
    return sp.csr_matrix((data, indices, indptr), **kwargs)
  if isinstance(A, jesp.CSC):
    # use the CSC constructor; reinterpreting CSC buffers as CSR
    # would transpose the matrix
    data, indices, indptr = toarr(A.data), toarr(A.indices), toarr(A.indptr)
    return sp.csc_matrix((data, indices, indptr), **kwargs)
  if isinstance(A, jesp.COO):
    row, col, data = toarr(A.row), toarr(A.col), toarr(A.data)
    return sp.coo_matrix((data, (row, col)), **kwargs)
  if isinstance(A, jesp.BCOO):
    assert A.indices.ndim == 2, "Only 2D batched COO matrix is supported."
    row, col = A.indices[:, 0], A.indices[:, 1]
    data, row, col = toarr(A.data), toarr(row), toarr(col)
    mat = sp.coo_matrix((data, (row, col)), **kwargs)
    if not sum_duplicates:
      return mat
    # the original matrix can contain duplicates => shape will be wrong;
    # we optionally correct it here
    mat.sum_duplicates()
    mat.eliminate_zeros()
    return sp.coo_matrix((mat.data, (mat.row, mat.col)), **kwargs)

  raise TypeError(type(A))


def _scipy_sparse_to_jax(
    A: sp.spmatrix, **kwargs: Any
) -> Union[jesp.CSR, jesp.CSC, jesp.BCOO]:
  toarr = jnp.asarray
  kwargs["shape"] = A.shape

  if sp.isspmatrix_csr(A):
    return jesp.CSR((toarr(A.data), toarr(A.indices), toarr(A.indptr)),
                    **kwargs)
  if sp.isspmatrix_csc(A):
    return jesp.CSC((toarr(A.data), toarr(A.indices), toarr(A.indptr)),
                    **kwargs)
  if sp.isspmatrix_coo(A):
    # prefer BCOO since it is more feature-complete
    data, indices = toarr(A.data), jnp.c_[toarr(A.row), toarr(A.col)]
    return jesp.BCOO((data, indices), **kwargs)

  raise TypeError(type(A))
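
# Illustrative sketch, not part of the original module: round-trip a random
# sparse matrix through the private conversion helpers. BCOO is what
# `_scipy_sparse_to_jax` produces for COO inputs. The helper name
# `_example_conversion_roundtrip` is hypothetical.
def _example_conversion_roundtrip() -> None:
  A = sp.random(5, 5, density=0.4, format="coo", random_state=0)
  A_jax = _scipy_sparse_to_jax(A)  # -> jesp.BCOO
  A_back = _jax_sparse_to_scipy(A_jax, shape=A.shape)
  assert np.allclose(A.toarray(), A_back.toarray())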