|Downloads| |Tests| |Docs| |Coverage|

Optimal Transport Tools (OTT)
=============================

Introduction
------------
``OTT`` is a `JAX `_ package that bundles a few utilities to compute,
and differentiate as needed, the solution to optimal transport (OT) problems, taken in a fairly wide sense.
For instance, ``OTT`` can of course compute Wasserstein (or Gromov-Wasserstein) distances between
weighted clouds of points (or histograms) in a wide variety of scenarios,
but it can also estimate Monge maps and Wasserstein barycenters, and help with simpler tasks
such as differentiable approximations to ranking or even clustering.

To achieve this, ``OTT`` rests on two families of tools.
The first family consists of *discrete* solvers computing transport between point clouds,
using the Sinkhorn :cite:`cuturi:13` and low-rank Sinkhorn :cite:`scetbon:21` algorithms,
and moving up towards Gromov-Wasserstein :cite:`memoli:11,peyre:16`.
The second family consists of *continuous* solvers, using suitable neural architectures :cite:`amos:17` coupled
with SGD-type estimators :cite:`makkuva:20,korotin:21`.

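The discrete solvers mentioned above build on Sinkhorn iterations. As a rough illustration only, the following plain-JAX sketch runs log-domain Sinkhorn between two weighted point clouds with a squared Euclidean cost, then JIT-compiles the transport cost together with its gradient with respect to the source points, obtained by differentiating through the unrolled iterations. It is not ``OTT``'s API: ``sqeuclidean_cost``, ``sinkhorn`` and ``ot_cost`` are hypothetical names, and the fixed iteration count and ``epsilon`` default are assumptions. ``OTT``'s own solvers go further, e.g. with implicit differentiation, which this sketch does not attempt.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp


    def sqeuclidean_cost(x, y):
        """Pairwise squared Euclidean costs between rows of x (n, d) and y (m, d)."""
        return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


    def sinkhorn(cost, a, b, epsilon=1.0, num_iters=1000):
        """Log-domain Sinkhorn for entropy-regularized OT (hypothetical helper).

        Returns the transport plan P and the transport cost <P, C>.
        """
        def body(_, fg):
            f, g = fg
            # Each update makes one marginal of P match (rows to a, columns to b).
            f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
            g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
            return f, g

        init = (jnp.zeros_like(a), jnp.zeros_like(b))
        f, g = jax.lax.fori_loop(0, num_iters, body, init)
        plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
        return plan, jnp.sum(plan * cost)


    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (5, 2))
    y = jax.random.normal(key_y, (7, 2)) + 1.0
    a = jnp.full(5, 1.0 / 5)  # uniform weights on the source cloud
    b = jnp.full(7, 1.0 / 7)  # uniform weights on the target cloud


    def ot_cost(x):
        # Transport cost seen as a function of the source point positions.
        return sinkhorn(sqeuclidean_cost(x, y), a, b)[1]


    plan, cost_value = sinkhorn(sqeuclidean_cost(x, y), a, b)
    # JIT-compile the cost and its gradient w.r.t. the source points.
    value, grad_x = jax.jit(jax.value_and_grad(ot_cost))(x)

Unrolling the loop is the simplest way to get gradients here, but its memory footprint grows with ``num_iters``, a cost that implicit differentiation avoids.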
Installation
------------
Install ``OTT`` from `PyPI `_ as:

.. code-block:: bash

    pip install ott-jax

or with ``conda`` via `conda-forge `_ as:

.. code-block:: bash

    conda install -c conda-forge ott-jax

Design Choices
--------------
``OTT`` is designed with the following choices:

- Take advantage, whenever possible, of JAX features such as `Just-in-time (JIT) compilation`_,
  `auto-vectorization (VMAP)`_, and both `automatic`_ and, most importantly, `implicit`_ differentiation.
- Split geometry from OT solvers in the discrete case: we argue that there
  should be one, and only one, implementation of every major OT algorithm
  (Sinkhorn, Gromov-Wasserstein, barycenters, etc.), regardless of the
  geometric setup that is considered. To give a concrete example, any
  speedup one may benefit from by using a specific cost
  (e.g. Sinkhorn being faster when run on a separable cost on histograms supported
  on a separable grid :cite:`solomon:15`) should not require a separate
  reimplementation of a Sinkhorn routine.
- As a consequence, and to minimize code copy/pasting, use object hierarchies
  as often as possible, and interleave outer solvers (such as quadratic,
  a.k.a. Gromov-Wasserstein, solvers) with inner solvers (e.g. low-rank Sinkhorn).
  This choice ensures that speedups achieved at lower computation levels
  (e.g. low-rank factorization of squared Euclidean distances) propagate seamlessly and
  automatically to higher-level calls (e.g. updates in Gromov-Wasserstein),
  without requiring any attention from the user.

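The geometry/solver split can be illustrated on a toy problem. The sketch below is plain JAX and is not ``OTT``'s API: ``sinkhorn_kernel``, ``dense_apply`` and ``sep_apply`` are hypothetical names. It assumes a Sinkhorn solver that only touches the geometry through products with the Gibbs kernel :math:`K = \exp(-C/\varepsilon)`. On a 2-D grid with squared Euclidean cost, that kernel is the Kronecker product of two 1-D kernels, in the spirit of :cite:`solomon:15`, so the very same solver runs either on the dense kernel or on the cheaper separable one.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def sinkhorn_kernel(apply_k, apply_kt, a, b, num_iters=500):
        """Scaling-form Sinkhorn that sees the geometry only through kernel products.

        apply_k(v) must return K @ v and apply_kt(u) must return K.T @ u.
        """
        def body(_, uv):
            u, v = uv
            u = a / apply_k(v)
            v = b / apply_kt(u)
            return u, v

        init = (jnp.ones_like(a), jnp.ones_like(b))
        return jax.lax.fori_loop(0, num_iters, body, init)


    # Toy geometry: an n1 x n2 grid in [0, 1]^2 with squared Euclidean cost.
    n1, n2, epsilon = 4, 5, 0.1
    s = jnp.linspace(0.0, 1.0, n1)
    t = jnp.linspace(0.0, 1.0, n2)

    # Dense geometry: materialize the full (n1 * n2) x (n1 * n2) Gibbs kernel.
    grid = jnp.stack(jnp.meshgrid(s, t, indexing="ij"), axis=-1).reshape(-1, 2)
    cost = jnp.sum((grid[:, None, :] - grid[None, :, :]) ** 2, axis=-1)
    k_dense = jnp.exp(-cost / epsilon)


    def dense_apply(v):
        return k_dense @ v


    # Separable geometry: the cost splits across axes, so K = kron(K1, K2)
    # and K @ v reduces to two small matrix products on the reshaped vector.
    k1 = jnp.exp(-((s[:, None] - s[None, :]) ** 2) / epsilon)
    k2 = jnp.exp(-((t[:, None] - t[None, :]) ** 2) / epsilon)


    def sep_apply(v):
        return (k1 @ v.reshape(n1, n2) @ k2.T).reshape(-1)


    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.uniform(key_a, (n1 * n2,), minval=0.5, maxval=1.5)
    a = a / a.sum()
    b = jax.random.uniform(key_b, (n1 * n2,), minval=0.5, maxval=1.5)
    b = b / b.sum()

    # Same solver, two geometries; K is symmetric here, so K.T @ u == K @ u.
    u_dense, v_dense = sinkhorn_kernel(dense_apply, dense_apply, a, b)
    u_sep, v_sep = sinkhorn_kernel(sep_apply, sep_apply, a, b)

The separable product costs :math:`O(n_1 n_2 (n_1 + n_2))` operations instead of :math:`O((n_1 n_2)^2)` for the dense one, while the solver code stays untouched.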
.. TODO(marcocuturi): add missing package descriptions below

Packages
--------
- :ref:`geometry` contains classes to instantiate objects that describe
  *two point clouds* paired with a *cost* function. Geometry objects are used to
  describe OT problems, which are handled by the solvers in :ref:`solvers`.
- :ref:`problems`
- :ref:`solvers`
- :ref:`initializers`
- :ref:`tools` provides an interface to exploit OT solutions, as produced by
  the solvers in :ref:`solvers`. Such tasks include computing approximations
  to Wasserstein distances :cite:`genevay:18,sejourne:19`, approximating OT
  between Gaussian mixture models (GMMs), and computing differentiable sort and
  quantile operations :cite:`cuturi:19`.
- :ref:`math`

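To give a flavor of the differentiable sort and rank operations listed under :ref:`tools`, here is a self-contained plain-JAX sketch loosely following the idea of :cite:`cuturi:19`: the inputs are transported onto ``n`` sorted anchors in :math:`[0, 1]`, and the entropic plan yields soft-sorted values and soft ranks. It is not the ``tools`` API; ``soft_sort_and_rank``, ``_sinkhorn_plan`` and their defaults are hypothetical, and the min-max rescaling of the inputs is a simplification chosen here for brevity.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp


    def _sinkhorn_plan(cost, a, b, epsilon, num_iters):
        # Log-domain Sinkhorn; returns the entropic transport plan.
        def body(_, fg):
            f, g = fg
            f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
            g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
            return f, g

        init = (jnp.zeros_like(a), jnp.zeros_like(b))
        f, g = jax.lax.fori_loop(0, num_iters, body, init)
        return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


    def soft_sort_and_rank(x, epsilon=1e-2, num_iters=1000):
        """Differentiable sort and rank of a 1-D array via entropic OT."""
        n = x.shape[0]
        # Map inputs to [0, 1] (min-max rescaling, a simplification).
        z = (x - x.min()) / (x.max() - x.min())
        anchors = jnp.linspace(0.0, 1.0, n)
        cost = (z[:, None] - anchors[None, :]) ** 2
        weights = jnp.full(n, 1.0 / n)
        plan = _sinkhorn_plan(cost, weights, weights, epsilon, num_iters)
        # Each anchor receives mass 1/n: average the inputs sent to it.
        soft_sorted = n * plan.T @ x
        # Each input sends mass 1/n: average the anchor indices it reaches.
        soft_ranks = n * plan @ jnp.arange(n, dtype=x.dtype)
        return soft_sorted, soft_ranks


    x = jnp.array([3.0, -1.0, 2.0, 0.5, 5.0])
    soft_sorted, soft_ranks = soft_sort_and_rank(x)
    # Gradient of the soft rank of x[0] w.r.t. all inputs; hard ranks from
    # jnp.argsort are piecewise constant and give no useful gradient.
    rank_grad = jax.grad(lambda v: soft_sort_and_rank(v)[1][0])(x)

Smaller ``epsilon`` values bring the outputs closer to a hard sort and hard ranks, at the price of more Sinkhorn iterations and less informative gradients.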
.. toctree::
    :maxdepth: 1
    :caption: Examples

    Getting Started
    tutorials/index

.. toctree::
    :maxdepth: 1
    :caption: API

    geometry
    problems/index
    solvers/index
    initializers/index
    tools
    math

.. toctree::
    :maxdepth: 1
    :caption: References

    references

.. |Downloads| image:: https://static.pepy.tech/badge/ott-jax
    :target: https://pypi.org/project/ott-jax/
    :alt: Downloads

.. |Tests| image:: https://img.shields.io/github/actions/workflow/status/ott-jax/ott/tests.yml?branch=main
    :target: https://github.com/ott-jax/ott/actions/workflows/tests.yml
    :alt: Tests

.. |Docs| image:: https://img.shields.io/readthedocs/ott-jax/latest
    :target: https://ott-jax.readthedocs.io/en/latest/
    :alt: Documentation

.. |Coverage| image:: https://img.shields.io/codecov/c/github/ott-jax/ott/main
    :target: https://app.codecov.io/gh/ott-jax/ott
    :alt: Coverage

.. _Just-in-time (JIT) compilation: https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit
.. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
.. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
.. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp