ott.tools.gaussian_mixture.gaussian_mixture_pair.GaussianMixturePair

class ott.tools.gaussian_mixture.gaussian_mixture_pair.GaussianMixturePair(gmm0, gmm1, epsilon=0.01, tau=1.0, lock_gmm1=False)

Pytree for a coupled pair of Gaussian mixture models.

Includes methods used in estimating an optimal pairing between GMM components using the Wasserstein-like method described in [Delon and Desolneux, 2020], as well as a generalization that allows for the re-weighting of components.

The Delon & Desolneux paper above proposes fitting a pair of GMMs to a pair of point clouds in such a way that the sum of the log-likelihoods of the points, minus a weighted penalty involving a Wasserstein-like distance between the GMMs, is maximized. Their proposed algorithm uses EM, with a balanced Sinkhorn algorithm estimating a coupling between the GMMs at each EM step.
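For reference, the Wasserstein-type distance between GMMs from [Delon and Desolneux, 2020] can be written as a discrete optimal transport problem over the component weights, with the squared 2-Wasserstein distance between Gaussian components as the ground cost. The notation below (weights π, means m, covariances Σ, component counts K₀ and K₁) is ours, not the library's. Adding an entropic term weighted by ε to this minimization gives the problem a Sinkhorn step solves, presumably with ε playing the role of the epsilon argument.

```latex
MW_2^2(\mu_0, \mu_1) = \min_{w \in \Pi(\pi_0, \pi_1)} \sum_{k=1}^{K_0} \sum_{l=1}^{K_1} w_{kl}\, W_2^2\!\left(\mathcal{N}(m_0^k, \Sigma_0^k),\, \mathcal{N}(m_1^l, \Sigma_1^l)\right),
\qquad \mu_j = \sum_{k=1}^{K_j} \pi_j^k\, \mathcal{N}(m_j^k, \Sigma_j^k),
\qquad \Pi(\pi_0, \pi_1) = \left\{ w \in \mathbb{R}_{\ge 0}^{K_0 \times K_1} : w\mathbf{1} = \pi_0,\ w^\top \mathbf{1} = \pi_1 \right\}.
```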

Our generalization of this algorithm allows for a mismatch between the marginals of the coupling and the GMM component weights. This mismatch can be interpreted as components being re-weighted rather than transported. We penalize re-weighting with a generalized KL-divergence penalty, and we give the option to use the unbalanced Sinkhorn algorithm rather than the balanced one to compute the divergence between GMMs.
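A minimal NumPy sketch of the difference between balanced and unbalanced couplings, assuming the common log-domain Sinkhorn scheme in which tau = rho / (rho + epsilon) damps each potential update, so that tau = 1 is the balanced case. This is a generic textbook scheme, not OTT's implementation: the function names and the toy weights and cost matrix are hypothetical, and we assume (without confirming it from the source) that the class's tau and rho attributes follow this convention.

```python
import numpy as np


def _logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def unbalanced_sinkhorn(a, b, cost, epsilon=0.01, tau=1.0, n_iters=2000):
    """Log-domain entropic coupling between weight vectors a and b.

    Assumption: tau = rho / (rho + epsilon) in (0, 1]. With tau = 1 the
    marginals are enforced exactly (balanced Sinkhorn); with tau < 1 they
    are only matched up to a KL penalty of strength rho.
    """
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(log_a)
    g = np.zeros_like(log_b)
    for _ in range(n_iters):
        # Row update: f_i = tau * eps * (log a_i - logsumexp_j((g_j - C_ij) / eps)).
        f = tau * epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        # Column update, with the roles of rows and columns swapped.
        g = tau * epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    # Coupling P_ij = exp((f_i + g_j - C_ij) / eps).
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


# Hypothetical 2-component weights that disagree; off-diagonal moves cost 1.
a = np.array([0.5, 0.5])
b = np.array([0.3, 0.7])
cost = np.array([[0.0, 1.0], [1.0, 0.0]])

balanced = unbalanced_sinkhorn(a, b, cost, tau=1.0)
unbalanced = unbalanced_sinkhorn(a, b, cost, tau=0.5)
print(np.round(balanced, 3))    # about 0.2 of mass is transported off the diagonal
print(np.round(unbalanced, 3))  # off-diagonal mass is ~0; marginals no longer match a, b
```

With tau=1.0 the row and column sums match a and b, so 0.2 units of mass must cross the expensive off-diagonal entry. With tau=0.5 the coupling stays numerically diagonal and its marginals drift away from a and b: the mismatch is paid for by the KL penalty instead of by transport, which is the re-weighting interpretation described above.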

Methods

get_bures_geometry()

Get a Bures Geometry for the two GMMs.

get_cost_matrix()

Get the matrix of squared 2-Wasserstein (W2^2) costs between all pairs of (gmm0, gmm1) components.
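The cost between two Gaussian components has a closed form: the squared distance between the means plus the squared Bures distance between the covariances. A minimal NumPy sketch of such a cost matrix, assuming the components are given as stacked mean and covariance arrays; the helper names are hypothetical, and this is not OTT's code, which presumably builds the same costs through its Bures geometry.

```python
import numpy as np


def _psd_sqrt(m):
    # Square root of a symmetric positive semi-definite matrix via eigendecomposition.
    vals, vecs = np.linalg.eigh(m)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def gaussian_w2_sq(m0, s0, m1, s1):
    """Squared 2-Wasserstein distance between N(m0, s0) and N(m1, s1).

    ||m0 - m1||^2 + tr(s0 + s1 - 2 (s1^{1/2} s0 s1^{1/2})^{1/2})
    """
    r1 = _psd_sqrt(s1)
    cross = _psd_sqrt(r1 @ s0 @ r1)
    return float(np.sum((m0 - m1) ** 2) + np.trace(s0 + s1 - 2.0 * cross))


def w2_cost_matrix(means0, covs0, means1, covs1):
    """W2^2 cost for every (gmm0 component, gmm1 component) pair."""
    return np.array([[gaussian_w2_sq(m0, s0, m1, s1)
                      for m1, s1 in zip(means1, covs1)]
                     for m0, s0 in zip(means0, covs0)])


# Hypothetical 2-D mixtures: two components in gmm0, one in gmm1.
means0 = np.array([[0.0, 0.0], [3.0, 0.0]])
covs0 = np.array([np.eye(2), 4.0 * np.eye(2)])
means1 = np.array([[0.0, 1.0]])
covs1 = np.array([np.eye(2)])

C = w2_cost_matrix(means0, covs0, means1, covs1)
print(C)  # approximately [[1.0], [12.0]]
```

For the second entry, the means differ by (3, -1), contributing 10, and the isotropic covariances 4I and I contribute 2 * (2 - 1)^2 = 2, giving 12.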

get_normalized_sinkhorn_coupling(sinkhorn_output)

Get the normalized coupling matrix for the specified Sinkhorn output.

get_sinkhorn(cost_matrix, **kwargs)

Get the output of Sinkhorn's method for a given cost matrix.

Attributes

dtype

epsilon

gmm0

gmm1

lock_gmm1

rho

tau

Parameters