Source code for ott.solvers.quadratic.gw_barycenter

from functools import partial
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.math import fixed_point_loop
from ott.problems.linear import linear_problem
from ott.problems.quadratic import gw_barycenter
from ott.solvers import was_solver
from ott.solvers.quadratic import gromov_wasserstein

# Public API of this module.
__all__ = [
    "GWBarycenterState",
    "GromovWassersteinBarycenter",
]


class GWBarycenterState(NamedTuple):
  """Holds the state of the \
  :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`.

  Args:
    cost: Barycenter cost matrix of shape ``[bar_size, bar_size]``.
    x: Barycenter features of shape ``[bar_size, ndim_fused]``. Only used in
      the fused case.
    a: Weights of the barycenter of shape ``[bar_size,]``.
    errors: Array of shape
      ``[max_iter, num_measures, quad_max_iter, lin_outer_iter]`` containing
      the GW errors at each iteration, or ``None`` when inner errors are not
      stored.
    costs: Array of shape ``[max_iter,]`` containing the cost at each
      iteration. Entries of iterations that were not run are ``-1``.
    gw_convergence: Array of shape ``[max_iter,]`` containing the convergence
      of all GW problems at each iteration. Entries of iterations that were
      not run are ``-1``.
  """

  cost: Optional[jnp.ndarray] = None
  x: Optional[jnp.ndarray] = None
  a: Optional[jnp.ndarray] = None
  errors: Optional[jnp.ndarray] = None
  costs: Optional[jnp.ndarray] = None
  gw_convergence: Optional[jnp.ndarray] = None

  def set(self, **kwargs: Any) -> "GWBarycenterState":
    """Return a copy of self, possibly with overwrites."""
    return self._replace(**kwargs)
