{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-_L4OiekjBRp" }, "source": [ "# Sinkhorn Divergence Hessians" ] }, { "cell_type": "markdown", "metadata": { "id": "jzzs0FmbPpvY" }, "source": [ "## Sample two point clouds and compute their Sinkhorn divergence\n", "\n", "This notebook shows how OTT and JAX can be used to automatically compute the Hessian of the {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` w.r.t. its input variables, such as weights ``a`` or locations ``x``." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " !pip install -q git+https://github.com/ott-jax/ott@main" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from ott.geometry import pointcloud\n", "from ott.solvers.linear import implicit_differentiation as implicit_lib\n", "from ott.tools import sinkhorn_divergence" ] }, { "cell_type": "markdown", "metadata": { "id": "1mTq29HkSFnU" }, "source": [ "Sample two random point clouds of dimension `dim`, together with normalized positive weights." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "0jfa6mSiWAw6" }, "outputs": [], "source": [ "def sample(n, m, dim):\n", " rngs = jax.random.split(jax.random.PRNGKey(0), 6)\n", " x = jax.random.uniform(rngs[0], (n, dim))\n", " y = jax.random.uniform(rngs[1], (m, dim))\n", " a = jax.random.uniform(rngs[2], (n,)) + 0.1\n", " b = jax.random.uniform(rngs[3], (m,)) + 0.1\n", " a = a / jnp.sum(a)\n", " b = b / jnp.sum(b)\n", " return a, x, b, y" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "79peUzQOVqcJ" }, "outputs": [], "source": [ "a, x, b, y = sample(15, 17, 3)" ] }, { "cell_type": "markdown", "metadata": { "id": "kOtW4xTTSJhg" }, "source": [ "As usual in JAX, we define a custom loss that outputs the quantity of interest, and is defined using 
relevant inputs as arguments, i.e. parameters against which we may want to differentiate. In addition to `a` and `x`, the loss takes an auxiliary ``implicit`` flag, which is used to switch between unrolling and implicit differentiation of the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm (see this excellent [tutorial](http://implicit-layers-tutorial.org/implicit_functions/) for a deep dive on their differences).\n", "\n", "The loss outputs the Sinkhorn divergence between two point clouds." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "38KEITACSR1D" }, "outputs": [], "source": [ "def loss(a, x, implicit):\n", " return sinkhorn_divergence.sinkhorn_divergence(\n", " pointcloud.PointCloud,\n", " x,\n", " y, # this part defines geometry\n", " a=a,\n", " b=b, # this sets weights\n", " sinkhorn_kwargs={\n", " # implicit differentiation if `implicit` is True, unrolling otherwise\n", " \"implicit_diff\": implicit_lib.ImplicitDiff(\n", " precondition_fun=lambda x: x\n", " )\n", " if implicit\n", " else None,\n", " \"use_danskin\": False,\n", " }, # to be used by the Sinkhorn algorithm\n", " ).divergence" ] }, { "cell_type": "markdown", "metadata": { "id": "tnrx9dMnVDxD" }, "source": [ "Let's parse the three lines in the call to {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` above:\n", "\n", "- The first one defines the point cloud geometry between `x` and `y` that will define the cost matrix. Here we could have added details on `epsilon` regularization (or scheduler), as well as alternative definitions of the cost function (here assumed by default to be squared Euclidean distance). We stick to the default setting.\n", "\n", "- The second one sets the respective weight vectors `a` and `b`. Those are simply two histograms of sizes `n` and `m` that each sum to 1, in the so-called balanced setting.\n", "\n", "- The third one passes on arguments to the three {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solvers that will be called, to compare `x` with `y`, `x` with `x` and `y` with `y` with their respective weights `a` and `b`. 
Rather than focusing on the several numerical options available to parameterize {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`'s behavior, we instruct JAX on how it should differentiate the outputs of the Sinkhorn algorithm. The `use_danskin` flag specifies whether the output potentials should be frozen when differentiating. Since we aim for second-order differentiation here, we must set this to ``False`` (if we only wanted to compute gradients, ``True`` would have resulted in faster yet almost equivalent computations)." ] }, { "cell_type": "markdown", "metadata": { "id": "hjIAe7ducbaH" }, "source": [ "## Computing Hessians" ] }, { "cell_type": "markdown", "metadata": { "id": "StMRwYUJVuOY" }, "source": [ "Let's now plot Hessians of this output w.r.t. either `a` or `x`. \n", "\n", "- The Hessian w.r.t. `a` will be an $n \\times n$ matrix, with the convention that `a` has size $n$. \n", "- Because `x` is itself a matrix of 3D coordinates, the Hessian w.r.t. `x` will be a 4D tensor of size $n \\times 3 \\times n \\times 3$.\n", "\n", "To plot both Hessians, we loop over argument 0 or 1 of `loss` and plot all (or, for `x`, a slice) of each Hessian, computed with both implicit and unrolled differentiation, to check that they match:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 939 }, "id": "JRBYQmh9WzLY", "outputId": "a8c38778-3887-4931-dd04-1e7915403c94" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Time: Implicit Hessian w.r.t. a\n", "11.2 ms ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "--- Time: Unrolled Hessian w.r.t. a\n", "11.2 ms ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "--- Time: Implicit Hessian w.r.t. x\n", "35.2 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "--- Time: Unrolled Hessian w.r.t. x\n", "35.2 ms ± 289 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for arg in [0, 1]:\n", " # Compute Hessians using either unrolling or implicit differentiation.\n", " hess_loss_imp = jax.jit(\n", " jax.hessian(lambda a, x: loss(a, x, True), argnums=arg)\n", " )\n", " print(\"--- Time: Implicit Hessian w.r.t. \" + (\"a\" if arg == 0 else \"x\"))\n", " %timeit _ = hess_loss_imp(a, x).block_until_ready()\n", " hess_imp = hess_loss_imp(a, x)\n", "\n", " hess_loss_back = jax.jit(\n", " jax.hessian(lambda a, x: loss(a, x, False), argnums=arg)\n", " )\n", " print(\"--- Time: Unrolled Hessian w.r.t. \" + (\"a\" if arg == 0 else \"x\"))\n", " %timeit _ = hess_loss_back(a, x).block_until_ready()\n", " hess_back = hess_loss_back(a, x)\n", "\n", " # Since we are solving balanced OT problems, Hessians w.r.t. weights are\n", " # only defined up to the orthogonal space of 1s.\n", " # For that reason we remove that contribution and check the\n", " # resulting matrices are equal.\n", " if arg == 0:\n", " hess_imp -= jnp.mean(hess_imp, axis=1)[:, None]\n", " hess_back -= jnp.mean(hess_back, axis=1)[:, None]\n", " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))\n", " im = ax1.imshow(hess_imp if arg == 0 else hess_imp[0, 0, :, :])\n", " ax1.set_title(\n", " \"Implicit Hessian w.r.t. \" + (\"a\" if arg == 0 else \"x (1st slice)\")\n", " )\n", " fig.colorbar(im, ax=ax1)\n", " im = ax2.imshow(hess_back if arg == 0 else hess_back[0, 0, :, :])\n", " ax2.set_title(\n", " \"Unrolled Hessian w.r.t. 
\" + (\"a\" if arg == 0 else \"x (1st slice)\")\n", " )\n", " fig.colorbar(im, ax=ax2)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Hessians.ipynb", "provenance": [] }, "kernelspec": { "display_name": "ott", "language": "python", "name": "ott" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 1 }