# Source code for ott.problems.quadratic.quadratic_problem

# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes defining OT problem(s) (objective function + utilities)."""

import copy
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax.scipy as jsp

from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_costs
from ott.types import Transport

if TYPE_CHECKING:
  from ott.solvers.linear import sinkhorn_lr

__all__ = ["QuadraticProblem"]


@jax.tree_util.register_pytree_node_class
class QuadraticProblem:
  r"""Quadratic OT problem.

  The quadratic loss of a single OT matrix is assumed to
  have the form given in :cite:`peyre:16`, eq. 4.

  The two geometries below parameterize matrices :math:`C` and :math:`\bar{C}`
  in that equation. The function :math:`L` (of two real values) in that
  equation is assumed to match the form given in eq. 5., with our notations:

  .. math::

    L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)

  Args:
    geom_xx: Ground geometry of the first space.
    geom_yy: Ground geometry of the second space.
    geom_xy: Geometry defining the linear penalty term for
      Fused Gromov-Wasserstein. If `None`, the problem reduces to a plain
      Gromov-Wasserstein problem.
    fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein,
      i.e. problem = purely quadratic + fused_penalty * linear problem.
      Ignored if ``geom_xy`` is not specified. Must be strictly positive.
    scale_cost: option to rescale the cost matrices; it is forwarded to every
      geometry passed to this problem:

      - if :obj:`True`, use the default for each geometry.
      - if :obj:`False`, keep the original scaling in geometries.
      - if :class:`str`, use a specific method available in
        :class:`~ott.geometry.geometry.Geometry` or
        :class:`~ott.geometry.pointcloud.PointCloud`.
      - if :obj:`None`, do not scale the cost matrices.

    a: array representing the probability weights of the samples
      from ``geom_xx``. If `None`, it will be uniform.
    b: array representing the probability weights of the samples
      from ``geom_yy``. If `None`, it will be uniform.
    loss: the Gromov-Wasserstein loss. Either ``'sqeucl'`` (squared
      Euclidean loss, the default), ``'kl'`` (KL loss), or a
      :class:`~ott.problems.quadratic.quadratic_costs.GWLoss` whose ``f1`` and
      ``f2`` fields form the linear part of the loss and whose ``h1`` and
      ``h2`` fields form the quadratic part (quad1, quad2). With ``'sqeucl'``
      the marginal-dependent term uses the geometries' ``apply_square_cost``,
      which is efficient for low-rank geometries.
    tau_a: if `< 1.0`, defines how much unbalanced the problem is on
      the first marginal.
    tau_b: if `< 1.0`, defines how much unbalanced the problem is on
      the second marginal.
    gw_unbalanced_correction: Whether the unbalanced version of
      :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b``
      only affect the inner Sinkhorn loop.
    ranks: Ranks of the cost matrices, see
      :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
      geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with
      `'sqeucl'` cost function. If `-1`, the geometries will not be converted
      to low-rank. If :class:`tuple`, it specifies the ranks of ``geom_xx``,
      ``geom_yy`` and ``geom_xy``, respectively (2 entries when ``geom_xy`` is
      `None`, 3 otherwise). If :class:`int`, rank is shared across all
      geometries.
    tolerances: Tolerances used when converting geometries to low-rank. Used
      when geometries are not :class:`~ott.geometry.pointcloud.PointCloud` with
      `'sqeucl'` cost. If :class:`float`, it is shared across all geometries.
      If :class:`tuple`, same length rules as for ``ranks`` apply.
  """

  def __init__(
      self,
      geom_xx: geometry.Geometry,
      geom_yy: geometry.Geometry,
      geom_xy: Optional[geometry.Geometry] = None,
      fused_penalty: float = 1.0,
      scale_cost: Optional[Union[bool, float, str]] = False,
      a: Optional[jnp.ndarray] = None,
      b: Optional[jnp.ndarray] = None,
      loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl',
      tau_a: Optional[float] = 1.0,
      tau_b: Optional[float] = 1.0,
      gw_unbalanced_correction: bool = True,
      ranks: Union[int, Tuple[int, ...]] = -1,
      tolerances: Union[float, Tuple[float, ...]] = 1e-2,
  ):
    assert fused_penalty > 0, fused_penalty
    self._geom_xx = geom_xx._set_scale_cost(scale_cost)
    self._geom_yy = geom_yy._set_scale_cost(scale_cost)
    self._geom_xy = (
        None if geom_xy is None else geom_xy._set_scale_cost(scale_cost)
    )
    self.fused_penalty = fused_penalty
    self.scale_cost = scale_cost
    self._a = a
    self._b = b
    self.tau_a = tau_a
    self.tau_b = tau_b
    self.gw_unbalanced_correction = gw_unbalanced_correction
    self.ranks = ranks
    self.tolerances = tolerances

    # The raw `loss` argument is kept: it is stored as pytree aux data in
    # `tree_flatten` and selects the 'sqeucl' fast path in
    # `marginal_dependent_cost`.
    self._loss_name = loss
    if self._loss_name == 'sqeucl':
      self.loss = quadratic_costs.make_square_loss()
    elif loss == 'kl':
      self.loss = quadratic_costs.make_kl_loss()
    else:
      self.loss = loss

  def marginal_dependent_cost(
      self, marginal_1: jnp.ndarray, marginal_2: jnp.ndarray
  ) -> low_rank.LRCGeometry:
    r"""Initialise cost term that depends on the marginals of the transport.

    Uses the first term in eq. 6, p. 1 of :cite:`peyre:16`.

    Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
    from ``geom_xx`` and :math:`q` [num_b,] be the marginal of the transport
    matrix for samples from ``geom_yy``. ``cost_xx`` (resp. ``cost_yy``) is
    the cost matrix of ``geom_xx`` (resp. ``geom_yy``). The cost term that
    depends on these marginals can be written as:

    `marginal_dep_term` = `lin1`(`cost_xx`) :math:`p \mathbb{1}_{num_b}^T`
                          + (`lin2`(`cost_yy`) :math:`q \mathbb{1}_{num_a}^T)^T`

    This term is returned in factored (low-rank) form: with
    ``tmp1 = lin1(cost_xx) p`` and ``tmp2 = lin2(cost_yy) q``, the first
    factor is ``[tmp1, 1]`` and the second is ``[1, tmp2]`` (column-wise
    concatenation).

    Args:
      marginal_1: jnp.ndarray<float>[num_a,], marginal of the transport matrix
        for samples from ``geom_xx``.
      marginal_2: jnp.ndarray<float>[num_b,], marginal of the transport matrix
        for samples from ``geom_yy``.

    Returns:
      Low-rank geometry representing the marginal-dependent cost term.
    """
    if self._loss_name == 'sqeucl':  # quadratic apply, efficient for LR
      tmp1 = self.geom_xx.apply_square_cost(marginal_1, axis=1)
      tmp2 = self.geom_yy.apply_square_cost(marginal_2, axis=1)
    else:
      f1, f2 = self.linear_loss
      tmp1 = apply_cost(self.geom_xx, marginal_1, axis=1, fn=f1)
      tmp2 = apply_cost(self.geom_yy, marginal_2, axis=1, fn=f2)
    x_term = jnp.concatenate((tmp1, jnp.ones_like(tmp1)), axis=1)
    y_term = jnp.concatenate((jnp.ones_like(tmp2), tmp2), axis=1)
    return low_rank.LRCGeometry(cost_1=x_term, cost_2=y_term)

  def cost_unbalanced_correction(
      self,
      transport_matrix: jnp.ndarray,
      marginal_1: jnp.ndarray,
      marginal_2: jnp.ndarray,
      epsilon: epsilon_scheduler.Epsilon,
  ) -> float:
    r"""Calculate cost term from the quadratic divergence when unbalanced.

    In the unbalanced setting (``tau_a < 1.0 or tau_b < 1.0``), the
    introduction of a quadratic divergence :cite:`sejourne:21` adds a term
    to the GW local cost.

    Let :math:`a` [num_a,] be the target weights for samples from
    ``geom_xx`` and :math:`b` [num_b,] be the target weights for samples
    from ``geom_yy``. Let :math:`P` [num_a, num_b] be the transport matrix,
    :math:`P1` the first marginal, :math:`P^T1` the second marginal and
    :math:`\varepsilon` the target value of ``epsilon``. The scalar computed
    here is (with the convention :math:`0 \log 0 = 0`):

    .. math::

      \varepsilon \sum_{ij} P_{ij} \log P_{ij}
      + \frac{\varepsilon \tau_a}{1 - \tau_a}
        \sum_i (P1)_i \log \frac{(P1)_i}{a_i}
      + \frac{\varepsilon \tau_b}{1 - \tau_b}
        \sum_j (P^T1)_j \log \frac{(P^T1)_j}{b_j}

    The :math:`\tau_a` (resp. :math:`\tau_b`) term is only included when
    ``tau_a != 1.0`` (resp. ``tau_b != 1.0``). These are the KL divergences
    without their mass terms.

    Args:
      transport_matrix: jnp.ndarray<float>[num_a, num_b], transport matrix.
      marginal_1: jnp.ndarray<float>[num_a,], marginal of the transport matrix
        for samples from :attr:`geom_xx`.
      marginal_2: jnp.ndarray<float>[num_b,], marginal of the transport matrix
        for samples from :attr:`geom_yy`.
      epsilon: regulariser.

    Returns:
      The cost term.
    """

    def regularizer(tau: float) -> float:
      return eps * tau / (1.0 - tau)

    eps = epsilon._target_init
    marginal_1loga = jsp.special.xlogy(marginal_1, self.a).sum()
    marginal_2logb = jsp.special.xlogy(marginal_2, self.b).sum()

    cost = eps * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
    # -entr(x) = x log x, so each bracket below is sum_i x_i log(x_i / w_i).
    if self.tau_a != 1.0:
      cost += regularizer(
          self.tau_a
      ) * (-jsp.special.entr(marginal_1).sum() - marginal_1loga)
    if self.tau_b != 1.0:
      cost += regularizer(
          self.tau_b
      ) * (-jsp.special.entr(marginal_2).sum() - marginal_2logb)
    return cost

  # TODO(michalk8): highly coupled to the pre-defined initializer, refactor
  def init_transport_mass(self) -> float:
    """Initialise the transport mass.

    Returns:
      The sum of the elements of the normalised transport matrix, i.e. the
      product of the masses of ``a`` and ``b`` (no gradient flows through it).
    """
    a = jax.lax.stop_gradient(self.a)
    b = jax.lax.stop_gradient(self.b)
    return a.sum() * b.sum()

  def update_lr_geom(
      self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput'
  ) -> geometry.Geometry:
    """Recompute (possibly LRC) linearization using LR Sinkhorn output.

    Args:
      lr_sink: Output of a low-rank Sinkhorn solve; its ``q``, ``r`` and ``g``
        factors and marginals are used.

    Returns:
      A :class:`~ott.geometry.low_rank.LRCGeometry` when all geometries are
      low-rank, otherwise a dense :class:`~ott.geometry.geometry.Geometry`.
    """
    marginal_1 = lr_sink.marginal(1)
    marginal_2 = lr_sink.marginal(0)
    marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)

    # Extract factors from LR Sinkhorn output
    q, r, inv_sqg = lr_sink.q, lr_sink.r, 1.0 / jnp.sqrt(lr_sink.g)
    # Distribute middle marginal evenly across both factors.
    q, r = q * inv_sqg[None, :], r * inv_sqg[None, :]

    # Handle LRC Geometry case.
    h1, h2 = self.quad_loss
    tmp1 = apply_cost(self.geom_xx, q, axis=1, fn=h1)
    tmp2 = apply_cost(self.geom_yy, r, axis=1, fn=h2)
    if self.is_low_rank:
      geom = low_rank.LRCGeometry(cost_1=tmp1, cost_2=-tmp2) + marginal_cost
      if self.is_fused:
        # `fused_penalty` is not applied here: `to_low_rank` folds it into
        # `geom_xy` via `scale=self.fused_penalty`.
        # NOTE(review): if the caller passed LRC geometries directly,
        # `to_low_rank` returns `self` and no penalty is applied - confirm.
        geom = geom + self.geom_xy
    else:
      cost_matrix = marginal_cost.cost_matrix - jnp.dot(tmp1, tmp2.T)
      cost_matrix += self.fused_penalty * self._fused_cost_matrix
      geom = geometry.Geometry(cost_matrix=cost_matrix)
    return geom

  def update_linearization(
      self,
      transport: Transport,
      epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None,
      old_transport_mass: float = 1.0
  ) -> linear_problem.LinearProblem:
    """Update linearization of GW problem by updating cost matrix.

    If the problem is balanced (``tau_a = 1.0 and tau_b = 1.0``, or
    ``gw_unbalanced_correction = False``), the equation follows eq. 6, p. 1
    of :cite:`peyre:16`.

    If the problem is unbalanced (``tau_a < 1.0 or tau_b < 1.0``) and
    ``gw_unbalanced_correction = True``, the approach of
    :cite:`sejourne:21` is used:

    - the transport matrix and its marginals are multiplied by
      ``sqrt(old_transport_mass / mass)``, where ``mass`` is the total mass
      of ``transport`` (with gradients stopped);
    - ``epsilon`` is scaled by the rescaled transport mass (see
      :func:`update_epsilon_unbalanced`);
    - the scalar returned by :meth:`cost_unbalanced_correction` is added to
      the cost matrix.

    Finally, for a Fused Gromov-Wasserstein problem, ``fused_penalty`` times
    the cost matrix of ``geom_xy`` (multiplied by the same rescaling factor)
    is added to the cost matrix.

    Args:
      transport: Solution of the linearization of the quadratic problem.
      epsilon: An epsilon scheduler or a float passed on to the linearization.
      old_transport_mass: Sum of the elements of the transport matrix at the
        previous iteration.

    Returns:
      Updated linear OT problem, a new local linearization of GW problem.
    """
    rescale_factor = 1.0
    unbalanced_correction = 0.0

    # Compute the first marginal once; it is reused for the mass below.
    marginal_1 = transport.marginal(axis=1)
    if not self.is_balanced:
      # Rescales transport for Unbalanced GW according to Sejourne et al (2021).
      transport_mass = jax.lax.stop_gradient(marginal_1.sum())
      rescale_factor = jnp.sqrt(old_transport_mass / transport_mass)

    marginal_1 = marginal_1 * rescale_factor
    marginal_2 = transport.marginal(axis=0) * rescale_factor
    marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)

    transport_matrix = transport.matrix * rescale_factor

    if not self.is_balanced:
      # Mass of the *rescaled* transport drives the epsilon rescaling.
      transport_mass = jax.lax.stop_gradient(marginal_1.sum())
      epsilon = update_epsilon_unbalanced(epsilon, transport_mass)
      unbalanced_correction = self.cost_unbalanced_correction(
          transport_matrix, marginal_1, marginal_2, epsilon
      )

    h1, h2 = self.quad_loss
    tmp = apply_cost(self.geom_xx, transport_matrix, axis=1, fn=h1)
    tmp = apply_cost(self.geom_yy, tmp.T, axis=1, fn=h2).T

    cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction
    cost_matrix += self.fused_penalty * self._fused_cost_matrix * rescale_factor

    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
    return linear_problem.LinearProblem(
        geom, self.a, self.b, tau_a=self.tau_a, tau_b=self.tau_b
    )

  def update_lr_linearization(
      self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput'
  ) -> linear_problem.LinearProblem:
    """Update a Quad problem linearization using a LR Sinkhorn.

    Args:
      lr_sink: Output of a low-rank Sinkhorn solve.

    Returns:
      Linear problem built on :meth:`update_lr_geom` with this problem's
      marginals and unbalancedness parameters.
    """
    return linear_problem.LinearProblem(
        self.update_lr_geom(lr_sink),
        self.a,
        self.b,
        tau_a=self.tau_a,
        tau_b=self.tau_b
    )

  @property
  def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]:
    """Cost matrix of ``geom_xy``, or ``0.`` when the problem is not fused."""
    if not self.is_fused:
      return 0.
    if isinstance(
        self.geom_xy, pointcloud.PointCloud
    ) and self.geom_xy.is_online:
      # Online point clouds do not materialize `cost_matrix`; build it and
      # apply the cost scaling explicitly.
      return self.geom_xy._compute_cost_matrix() * self.geom_xy.inv_scale_cost
    return self.geom_xy.cost_matrix

  @property
  def _is_low_rank_convertible(self) -> bool:
    """Whether :meth:`to_low_rank` can produce low-rank geometries."""

    def convertible(geom: geometry.Geometry) -> bool:
      return isinstance(geom, low_rank.LRCGeometry) or (
          isinstance(geom, pointcloud.PointCloud) and geom.is_squared_euclidean
      )

    if self.is_low_rank:
      return True

    geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
    # either explicitly via cost factorization or implicitly (e.g., a PC)
    return self.ranks != -1 or (
        convertible(geom_xx) and convertible(geom_yy) and
        (geom_xy is None or convertible(geom_xy))
    )

  def to_low_rank(self, seed: int = 0) -> "QuadraticProblem":
    """Convert geometries to low-rank.

    Args:
      seed: Random seed.

    Returns:
      Quadratic problem with low-rank geometries. If all geometries are
      already low-rank, ``self`` is returned.
    """

    def convert(
        vals: Union[int, float, Tuple[Union[int, float], ...]]
    ) -> Tuple[Union[int, float], ...]:
      # A scalar is shared by all 3 geometries; a sequence needs one entry per
      # geometry (2 without fusion, 3 with it) and is padded with `None`.
      size = 2 + self.is_fused
      if isinstance(vals, (int, float)):
        return (vals,) * 3
      vals = tuple(vals)  # accept lists too, `list + tuple` would fail
      assert len(vals) == size, vals
      return vals + (None,) * (3 - size)

    if self.is_low_rank:
      return self

    (geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten()
    (s1, s2, s3) = jax.random.split(jax.random.PRNGKey(seed), 3)[:, 0]
    (r1, r2, r3), (t1, t2, t3) = convert(self.ranks), convert(self.tolerances)

    geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, seed=s1)
    geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, seed=s2)
    if self.is_fused:
      # The fused penalty is folded into the low-rank `geom_xy` here.
      if isinstance(
          geom_xy, pointcloud.PointCloud
      ) and geom_xy.is_squared_euclidean:
        geom_xy = geom_xy.to_LRCGeometry(scale=self.fused_penalty)
      else:
        geom_xy = geom_xy.to_LRCGeometry(
            rank=r3, tol=t3, seed=s3, scale=self.fused_penalty
        )

    return type(self).tree_unflatten(
        aux_data, [geom_xx, geom_yy, geom_xy] + children
    )

  @property
  def geom_xx(self) -> geometry.Geometry:
    """Geometry of the first space."""
    return self._geom_xx

  @property
  def geom_yy(self) -> geometry.Geometry:
    """Geometry of the second space."""
    return self._geom_yy

  @property
  def geom_xy(self) -> Optional[geometry.Geometry]:
    """Geometry of the joint space."""
    return self._geom_xy

  @property
  def a(self) -> jnp.ndarray:
    """First marginal (uniform over ``geom_xx`` rows if not given)."""
    num_a = self.geom_xx.shape[0]
    return jnp.ones((num_a,)) / num_a if self._a is None else self._a

  @property
  def b(self) -> jnp.ndarray:
    """Second marginal (uniform over ``geom_yy`` rows if not given)."""
    num_b = self.geom_yy.shape[0]
    return jnp.ones((num_b,)) / num_b if self._b is None else self._b

  @property
  def is_fused(self) -> bool:
    """Whether the problem is fused."""
    return self.geom_xy is not None

  @property
  def is_low_rank(self) -> bool:
    """Whether all geometries are low-rank."""
    return (
        isinstance(self.geom_xx, low_rank.LRCGeometry) and
        isinstance(self.geom_yy, low_rank.LRCGeometry) and
        (not self.is_fused or isinstance(self.geom_xy, low_rank.LRCGeometry))
    )

  @property
  def linear_loss(self) -> Tuple[quadratic_costs.Loss, quadratic_costs.Loss]:
    """Linear part of the Gromov-Wasserstein loss."""
    return self.loss.f1, self.loss.f2

  @property
  def quad_loss(self) -> Tuple[quadratic_costs.Loss, quadratic_costs.Loss]:
    """Quadratic part of the Gromov-Wasserstein loss."""
    return self.loss.h1, self.loss.h2

  @property
  def is_balanced(self) -> bool:
    """Whether the problem is treated as balanced.

    This is the case when ``tau_a == tau_b == 1.0`` or when
    ``gw_unbalanced_correction`` is disabled.
    """
    return ((not self.gw_unbalanced_correction) or
            (self.tau_a == 1.0 and self.tau_b == 1.0))

  def tree_flatten(self):
    return ([self.geom_xx, self.geom_yy, self.geom_xy, self._a, self._b], {
        'tau_a': self.tau_a,
        'tau_b': self.tau_b,
        'loss': self._loss_name,
        'fused_penalty': self.fused_penalty,
        'scale_cost': self.scale_cost,
        'gw_unbalanced_correction': self.gw_unbalanced_correction,
        'ranks': self.ranks,
        'tolerances': self.tolerances
    })

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    geoms, (a, b) = children[:3], children[3:]
    return cls(*geoms, a=a, b=b, **aux_data)
