from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Tuple,
)
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from ott.problems.linear import linear_problem
if TYPE_CHECKING:
from ott.geometry import costs
__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]
@jtu.register_pytree_node_class
class DualPotentials:
    r"""The Kantorovich dual potential functions :math:`f` and :math:`g`.

    :math:`f` and :math:`g` are a pair of functions, candidates for the dual
    OT Kantorovich problem, supposedly optimal for a given pair of measures.

    Args:
        f: The first dual potential function.
        g: The second dual potential function.
        cost_fn: The cost function used to solve the OT problem.
        corr: Whether the duals solve the problem in distance form, or
            correlation form (as used for instance for ICNNs, see, e.g., top
            right of p.3 in :cite:`makkuva:20`).
    """

    def __init__(
        self,
        f: Potential_t,
        g: Potential_t,
        *,
        cost_fn: 'costs.CostFn',
        corr: bool = False
    ):
        self._f = f
        self._g = g
        self.cost_fn = cost_fn
        self._corr = corr

    def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
        r"""Transport ``vec`` according to Brenier formula :cite:`brenier:91`.

        Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map
        when given the Legendre transform of the dual potentials.

        That OT map can be recovered as
        :math:`x - (\nabla h)^{-1}\circ \nabla f(x)`.
        For the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`,
        and as a consequence :math:`h^*(\cdot) = \|.\|^2 / 4`, while one has
        that :math:`\nabla h^*(\cdot) = (\nabla h)^{-1}(\cdot) = 0.5 \cdot\,`.

        When the dual potentials are solved in correlation form (only in the
        Sq. Euclidean distance case), the maps are :math:`\nabla g` for
        forward, :math:`\nabla f` for backward.

        Args:
            vec: Points to transport, array of shape ``[n, d]``.
            forward: Whether to transport the points from source to the target
                distribution or vice-versa.

        Returns:
            The transported points.
        """
        # Local import avoids a circular dependency with `ott.geometry`.
        from ott.geometry import costs

        vec = jnp.atleast_2d(vec)
        if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
            # Correlation form: the gradient of the potential is the map itself.
            return self._grad_g(vec) if forward else self._grad_f(vec)
        if forward:
            return vec - self._grad_h_inv(self._grad_f(vec))
        return vec - self._grad_h_inv(self._grad_g(vec))

    def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
        """Evaluate 2-Wasserstein distance between samples using dual potentials.

        Uses Eq. 5 from :cite:`makkuva:20` when given in `corr` form, direct
        estimation by integrating dual function against points when using
        dual form.

        Args:
            src: Samples from the source distribution, array of shape
                ``[n, d]``.
            tgt: Samples from the target distribution, array of shape
                ``[m, d]``.

        Returns:
            Wasserstein distance.
        """
        src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)
        f = jax.vmap(self.f)

        if self._corr:
            grad_g_y = self._grad_g(tgt)
            term1 = -jnp.mean(f(src))
            term2 = -jnp.mean(jnp.sum(tgt * grad_g_y, axis=-1) - f(grad_g_y))

            # Second moments of both measures, added back to recover the
            # distance from the correlation-form objective.
            C = jnp.mean(jnp.sum(src ** 2, axis=-1))
            C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))

            return 2. * (term1 + term2) + C

        g = jax.vmap(self.g)
        return jnp.mean(f(src)) + jnp.mean(g(tgt))

    @property
    def f(self) -> Potential_t:
        """The first dual potential function."""
        return self._f

    @property
    def g(self) -> Potential_t:
        """The second dual potential function."""
        return self._g

    @property
    def _grad_f(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Vectorized gradient of the potential function :attr:`f`."""
        return jax.vmap(jax.grad(self.f, argnums=0))

    @property
    def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Vectorized gradient of the potential function :attr:`g`."""
        return jax.vmap(jax.grad(self.g, argnums=0))

    @property
    def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Vectorized gradient of the Legendre transform of ``h``."""
        from ott.geometry import costs
        assert isinstance(self.cost_fn, costs.TICost), (
            "Cost must be a `TICost` and "
            "provide access to Legendre transform of `h`."
        )
        return jax.vmap(jax.grad(self.cost_fn.h_legendre))

    def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
        # Everything goes into aux_data: the potentials are callables, not
        # traced arrays, so the pytree has no children.
        return [], {
            "f": self._f,
            "g": self._g,
            "cost_fn": self.cost_fn,
            "corr": self._corr
        }

    @classmethod
    def tree_unflatten(
        cls, aux_data: Dict[str, Any], children: Sequence[Any]
    ) -> "DualPotentials":
        return cls(*children, **aux_data)
@jtu.register_pytree_node_class
class EntropicPotentials(DualPotentials):
    """Dual potential functions from finite samples :cite:`pooladian:21`.

    Args:
        f_xy: The first dual potential vector of shape ``[n,]``.
        g_xy: The second dual potential vector of shape ``[m,]``.
        prob: Linear problem with
            :class:`~ott.geometry.pointcloud.PointCloud` geometry that was
            used to compute the dual potentials using, e.g.,
            :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
        f_xx: The first dual potential vector of shape ``[n,]`` used for
            debiasing :cite:`pooladian:22`.
        g_yy: The second dual potential vector of shape ``[m,]`` used for
            debiasing.
    """

    def __init__(
        self,
        f_xy: jnp.ndarray,
        g_xy: jnp.ndarray,
        prob: linear_problem.LinearProblem,
        f_xx: Optional[jnp.ndarray] = None,
        g_yy: Optional[jnp.ndarray] = None,
    ):
        # We pass directly the arrays and override the `f`/`g` properties
        # below, since only the properties need to be callable.
        super().__init__(f_xy, g_xy, cost_fn=prob.geom.cost_fn, corr=False)
        self._prob = prob
        self._f_xx = f_xx
        self._g_yy = g_yy

    @property
    def f(self) -> Potential_t:
        return self._potential_fn(kind="f")

    @property
    def g(self) -> Potential_t:
        return self._potential_fn(kind="g")

    def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t:
        """Build a callable entropic potential from the stored dual vectors.

        Args:
            kind: Which potential function to build, ``'f'`` (source side) or
                ``'g'`` (target side).

        Returns:
            A function evaluating the requested potential at arbitrary points.
        """
        from ott.geometry import pointcloud

        def callback(
            x: jnp.ndarray,
            *,
            potential: jnp.ndarray,
            y: jnp.ndarray,
            weights: jnp.ndarray,
            epsilon: float,
        ) -> float:
            # Soft c-transform of the finite-sample potential, evaluated at
            # `x` against the support `y` with weights `weights`.
            x = jnp.atleast_2d(x)
            assert x.shape[-1] == y.shape[-1], (x.shape, y.shape)
            geom = pointcloud.PointCloud(x, y, cost_fn=self.cost_fn)
            cost = geom.cost_matrix
            z = (potential - cost) / epsilon
            lse = -epsilon * jsp.special.logsumexp(z, b=weights, axis=-1)
            return jnp.squeeze(lse)

        assert isinstance(
            self._prob.geom, pointcloud.PointCloud
        ), f"Expected point cloud geometry, found `{type(self._prob.geom)}`."
        x, y = self._prob.geom.x, self._prob.geom.y
        a, b = self._prob.a, self._prob.b

        if kind == "f":
            # When seeking to evaluate 1st potential function,
            # the 2nd set of potential values and support should be used,
            # see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
            potential, arr, weights = self._g, y, b
        else:
            potential, arr, weights = self._f, x, a

        # `jtu.Partial` keeps the closure a valid pytree leaf for jitting.
        potential_xy = jtu.Partial(
            callback,
            potential=potential,
            y=arr,
            weights=weights,
            epsilon=self.epsilon,
        )
        if not self.is_debiased:
            return potential_xy

        ep = EntropicPotentials(self._f_xx, self._g_yy, prob=self._prob)
        # switch the order because for `kind='f'` we require `f/x/a` in
        # `other` which is accessed when `kind='g'`
        potential_other = ep._potential_fn(kind="g" if kind == "f" else "f")

        return lambda x: (potential_xy(x) - potential_other(x))

    @property
    def is_debiased(self) -> bool:
        """Whether the entropic map is debiased."""
        return self._f_xx is not None and self._g_yy is not None

    @property
    def epsilon(self) -> float:
        """Entropy regularizer."""
        return self._prob.geom.epsilon

    def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
        # All fields are arrays/pytrees here, so they are children (traced),
        # unlike the parent class which stores callables in aux_data.
        return [self._f, self._g, self._prob, self._f_xx, self._g_yy], {}