{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "qL_G7B4fovBH" }, "source": [ "# Grid geometry" ] }, { "cell_type": "markdown", "metadata": { "id": "y18hMXUWJxv6" }, "source": [ "In this tutorial, we cover how to instantiate and use the {class}`~ott.geometry.grid.Grid` geometry.\n", "\n", "{class}`~ott.geometry.grid.Grid` is a geometry that is useful when the probability measures are supported on a $d$-dimensional Cartesian grid, i.e. a Cartesian product of $d$ lists of values, each list $i$ being of size $n_i$. The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, as in $\\text{cost}(x,y) = \\sum_{i=1}^d \\text{cost}_i(x_i, y_i)$ where $\\text{cost}_i: \\mathbb{R} \\times \\mathbb{R} \\rightarrow \\mathbb{R}$.\n", "\n", "The advantage of using {class}`~ott.geometry.grid.Grid` over {class}`~ott.geometry.pointcloud.PointCloud` for such cases is that the computational cost of applying a kernel is $O(N^{(1+1/d)})$ instead of $O(N^2)$, where $N$ is the total number of points in the grid." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", "    !pip install -q git+https://github.com/ott-jax/ott@main" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "from ott.geometry import costs, grid, pointcloud\n", "from ott.problems.linear import linear_problem\n", "from ott.solvers.linear import sinkhorn" ] }, { "cell_type": "markdown", "metadata": { "id": "ChPtJxt0qInW" }, "source": [ "## Use `Grid` with the argument ```x```" ] }, { "cell_type": "markdown", "metadata": { "id": "ZW3w2Ys_ulct" }, "source": [ "In this example, the argument `x` is a list of $3$ vectors, of varying sizes $\\{n_1, n_2, n_3\\}$, that describe the locations of the grid. 
The resulting grid is the Cartesian product of these vectors. `a` and `b` are two histograms in a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "executionInfo": { "elapsed": 5339, "status": "ok", "timestamp": 1613817615565, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "Af0dnf1qqHv7" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(0)\n", "keys = jax.random.split(rng, 5)\n", "\n", "grid_size = (5, 6, 7)\n", "x = [\n", " jax.random.uniform(keys[0], (grid_size[0],)),\n", " jax.random.uniform(keys[1], (grid_size[1],)),\n", " jax.random.uniform(keys[2], (grid_size[2],)),\n", "]\n", "a = jax.random.uniform(keys[3], grid_size)\n", "b = jax.random.uniform(keys[4], grid_size)\n", "a = a.ravel() / jnp.sum(a)\n", "b = b.ravel() / jnp.sum(b)" ] }, { "cell_type": "markdown", "metadata": { "id": "GHrL_JjevsYN" }, "source": [ "Instantiate {class}`~ott.geometry.grid.Grid` and calculate the regularized optimal transport cost." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "executionInfo": { "elapsed": 2768, "status": "ok", "timestamp": 1613817618340, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "fAoTEGNFv5y7", "outputId": "3b639a18-92a8-4f42-aa1f-a6e26cce7b83" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularised optimal transport cost = 0.20520979166030884\n" ] } ], "source": [ "geom = grid.Grid(x=x, epsilon=0.1)\n", "prob = linear_problem.LinearProblem(geom, a=a, b=b)\n", "\n", "solver = sinkhorn.Sinkhorn()\n", "out = solver(prob)\n", "\n", "print(f\"Regularised optimal transport cost = {out.reg_ot_cost}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "96rC8WsdHQEf" }, "source": [ "## Use `Grid` with the argument ```grid_size```\n", "\n", "In this example, the grid is described as points regularly sampled in $[0, 1]$. `a` and `b` are two histograms in a grid of size 5 x 6 x 7 that lies in the 3-dimensional hypercube $[0, 1]^3$. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "executionInfo": { "elapsed": 394, "status": "ok", "timestamp": 1613817618751, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "XQNQFhe2pKR7" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(1)\n", "keys = jax.random.split(rng, 2)\n", "\n", "grid_size = (5, 6, 7)\n", "a = jax.random.uniform(keys[0], grid_size)\n", "b = jax.random.uniform(keys[1], grid_size)\n", "a = a.ravel() / jnp.sum(a)\n", "b = b.ravel() / jnp.sum(b)" ] }, { "cell_type": "markdown", "metadata": { "id": "OSM5LhbVz99j" }, "source": [ "Instantiate {class}`~ott.geometry.grid.Grid` and calculate the regularized optimal transport cost." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "executionInfo": { "elapsed": 2014, "status": "ok", "timestamp": 1613817620772, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "arLJWk-kklnX", "outputId": "77961d66-f220-4cc0-bb89-1240dcf6bb17" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularised optimal transport cost = 0.2816336154937744\n" ] } ], "source": [ "geom = grid.Grid(grid_size=grid_size, epsilon=0.1)\n", "prob = linear_problem.LinearProblem(geom, a=a, b=b)\n", "\n", "out = solver(prob)\n", "\n", "print(f\"Regularised optimal transport cost = {out.reg_ot_cost}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "73KzjNwpAb60" }, "source": [ "## Vary the cost function in each dimension\n", "\n", "Instead of the squared Euclidean distance, we will use a squared\n", "Mahalanobis distance, where the covariance matrix is diagonal. This example illustrates the possibility of choosing a cost function for each dimension.\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "executionInfo": { "elapsed": 646, "status": "ok", "timestamp": 1613817621425, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "ttOyhIWIAccB" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(1)\n", "keys = jax.random.split(rng, 2)\n", "\n", "grid_size = (5, 6)\n", "a = jax.random.uniform(keys[0], grid_size)\n", "b = jax.random.uniform(keys[1], grid_size)\n", "a = a.ravel() / jnp.sum(a)\n", "b = b.ravel() / jnp.sum(b)" ] }, { "cell_type": "markdown", "metadata": { "id": "8a6iU3t5Alr8" }, "source": [ "As the covariance matrix for the Mahalanobis distance, we use the diagonal $2 \\times 2$ matrix with $[1/2, 1]$ on its diagonal. Because the squared Mahalanobis distance uses the inverse covariance, the cost along the first dimension is the squared Euclidean cost multiplied by $2$, while the second dimension keeps the plain squared Euclidean cost. We therefore create an additional {class}`~ott.geometry.costs.CostFn`." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "executionInfo": { "elapsed": 279, "status": "ok", "timestamp": 1613817621708, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "2ilbOUD1Acwj" }, "outputs": [], "source": [ "@jax.tree_util.register_pytree_node_class\n", "class SqEuclideanTimes2(costs.CostFn):\n", "    \"\"\"The cost function corresponding to the squared Euclidean distance times 2.\"\"\"\n", "\n", "    def norm(self, x):\n", "        return jnp.sum(x**2, axis=-1) * 2\n", "\n", "    def pairwise(self, x, y):\n", "        return -2 * jnp.sum(x * y) * 2\n", "\n", "\n", "cost_fns = [SqEuclideanTimes2(), costs.SqEuclidean()]" ] }, { "cell_type": "markdown", "metadata": { "id": "qPDNxISlAtUa" }, "source": [ "Instantiate {class}`~ott.geometry.grid.Grid` and calculate the regularized optimal transport cost." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "executionInfo": { "elapsed": 1810, "status": "ok", "timestamp": 1613817623524, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "0q_phX06Ao9d", "outputId": "767e0d21-16cc-402e-bec9-5f521f9723d0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularised optimal transport cost = 0.22414202988147736\n" ] } ], "source": [ "geom = grid.Grid(grid_size=grid_size, cost_fns=cost_fns, epsilon=0.1)\n", "prob = linear_problem.LinearProblem(geom, a=a, b=b)\n", "out = solver(prob)\n", "\n", "print(f\"Regularised optimal transport cost = {out.reg_ot_cost}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "mr9ogxkxHYsL" }, "source": [ "## Compare runtime between using `Grid` and `PointCloud`\n", "\n", "The squared Euclidean distance is an example of a separable cost for which it is possible to use {class}`~ott.geometry.grid.Grid` instead of {class}`~ott.geometry.pointcloud.PointCloud`. 
In this case, using {class}`~ott.geometry.grid.Grid` rather than {class}`~ott.geometry.pointcloud.PointCloud` as the geometry for regularized optimal transport has a computational advantage: applying a kernel in each {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` step costs $O(N^{(1+1/d)})$ instead of the naive $O(N^2)$, where $N$ is the total number of points in the grid and $d$ its dimension.\n", "\n", "In this example, for the same grid size and points, the runtime of Sinkhorn with {class}`~ott.geometry.grid.Grid` is smaller than with {class}`~ott.geometry.pointcloud.PointCloud`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 6373, "status": "ok", "timestamp": 1613817647644, "user": { "displayName": "Laetitia Papaxanthos", "photoUrl": "", "userId": "13824884068334195048" }, "user_tz": -60 }, "id": "IfARvic82UnF", "outputId": "89985d99-857f-4a4c-fa45-407d2245ef49" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.53 s ± 6.32 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n", "Regularised optimal transport cost using Grid = 0.24500826001167297\n", "\n" ] } ], "source": [ "epsilon = 0.1\n", "grid_size = (50, 50, 50)\n", "\n", "rng = jax.random.PRNGKey(2)\n", "keys = jax.random.split(rng, 2)\n", "a = jax.random.uniform(keys[0], grid_size)\n", "b = jax.random.uniform(keys[1], grid_size)\n", "a = a.ravel() / jnp.sum(a)\n", "b = b.ravel() / jnp.sum(b)\n", "\n", "# Instantiates Grid\n", "geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)\n", "prob_grid = linear_problem.LinearProblem(geometry_grid, a=a, b=b)\n", "\n", "x, y, z = np.mgrid[0 : grid_size[0], 0 : grid_size[1], 0 : grid_size[2]]\n", "xyz = jnp.stack(\n", " [\n", " jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),\n", " jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),\n", " jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),\n", " ]\n", ").transpose()\n", "# Instantiates PointCloud with `batch_size` argument\n", "geometry_pointcloud = pointcloud.PointCloud(\n", " xyz, xyz, epsilon=epsilon, batch_size=1024\n", ")\n", "prob_pointcloud = linear_problem.LinearProblem(geometry_pointcloud, a=a, b=b)\n", "\n", "%timeit solver(prob_grid).reg_ot_cost.block_until_ready()\n", "out_grid = solver(prob_grid)\n", "print(\n", " f\"Regularised optimal transport cost using Grid = {out_grid.reg_ot_cost}\\n\"\n", ")\n", "\n", "%timeit solver(prob_pointcloud).reg_ot_cost.block_until_ready()\n", "out_pointcloud = solver(prob_pointcloud)\n", "print(\n", " f\"Regularised optimal transport cost using Pointcloud = {out_pointcloud.reg_ot_cost}\"\n", ")" ] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "", "kind": "private" }, "name": "Grid Geometry for OTT.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", 
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 1 }