ott.tools.gaussian_mixture.gaussian.Gaussian
class ott.tools.gaussian_mixture.gaussian.Gaussian(loc, scale)
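PyTree for a multivariate normal distribution, parameterized by its mean loc and a lower-triangular scale factor scale (a ScaleTriL).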
Methods
- rtype: Array
- f_potential(dest, points): Optimal potential for the W2 distance between Gaussians.
- from_mean_and_cov(mean, cov): Construct a Gaussian from a mean and covariance.
- from_random(key, n_dimensions[, stdev_mean, ...]): Construct a random Gaussian.
- from_samples(points[, weights]): Construct a Gaussian from weighted samples.
- from_z(z): rtype: Array
- log_prob(x): Log probability for a Gaussian with a diagonal covariance.
- sample(key, size): Generate samples from the distribution.
- to_z(x): rtype: Array
- transport(dest, points): Transport points according to the transport map between two Gaussian measures.
- w2_dist(other): Squared 2-Wasserstein distance W_2^2 to another Gaussian.
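The closed-form quantities behind w2_dist and transport can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not OTT's implementation. It works on dense mean and covariance arrays instead of Gaussian/ScaleTriL objects, and the helper names sqrtm_psd, w2_squared and transport_map are hypothetical. It uses the standard Gaussian formulas W_2^2 = ||m0 - m1||^2 + tr(S0 + S1 - 2 (S0^{1/2} S1 S0^{1/2})^{1/2}) and T(x) = m1 + A (x - m0), with A = S0^{-1/2} (S0^{1/2} S1 S0^{1/2})^{1/2} S0^{-1/2}.

```python
import numpy as np


def sqrtm_psd(m: np.ndarray) -> np.ndarray:
    """Symmetric square root of a symmetric positive semi-definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(m)
    eigvals = np.clip(eigvals, 0.0, None)  # guard against tiny negative round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def w2_squared(m0, cov0, m1, cov1) -> float:
    """Closed-form squared 2-Wasserstein distance between N(m0, cov0) and N(m1, cov1)."""
    root0 = sqrtm_psd(cov0)
    cross = sqrtm_psd(root0 @ cov1 @ root0)
    return float(np.sum((m0 - m1) ** 2) + np.trace(cov0 + cov1 - 2.0 * cross))


def transport_map(m0, cov0, m1, cov1, points):
    """Apply T(x) = m1 + A (x - m0) to each row of `points` (shape [n, d])."""
    root0 = sqrtm_psd(cov0)
    inv_root0 = np.linalg.inv(root0)
    # A is symmetric positive definite when cov0 and cov1 are
    a = inv_root0 @ sqrtm_psd(root0 @ cov1 @ root0) @ inv_root0
    return m1 + (points - m0) @ a.T
```

For N(0, 1) and N(3, 4) in one dimension, this gives W_2^2 = 9 + 1 = 10 and T(x) = 3 + 2x. The matrix square root uses an eigendecomposition because both covariances are symmetric, which avoids a general-purpose sqrtm.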
Attributes
- rtype: Array
- rtype
- rtype: ScaleTriL
Parameters
- loc (jax.Array) – mean of the distribution.
- scale (ott.tools.gaussian_mixture.scale_tril.ScaleTriL) – lower-triangular scale factor of the covariance.
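The NumPy sketch below relates the loc and scale parameters to construction from a covariance or from samples, to sampling, and to the log density. It does not reproduce OTT's code, and several parts are assumptions:

- From the name ScaleTriL, scale is taken to act as a lower-triangular factor L with covariance L L^T.
- All helper names are hypothetical.
- A NumPy Generator stands in for the JAX key that sample takes.
- The weights-normalized covariance in mean_cov_from_samples is an assumed estimator, not a confirmed description of from_samples.

```python
import numpy as np


def scale_from_cov(cov: np.ndarray) -> np.ndarray:
    """Lower-triangular factor L with cov = L @ L.T (assumed role of `scale`)."""
    return np.linalg.cholesky(cov)


def mean_cov_from_samples(points, weights=None):
    """Weighted mean and covariance of the rows of `points`; uniform weights by default."""
    n = points.shape[0]
    w = np.full(n, 1.0 / n) if weights is None else weights / np.sum(weights)
    mean = w @ points
    centered = points - mean
    # Assumed estimator: normalized by the total weight, no bias correction
    cov = (centered * w[:, None]).T @ centered
    return mean, cov


def sample(rng, loc, scale, size):
    """Draw `size` samples as loc + L z with z ~ N(0, I)."""
    z = rng.standard_normal((size, loc.shape[-1]))
    return loc + z @ scale.T


def log_prob(x, loc, scale):
    """Log density of N(loc, L L^T) at each row of x (full covariance)."""
    d = loc.shape[-1]
    # z = L^{-1} (x - loc), solved for every row at once
    z = np.linalg.solve(scale, (x - loc).T).T
    log_det_cov = 2.0 * np.sum(np.log(np.abs(np.diag(scale))))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det_cov + np.sum(z ** 2, axis=-1))
```

Working through the factor L avoids forming an explicit inverse of the covariance. A diagonal covariance is the special case where L is diagonal.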