"""Decomposable losses for quadratic (Gromov-Wasserstein) problems.

Source module: ``ott.problems.quadratic.quadratic_costs``.
"""

from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

# Public API of this module: only the two loss factories are exported.
__all__ = [
    "make_square_loss",
    "make_kl_loss",
]


class Loss(NamedTuple):
  """One scalar term of a decomposed loss.

  Attributes:
    func: Function mapping an array to an array. The factories in this module
      build it from elementwise operations.
    is_linear: Whether ``func`` is linear in its argument.
  """

  # Array -> array map implementing this term.
  func: Callable[[jnp.ndarray], jnp.ndarray]
  # True when ``func`` is linear (e.g. ``x -> 2 * x``), False otherwise.
  is_linear: bool


class GWLoss(NamedTuple):
  """Loss written as ``ell(x, y) = f1(x) + f2(y) - h1(x) * h2(y)``.

  The square and KL factories in this module build their losses in this form.
  It separates the terms that depend only on ``x`` or only on ``y`` from the
  cross term ``h1(x) * h2(y)``.
  """

  # Term depending only on the first argument.
  f1: Loss
  # Term depending only on the second argument.
  f2: Loss
  # Factor of the cross term applied to the first argument.
  h1: Loss
  # Factor of the cross term applied to the second argument.
  h2: Loss


def make_square_loss() -> GWLoss:
  """Build the square loss ``(x - y) ** 2`` in decomposed form.

  Since ``(x - y) ** 2 = x ** 2 + y ** 2 - x * (2 * y)``, the loss is written
  as ``f1(x) + f2(y) - h1(x) * h2(y)`` with:

  * ``f1(x) = x ** 2`` (non-linear),
  * ``f2(y) = y ** 2`` (non-linear),
  * ``h1(x) = x`` (linear),
  * ``h2(y) = 2 * y`` (linear).

  Returns:
    A :class:`GWLoss` holding the four component terms.
  """
  f1 = Loss(lambda x: x ** 2, is_linear=False)
  f2 = Loss(lambda y: y ** 2, is_linear=False)
  h1 = Loss(lambda x: x, is_linear=True)
  # The factor 2 of the cross term is carried by h2.
  h2 = Loss(lambda y: 2.0 * y, is_linear=True)
  return GWLoss(f1, f2, h1, h2)
def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
  """Build the KL loss ``x * log(x / y) - x + y`` in decomposed form.

  Expanding gives ``(x * log(x) - x) + y - x * log(y)``. The loss is therefore
  written as ``f1(x) + f2(y) - h1(x) * h2(y)`` with:

  * ``f1(x) = x * log(x) - x`` (non-linear),
  * ``f2(y) = y`` (linear),
  * ``h1(x) = x`` (linear),
  * ``h2(y) = log(clip(y, clipping_value))`` (non-linear).

  Args:
    clipping_value: Lower bound applied to ``y`` before taking its logarithm.
      When positive, this keeps ``h2`` finite at ``y = 0``.

  Returns:
    A :class:`GWLoss` holding the four component terms.
  """
  # entr(x) = -x * log(x), with entr(0) = 0. So this term equals x log x - x
  # and is well defined at x = 0.
  f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
  f2 = Loss(lambda y: y, is_linear=True)
  h1 = Loss(lambda x: x, is_linear=True)
  # Clip y from below so that log never sees 0 (which would give -inf).
  h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
  return GWLoss(f1, f2, h1, h2)