ott.solvers.quadratic.gw_barycenter.GromovWassersteinBarycenter.init_state
- GromovWassersteinBarycenter.init_state(problem, bar_size, bar_init=None, a=None, seed=0)
Initialize the (fused) Gromov-Wasserstein barycenter state.
- Parameters
  - problem (GWBarycenterProblem) – The barycenter problem.
  - bar_size (int) – Size of the barycenter.
  - bar_init (Union[Array, Tuple[Array, Array], None]) – Initial barycenter value. Can be one of the following:
    - None – randomly initialize the barycenter.
    - jax.numpy.ndarray – barycenter cost matrix of shape [bar_size, bar_size]. Only used in the non-fused case.
    - tuple of jax.numpy.ndarray – the 1st array is a cost matrix of shape [bar_size, bar_size]; the 2nd array is a feature matrix of shape [bar_size, ndim_fused]. Used in the fused case.
  - a (Optional[Array]) – An array of shape [bar_size,] containing the barycenter weights.
  - seed (int) – Random seed used when bar_init = None.
- Return type
- Returns
The initial barycenter state.
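The shape contract for bar_init and a can be illustrated with a small sketch. It does not call OTT. The helpers make_bar_init, check_bar_init and uniform_weights are hypothetical and not part of the library. NumPy stands in for jax.numpy, and the two expose the same array API for these operations. Two further assumptions apply. Building the cost matrix from squared Euclidean distances between random points is only an illustrative choice, not how init_state performs its own random initialization. Uniform weights are one natural choice for a, but this page does not say what init_state uses when a=None.

```python
import numpy as np


def make_bar_init(bar_size, ndim_fused=None, seed=0):
    """Hypothetical helper: build a `bar_init` value with the documented shapes."""
    rng = np.random.default_rng(seed)
    # Illustrative choice: squared Euclidean costs between random 2-D points,
    # giving a symmetric [bar_size, bar_size] matrix with a zero diagonal.
    x = rng.normal(size=(bar_size, 2))
    cost = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    if ndim_fused is None:
        # Non-fused case: a single cost matrix.
        return cost
    # Fused case: (cost matrix, feature matrix of shape [bar_size, ndim_fused]).
    features = rng.normal(size=(bar_size, ndim_fused))
    return cost, features


def check_bar_init(bar_init, bar_size, fused):
    """Hypothetical helper: test a `bar_init` value against the documented shapes."""
    if bar_init is None:
        # None means "randomly initialize" in both the fused and non-fused case.
        return True
    if fused:
        # Fused case expects a (cost, features) pair.
        if not isinstance(bar_init, tuple) or len(bar_init) != 2:
            return False
        cost, features = bar_init
        return (
            cost.shape == (bar_size, bar_size)
            and features.ndim == 2
            and features.shape[0] == bar_size
        )
    # Non-fused case expects a single [bar_size, bar_size] array.
    return not isinstance(bar_init, tuple) and bar_init.shape == (bar_size, bar_size)


def uniform_weights(bar_size):
    """Assumption: uniform barycenter weights of shape [bar_size,]."""
    return np.full((bar_size,), 1.0 / bar_size)


if __name__ == "__main__":
    bar_size = 4
    cost = make_bar_init(bar_size)
    cost_f, feats = make_bar_init(bar_size, ndim_fused=3)
    print(check_bar_init(cost, bar_size, fused=False))  # True
    print(check_bar_init((cost_f, feats), bar_size, fused=True))  # True
    print(check_bar_init(cost, bar_size, fused=True))  # False
    print(uniform_weights(bar_size).shape)  # (4,)
```

Passing the same seed to make_bar_init reproduces the same matrix. This mirrors the role the seed parameter plays when bar_init = None.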