ott.tools.gaussian_mixture.gaussian_mixture.GaussianMixture

class ott.tools.gaussian_mixture.gaussian_mixture.GaussianMixture(loc, scale_params, component_weight_ob)

A JAX pytree representing a Gaussian mixture model (GMM).

Methods

components()

Get a list of all GMM components.

conditional_log_prob(x)

Compute the log probability of x under each component, i.e. log p(x | k) for every component k.
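The component-conditional log probability is, for each observation x_i and component k, log N(x_i | mu_k, Sigma_k). A minimal plain-JAX sketch of that quantity follows. It is not OTT's implementation: it assumes the components are given directly as means and full covariance matrices rather than through the class's `loc` and `scale_params`, and the helper name `conditional_log_prob` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def conditional_log_prob(x, means, covs):
  """Hypothetical helper: log N(x_i | mean_k, cov_k) as an [n, n_components] array.

  x:     [n, n_dimensions] observations
  means: [n_components, n_dimensions] component means
  covs:  [n_components, n_dimensions, n_dimensions] full covariances
  """
  # vmap over the component axis; each call returns an [n] vector of log densities,
  # and out_axes=1 stacks them as columns of an [n, n_components] matrix.
  return jax.vmap(
      lambda m, c: multivariate_normal.logpdf(x, m, c), out_axes=1
  )(means, covs)


x = jnp.array([[0.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
means = jnp.array([[0.0, 0.0], [3.0, 3.0]])
covs = jnp.stack([jnp.eye(2), jnp.eye(2)])
lp = conditional_log_prob(x, means, covs)  # shape (3, 2)
```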

from_mean_cov_component_weights(mean, cov, ...)

Construct a GMM from means, covariances, and component weights.

from_points_and_assignment_probs(points, ...)

Estimate a GMM from points and a set of component probabilities.
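Estimating a GMM from points and per-component assignment probabilities corresponds to the M-step of expectation-maximization. The sketch below uses the standard weighted maximum-likelihood formulas. The function `estimate_gmm` and its optional `point_weights` argument are hypothetical, and OTT's estimator may differ in details such as regularization.

```python
import jax.numpy as jnp


def estimate_gmm(points, assignment_probs, point_weights=None):
  """Hypothetical weighted M-step returning (means, covariances, component weights).

  points:           [n, d] observations
  assignment_probs: [n, k] probability that point i belongs to component k
  point_weights:    optional [n] nonnegative per-point weights (assumption)
  """
  n = points.shape[0]
  if point_weights is None:
    point_weights = jnp.ones(n, dtype=points.dtype)
  # Effective responsibility of point i for component k.
  w = point_weights[:, None] * assignment_probs               # [n, k]
  total_k = w.sum(axis=0)                                     # [k]
  component_weights = total_k / total_k.sum()                 # [k], sums to one
  means = (w.T @ points) / total_k[:, None]                   # [k, d]
  centered = points[None, :, :] - means[:, None, :]           # [k, n, d]
  # Weighted outer products of the centered points, summed over n.
  covs = jnp.einsum("nk,kni,knj->kij", w, centered, centered)
  covs = covs / total_k[:, None, None]                        # [k, d, d]
  return means, covs, component_weights


points = jnp.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0]])
# Hard assignments: first two points to component 0, last two to component 1.
probs = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
means, covs, weights = estimate_gmm(points, probs)
# means == [[1, 0], [10, 11]], weights == [0.5, 0.5]
```

With hard 0/1 assignments this reduces to per-cluster sample means and (biased) covariances; soft probabilities simply spread each point's contribution across components.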

from_random(key, n_components, n_dimensions)

Construct a random GMM.
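One possible way to construct a random GMM is sketched below: Gaussian-distributed means, covariances built as L Lᵀ plus a small ridge so that each is positive definite, and softmax-normalized weights. The function `random_gmm_params` and its `stdev` and `ridge` arguments are hypothetical; the sketch does not claim to mirror the scheme `from_random` actually uses.

```python
import jax
import jax.numpy as jnp


def random_gmm_params(key, n_components, n_dimensions, stdev=0.1, ridge=1e-3):
  """Hypothetical sampler of GMM parameters (not OTT's exact scheme)."""
  k_mean, k_chol, k_w = jax.random.split(key, 3)
  means = stdev * jax.random.normal(k_mean, (n_components, n_dimensions))
  # Random lower-triangular factors give symmetric positive semi-definite L L^T;
  # adding ridge * I makes every covariance strictly positive definite.
  raw = stdev * jax.random.normal(k_chol, (n_components, n_dimensions, n_dimensions))
  chol = jnp.tril(raw)
  covs = chol @ jnp.swapaxes(chol, -1, -2) + ridge * jnp.eye(n_dimensions)
  # Softmax of random logits yields weights that are positive and sum to one.
  weights = jax.nn.softmax(stdev * jax.random.normal(k_w, (n_components,)))
  return means, covs, weights


means, covs, weights = random_gmm_params(jax.random.PRNGKey(0), 3, 2)
```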

get_component(index)

Get the specified GMM component.

get_log_component_posterior(x)

Compute the log posterior probability that x came from each component.
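In log space the component posterior follows from Bayes' rule: add the log component weights to the component-conditional log probabilities, then normalize each row with log-sum-exp. A minimal sketch, assuming the conditional log probabilities are already available as an [n, n_components] array; the helper name `log_component_posterior` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_component_posterior(conditional_log_probs, log_component_weights):
  """Hypothetical helper: log p(k | x_i) from log p(x_i | k) [n, k] and log w_k [k]."""
  # Bayes' rule numerator in log space: log p(x | k) + log w_k.
  log_joint = conditional_log_probs + log_component_weights[None, :]
  # Subtract the per-row log normalizer so each row exponentiates to a distribution.
  return log_joint - logsumexp(log_joint, axis=-1, keepdims=True)


cond = jnp.log(jnp.array([[0.2, 0.6], [0.9, 0.1]]))
log_w = jnp.log(jnp.array([0.5, 0.5]))
post = jnp.exp(log_component_posterior(cond, log_w))
# post == [[0.25, 0.75], [0.9, 0.1]]
```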

has_nans()

Return type: bool
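Because the model is a pytree, a NaN check can be phrased as a scan over its array leaves. The sketch below uses a plain dict as a stand-in for the GMM's leaves; `tree_has_nans` is hypothetical, and whether `has_nans` also flags infinities is not stated here.

```python
import jax
import jax.numpy as jnp


def tree_has_nans(tree) -> bool:
  """Hypothetical helper: True if any array leaf of a pytree contains a NaN."""
  return any(
      bool(jnp.any(jnp.isnan(leaf))) for leaf in jax.tree_util.tree_leaves(tree)
  )


# Plain dicts standing in for the model's array leaves.
clean = {"loc": jnp.zeros((2, 3)), "scale_params": jnp.ones((2, 3))}
broken = {"loc": jnp.array([[0.0, jnp.nan, 0.0]]), "scale_params": jnp.ones((1, 3))}
```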

log_component_weights()

Return the log of the component weights.

Return type: Array

log_prob(x)

Compute the log probability of the observations x.
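The mixture log density marginalizes over components, log p(x) = logsumexp_k [log w_k + log N(x | mu_k, Sigma_k)], and computing it with log-sum-exp keeps it numerically stable. A minimal plain-JAX sketch, assuming means, full covariances, and normalized weights as inputs; `gmm_log_prob` is a hypothetical helper, not OTT's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal


def gmm_log_prob(x, means, covs, component_weights):
  """Hypothetical helper: log p(x_i) under a GMM, returned as an [n] array."""
  # [n, k] matrix of per-component log densities.
  cond = jax.vmap(
      lambda m, c: multivariate_normal.logpdf(x, m, c), out_axes=1
  )(means, covs)
  # Marginalize over components in log space.
  return logsumexp(cond + jnp.log(component_weights)[None, :], axis=-1)


x = jnp.array([[0.0, 0.0], [1.0, -1.0]])
means = jnp.zeros((2, 2))  # two identical components
covs = jnp.stack([jnp.eye(2), jnp.eye(2)])
weights = jnp.array([0.3, 0.7])
lp = gmm_log_prob(x, means, covs, weights)
```

With two identical components the result equals a single Gaussian's log density whatever the weights, which makes a convenient sanity check.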

sample(key, size)

Generate samples from the distribution.
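Sampling from a mixture is typically done ancestrally: draw a component index from the mixture weights, then draw from that component's Gaussian. A minimal sketch with the hypothetical helper `sample_gmm`, assuming means, full covariances, and weights as inputs rather than a `GaussianMixture` instance.

```python
import jax
import jax.numpy as jnp


def sample_gmm(key, means, covs, component_weights, size):
  """Hypothetical helper: draw `size` samples by ancestral sampling."""
  key_comp, key_normal = jax.random.split(key)
  # 1. Pick a component index for each sample according to the mixture weights.
  idx = jax.random.categorical(key_comp, jnp.log(component_weights), shape=(size,))
  # 2. Draw from each chosen component's Gaussian (batched over the samples).
  return jax.random.multivariate_normal(key_normal, means[idx], covs[idx])


key = jax.random.PRNGKey(0)
means = jnp.array([[-5.0, 0.0], [5.0, 0.0]])
covs = jnp.stack([0.01 * jnp.eye(2), 0.01 * jnp.eye(2)])
samples = sample_gmm(key, means, covs, jnp.array([0.5, 0.5]), size=1000)
```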

Attributes

cholesky

Return type: Array

component_weight_ob

Return type: Probabilities

component_weights

Return type: Array

covariance

Return type: Array
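Assuming `cholesky` holds one lower-triangular factor L_k per component and `covariance` the matching full matrices, the two are related by Sigma_k = L_k L_kᵀ. The sketch below checks that round trip in batched form; `covariance_from_cholesky` is a hypothetical helper.

```python
import jax.numpy as jnp


def covariance_from_cholesky(cholesky):
  """Hypothetical helper: batched Sigma_k = L_k L_k^T for factors of shape [k, d, d]."""
  return cholesky @ jnp.swapaxes(cholesky, -1, -2)


covs = jnp.stack([jnp.array([[4.0, 2.0], [2.0, 3.0]]), jnp.eye(2)])
chol = jnp.linalg.cholesky(covs)  # batched lower-triangular factors
recovered = covariance_from_cholesky(chol)
```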

dtype

loc

Return type: Array

n_components

n_dimensions

scale_params

Return type: Array

Parameters
  • loc (jax.Array)

  • scale_params (jax.Array)

  • component_weight_ob (ott.tools.gaussian_mixture.probabilities.Probabilities)