{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "C4nZjbMHWcm_" }, "source": [ "# OTT vs. POT\n", "\n", "The [Python Optimal Transport (POT)](https://pythonot.github.io/) toolbox paved the way for much progress in OT. `POT` implements several OT solvers (LP and regularized), and is complemented with various tools (barycenters, domain adaptation, Gromov-Wasserstein distances, sliced W, etc.).\n", "\n", "The goal of this notebook is to compare the performance of `OTT`'s {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and `POT`'s {func}`~ot.sinkhorn` solvers. `OTT` benefits from just-in-time compilation, which should give it an edge.\n", "\n", "The comparisons carried out below have limitations: minor modifications in the setup (e.g. data distributions, tolerance thresholds, type of accelerator...) could have an impact on these results. Feel free to change these settings and experiment yourself!" ] }, { "cell_type": "markdown", "metadata": { "id": "dpTlNSRqXevL" }, "source": [ "## Installs toolboxes" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", "    !pip install -q git+https://github.com/ott-jax/ott@main\n", "    !pip install -q POT" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import timeit\n", "\n", "import ot\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", "import mpl_toolkits.axes_grid1\n", "\n", "import ott\n", "from ott.geometry import pointcloud\n", "from ott.problems.linear import linear_problem\n", "from ott.solvers.linear import sinkhorn" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "ysURew0UKhHE" }, "outputs": [], "source": [ "plt.rc(\"font\", size=20)" ] }, { "cell_type": "markdown", "metadata": { "id": "rzeoMqPlcGMT" }, "source": [ "## Regularized OT in a nutshell\n", "\n", "We consider two probability measures $\\mu,\\nu$ compared with the squared-Euclidean cost $c(x,y)=\\|x-y\\|^2$. In this notebook, these measures are discrete and supported on the same number of points:\n", "\n", "$$\\mu=\\sum_{i=1}^n a_i\\delta_{x_i}, \\quad \\nu =\\sum_{j=1}^n b_j\\delta_{y_j}.$$\n", "\n", "The regularized OT problem in its primal form reads\n", "$$\\min_{P \\in U(a,b)} \\langle C, P \\rangle - \\varepsilon H(P),$$\n", "\n", "where $U(a,b):=\\{P \\in \\mathbf{R}_+^{n\\times n}, P\\mathbf{1}_{n}=a, P^T\\mathbf{1}_n=b\\}$, and $C = [ \\|x_i - y_j \\|^2 ]_{i,j}\\in \\mathbf{R}_+^{n\\times n}$.\n", "\n", "That problem is equivalent to the following dual form,\n", "$$\\max_{f, g} \\langle a, f \\rangle + \\langle b, g \\rangle - \\varepsilon \\langle e^{f/\\varepsilon},Ke^{g/\\varepsilon} \\rangle.$$\n", "\n", "Both `OTT` and `POT` solve these two problems with the *Sinkhorn iterations*: starting from a simple initialization for $v$, they alternate the updates $u \\leftarrow a / Kv$ and $v \\leftarrow b / K^Tu$ (divisions are elementwise), where $K:=e^{-C/\\varepsilon}$.\n", "\n", "Upon convergence to fixed points $u^*, v^*$, one has $$P^*=D(u^*)KD(v^*),$$ where $D(u)$ is the diagonal matrix with diagonal $u$, or, alternatively, in terms of dual potentials,\n", "$$f^*, g^* = \\varepsilon \\log(u^*), \\varepsilon\\log(v^*).$$" ] }, { "cell_type": "markdown", "metadata": { "id": "OhSVMQUWYZiY" }, "source": [ "## OTT and POT implementation\n", "\n", "Both toolboxes carry out Sinkhorn updates either with the formulas above directly (this corresponds to `lse_mode=False` in `OTT` and `method='sinkhorn'` in `POT`) or with slightly slower but more robust approaches:\n", "\n", "`OTT` relies on log-space iterations (`lse_mode=True`), whereas `POT` uses a stabilization trick (`method='sinkhorn_stabilized'`) designed to avoid numerical overflows while still benefiting from the speed of matrix-vector products. 
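\n", "\n", "The two parametrizations can be contrasted in plain NumPy. The next cell is a minimal sketch for illustration only, not the code of either library; the helper names `logsumexp`, `sinkhorn_kernel`, `sinkhorn_log`, `marginal_error` and `compare_updates` are hypothetical. The first solver iterates on the scalings $u, v$ with the kernel $K$, the second one iterates on the potentials $f = \\varepsilon \\log u$, $g = \\varepsilon \\log v$ with a log-sum-exp, in the spirit of `lse_mode=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "\n",
 "def logsumexp(m, axis):\n",
 "    # Stable log(sum(exp(m))) along `axis`: subtract the maximum before exponentiating.\n",
 "    mx = np.max(m, axis=axis, keepdims=True)\n",
 "    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))\n",
 "\n",
 "\n",
 "def sinkhorn_kernel(a, b, C, eps, n_iter=200):\n",
 "    # Kernel-space updates u <- a / Kv, v <- b / K^T u, starting from v = 1.\n",
 "    K = np.exp(-C / eps)\n",
 "    u, v = np.ones_like(a), np.ones_like(b)\n",
 "    for _ in range(n_iter):\n",
 "        u = a / (K @ v)\n",
 "        v = b / (K.T @ u)\n",
 "    return eps * np.log(u), eps * np.log(v)\n",
 "\n",
 "\n",
 "def sinkhorn_log(a, b, C, eps, n_iter=200):\n",
 "    # The same updates written on f = eps * log(u) and g = eps * log(v), starting from g = 0.\n",
 "    f, g = np.zeros_like(a), np.zeros_like(b)\n",
 "    for _ in range(n_iter):\n",
 "        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)\n",
 "        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)\n",
 "    return f, g\n",
 "\n",
 "\n",
 "def marginal_error(f, g, a, b, C, eps):\n",
 "    # ||u * Kv - a||_2 + ||v * K^T u - b||_2, i.e. row and column sums of P = D(u) K D(v) vs. a and b.\n",
 "    P = np.exp((f[:, None] + g[None, :] - C) / eps)\n",
 "    return np.linalg.norm(P.sum(axis=1) - a) + np.linalg.norm(P.sum(axis=0) - b)\n",
 "\n",
 "\n",
 "def compare_updates(eps=0.1, seed=0):\n",
 "    # Small random point clouds, with y shifted by 0.1 as in this notebook.\n",
 "    gen = np.random.default_rng(seed)\n",
 "    x = gen.uniform(size=(50, 3))\n",
 "    y = gen.uniform(size=(60, 3)) + 0.1\n",
 "    a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)\n",
 "    C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)\n",
 "    f_k, g_k = sinkhorn_kernel(a, b, C, eps)\n",
 "    f_l, g_l = sinkhorn_log(a, b, C, eps)\n",
 "    # Largest gap between the two sets of potentials, and marginal error of the log-space run.\n",
 "    gap = max(np.max(np.abs(f_k - f_l)), np.max(np.abs(g_k - g_l)))\n",
 "    return float(gap), float(marginal_error(f_l, g_l, a, b, C, eps))\n",
 "\n",
 "\n",
 "print(compare_updates())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Both functions run the same iterates, so their potentials should agree up to round-off. With a much smaller $\\varepsilon$ (say $10^{-3}$), many entries of $K = e^{-C/\\varepsilon}$ underflow to zero in floating point, which is the kind of numerical issue that log-space iterations and `POT`'s stabilization trick are designed to avoid.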
\n", "\n", "The default behaviour of `OTT` and [POT](https://github.com/PythonOT/POT/blob/f6139428e70ce964de3bef703ef13aa701a83620/ot/bregman.py#L413) is to carry out these updates until $\\|u\\circ Kv - a\\|_2 + \\|v\\circ K^Tu - b\\|_2$ is smaller than the user-defined `threshold`." ] }, { "cell_type": "markdown", "metadata": { "id": "yjjlf297b-WF" }, "source": [ "## Common API for `OTT` and `POT`\n", "\n", "In our experiments, we compare `OTT` and `POT` in their more stable setups (`lse_mode=True` and `method='sinkhorn_stabilized'`, respectively). We define a common API for both, making sure their results are comparable. That API takes as inputs the weights and supports of the two measures, the target $\\varepsilon$ value and the `threshold` used to terminate the algorithm. It returns the centered dual potentials $f, g$ and the objective $\\langle a, f \\rangle + \\langle b, g \\rangle$, which is set to NaN when the solver did not converge. We set a maximum of 1000 iterations for both." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "cM2cM87nZ6XU" }, "outputs": [], "source": [ "def solve_ot(a, b, x, y, 𝜀, threshold):\n", "    # POT: stabilized Sinkhorn on the squared-Euclidean cost matrix (default of ot.dist).\n", "    _, log = ot.sinkhorn(\n", "        a,\n", "        b,\n", "        ot.dist(x, y),\n", "        𝜀,\n", "        stopThr=threshold,\n", "        method=\"sinkhorn_stabilized\",\n", "        log=True,\n", "        numItermax=1000,\n", "    )\n", "    f, g = 𝜀 * log[\"logu\"], 𝜀 * log[\"logv\"]\n", "    # center the potentials (the objective is unchanged), useful to compare them.\n", "    f, g = f - np.mean(f), g + np.mean(f)\n", "    # objective <f, a> + <g, b>, or NaN if the last recorded error is above the threshold.\n", "    reg_ot = np.sum(f * a) + np.sum(g * b) if log[\"err\"][-1] < threshold else np.nan\n", "    return f, g, reg_ot\n", "\n", "\n", "@jax.jit\n", "def solve_ott(a, b, x, y, 𝜀, threshold):\n", "    # OTT: log-domain Sinkhorn on a point cloud (squared-Euclidean cost by default).\n", "    geom = pointcloud.PointCloud(x, y, epsilon=𝜀)\n", "    prob = linear_problem.LinearProblem(geom, a=a, b=b)\n", "\n", "    solver = sinkhorn.Sinkhorn(threshold=threshold, lse_mode=True, max_iterations=1000)\n", "    out = solver(prob)\n", "\n", "    f, g = out.f, out.g\n", "    # center the potentials as above, using jnp inside the jitted function.\n", "    f, g = f - jnp.mean(f), g + jnp.mean(f)\n", "    # objective, or NaN if the solver did not converge.\n", "    reg_ot = jnp.where(out.converged, jnp.sum(f * a) + jnp.sum(g * b), jnp.nan)\n", "    return f, g, reg_ot" ] }, { "cell_type": "markdown", "metadata": { "id": "bPWuBwYvC-y-" }, "source": [ "To test both solvers, we run simulations that generate random point clouds of size $n$ from a fixed random seed. Random generation is carried out with {func}`jax.random.PRNGKey` to ensure reproducibility. Each solver is specified by three pieces of information: the function (following our common API), its numerical environment (NumPy or JAX) and its name." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "kchT2nnMKl2q" }, "outputs": [], "source": [ "dim = 3\n", "\n", "\n", "def run_simulation(rng, n, 𝜀, threshold, solver_spec):\n", "    # setting global variables helps avoid a timeit bug.\n", "    global solver_\n", "    global a, b, x, y\n", "\n", "    # unpack the solver specification.\n", "    solver_, env, name = solver_spec\n", "\n", "    # draw data at random using JAX\n", "    rng, *rngs = jax.random.split(rng, 5)\n", "    x = jax.random.uniform(rngs[0], (n, dim))\n", "    y = jax.random.uniform(rngs[1], (n, dim)) + 0.1\n", "    a = jax.random.uniform(rngs[2], (n,))\n", "    b = jax.random.uniform(rngs[3], (n,))\n", "    a = a / jnp.sum(a)\n", "    b = b / jnp.sum(b)\n", "\n", "    # map to numpy if needed\n", "    if env == \"np\":\n", "        a, b, x, y = map(np.array, (a, b, x, y))\n", "\n", "    # wait for JAX's asynchronous dispatch to finish, so that timings cover the full computation.\n", "    timeit_res = %timeit -o jax.block_until_ready(solver_(a, b, x, y, 𝜀, threshold))\n", "    out = solver_(a, b, x, y, 𝜀, threshold)\n", "    # no time is reported when the solver did not converge (NaN objective).\n", "    exec_time = np.nan if np.isnan(out[-1]) else timeit_res.best\n", "    return exec_time, out" ] }, { "cell_type": "markdown", "metadata": { "id": "DxySanaEOwYX" }, "source": [ "Defines the two solvers used in this experiment:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "EWmVxyysvEKT" }, "outputs": [], "source": [ "POT = (solve_ot, \"np\", \"POT\")\n", "OTT = (solve_ott, \"jax\", \"OTT\")" ] }, { "cell_type": "markdown", "metadata": { "id": "szWoHukXOz08" }, "source": [ "## Runs simulations with varying $n$ and $\\varepsilon$\n", "We run simulations by setting the regularization strength $\\varepsilon$ to either $10^{-2}$ or $10^{-1}$.\n", "\n", "We consider $n$ between $2^{8}= 256$ and $2^{12}= 4096$. We do not go higher because `POT` runs into out-of-memory errors for $2^{13}=8192$ in this RAM-restricted Colab environment. `OTT` can avoid these errors by setting the flag `batch_size` to, e.g., `1024`, as done in the tutorial for grids; this is also handled by the [GeomLoss](https://www.kernel-operations.io/geomloss/) toolbox. We leave the comparison with `geomloss` to a different notebook.\n", "\n", "When `%timeit` outputs execution times, **notice the warning message** highlighting that, for `OTT`, at least one run took significantly longer. That run is the one performing the **JIT compilation** of the procedure for that particular problem size $n$. Once compiled, subsequent runs are up to two orders of magnitude faster, thanks to the {func}`jax.jit` decorator added to `solve_ott`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9VWhb6B3VFJN", "outputId": "7b85bed2-903e-465f-b7eb-cf8479554c4e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----- POT\n", "10 loops, best of 5: 43.7 ms per loop\n", "100 loops, best of 5: 11.9 ms per loop\n", "1 loop, best of 5: 230 ms per loop\n", "10 loops, best of 5: 41.4 ms per loop\n", "1 loop, best of 5: 33.4 s per loop\n", "10 loops, best of 5: 155 ms per loop\n", "1 loop, best of 5: 2min 13s per loop\n", "1 loop, best of 5: 367 ms per loop\n", "1 loop, best of 5: 6min 21s per loop\n", "1 loop, best of 5: 1.22 s per loop\n", "----- OTT\n", "The slowest run took 66.78 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1 loop, best of 5: 11.2 ms per loop\n", "1000 loops, best of 5: 1.04 ms per loop\n", "The slowest run took 128.37 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1 loop, best of 5: 6.12 ms per loop\n", "1000 loops, best of 5: 1.08 ms per loop\n", "The slowest run took 94.84 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1 loop, best of 5: 8.95 ms per loop\n", "1000 loops, best of 5: 1.42 ms per loop\n", "The slowest run took 33.90 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1 loop, best of 5: 24 ms per loop\n", "100 loops, best of 5: 3.47 ms per loop\n", "The slowest run took 8.19 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "1 loop, best of 5: 112 ms per loop\n", "100 loops, best of 5: 14.3 ms per loop\n" ] } ], "source": [ "rng = jax.random.PRNGKey(0)\n", "solvers = (POT, OTT)\n", "n_range = 2 ** np.arange(8, 13)\n", "𝜀_range = 10 ** np.arange(-2.0, 0.0)\n", "\n", "threshold = 1e-2\n", "\n", "exec_time = {}\n", "reg_ot = {}\n", "for solver_spec in solvers:\n", "    solver, env, name = solver_spec\n", "    print(\"----- \", name)\n", "    exec_time[name] = np.full((len(n_range), len(𝜀_range)), np.nan)\n", "    reg_ot[name] = np.full((len(n_range), len(𝜀_range)), np.nan)\n", "    for i, n in enumerate(n_range):\n", "        for j, 𝜀 in enumerate(𝜀_range):\n", "            t, out = run_simulation(rng, n, 𝜀, threshold, solver_spec)\n", "            exec_time[name][i, j] = t\n", "            reg_ot[name][i, j] = out[-1]" ] }, { "cell_type": "markdown", "metadata": { "id": "zruxPCbHN1HY" }, "source": [ "## Plots results in terms of time and difference in objective\n", "\n", "When the algorithm does not converge within the maximal number of 1000 iterations, or runs into numerical issues, our wrappers return a NaN objective (and `run_simulation` records a NaN time), so that point does not appear in the plot." 
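] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The objectives compared in the plots below are computed from centered potentials: both wrappers replace $(f, g)$ with $(f - c, g + c)$, where $c$ is the mean of $f$. Since $a$ and $b$ both sum to one, this shift changes $\\langle a, f \\rangle + \\langle b, g \\rangle$ by $-c + c = 0$, so centering does not affect the gaps shown below; it only removes the arbitrary additive constant in the dual potentials. The next cell is a small sanity check of this argument on random stand-in vectors (not solver outputs); `check_centering` is a hypothetical helper." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "\n",
 "def check_centering(n=5, seed=1):\n",
 "    gen = np.random.default_rng(seed)\n",
 "    # Random probability vectors and potentials, standing in for solver outputs.\n",
 "    a = gen.uniform(size=n)\n",
 "    a = a / a.sum()\n",
 "    b = gen.uniform(size=n)\n",
 "    b = b / b.sum()\n",
 "    f, g = gen.normal(size=n), gen.normal(size=n)\n",
 "    # Same centering as in solve_ot and solve_ott.\n",
 "    f_c, g_c = f - np.mean(f), g + np.mean(f)\n",
 "    obj = np.sum(f * a) + np.sum(g * b)\n",
 "    obj_c = np.sum(f_c * a) + np.sum(g_c * b)\n",
 "    # Is the objective unchanged, and does the centered f have zero mean?\n",
 "    return bool(np.isclose(obj, obj_c)), bool(np.isclose(np.mean(f_c), 0.0))\n",
 "\n",
 "\n",
 "print(check_centering())" 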
] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 540 }, "id": "xj_9C6-3uHMH", "outputId": "736a26fd-3a0f-4b0f-b107-843a84503bc7" }, "outputs": [], "source": [ "list_legend = []\n", "fig = plt.figure(figsize=(14, 8))\n", "\n", "for solver_spec, marker, col in zip(solvers, (\"p\", \"o\"), (\"blue\", \"red\")):\n", "    solver, env, name = solver_spec\n", "    # one curve per epsilon value: dotted for the first one, solid for the second one.\n", "    p = plt.plot(\n", "        exec_time[name],\n", "        marker=marker,\n", "        color=col,\n", "        markersize=16,\n", "        markeredgecolor=\"k\",\n", "        lw=3,\n", "    )\n", "    p[0].set_linestyle(\"dotted\")\n", "    p[1].set_linestyle(\"solid\")\n", "    list_legend += [name + r\" $\\varepsilon $=\" + f\"{𝜀:.2g}\" for 𝜀 in 𝜀_range]\n", "\n", "plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)\n", "plt.legend(list_legend)\n", "plt.yscale(\"log\")\n", "plt.xlabel(\"number of points $n$\")\n", "plt.ylabel(\"time (s)\")\n", "plt.title(\n", "    r\"Execution time vs. number of points for OTT and POT, for two $\\varepsilon$ values\"\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "ma5sqdMtzggT" }, "source": [ "For good measure, we also show the differences in *objectives* between the two solvers: we subtract the objective returned by `POT` from that returned by `OTT`.\n", "\n", "Since the problem is evaluated in its dual form, a *higher* objective is *better*, and therefore a positive difference denotes better performance for `OTT`. White areas correspond to settings for which `POT` did not converge (because it either exhausted the maximal number of iterations or ran into numerical issues)." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 357 }, "id": "RCoef_sbzyFn", "outputId": "42851943-d6e7-4765-9432-5f95237a4b4f" }, "outputs": [], "source": [ "fig = plt.figure(figsize=(12, 8))\n", "ax = plt.gca()\n", "im = ax.imshow(reg_ot[\"OTT\"].T - reg_ot[\"POT\"].T)\n", "plt.xticks(ticks=np.arange(len(n_range)), labels=n_range)\n", "plt.yticks(ticks=np.arange(len(𝜀_range)), labels=𝜀_range)\n", "plt.xlabel(\"number of points $n$\")\n", "plt.ylabel(r\"regularization $\\varepsilon$\")\n", "plt.title(\"Gap in objective, >0 when OTT is better\")\n", "divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)\n", "cax = divider.append_axes(\"right\", size=\"5%\", pad=0.1)\n", "plt.colorbar(im, cax=cax)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZxwbScNpz3LR", "outputId": "36eed4e2-4fca-4cb1-fe00-cf10c4291c09" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---- POT\n", "Objective\n", "[[-0.00862313 -0.79116929]\n", " [-0.02666368 -0.93283839]\n", " [ nan -1.07958862]\n", " [ nan -1.22432204]\n", " [ nan -1.36762311]]\n", "Time\n", "[[0.04367424 0.01185102]\n", " [0.22960342 0.04137421]\n", " [ nan 0.15465033]\n", " [ nan 0.3669143 ]\n", " [ nan 1.21968372]]\n", "---- OTT\n", "Objective\n", "[[-0.00783848 -0.79117149]\n", " [-0.02610656 -0.93283963]\n", " [-0.05083928 -1.07959068]\n", " [-0.06328616 -1.21402502]\n", " [-0.07956241 -1.35710597]]\n", "Time\n", "[[0.01124264 0.00103751]\n", " [0.00612156 0.00107929]\n", " [0.00895449 0.00142238]\n", " [0.02404206 0.00346715]\n", " [0.11208566 0.01432985]]\n" ] } ], "source": [ "for name in (\"POT\", \"OTT\"):\n", "    print(\"----\", name)\n", "    print(\"Objective\")\n", "    print(reg_ot[name])\n", "    print(\"Time\")\n", "    print(exec_time[name])" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "OTT & POT", "provenance": [] }, "kernelspec": { "display_name": "ott", "language": "python", "name": "ott" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 1 }