@jax.tree_util.register_pytree_node_class
class GromovWassersteinBarycenter(was_solver.WassersteinSolver):
  """Gromov-Wasserstein barycenter solver of the \
  :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`.

  Args:
    epsilon: Entropy regulariser.
    min_iterations: Minimum number of iterations.
    max_iterations: Maximum number of outermost iterations.
    threshold: Convergence threshold.
    store_inner_errors: Whether to store the errors of the GW solver, as well
      as its linear solver, at each iteration for each measure.
    quad_solver: The GW solver.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`.
      Only used when ``quad_solver = None``.
  """

  def __init__(
      self,
      epsilon: Optional[float] = None,
      min_iterations: int = 5,
      max_iterations: int = 50,
      threshold: float = 1e-3,
      store_inner_errors: bool = False,
      quad_solver: Optional[gromov_wasserstein.GromovWasserstein] = None,
      # TODO(michalk8): maintain the API compatibility with `was_solver`
      # but makes passing kwargs with the same name to `quad_solver` impossible
      # will be fixed when refactoring the solvers
      # note that `was_solver` also suffers from this
      **kwargs: Any,
  ):
    super().__init__(
        epsilon=epsilon,
        min_iterations=min_iterations,
        max_iterations=max_iterations,
        threshold=threshold,
        store_inner_errors=store_inner_errors
    )
    self._quad_solver = quad_solver
    if quad_solver is None:
      # Build a default GW solver that shares this solver's `epsilon` and
      # `store_inner_errors`; any remaining kwargs are forwarded to it.
      kwargs["epsilon"] = epsilon
      # TODO(michalk8): store only GW errors?
      kwargs["store_inner_errors"] = store_inner_errors
      self._quad_solver = gromov_wasserstein.GromovWasserstein(**kwargs)

  def __call__(
      self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int,
      **kwargs: Any
  ) -> GWBarycenterState:
    """Solve the (fused) GW barycenter problem.

    Args:
      problem: The GW barycenter problem.
      bar_size: Size of the barycenter.
      kwargs: Keyword arguments for :meth:`init_state`.

    Returns:
      The solution.
    """
    state = self.init_state(problem, bar_size, **kwargs)
    state = iterations(solver=self, problem=problem, init_state=state)
    return self.output_from_state(state)

  def init_state(
      self,
      problem: gw_barycenter.GWBarycenterProblem,
      bar_size: int,
      bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray,
                                                  jnp.ndarray]]] = None,
      a: Optional[jnp.ndarray] = None,
      seed: int = 0,
  ) -> GWBarycenterState:
    """Initialize the (fused) Gromov-Wasserstein barycenter state.

    Args:
      problem: The barycenter problem.
      bar_size: Size of the barycenter.
      bar_init: Initial barycenter value. Can be one of the following:

        - ``None`` - randomly initialize the barycenter.
        - :class:`jax.numpy.ndarray` - barycenter cost matrix of shape
          ``[bar_size, bar_size]``. Only used in the non-fused case.
        - :class:`tuple` of :class:`jax.numpy.ndarray` - the 1st array
          corresponds to a cost matrix of shape ``[bar_size, bar_size]``,
          the 2nd array is a ``[bar_size, ndim_fused]`` feature matrix used in
          the fused case.

      a: An array of shape ``[bar_size,]`` containing the barycenter weights.
        Defaults to uniform weights.
      seed: Random seed used when ``bar_init = None``.

    Returns:
      The initial barycenter state.
    """
    if a is None:
      a = jnp.ones((bar_size,)) / bar_size
    else:
      assert a.shape == (bar_size,)

    if bar_init is None:
      # Random initialization: solve a linear OT problem between random point
      # clouds for every measure (one key per measure), then derive the
      # initial barycenter (and its features) from those transport matrices.
      _, b = problem.segmented_y_b
      rng = jax.random.PRNGKey(seed)
      keys = jax.random.split(rng, problem.num_measures)
      linear_solver = self._quad_solver.linear_ot_solver
      transports = init_transports(linear_solver, keys, a, b, problem.epsilon)
      x = problem.update_features(transports, a) if problem.is_fused else None
      cost = problem.update_barycenter(transports, a)
    else:
      cost, x = bar_init if isinstance(bar_init, tuple) else (bar_init, None)

    assert cost.shape == (bar_size, bar_size)
    if problem.is_fused:
      assert x is not None, "Barycenter features are not initialized."
      assert x.shape == (bar_size, problem.ndim_fused)

    num_iter = self.max_iterations
    if self.store_inner_errors:
      # TODO(michalk8): in the future, think about how to do this in general
      errors = -jnp.ones((
          num_iter, problem.num_measures, self._quad_solver.max_iterations,
          self._quad_solver.linear_ot_solver.outer_iterations
      ))
    else:
      errors = None

    # Per-iteration traces; -1 marks iterations that have not been run.
    costs = -jnp.ones((num_iter,))
    gw_convergence = -jnp.ones((num_iter,))
    return GWBarycenterState(
        cost=cost,
        x=x,
        a=a,
        errors=errors,
        costs=costs,
        gw_convergence=gw_convergence
    )

  def update_state(
      self,
      state: GWBarycenterState,
      iteration: int,
      problem: gw_barycenter.GWBarycenterProblem,
      store_errors: bool = True,
  ) -> GWBarycenterState:
    """Run one outer iteration of the barycenter solver.

    Solves one (fused) GW problem between the current barycenter and each
    measure (vectorized with :func:`jax.vmap`), records the weighted GW cost
    and the joint convergence flag at index ``iteration``, and updates the
    barycenter cost matrix (and features, in the fused case) from the
    resulting transport matrices.

    Args:
      state: Current barycenter state.
      iteration: Index of the current outer iteration.
      problem: The barycenter problem.
      store_errors: Whether each GW solve returns its errors. NOTE(review):
        when ``store_inner_errors = True`` this presumably must stay ``True``,
        otherwise ``None`` is written into ``state.errors``.

    Returns:
      The updated barycenter state.
    """

    def solve_gw(
        state: GWBarycenterState, b: jnp.ndarray, y: jnp.ndarray,
        f: Optional[jnp.ndarray]
    ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]:
      quad_problem = problem._create_problem(state, y=y, b=b, f=f)
      out = self._quad_solver(quad_problem)
      return (
          out.reg_gw_cost, out.converged, out.matrix,
          out.errors if store_errors else None
      )

    # The state is shared across measures; `b` and `y` are mapped over the
    # measure axis, and so are the fused features `f` in the fused case.
    in_axes = [None, 0, 0]
    in_axes += [0] if problem.is_fused else [None]
    solve_fn = jax.vmap(solve_gw, in_axes=in_axes)

    y, b = problem.segmented_y_b
    y_f = problem.segmented_y_fused
    gw_costs, convergeds, transports, errors = solve_fn(state, b, y, y_f)

    # Weighted sum of the per-measure regularized GW costs.
    total_cost = jnp.sum(gw_costs * problem.weights)
    costs = state.costs.at[iteration].set(total_cost)
    converged = jnp.all(convergeds)
    gw_convergence = state.gw_convergence.at[iteration].set(converged)

    if self.store_inner_errors:
      errors = state.errors.at[iteration, ...].set(errors)
    else:
      errors = None

    x = problem.update_features(
        transports, state.a
    ) if problem.is_fused else state.x
    bar_cost = problem.update_barycenter(transports, state.a)
    return state.set(
        cost=bar_cost,
        x=x,
        costs=costs,
        errors=errors,
        gw_convergence=gw_convergence
    )

  def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState:
    """No-op."""
    # TODO(michalk8): just for consistency with continuous barycenter
    # will be refactored in the future to create an output
    return state

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    # Append the GW solver to the parent's children so it is traced as well.
    children, aux = super().tree_flatten()
    return children + [self._quad_solver], aux

  @classmethod
  def tree_unflatten(
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "GromovWassersteinBarycenter":
    # Assumes the parent's children are `(epsilon, <unused>, threshold)`;
    # the last child is the GW solver appended in `tree_flatten`.
    # NOTE(review): confirm this layout against `was_solver.tree_flatten`.
    epsilon, _, threshold, quad_solver = children
    return cls(
        epsilon=epsilon,
        threshold=threshold,
        quad_solver=quad_solver,
        **aux_data,
    )
@partial(jax.vmap, in_axes=[None, 0, None, 0, None])
def init_transports(
    solver, key: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray,
    epsilon: Optional[float]
) -> jnp.ndarray:
  """Initialize random 2D point cloud and solve the linear OT problem.

  Vectorized over ``key`` and ``b`` (one entry per measure); ``solver``,
  ``a`` and ``epsilon`` are shared.

  Args:
    solver: Linear OT solver.
    key: Random key.
    a: Source marginals (e.g., for barycenter) of shape ``[bar_size,]``.
    b: Target marginals of shape ``[max_measure_size,]``.
    epsilon: Entropy regularization.

  Returns:
    Transport map of shape ``[bar_size, max_measure_size]``.
  """
  key1, key2 = jax.random.split(key, 2)
  x = jax.random.normal(key1, shape=(len(a), 2))
  y = jax.random.normal(key2, shape=(len(b), 2))
  # Points with zero marginal weight (e.g. padding) are masked out.
  geom = pointcloud.PointCloud(
      x, y, epsilon=epsilon, src_mask=a > 0, tgt_mask=b > 0
  )
  problem = linear_problem.LinearProblem(geom, a=a, b=b)
  return solver(problem).matrix


def iterations(
    solver: GromovWassersteinBarycenter,
    problem: gw_barycenter.GWBarycenterProblem,
    init_state: GWBarycenterState
) -> GWBarycenterState:
  """Run the outer barycenter loop until convergence or ``max_iterations``.

  Args:
    solver: The barycenter solver.
    problem: The barycenter problem.
    init_state: Initial barycenter state.

  Returns:
    The final barycenter state.
  """

  def cond_fn(
      iteration: int,
      constants: Tuple[GromovWassersteinBarycenter,
                       gw_barycenter.GWBarycenterProblem],
      state: GWBarycenterState
  ) -> bool:
    bary_solver, _ = constants
    return bary_solver._continue(state, iteration)

  def body_fn(
      iteration: int,
      constants: Tuple[GromovWassersteinBarycenter,
                       gw_barycenter.GWBarycenterProblem],
      state: GWBarycenterState,
      compute_error: bool
  ) -> GWBarycenterState:
    del compute_error  # always assumed true
    bary_solver, bary_problem = constants
    return bary_solver.update_state(state, iteration, bary_problem)

  # The solver and problem are passed to the loop via `constants` rather than
  # captured by the closures.
  state = fixed_point_loop.fixpoint_iter(
      cond_fn=cond_fn,
      body_fn=body_fn,
      min_iterations=solver.min_iterations,
      max_iterations=solver.max_iterations,
      inner_iterations=1,
      constants=(solver, problem),
      state=init_state,
  )
  return state