ott.tools.soft_sort.sort_with
- ott.tools.soft_sort.sort_with(inputs, criterion, topk=-1, **kwargs)
Sort a multidimensional array according to a real-valued criterion.
Given batch vectors of dimension dim, each associated with a real-valued criterion, this function produces topk (or batch if topk is set to -1, as by default) composite vectors of size dim that are convex combinations of all vectors, ranked from smallest to largest criterion. Composite vectors with the largest indices contain convex combinations of the vectors with the highest criterion values; composite vectors with smaller indices contain combinations of vectors with smaller criterion values.
- Parameters
inputs (Array) – the inputs as a jnp.ndarray[batch, dim].
criterion (Array) – the values according to which to sort the inputs. It has shape [batch, 1].
topk (int) – the number of outputs to keep.
kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0, 1] (sigmoid of whitened values by default) to solve the optimal transport problem; cost_fn, used in PointCloud, which defines the ground cost function to transport from inputs to the num_targets target values (squared Euclidean distance by default, see pointcloud.py for more details); and epsilon values as well as other parameters to shape the sinkhorn algorithm.
- Return type
Array
- Returns
A jnp.ndarray[batch | topk, dim].
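
To make the "convex combinations ranked by criterion" behaviour concrete, here is a minimal sketch written in plain NumPy rather than calling ott. It is not the library's code path: the helper name soft_sort_with_sketch, the evenly spaced targets in [0, 1], the fixed iteration count, and the hand-rolled log-domain Sinkhorn loop are all assumptions made for illustration, and it covers only the default topk=-1 case.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def soft_sort_with_sketch(inputs, criterion, epsilon=1e-2, num_iters=500):
    """Hypothetical helper: soft-sort the rows of `inputs` by `criterion`.

    Returns (outputs, weights) with outputs[j] = sum_i weights[j, i] * inputs[i].
    """
    inputs = np.asarray(inputs, dtype=float)
    n = inputs.shape[0]
    c = np.asarray(criterion, dtype=float).reshape(-1)
    # Squash the criterion into (0, 1): sigmoid of whitened values.
    z = (c - c.mean()) / (c.std() + 1e-12)
    x = 1.0 / (1.0 + np.exp(-z))
    # Assumption of this sketch: n evenly spaced, already sorted targets.
    y = np.linspace(0.0, 1.0, n)
    cost = (x[:, None] - y[None, :]) ** 2  # squared Euclidean ground cost
    log_marginal = np.full(n, -np.log(n))  # uniform weights on both sides
    f = np.zeros(n)
    g = np.zeros(n)
    for _ in range(num_iters):  # log-domain Sinkhorn updates
        f = epsilon * (log_marginal - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_marginal - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Normalize each column so the weights behind every output sum to 1.
    weights = (plan / plan.sum(axis=0, keepdims=True)).T
    return weights @ inputs, weights


inputs = np.array([[10.0, 0.0], [20.0, 1.0], [30.0, 2.0]])
criterion = np.array([[3.0], [1.0], [2.0]])
outputs, weights = soft_sort_with_sketch(inputs, criterion)
print(outputs)  # rows land close to the hard-sorted order inputs[[1, 2, 0]]
```

Each row of weights is nonnegative and sums to 1, so every output row is a convex combination of the input rows. Raising epsilon spreads each row of weights across more inputs, which is what makes the operation smooth rather than a hard permutation.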