ott.solvers.quadratic.gromov_wasserstein.solve
- ott.solvers.quadratic.gromov_wasserstein.solve(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01, **kwargs)
Solve a quadratic regularized OT problem.
The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.
The two geometries below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to match the form given in eq. 5, which in our notation reads:
\[L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x)\,\mathrm{quad2}(y)\]
- Parameters
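Both losses that can be selected by name, 'sqeucl' and 'kl', fit this form. The NumPy sketch below writes out the splits given in [Peyré et al., 2016] and checks them numerically. For squared Euclidean, the split is (x², y², x, 2y). For KL, defined for positive entries only, it is (x log x − x, y, x, log y). The split is not unique; for example, the factor 2 could move into quad1. The (lin1, lin2, quad1, quad2) tuples and the names SQEUCL, KL and decomposed_loss are illustrative assumptions, not OTT's internal representation.

```python
import numpy as np

# Each loss is written as (lin1, lin2, quad1, quad2) so that
# L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y).
# Illustrative tuples only; not OTT's internal loss objects.

# Squared Euclidean: (x - y)^2 = x^2 + y^2 - x * (2 y)
SQEUCL = (
    lambda x: x ** 2,
    lambda y: y ** 2,
    lambda x: x,
    lambda y: 2.0 * y,
)

# KL (positive entries only): x log(x / y) - x + y = (x log x - x) + y - x * log(y)
KL = (
    lambda x: x * np.log(x) - x,
    lambda y: y,
    lambda x: x,
    lambda y: np.log(y),
)


def decomposed_loss(parts, x, y):
    """Evaluate L(x, y) elementwise from its four component functions."""
    lin1, lin2, quad1, quad2 = parts
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


rng = np.random.default_rng(0)
x = rng.uniform(0.1, 2.0, size=5)
y = rng.uniform(0.1, 2.0, size=5)
print(np.allclose(decomposed_loss(SQEUCL, x, y), (x - y) ** 2))
print(np.allclose(decomposed_loss(KL, x, y), x * np.log(x / y) - x + y))
```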
geom_xx (Geometry) – Ground geometry of the first space.
geom_yy (Geometry) – Ground geometry of the second space.
geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.
fused_penalty (float) – Multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if geom_xy is not specified.
scale_cost (Union[bool, float, str, None]) – Option to rescale the cost matrices:
a (Optional[Array]) – Array of probability weights of the samples from geom_xx. If None, uniform weights are used.
b (Optional[Array]) – Array of probability weights of the samples from geom_yy. If None, uniform weights are used.
loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – The loss \(L\): either 'sqeucl', 'kl', or a GWLoss, i.e. a 2-tuple of 2-tuples of Callable. The first tuple is the linear part (lin1, lin2); the second is the quadratic part (quad1, quad2). By default, the loss is set to the four functions representing the squared Euclidean loss, and this structure is exploited in subsequent computations. The KL loss ('kl') can be specified instead and is handled just as efficiently.
tau_a (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the first marginal. The default, 1.0, corresponds to the balanced case.
tau_b (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the second marginal. The default, 1.0, corresponds to the balanced case.
gw_unbalanced_correction (bool) – Whether to use the unbalanced version of [Sejourne et al., 2021]. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.
ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when the geometries are not PointCloud with the 'sqeucl' cost function. If -1, the geometries are not converted to low-rank. If tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If int, the rank is shared across all geometries.
tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when the geometries are not PointCloud with the 'sqeucl' cost. If float, it is shared across all geometries.
kwargs (Any) – Keyword arguments for GromovWasserstein.
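The loss description says the (lin, quad) structure is exploited in subsequent computations. One standard way to do this comes from Prop. 1 of [Peyré et al., 2016]. For a coupling T with marginals p = T1 and q = Tᵀ1, the matrix (L(C, C̄) ⊗ T) equals lin1(C) p 1ᵀ + 1 (lin2(C̄) q)ᵀ − quad1(C) T quad2(C̄)ᵀ. Computing it this way needs only matrix products.

The NumPy sketch below checks that identity against the four-index definition. It also evaluates the objective, adding fused_penalty * ⟨M, T⟩ as the fused_penalty description states. This is plain NumPy, not OTT's implementation. The function names are hypothetical, and OTT may scale the fused linear term differently internally.

```python
import numpy as np

# Squared Euclidean loss split as (lin1, lin2, quad1, quad2); illustrative only.
SQEUCL = (lambda x: x ** 2, lambda y: y ** 2, lambda x: x, lambda y: 2.0 * y)


def tensor_product(parts, C, Cbar, T):
    """(L(C, Cbar) (x) T)[i, j] = sum_{k,l} L(C[i, k], Cbar[j, l]) T[k, l], via matrix products."""
    lin1, lin2, quad1, quad2 = parts
    p = T.sum(axis=1)  # first marginal of T, shape (n,)
    q = T.sum(axis=0)  # second marginal of T, shape (m,)
    const = (lin1(C) @ p)[:, None] + (lin2(Cbar) @ q)[None, :]
    return const - quad1(C) @ T @ quad2(Cbar).T


def tensor_product_naive(parts, C, Cbar, T):
    """Same quantity from the four-index definition, O(n^2 m^2); for checking only."""
    lin1, lin2, quad1, quad2 = parts
    # L4[i, j, k, l] = L(C[i, k], Cbar[j, l])
    L4 = (
        lin1(C)[:, None, :, None]
        + lin2(Cbar)[None, :, None, :]
        - quad1(C)[:, None, :, None] * quad2(Cbar)[None, :, None, :]
    )
    return np.einsum("ijkl,kl->ij", L4, T)


def fgw_objective(parts, C, Cbar, T, M=None, fused_penalty=1.0):
    """Quadratic term <L(C, Cbar) (x) T, T>, plus fused_penalty * <M, T> if M is given."""
    quadratic = np.sum(tensor_product(parts, C, Cbar, T) * T)
    if M is None:
        return quadratic
    return quadratic + fused_penalty * np.sum(M * T)


def sq_dists(u, v):
    """Pairwise squared Euclidean distances between the rows of u and v."""
    return ((u[:, None, :] - v[None, :, :]) ** 2).sum(-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))  # 4 points in the first space
y = rng.normal(size=(3, 2))  # 3 points in the second space
C, Cbar, M = sq_dists(x, x), sq_dists(y, y), sq_dists(x, y)
T = np.full((4, 3), 1.0 / 12.0)  # product coupling of two uniform marginals

print(fgw_objective(SQEUCL, C, Cbar, T))
print(fgw_objective(SQEUCL, C, Cbar, T, M=M, fused_penalty=0.5))
```

The matrix-product form costs O(n²m + nm²) operations, compared with O(n²m²) for the naive four-index sum.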
- Return type
- Returns
Gromov-Wasserstein output.
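The ranks and tolerances parameters only apply when the geometries are not PointCloud with the 'sqeucl' cost. A plausible reason is that a squared Euclidean cost between n and m points in d dimensions factors exactly as A Bᵀ with d + 2 columns, whereas a general cost matrix can only be approximated at a given rank.

The NumPy sketch below builds the exact factorization. For contrast, it also shows a truncated SVD as one illustrative low-rank approximation of an arbitrary cost matrix. The SVD is not necessarily the method to_LRCGeometry() uses, and all function names here are hypothetical.

```python
import numpy as np


def sqeucl_cost(x, y):
    """Dense cost matrix C[i, j] = ||x_i - y_j||^2, shape (n, m)."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def sqeucl_factors(x, y):
    """Exact factors A (n, d + 2) and B (m, d + 2) with A @ B.T == sqeucl_cost(x, y)."""
    n, m = x.shape[0], y.shape[0]
    x_sq = (x ** 2).sum(-1, keepdims=True)  # ||x_i||^2, shape (n, 1)
    y_sq = (y ** 2).sum(-1, keepdims=True)  # ||y_j||^2, shape (m, 1)
    # Row i of A dotted with row j of B gives ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
    A = np.concatenate([x_sq, np.ones((n, 1)), -2.0 * x], axis=1)
    B = np.concatenate([np.ones((m, 1)), y_sq, y], axis=1)
    return A, B


def truncated_svd_factors(C, rank):
    """Best rank-`rank` factors of an arbitrary cost matrix (in Frobenius norm)."""
    U, s, Vt = np.linalg.svd(C, full_matrices=False)
    A = U[:, :rank] * s[:rank]  # scale left singular vectors by singular values
    B = Vt[:rank].T
    return A, B


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 2))
y = rng.normal(size=(7, 2))
C = sqeucl_cost(x, y)

A, B = sqeucl_factors(x, y)
print(A.shape, B.shape, np.allclose(A @ B.T, C))
```

Keeping A and B instead of C stores (n + m)(d + 2) numbers rather than n·m, which is what makes low-rank geometries attractive for large inputs.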