ott.tools.soft_sort.sort
- ott.tools.soft_sort.sort(inputs, axis=-1, topk=-1, num_targets=None, **kwargs)
Apply the soft sort operator on a given axis of the input.
- Parameters
  - inputs (Array) – jnp.ndarray<float> of any shape.
  - axis (int) – the axis on which to apply the operator.
  - topk (int) – if set to a positive value, the returned vector will only contain the topk largest values, in increasing order. This also reduces the complexity of soft sorting.
  - num_targets (Optional[int]) – if topk is not specified, num_targets defines the number of (composite) sorted values computed from the inputs. Each value is a convex combination of values recorded in the inputs, and the values are provided in increasing order. If not specified, num_targets defaults to the size of the slices of the input that are sorted, i.e. the number of composite sorted values equals the number of inputs being sorted.
  - kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0, 1] (by default, a sigmoid of the whitened values) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for details); and epsilon, as well as other parameters that shape the sinkhorn algorithm.
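The parameters above describe soft sorting as an entropic optimal transport problem between the squashed inputs and a set of increasing target values. As a rough illustration of that pipeline on a single 1-D slice, here is a minimal NumPy sketch. It is not OTT's implementation: the name soft_sort_1d, the fixed-iteration log-domain Sinkhorn loop, the evenly spaced targets on [0, 1], the top-k target weighting, and the default epsilon are all assumptions; the squashing and the cost follow the defaults described for squashing_fun and cost_fn.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along `axis`.
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def soft_sort_1d(x, topk=-1, num_targets=None, epsilon=1e-2, num_iters=1000):
    """Hypothetical sketch: soft sort a 1-D array via entropic OT (Sinkhorn)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    # Squash inputs into [0, 1]: sigmoid of whitened values.
    z = (x - x.mean()) / (x.std() + 1e-10)
    s = 1.0 / (1.0 + np.exp(-z))
    # Uniform mass 1/n on each input.
    a = np.full(n, 1.0 / n)
    if 0 < topk < n:
        # Assumed top-k scheme: the first target absorbs the n - topk smallest
        # inputs; each of the remaining topk targets receives mass 1/n.
        b = np.concatenate([[(n - topk) / n], np.full(topk, 1.0 / n)])
        start = 1
    else:
        m = n if num_targets is None else num_targets
        b = np.full(m, 1.0 / m)
        start = 0
    # Assumed targets: evenly spaced points on [0, 1], in increasing order.
    y = np.linspace(0.0, 1.0, b.shape[0])
    cost = (s[:, None] - y[None, :]) ** 2  # squared Euclidean ground cost
    # Log-domain Sinkhorn: alternate updates of dual potentials f and g.
    f = np.zeros(n)
    g = np.zeros(b.shape[0])
    for _ in range(num_iters):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Column j of the plan has mass b[j], so each output is a convex
    # combination of the inputs, approximately in increasing order.
    out = (plan.T @ x) / b
    return out[start:]
```

For example, soft_sort_1d(np.array([3.0, 1.0, 2.0, 5.0, 4.0])) returns values close to [1, 2, 3, 4, 5], and passing topk=2 returns two values close to [4, 5]. Increasing epsilon blurs every output toward the mean of the inputs, since the plan then approaches the independent coupling of the two marginals.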
- Return type
Array
- Returns
A jnp.ndarray with soft sorted values on the given axis. It has the same shape as the input, except that the size along axis equals topk or num_targets when one of them is set.
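The operator acts on one 1-D slice at a time along axis, and the size of the result along that axis is whatever the per-slice operator returns. The NumPy sketch below illustrates that axis handling with a hypothetical helper, apply_along_axis_sketch, which is not part of OTT. It plugs in np.sort as a stand-in for the soft operator (soft sorting approaches hard sorting as epsilon goes to 0 for distinct inputs).

```python
import numpy as np


def apply_along_axis_sketch(op, inputs, axis=-1):
    """Hypothetical helper: apply a 1-D operator to every slice along `axis`."""
    x = np.moveaxis(np.asarray(inputs), axis, -1)  # sorted axis goes last
    batch_shape = x.shape[:-1]
    rows = x.reshape(-1, x.shape[-1])              # one row per 1-D slice
    out = np.stack([op(row) for row in rows])      # op may change row length
    out = out.reshape(batch_shape + (out.shape[-1],))
    return np.moveaxis(out, -1, axis)              # restore original layout


# Usage with hard sorting as a stand-in for the soft operator.
x = np.random.default_rng(0).normal(size=(3, 4, 5))
full = apply_along_axis_sketch(np.sort, x, axis=1)
top2 = apply_along_axis_sketch(lambda row: np.sort(row)[-2:], x, axis=1)
print(full.shape, top2.shape)  # (3, 4, 5) (3, 2, 5)
```

Moving the target axis last and flattening the rest keeps the per-slice operator strictly one-dimensional, which is why a top-k style operator only changes the size of the chosen axis.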