"""Initializers for quadratic OT problems (``ott.initializers.quadratic.initializers``)."""

import abc
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import geometry

if TYPE_CHECKING:
  from ott.initializers.linear import initializers_lr
  from ott.problems.linear import linear_problem
  from ott.problems.quadratic import quadratic_problem

__all__ = ["QuadraticInitializer", "LRQuadraticInitializer"]


@jax.tree_util.register_pytree_node_class
class BaseQuadraticInitializer(abc.ABC):
  """Base class for quadratic initializers.

  Subclasses implement :meth:`_create_geometry`; :meth:`__call__` wraps the
  returned geometry into a linear problem that reuses the marginals and the
  unbalancedness parameters of the quadratic problem.

  Args:
    kwargs: Keyword arguments. They are stored as-is and used as the auxiliary
      data of the pytree, so they are passed back to the constructor when the
      initializer is unflattened.
  """

  def __init__(self, **kwargs: Any):
    self._kwargs = kwargs

  def __call__(
      self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any
  ) -> 'linear_problem.LinearProblem':
    """Compute the initial linearization of a quadratic problem.

    Args:
      quad_prob: Quadratic problem to linearize.
      kwargs: Additional keyword arguments, forwarded to
        :meth:`_create_geometry`.

    Returns:
      Linear problem built on the geometry returned by
      :meth:`_create_geometry`, with ``a``, ``b``, ``tau_a`` and ``tau_b``
      taken from ``quad_prob``.
    """
    # Local import; presumably avoids a circular import at module load time.
    from ott.problems.linear import linear_problem

    n, m = quad_prob.geom_xx.shape[0], quad_prob.geom_yy.shape[0]
    geom = self._create_geometry(quad_prob, **kwargs)
    # Internal sanity check on the subclass' output: the linearized geometry
    # must match the sizes of ``geom_xx`` (rows) and ``geom_yy`` (columns).
    assert geom.shape == (n, m), (
        f"Expected geometry of shape `{n, m}`, found `{geom.shape}`."
    )
    return linear_problem.LinearProblem(
        geom,
        a=quad_prob.a,
        b=quad_prob.b,
        tau_a=quad_prob.tau_a,
        tau_b=quad_prob.tau_b
    )

  @abc.abstractmethod
  def _create_geometry(
      self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any
  ) -> geometry.Geometry:
    """Compute initial geometry for linearization.

    Args:
      quad_prob: Quadratic problem.
      kwargs: Additional keyword arguments.

    Returns:
      Geometry used to initialize the linearized problem. Its shape must be
      ``[geom_xx.shape[0], geom_yy.shape[0]]``.
    """

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    # No array children; the constructor kwargs are the auxiliary data.
    return [], self._kwargs

  @classmethod
  def tree_unflatten(
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "BaseQuadraticInitializer":
    # Children are passed positionally, auxiliary data as keyword arguments.
    return cls(*children, **aux_data)


class QuadraticInitializer(BaseQuadraticInitializer):
  """Initialize a linear problem locally around a naive initializer ab'.

  If the problem is balanced (``tau_a = 1`` and ``tau_b = 1``),
  the equation of the cost follows eq. 6, p. 1 of :cite:`peyre:16`.

  If the problem is unbalanced (``tau_a < 1`` or ``tau_b < 1``), there are two
  possible cases. A first possibility is to introduce a quadratic KL
  divergence on the marginals in the objective as done in :cite:`sejourne:21`
  (``gw_unbalanced_correction = True``), which in turn modifies the
  local cost matrix.

  Alternatively, it could be possible to leave the formulation of the
  local cost unchanged, i.e. follow eq. 6, p. 1 of :cite:`peyre:16`
  (``gw_unbalanced_correction = False``) and include the unbalanced terms
  at the level of the linear problem only.

  Let :math:`P` [num_a, num_b] be the transport matrix, `cost_xx` is the
  cost matrix of `geom_xx` and `cost_yy` is the cost matrix of `geom_yy`.
  `left_x` and `right_y` depend on the loss chosen for GW.
  `gw_unbalanced_correction` is a boolean indicating whether or not the
  unbalanced correction applies.
  The equation of the local cost can be written as:

  `cost_matrix` = `marginal_dep_term`
              + `left_x`(`cost_xx`) :math:`P` `right_y`(`cost_yy`):math:`^T`
              + `unbalanced_correction` * `gw_unbalanced_correction`

  When working with the fused problem, a linear term is added to the cost
  matrix:
  `cost_matrix` += `fused_penalty` * `geom_xy.cost_matrix`
  """

  def _create_geometry(
      self,
      quad_prob: 'quadratic_problem.QuadraticProblem',
      *,
      epsilon: float,
      **kwargs: Any
  ) -> geometry.Geometry:
    """Compute initial geometry for linearization.

    The linearization is taken around the coupling built as the outer
    product of the marginals ``quad_prob.a`` and ``quad_prob.b``.

    Args:
      quad_prob: Quadratic OT problem.
      epsilon: Epsilon regularization. When ``quad_prob`` is unbalanced, it is
        replaced by the result of
        :func:`~ott.problems.quadratic.quadratic_problem.update_epsilon_unbalanced`
        with ``transport_mass`` set to the total mass of that initial coupling.
      kwargs: Additional keyword arguments, unused.

    Returns:
      The initial geometry used to initialize the linearized problem, carrying
      the (possibly updated) ``epsilon``.
    """
    # Local import; presumably avoids a circular import at module load time.
    from ott.problems.quadratic import quadratic_problem

    del kwargs

    marginal_cost = quad_prob.marginal_dependent_cost(quad_prob.a, quad_prob.b)
    geom_xx, geom_yy = quad_prob.geom_xx, quad_prob.geom_yy

    h1, h2 = quad_prob.quad_loss
    tmp1 = quadratic_problem.apply_cost(geom_xx, quad_prob.a, axis=1, fn=h1)
    tmp2 = quadratic_problem.apply_cost(geom_yy, quad_prob.b, axis=1, fn=h2)
    # Presumably h1(cost_xx) @ a and h2(cost_yy) @ b: since the initial
    # coupling a b^T has rank 1, the quadratic term collapses to an outer
    # product of two vectors.
    tmp = jnp.outer(tmp1, tmp2)

    if quad_prob.is_balanced:
      cost_matrix = marginal_cost.cost_matrix - tmp
    else:
      # Initialize epsilon for unbalanced GW according to Sejourne et al. (2021).
      init_transport = jnp.outer(quad_prob.a, quad_prob.b)
      marginal_1, marginal_2 = init_transport.sum(1), init_transport.sum(0)
      epsilon = quadratic_problem.update_epsilon_unbalanced(
          epsilon=epsilon, transport_mass=marginal_1.sum()
      )
      unbalanced_correction = quad_prob.cost_unbalanced_correction(
          init_transport, marginal_1, marginal_2, epsilon=epsilon
      )
      cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction

    # Linear (fused) term, weighted by the problem's fused penalty.
    cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix
    return geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
class LRQuadraticInitializer(BaseQuadraticInitializer):
  """Quadratic initializer that wraps a low-rank linear initializer.

  The factors produced by the wrapped low-rank (Sinkhorn) initializer are
  used to build the initial linearization of the quadratic problem.

  Args:
    lr_linear_initializer: Low-rank linear initializer.
  """

  def __init__(self, lr_linear_initializer: 'initializers_lr.LRInitializer'):
    super().__init__()
    self._linear_lr_initializer = lr_linear_initializer

  def _create_geometry(
      self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any
  ) -> geometry.Geometry:
    """Compute initial geometry for linearization.

    Args:
      quad_prob: Quadratic OT problem.
      kwargs: Keyword arguments for
        :meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`.

    Returns:
      The initial geometry used to initialize a linear problem, obtained by
      calling ``quad_prob.update_lr_geom`` on the initial factors.
    """
    # Local import; presumably avoids a circular import at module load time.
    from ott.solvers.linear import sinkhorn_lr

    q, r, g = self._linear_lr_initializer(quad_prob, **kwargs)
    # Wrap the initial factors into an output object; only ``q``, ``r`` and
    # ``g`` are set, all other fields are left as ``None``.
    tmp_out = sinkhorn_lr.LRSinkhornOutput(
        q=q, r=r, g=g, costs=None, errors=None, ot_prob=None
    )

    return quad_prob.update_lr_geom(tmp_out)

  @property
  def rank(self) -> int:
    """Rank of the transport matrix factorization."""
    return self._linear_lr_initializer.rank

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    # The wrapped initializer is appended as a child; the inherited
    # ``tree_unflatten`` passes children positionally to ``__init__``.
    children, aux_data = super().tree_flatten()
    return children + [self._linear_lr_initializer], aux_data