# Copyright 2022 The OTT Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sinkhorn initializers."""
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
if TYPE_CHECKING:
from ott.problems.linear import linear_problem
__all__ = ["DefaultInitializer", "GaussianInitializer", "SortingInitializer"]


@jax.tree_util.register_pytree_node_class
class SinkhornInitializer(abc.ABC):
"""Base class for Sinkhorn initializers."""

  @abc.abstractmethod
def init_dual_a(
self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool
) -> jnp.ndarray:
"""Initialization for Sinkhorn potential/scaling f_u."""

  @abc.abstractmethod
def init_dual_b(
self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool
) -> jnp.ndarray:
"""Initialization for Sinkhorn potential/scaling g_v."""

  def __call__(
self,
ot_prob: 'linear_problem.LinearProblem',
a: Optional[jnp.ndarray],
b: Optional[jnp.ndarray],
lse_mode: bool,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Initialize Sinkhorn potentials/scalings f_u and g_v.
Args:
ot_prob: Linear OT problem.
a: Initial potential/scaling f_u. If ``None``, it will be initialized using
:meth:`init_dual_a`.
b: Initial potential/scaling g_v. If ``None``, it will be initialized using
:meth:`init_dual_b`.
lse_mode: Return potentials if true, scalings otherwise.
Returns:
The initial potentials/scalings.
"""
n, m = ot_prob.geom.shape
if a is None:
a = self.init_dual_a(ot_prob, lse_mode=lse_mode)
if b is None:
b = self.init_dual_b(ot_prob, lse_mode=lse_mode)
    assert a.shape == (n,), \
      f"Expected `f_u` to have shape `{(n,)}`, found `{a.shape}`."
    assert b.shape == (m,), \
      f"Expected `g_v` to have shape `{(m,)}`, found `{b.shape}`."
# cancel dual variables for zero weights
a = jnp.where(ot_prob.a > 0., a, -jnp.inf if lse_mode else 0.)
b = jnp.where(ot_prob.b > 0., b, -jnp.inf if lse_mode else 0.)
return a, b

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [], {}

  @classmethod
def tree_unflatten(
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "SinkhornInitializer":
return cls(*children, **aux_data)
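

# A minimal sketch (illustrative only, not part of the public API) of how a
# custom initializer can be written: subclass `SinkhornInitializer`, implement
# both abstract methods, and register the class as a pytree. The warm-start
# dual vectors below are hypothetical user-provided values.
@jax.tree_util.register_pytree_node_class
class _ExampleWarmStartInitializer(SinkhornInitializer):
  """Illustrative warm start from user-provided dual potentials f, g."""

  def __init__(self, f: jnp.ndarray, g: jnp.ndarray):
    super().__init__()
    self.f = f
    self.g = g

  def init_dual_a(
      self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool
  ) -> jnp.ndarray:
    # Return the stored potential, converted to a scaling in kernel mode.
    return self.f if lse_mode else ot_prob.geom.scaling_from_potential(self.f)

  def init_dual_b(
      self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool
  ) -> jnp.ndarray:
    return self.g if lse_mode else ot_prob.geom.scaling_from_potential(self.g)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    # f and g are pytree children, so the base class `tree_unflatten`
    # reconstructs the instance as cls(f, g).
    return [self.f, self.g], {}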


@jax.tree_util.register_pytree_node_class
class DefaultInitializer(SinkhornInitializer):
"""Default initialization of Sinkhorn dual potentials/primal scalings."""

  def init_dual_a(
self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool
) -> jnp.ndarray:
"""Initialize Sinkhorn potential/scaling f_u.
Args:
ot_prob: OT problem between discrete distributions of size n and m.
lse_mode: Return potential if true, scaling if false.
Returns:
potential/scaling, array of size n.
"""
a = ot_prob.a
init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a)
return init_dual_a

  def init_dual_b(
self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool
) -> jnp.ndarray:
"""Initialize Sinkhorn potential/scaling g_v.
Args:
ot_prob: OT problem between discrete distributions of size n and m.
lse_mode: Return potential if true, scaling if false.
Returns:
potential/scaling, array of size m.
"""
b = ot_prob.b
init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b)
return init_dual_b
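

# Usage sketch: an initializer is typically handed to a Sinkhorn solver, which
# invokes `__call__` to seed the dual variables. The entry point below
# (`ott.solvers.linear.sinkhorn` and its `initializer` argument) is an
# assumption about the wider library, kept in a comment to avoid circular
# imports:
#
#   from ott.problems.linear import linear_problem
#   from ott.solvers.linear import sinkhorn
#
#   geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
#   prob = linear_problem.LinearProblem(geom)
#   out = sinkhorn.Sinkhorn(initializer=DefaultInitializer())(prob)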


@jax.tree_util.register_pytree_node_class
class GaussianInitializer(DefaultInitializer):
"""Gaussian initializer :cite:`thornton2022rethinking:22`.
Compute Gaussian approximations of each point cloud, then compute closed from
Kantorovich potential between Gaussian approximations using Brenier's theorem
(adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential
to initialize Sinkhorn potentials/scalings.
"""

  def init_dual_a(
self,
ot_prob: 'linear_problem.LinearProblem',
lse_mode: bool,
) -> jnp.ndarray:
"""Gaussian initialization function.
Args:
ot_prob: OT problem between discrete distributions of size n and m.
lse_mode: Return potential if true, scaling if false.
Returns:
potential/scaling, array of size n.
"""
    # `gaussian` is imported here rather than at the top of the module to
    # avoid a circular import.
    from ott.tools.gaussian_mixture import gaussian
assert isinstance(
ot_prob.geom, pointcloud.PointCloud
), "Gaussian initializer valid only for point clouds."
x, y = ot_prob.geom.x, ot_prob.geom.y
a, b = ot_prob.a, ot_prob.b
gaussian_a = gaussian.Gaussian.from_samples(x, weights=a)
gaussian_b = gaussian.Gaussian.from_samples(y, weights=b)
    # The Brenier potential is computed for the cost ||x-y||^2 / 2; multiply
    # by 2 to match the squared Euclidean cost ||x-y||^2.
f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x)
f_potential = f_potential - jnp.mean(f_potential)
f_u = f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
f_potential
)
return f_u
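

# Quick usage sketch (illustrative; `linear_problem` is imported lazily for
# the circular-import reason noted above):
#
#   from ott.problems.linear import linear_problem
#   geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
#   prob = linear_problem.LinearProblem(geom)
#   f_init = GaussianInitializer().init_dual_a(prob, lse_mode=True)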


@jax.tree_util.register_pytree_node_class
class SortingInitializer(DefaultInitializer):
"""Sorting initializer :cite:`thornton2022rethinking:22`.
Solves non-regularized OT problem via sorting, then compute potential through
iterated minimum on C-transform and use this potential to initialize
regularized potential.
Args:
vectorized_update: Use vectorized inner loop if true.
tolerance: DualSort convergence threshold.
max_iter: Max DualSort steps.
"""

  def __init__(
self,
vectorized_update: bool = True,
tolerance: float = 1e-2,
max_iter: int = 100
):
super().__init__()
self.tolerance = tolerance
self.max_iter = max_iter
self.vectorized_update = vectorized_update

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return ([], {
'tolerance': self.tolerance,
'max_iter': self.max_iter,
'vectorized_update': self.vectorized_update
})

  def _init_sorting_dual(
self, modified_cost: jnp.ndarray, init_f: jnp.ndarray
) -> jnp.ndarray:
"""Run DualSort algorithm.
Args:
modified_cost: cost matrix minus diagonal column-wise.
init_f: potential f, array of size n. This is the starting potential,
which is then updated to make the init potential, so an init of an init.
Returns:
potential f, array of size n.
"""

    def body_fn(
state: Tuple[jnp.ndarray, float, int]
) -> Tuple[jnp.ndarray, float, int]:
prev_f, _, it = state
new_f = fn(prev_f, modified_cost)
diff = jnp.sum((new_f - prev_f) ** 2)
it += 1
return new_f, diff, it

    def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool:
_, diff, it = state
return jnp.logical_and(diff > self.tolerance, it < self.max_iter)
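
    # One DualSort sweep applies f_i <- min_j (modified_cost[i, j] + f_j),
    # either vectorized over all i at once or coordinate-by-coordinate.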
fn = _vectorized_update if self.vectorized_update else _coordinate_update
state = (init_f, jnp.inf, 0) # init, error, iter
f_potential, _, _ = jax.lax.while_loop(
cond_fun=cond_fn, body_fun=body_fn, init_val=state
)
return f_potential

  def init_dual_a(
self,
ot_prob: 'linear_problem.LinearProblem',
lse_mode: bool,
init_f: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""Apply DualSort algorithm.
Args:
ot_prob: OT problem.
lse_mode: Return potential if true, scaling if false.
init_f: potential f, array of size n. This is the starting potential,
which is then updated to make the init potential, so an init of an init.
Returns:
potential/scaling f_u, array of size n.
"""
assert not ot_prob.geom.is_online, \
"Sorting initializer does not work for online geometry."
    # Verifying that x and y are sorted would require a point-cloud geometry
    # and could slow down the initializer, so no such check is performed.
cost_matrix = ot_prob.geom.cost_matrix
    assert cost_matrix.shape[0] == cost_matrix.shape[1], \
      "Requires square cost matrix."
modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :]
n = cost_matrix.shape[0]
init_f = jnp.zeros(n) if init_f is None else init_f
f_potential = self._init_sorting_dual(modified_cost, init_f)
f_potential = f_potential - jnp.mean(f_potential)
f_u = f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
f_potential
)
return f_u
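

# Usage sketch (illustrative): DualSort needs a dense, square cost matrix, so
# both point clouds must have the same number of points.
#
#   init = SortingInitializer(vectorized_update=True, max_iter=50)
#   f_init = init.init_dual_a(prob, lse_mode=True)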


def _vectorized_update(
f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
"""Inner loop DualSort Update.
Args:
f: potential f, array of size n.
modified_cost: cost matrix minus diagonal column-wise.
Returns:
updated potential vector, f.
"""
return jnp.min(modified_cost + f[None, :], axis=1)


def _coordinate_update(
f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
"""Coordinate-wise updates within inner loop.
Args:
f: potential f, array of size n.
modified_cost: cost matrix minus diagonal column-wise.
Returns:
updated potential vector, f.
"""

  def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray:
new_f = jnp.min(modified_cost[i, :] + f)
return f.at[i].set(new_f)
return jax.lax.fori_loop(0, len(f), body_fn, f)
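

if __name__ == "__main__":
  # Minimal smoke-test sketch exercising the three initializers above. It uses
  # only modules already referenced in this file; `linear_problem` is imported
  # lazily to avoid a circular import at module load.
  from ott.problems.linear import linear_problem

  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (32, 2))
  y = jax.random.normal(rng_y, (32, 2)) + 1.
  geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
  prob = linear_problem.LinearProblem(geom)

  # SortingInitializer requires a square cost matrix, hence n == m above.
  inits = (DefaultInitializer(), GaussianInitializer(), SortingInitializer())
  for init in inits:
    f_u, g_v = init(prob, a=None, b=None, lse_mode=True)
    print(type(init).__name__, f_u.shape, g_v.shape)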