def update_epsilon_unbalanced(
    epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float
) -> epsilon_scheduler.Epsilon:
  """Return an epsilon scheduler whose scale is multiplied by ``transport_mass``.

  Used by unbalanced Gromov-Wasserstein (:cite:`sejourne:21`), where the
  regularization is rescaled by the current transport mass.

  Args:
    epsilon: Regularization value or scheduler; it is passed to
      :meth:`~ott.geometry.epsilon_scheduler.Epsilon.make`.
    transport_mass: Mass of the (rescaled) transport matrix.

  Returns:
    An epsilon scheduler whose ``_scale_epsilon`` is the input's scale times
    ``transport_mass``. The ``epsilon`` argument itself is never modified.
  """
  updated_epsilon = epsilon_scheduler.Epsilon.make(epsilon)
  # NOTE(review): `Epsilon.make` presumably returns an existing `Epsilon`
  # unchanged. Copy it so the caller's scheduler is not mutated in place,
  # which would compound the scaling across calls and could leak traced
  # values into a long-lived object.
  updated_epsilon = copy.copy(updated_epsilon)
  updated_epsilon._scale_epsilon = (
      updated_epsilon._scale_epsilon * transport_mass
  )
  return updated_epsilon


def apply_cost(
    geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int,
    fn: quadratic_costs.Loss
) -> jnp.ndarray:
  """Apply ``geom``'s cost, transformed element-wise by a loss term, to ``arr``.

  Args:
    geom: Geometry whose cost is applied.
    arr: Array to which the cost is applied.
    axis: Axis along which the cost is applied, forwarded to
      ``geom.apply_cost``.
    fn: Loss term; its ``func`` and ``is_linear`` fields are forwarded to
      ``geom.apply_cost``.

  Returns:
    The result of ``geom.apply_cost``.
  """
  return geom.apply_cost(arr, axis=axis, fn=fn.func, is_linear=fn.is_linear)