ott.tools.gaussian_mixture.gaussian.Gaussian

class ott.tools.gaussian_mixture.gaussian.Gaussian(loc, scale)

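PyTree for a multivariate normal distribution, parameterized by its mean loc and the lower-triangular scale of its covariance, scale.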

Methods

covariance()

Covariance matrix of the distribution.

Return type: Array

f_potential(dest, points)

Optimal potential for the W_2 distance between this Gaussian and dest, evaluated at points.
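For intuition, the sketch below evaluates such a potential in NumPy for two Gaussians N(m0, S0) and N(m1, S1). It assumes the cost c(x, y) = 0.5 * |x - y|^2 and the closed-form OT map T(x) = m1 + A (x - m0) with A = S0^{-1/2} (S0^{1/2} S1 S0^{1/2})^{1/2} S0^{-1/2}. The potential is then f(x) = 0.5 * |x|^2 - phi(x), where phi(x) = 0.5 (x - m0)^T A (x - m0) + m1^T x is the convex Brenier potential whose gradient is T. Potentials are only defined up to an additive constant and the library may normalize differently; all helper names here are hypothetical.

```python
import numpy as np


def sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix."""
    m = 0.5 * (m + m.T)  # symmetrize to absorb round-off
    w, v = np.linalg.eigh(m)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def map_matrix(cov0, cov1):
    """Symmetric matrix A of the Gaussian OT map x -> m1 + A (x - m0)."""
    s0 = sqrtm_psd(cov0)
    s0_inv = np.linalg.inv(s0)
    return s0_inv @ sqrtm_psd(s0 @ cov1 @ s0) @ s0_inv


def f_potential_sketch(mean0, cov0, mean1, cov1, points):
    """Evaluate f(x) = 0.5 |x|^2 - phi(x) at each row of `points`."""
    a = map_matrix(cov0, cov1)
    centered = points - mean0
    # Brenier potential phi(x) = 0.5 (x - m0)^T A (x - m0) + m1^T x.
    phi = 0.5 * np.einsum("ni,ij,nj->n", centered, a, centered) + points @ mean1
    return 0.5 * np.sum(points**2, axis=-1) - phi
```

Because the gradient of phi is T, the gradient of f at x equals x - T(x), which gives a quick finite-difference check of the sketch.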

from_mean_and_cov(mean, cov)

Construct a Gaussian from a mean and covariance.
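Assuming the scale is stored as a lower-triangular factor L with covariance L L^T (as the name ScaleTriL suggests), a mean/covariance pair maps onto that parameterization through a Cholesky factorization. A minimal NumPy sketch; the helper names are hypothetical and not part of ott:

```python
import numpy as np


def scale_from_cov(cov):
    """Lower-triangular L such that cov == L @ L.T (Cholesky factor)."""
    return np.linalg.cholesky(cov)


def cov_from_scale(scale_tril):
    """Recover the covariance matrix from its lower-triangular factor."""
    return scale_tril @ scale_tril.T


cov = np.array([[2.0, 0.3], [0.3, 1.0]])
scale_tril = scale_from_cov(cov)
```

The round trip mirrors how a covariance can be recovered from the scale; note that the Cholesky factor exists only for a positive-definite covariance.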

from_random(key, n_dimensions[, stdev_mean, ...])

Construct a random Gaussian.

from_samples(points[, weights])

Construct a Gaussian from weighted samples.
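Fitting a Gaussian to weighted samples reduces to a weighted mean and a weighted covariance. The NumPy sketch below assumes that weights default to uniform, are normalized to sum to one, and that no small-sample bias correction is applied; these are assumptions, and weighted_fit is a hypothetical helper.

```python
import numpy as np


def weighted_fit(points, weights=None):
    """Weighted mean and (biased) covariance of the rows of `points`."""
    n = points.shape[0]
    if weights is None:
        weights = np.full(n, 1.0 / n)  # uniform weights by default
    weights = weights / np.sum(weights)  # normalize to sum to one
    mean = weights @ points
    centered = points - mean
    cov = (weights[:, None] * centered).T @ centered
    return mean, cov
```

With integer weights, the result matches the unweighted fit on a dataset in which each point is repeated as many times as its weight.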

from_z(z)

Map standardized coordinates z back to points in the original space.

Return type: Array

log_prob(x)

Log probability for a Gaussian with a diagonal covariance.
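For reference, the log density of N(m, S) in d dimensions is -0.5 * (d log(2 pi) + log det S + |z|^2), with z = L^{-1}(x - m) and S = L L^T. The NumPy sketch below evaluates it for a general covariance, which includes the diagonal case; gaussian_log_prob is a hypothetical helper, not the library's implementation.

```python
import numpy as np


def gaussian_log_prob(x, mean, cov):
    """Log density of N(mean, cov) at each row of `x` (shape [n, d])."""
    d = mean.shape[-1]
    chol = np.linalg.cholesky(cov)  # cov = chol @ chol.T
    # Whitened residuals z = chol^{-1} (x - mean), solved for all rows at once.
    z = np.linalg.solve(chol, (x - mean).T).T
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + np.sum(z**2, axis=-1))
```

For a diagonal covariance the result equals the sum of independent one-dimensional log densities.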

sample(key, size)

Generate samples from the distribution.
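Sampling can be done by reparameterization: x = m + L eps with eps ~ N(0, I) and S = L L^T. The sketch uses a NumPy Generator in place of the JAX PRNG key that sample(key, size) takes; sample_sketch is a hypothetical helper.

```python
import numpy as np


def sample_sketch(rng, mean, cov, size):
    """Draw `size` samples from N(mean, cov); returns shape [size, d]."""
    chol = np.linalg.cholesky(cov)
    eps = rng.standard_normal((size, mean.shape[0]))  # eps ~ N(0, I)
    return mean + eps @ chol.T  # each row is mean + chol @ eps_row


rng = np.random.default_rng(0)
mean = np.array([1.0, -1.0])
cov = np.array([[2.0, 0.3], [0.3, 1.0]])
samples = sample_sketch(rng, mean, cov, 100_000)
```

With enough samples, the empirical mean and covariance approach mean and cov.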

to_z(x)

Map points x to standardized (whitened) coordinates.

Return type: Array
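to_z and from_z presumably standardize points with respect to this Gaussian and undo that standardization. Assuming z = L^{-1}(x - m), where L is the lower-triangular scale (an assumption about the library's convention), the pair can be sketched in NumPy as follows; the helper names are hypothetical.

```python
import numpy as np


def to_z_sketch(x, mean, scale_tril):
    """Whiten rows of `x`: z = L^{-1} (x - mean)."""
    return np.linalg.solve(scale_tril, (x - mean).T).T


def from_z_sketch(z, mean, scale_tril):
    """Inverse map: x = mean + L z, applied row-wise."""
    return mean + z @ scale_tril.T


mean = np.array([1.0, -1.0])
scale_tril = np.linalg.cholesky(np.array([[2.0, 0.3], [0.3, 1.0]]))
x = np.array([[0.0, 0.0], [2.0, 1.0]])
z = to_z_sketch(x, mean, scale_tril)
```

The two maps are inverses of each other, and the mean itself maps to z = 0.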

transport(dest, points)

Transport points with the optimal transport map from this Gaussian measure to dest.
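Between N(m0, S0) and N(m1, S1) with positive-definite S0, the optimal map for the squared Euclidean cost is affine: T(x) = m1 + A (x - m0) with A = S0^{-1/2} (S0^{1/2} S1 S0^{1/2})^{1/2} S0^{-1/2}. A NumPy sketch of that closed form, with hypothetical helper names:

```python
import numpy as np


def sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix."""
    m = 0.5 * (m + m.T)  # symmetrize to absorb round-off
    w, v = np.linalg.eigh(m)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def transport_sketch(mean0, cov0, mean1, cov1, points):
    """Push rows of `points` through T(x) = m1 + A (x - m0)."""
    s0 = sqrtm_psd(cov0)
    s0_inv = np.linalg.inv(s0)
    a = s0_inv @ sqrtm_psd(s0 @ cov1 @ s0) @ s0_inv  # symmetric, so a.T == a
    return mean1 + (points - mean0) @ a.T
```

In one dimension this reduces to T(x) = m1 + (s1 / s0)(x - m0) with s0, s1 the standard deviations; for example, mapping mean 0 and variance 4 to mean 1 and variance 9 sends x = 2 to 4. In general A S0 A = S1, so the pushed-forward covariance is exactly S1.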

w2_dist(other)

Squared 2-Wasserstein distance, W_2^2, to another Gaussian.
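Between Gaussians the squared distance has a closed form (the Bures-Wasserstein formula): W_2^2 = |m0 - m1|^2 + tr(S0 + S1 - 2 (S0^{1/2} S1 S0^{1/2})^{1/2}). A NumPy sketch, with hypothetical helper names:

```python
import numpy as np


def sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix."""
    m = 0.5 * (m + m.T)  # symmetrize to absorb round-off
    w, v = np.linalg.eigh(m)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def w2_squared(mean0, cov0, mean1, cov1):
    """Squared 2-Wasserstein distance between N(mean0, cov0) and N(mean1, cov1)."""
    s0 = sqrtm_psd(cov0)
    cross = sqrtm_psd(s0 @ cov1 @ s0)
    return float(np.sum((mean0 - mean1) ** 2) + np.trace(cov0 + cov1 - 2.0 * cross))
```

In one dimension this reduces to (m0 - m1)^2 + (s0 - s1)^2 with s0, s1 the standard deviations; for means 0 and 1 with variances 4 and 9, the squared distance is 1 + 1 = 2.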

Attributes

loc

Mean of the distribution.

Return type: Array

n_dimensions

Dimension of the space on which the distribution is defined.

Return type: int

scale

Lower-triangular scale of the covariance.

Return type: ScaleTriL

Parameters
  • loc (jax.Array) – Mean of the distribution.

  • scale (ott.tools.gaussian_mixture.scale_tril.ScaleTriL) – Lower-triangular scale of the covariance.