{ "cells": [ { "cell_type": "markdown", "id": "bc14beac", "metadata": {}, "source": [ "# Wasserstein Barycenters of Gaussian Mixture Models" ] }, { "cell_type": "markdown", "id": "e7152202", "metadata": {}, "source": [ "In this notebook, we demonstrate how to compute the {class}`~ott.solvers.linear.continuous_barycenter.WassersteinBarycenter` of several Gaussian mixture models (GMMs), i.e. mixtures of {class}`~ott.tools.gaussian_mixture.gaussian.Gaussian`s." ] }, { "cell_type": "code", "execution_count": null, "id": "88cd6ac2", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", "    !pip install -q git+https://github.com/ott-jax/ott@main" ] }, { "cell_type": "code", "execution_count": 1, "id": "d2799c24", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", "from matplotlib.colors import LogNorm\n", "\n", "from ott.geometry import costs\n", "from ott.problems.linear import barycenter_problem\n", "from ott.solvers.linear import continuous_barycenter\n", "from ott.tools.gaussian_mixture import gaussian_mixture" ] }, { "cell_type": "markdown", "id": "603a2cd7", "metadata": {}, "source": [ "## Generate the Gaussian Mixture Models" ] }, { "cell_type": "markdown", "id": "f867aca0", "metadata": {}, "source": [ "First, we randomly generate $d$-dimensional mixtures of Gaussians.\n", "\n", "For each GMM, we choose a ridge $c \\in \\mathbb{R}^d$ that shifts the means of the components and a parameter $s>0$ that scales the covariance matrices of the components.\n", "\n", "The mean of each component of a GMM is generated as\n", "\n", "$$\\mu = m\\,u + c,$$\n", "\n", "where $u \\in \\mathbb{R}^d$ consists of $d$ independent samples from the standard normal distribution and $m=0.1\\frac{1}{d} \\sum_{i=1}^d c_i$.\n",
"\n", "The covariance matrices are generated using the eigendecomposition\n", "\n", "$$C=U \\Lambda U^{\\top},$$\n", "\n", "where $U \\in \\mathbb{R}^{d \\times d}$ is a randomly generated orthogonal matrix and $\\Lambda$ is the diagonal matrix whose diagonal entries are the randomly generated eigenvalues\n", "\n", "$$\\lambda = s\\, e^h,$$\n", "\n", "where $h \\in \\mathbb{R}^d$ consists of $d$ independent samples from the standard normal distribution and the exponential is taken elementwise. Since $s>0$, all eigenvalues are positive, so every covariance matrix $C$ is symmetric positive definite." ] }, { "cell_type": "markdown", "id": "c9ce4856", "metadata": {}, "source": [ "The weights of the $K$ components of a GMM are generated with a softmax:\n", "\n", "$$w_k = \\frac{e^{a_k}}{\\sum_{j=1}^K e^{a_j}}, \\quad k = 1, \\dots, K,$$\n", "\n", "where $a_k=0.1\\, v_k$ and each $v_k$ is sampled from the standard normal distribution." ] }, { "cell_type": "markdown", "id": "0e25d363", "metadata": {}, "source": [ "We generate three GMMs with 2, 3 and 5 components, respectively." ] }, { "cell_type": "code", "execution_count": 4, "id": "b9ccbb51", "metadata": {}, "outputs": [], "source": [ "dim = 2  # the dimension of the Gaussians\n", "n_components = (2, 3, 5)  # the number of components of the GMMs\n", "# the number of GMMs whose barycenter will be computed\n", "n_gmms = len(n_components)\n", "epsilon = 0.1  # the entropy regularization parameter" ] }, { "cell_type": "markdown", "id": "89a59ccb", "metadata": {}, "source": [ "The barycentric weights determine how much each {class}`~ott.tools.gaussian_mixture.gaussian_mixture.GaussianMixture` contributes to the {class}`~ott.solvers.linear.continuous_barycenter.WassersteinBarycenter` computation. Since these weights must be nonnegative and sum to one, we sample them from a symmetric Dirichlet distribution with concentration parameter $\\alpha$. Larger values of $\\alpha$ lead to more uniform barycentric weights, whereas smaller values of $\\alpha$ allow some GMMs to contribute significantly more than others to the barycenter." ] },
{ "cell_type": "code", "execution_count": 5, "id": "53bf20b1", "metadata": {}, "outputs": [], "source": [ "# generate the pseudo-random keys that will be needed\n", "key = jax.random.PRNGKey(seed=0)\n", "keys = jax.random.split(key, num=3)" ] }, { "cell_type": "code", "execution_count": 6, "id": "d48c7e8f", "metadata": {}, "outputs": [], "source": [ "alpha = 50.0  # the concentration parameter of the Dirichlet distribution" ] }, { "cell_type": "code", "execution_count": 7, "id": "bc7f8af4", "metadata": {}, "outputs": [], "source": [ "barycentric_weights = jax.random.dirichlet(\n", "    keys[0], alpha=jnp.ones(n_gmms) * alpha\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "id": "3e2b2cd6", "metadata": {}, "outputs": [], "source": [ "# create the seeds for the random generation of each GMM\n", "seeds = jax.random.randint(keys[1], shape=(n_gmms,), minval=0, maxval=100)" ] }, { "cell_type": "markdown", "id": "12a43944", "metadata": {}, "source": [ "We choose well-separated offsets $c$ for the three GMMs so that they are easy to distinguish in the final plot." 
] }, { "cell_type": "code", "execution_count": 9, "id": "2cbb4efe", "metadata": {}, "outputs": [], "source": [ "# offsets (ridges) c for the means of each GMM\n", "cs = jnp.array([[-20, -15], [60, -15], [50, 65]])" ] }, { "cell_type": "code", "execution_count": 10, "id": "3c2d0cec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3, 2)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cs.shape" ] }, { "cell_type": "code", "execution_count": 11, "id": "0c7829ff", "metadata": {}, "outputs": [], "source": [ "# the scales m = 0.1 * mean(c) of the component means\n", "ms = 0.1 * jnp.mean(cs, axis=1)" ] }, { "cell_type": "code", "execution_count": 12, "id": "86ab5b36", "metadata": {}, "outputs": [], "source": [ "# the parameters s that control the covariance matrices\n", "ss = jnp.array([4, 3, 5])" ] }, { "cell_type": "code", "execution_count": 13, "id": "9bd26384", "metadata": {}, "outputs": [], "source": [ "assert cs.shape[0] == n_gmms\n", "assert ss.size == n_gmms\n", "assert seeds.size == n_gmms\n", "assert len(n_components) == n_gmms\n", "# m may be negative (m * u is symmetric in distribution) but must be nonzero\n", "assert (jnp.mean(cs, axis=1) != 0).all()\n", "# the parameters s must be strictly positive\n", "assert (ss > 0).all()" ] }, { "cell_type": "code", "execution_count": 14, "id": "82a461e5", "metadata": {}, "outputs": [], "source": [ "gmm_generators = [\n", "    gaussian_mixture.GaussianMixture.from_random(\n", "        jax.random.PRNGKey(seeds[i]),\n", "        n_components=n_components[i],\n", "        n_dimensions=dim,\n", "        stdev_cov=ss[i],\n", "        stdev_mean=ms[i],\n", "        ridge=cs[i],\n", "    )\n", "    for i in range(n_gmms)\n", "]" ] }, { "cell_type": "code", "execution_count": 15, "id": "1c7566e7", "metadata": {}, "outputs": [], "source": [ "# get the means and covariances of the GMMs\n", "means_covs = [\n", "    (gmm_generators[i].loc, gmm_generators[i].covariance) for i in range(n_gmms)\n", "]" ] }, { "cell_type": "code", "execution_count": 16, "id": "6cafd87d", "metadata": {}, "outputs": [], "source": [ "# vectorize over the components: each row concatenates a mean and its raveled covariance\n", "means_and_covs_to_x = jax.vmap(costs.mean_and_cov_to_x, in_axes=[0, 0, None])" ] }, { "cell_type": "code", "execution_count": 17, 
"id": "5d746c32", "metadata": {}, "outputs": [], "source": [ "# stack the concatenated means and raveled covariances of all GMM components\n", "ys = jnp.vstack(\n", "    [\n", "        means_and_covs_to_x(means_covs[i][0], means_covs[i][1], dim)\n", "        for i in range(n_gmms)\n", "    ]\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "id": "eb718f6e", "metadata": {}, "outputs": [], "source": [ "# get the component weights of the randomly generated GMMs\n", "weights = [gmm_generators[i].component_weight_ob.probs() for i in range(n_gmms)]" ] }, { "cell_type": "code", "execution_count": 19, "id": "91700d09", "metadata": {}, "outputs": [], "source": [ "# concatenate the component weights of all GMMs into a single vector\n", "bs = jnp.hstack([jnp.asarray(weights[i]) for i in range(n_gmms)])" ] }, { "cell_type": "markdown", "id": "7a87beb4", "metadata": {}, "source": [ "## Compute the Wasserstein barycenter of the GMMs\n", "\n", "We can now compute the barycenter of the input GMMs: we fix the number of components of the barycenter, define a {class}`~ott.problems.linear.barycenter_problem.BarycenterProblem` and solve it. The initial barycenter passed to the solver must have positive semidefinite covariance matrices, so we initialize it with the means and covariances of a random {class}`~ott.tools.gaussian_mixture.gaussian_mixture.GaussianMixture`." 
] }, { "cell_type": "code", "execution_count": 20, "id": "1c178f92", "metadata": {}, "outputs": [], "source": [ "# the number of components of the barycenter\n", "bar_size = 6" ] }, { "cell_type": "code", "execution_count": 21, "id": "d4bfa77c", "metadata": {}, "outputs": [], "source": [ "# initialize the barycenter with the means and covariances of a random GMM\n", "gmm_generator = gaussian_mixture.GaussianMixture.from_random(\n", "    keys[2], n_components=bar_size, n_dimensions=dim\n", ")\n", "\n", "x_init_means = gmm_generator.loc\n", "x_init_covs = gmm_generator.covariance\n", "\n", "x_init = means_and_covs_to_x(x_init_means, x_init_covs, dim)" ] }, { "cell_type": "code", "execution_count": 22, "id": "85428ed6", "metadata": {}, "outputs": [], "source": [ "# create an instance of the Bures cost class.\n", "b_cost = costs.Bures(dimension=dim)" ] }, { "cell_type": "code", "execution_count": 23, "id": "bd016819", "metadata": {}, "outputs": [], "source": [ "# create a barycenter problem.\n", "bar_p = barycenter_problem.BarycenterProblem(\n", "    y=ys,\n", "    b=bs,\n", "    weights=barycentric_weights,\n", "    num_per_segment=n_components,\n", "    cost_fn=b_cost,\n", "    epsilon=epsilon,\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "id": "5dba3154", "metadata": {}, "outputs": [], "source": [ "# create a Wasserstein barycenter solver.\n", "solver = continuous_barycenter.WassersteinBarycenter(lse_mode=True)" ] }, { "cell_type": "code", "execution_count": 25, "id": "5702e303", "metadata": {}, "outputs": [], "source": [ "# compute the barycenter.\n", "out = solver(bar_p, bar_size=bar_size, x_init=x_init)\n", "barycenter = out.x" ] }, { "cell_type": "markdown", "id": "587f463a", "metadata": {}, "source": [ "Now that we have computed the barycenter, we can extract the means and the covariances of its components." 
] }, { "cell_type": "code", "execution_count": 26, "id": "1162408b", "metadata": {}, "outputs": [], "source": [ "# extract the means and covariances of the barycenter.\n", "means_bary, covs_bary = costs.x_to_means_and_covs(barycenter, dim)" ] }, { "cell_type": "markdown", "id": "80741af9", "metadata": {}, "source": [ "## Visualize the results" ] }, { "cell_type": "markdown", "id": "742a6785", "metadata": {}, "source": [ "To plot the GMMs, we evaluate their negative log probabilities on a regular 2D grid that covers all of them." ] }, { "cell_type": "code", "execution_count": 27, "id": "e2d51265", "metadata": {}, "outputs": [], "source": [ "# create a 50 x 50 grid over [-30, 90]^2 (np.linspace uses 50 points by default)\n", "x1 = np.linspace(-30.0, 90.0)\n", "x2 = np.linspace(-30.0, 90.0)\n", "x, y = np.meshgrid(x1, x2)\n", "grid = np.array([x.ravel(), y.ravel()]).T" ] }, { "cell_type": "code", "execution_count": 28, "id": "101063de", "metadata": {}, "outputs": [], "source": [ "# compute the negative log probabilities of the GMMs on the grid.\n", "n_log_probs = jnp.asarray(\n", "    [-gmm_generators[i].log_prob(grid) for i in range(n_gmms)]\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "1f779128", "metadata": {}, "outputs": [], "source": [ "# give every component of the barycenter the same weight\n", "weights_bary = jnp.ones(bar_size) / bar_size" ] }, { "cell_type": "markdown", "id": "49af2e5f", "metadata": {}, "source": [ "We now create a {class}`~ott.tools.gaussian_mixture.gaussian_mixture.GaussianMixture` object from the means and covariances of the computed barycenter and uniform component weights." 
] }, { "cell_type": "code", "execution_count": 30, "id": "67b141a8", "metadata": {}, "outputs": [], "source": [ "# build a GMM from the barycenter's means, covariances and uniform weights.\n", "gmm_generator_bary = (\n", "    gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(\n", "        mean=means_bary, cov=covs_bary, component_weights=weights_bary\n", "    )\n", ")\n", "\n", "# compute the negative log probabilities of the barycenter on the grid.\n", "n_log_prob_bary = -gmm_generator_bary.log_prob(grid)" ] }, { "cell_type": "markdown", "id": "b249681b", "metadata": {}, "source": [ "We visualize the three GMMs and their barycenter by plotting contour lines of their negative log probabilities on the grid." ] }, { "cell_type": "code", "execution_count": 31, "id": "c7403b2f", "metadata": {}, "outputs": [], "source": [ "# reshape the negative log probabilities to the grid shape for plotting.\n", "n_log_probs = n_log_probs.reshape((n_gmms, x.shape[0], x.shape[1]))\n", "n_log_prob_bary = n_log_prob_bary.reshape((x.shape[0], x.shape[1]))" ] }, { "cell_type": "code", "execution_count": 32, "id": "1258a9ca", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for i in range(n_gmms):\n", "    _ = plt.contour(\n", "        x,\n", "        y,\n", "        n_log_probs[i, :, :],\n", "        norm=LogNorm(vmin=1.0, vmax=10.0),\n", "        levels=np.logspace(0, 1, 10),\n", "    )\n", "_ = plt.contour(\n", "    x,\n", "    y,\n", "    n_log_prob_bary,\n", "    norm=LogNorm(vmin=1.0, vmax=10.0),\n", "    levels=np.logspace(0, 1, 10),\n", ")\n", "plt.annotate(\"First GMM\", (-25, 0), fontsize=16)\n", "plt.annotate(\"Second GMM\", (50, -8), fontsize=16)\n", "plt.annotate(\"Third GMM\", (40, 78), fontsize=16)\n", "plt.annotate(\"Barycenter\", (20, 30), fontsize=16)\n", "\n", "plt.title(\"Three GMMs and their barycenter\", fontsize=20)\n", "plt.axis(\"tight\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }