ott.tools.soft_sort.ranks
- ott.tools.soft_sort.ranks(inputs, axis=-1, num_targets=None, **kwargs)
Apply the soft ranks operator to an input tensor.
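A soft ranks operator is a relaxation of the ordinary (hard) ranks of each slice along axis. To make that target concrete, the NumPy sketch below computes hard ranks with a double argsort. hard_ranks is a hypothetical helper, not part of OTT, and the 0-based convention is an assumption: this page does not state the offset or scale of the library's output.

```python
import numpy as np


def hard_ranks(inputs, axis=-1):
    """0-based rank of every entry within its slice along `axis`.

    Ties are broken by position (stable sort). The output has the same
    shape as `inputs`, like the soft operator documented here.
    """
    inputs = np.asarray(inputs)
    # First argsort: indices that would sort each slice.
    order = np.argsort(inputs, axis=axis, kind="stable")
    # Second argsort: inverts that permutation, giving each entry's rank.
    return np.argsort(order, axis=axis, kind="stable")


x = np.array([[3.0, 1.0, 2.0],
              [0.5, 0.7, 0.1]])
print(hard_ranks(x))          # ranks within each row: [[2 0 1] [1 2 0]]
print(hard_ranks(x, axis=0))  # ranks within each column: [[1 1 1] [0 0 0]]
```

The hard ranks are piecewise constant in the inputs, so their gradient is zero almost everywhere. This is the motivation for replacing them with an optimal-transport-based soft version.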
- Parameters
inputs (Array) – a jnp.ndarray<float> of any shape.
axis (int) – the axis along which to apply the soft ranks operator.
num_targets (Optional[int]) – the number of targets used to compute a composite rank for each value in inputs: that soft rank is a convex combination of the values in [0, …, (num_targets-2)/num_targets, 1], with weights given by the optimal transport from the values in inputs to those target values. If not specified, num_targets defaults to the size of the slices of the input that are sorted.
kwargs (Any) – keyword arguments passed on to lower-level functions. Of interest to the user are squashing_fun, which redistributes the values in inputs to lie in [0, 1] (sigmoid of whitened values by default) before the optimal transport problem is solved; cost_fn, used in PointCloud, which defines the ground cost for transporting inputs to the num_targets target values (squared Euclidean distance by default; see pointcloud.py for more details); and epsilon, as well as other parameters that shape the sinkhorn algorithm.
- Return type
Array
- Returns
A jnp.ndarray<float> of the same shape as inputs, with the ranks.
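The num_targets and kwargs descriptions above can be illustrated with a small NumPy re-implementation of the idea. This is a sketch under stated assumptions, not OTT's actual code. The assumptions are:

* inputs are squashed with a sigmoid of whitened values;
* num_targets targets sit at evenly spaced points of [0, 1], with uniform weights on inputs and targets;
* the ground cost is squared Euclidean;
* a plain log-domain Sinkhorn loop with a fixed epsilon and iteration count computes the transport plan.

Each soft rank is the expected target index under its row of the plan, so it lies in [0, num_targets - 1]. Dividing by num_targets - 1 maps it onto the evenly spaced [0, 1] grid. The names squash, sinkhorn_plan, soft_ranks_1d and soft_ranks_sketch are hypothetical, and OTT's default epsilon and iteration settings may differ.

```python
import numpy as np


def squash(x):
    # Assumed squashing: sigmoid of whitened values, mapping inputs into (0, 1).
    z = (x - x.mean()) / (x.std() + 1e-12)
    return 1.0 / (1.0 + np.exp(-z))


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along `axis`.
    mx = np.max(m, axis=axis, keepdims=True)
    out = mx + np.log(np.sum(np.exp(m - mx), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sinkhorn_plan(cost, a, b, epsilon, num_iters):
    # Log-domain Sinkhorn for entropic OT between weights a (rows) and b (columns).
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(num_iters):
        # Column update: makes the column sums of the plan equal b.
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
        # Row update (done last): makes the row sums of the plan equal a exactly.
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def soft_ranks_1d(x, num_targets=None, epsilon=1e-2, num_iters=2000):
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    m = n if num_targets is None else num_targets  # default: slice size
    a = np.full(n, 1.0 / n)  # uniform weights on the inputs
    b = np.full(m, 1.0 / m)  # uniform weights on the targets
    y = np.linspace(0.0, 1.0, m)  # assumed target positions in [0, 1]
    cost = (squash(x)[:, None] - y[None, :]) ** 2  # squared Euclidean cost
    plan = sinkhorn_plan(cost, a, b, epsilon, num_iters)
    # Row i of plan / a[i] is a distribution over targets; the soft rank is
    # the expected target index 0, ..., m - 1 under that distribution.
    return plan @ np.arange(m) / a


def soft_ranks_sketch(inputs, axis=-1, num_targets=None, **kwargs):
    # Apply the 1-D sketch independently to every slice along `axis`;
    # the output has the same shape as `inputs`.
    return np.apply_along_axis(
        soft_ranks_1d, axis, np.asarray(inputs, dtype=float), num_targets, **kwargs
    )


x = np.array([0.3, -1.2, 2.5, 0.9])
print(soft_ranks_sketch(x))                 # close to the hard ranks [1, 0, 3, 2]
print(soft_ranks_sketch(x, epsilon=1.0))    # smoother, compressed toward index 1.5
print(soft_ranks_sketch(x, num_targets=2))  # coarse ranks in [0, 1]
```

With a small epsilon, each row of the plan concentrates on one target and the soft ranks approach the hard ranks. A larger epsilon spreads every row over several targets. This pulls the ranks toward the middle index but makes them smooth functions of the inputs. Because squashing is monotone and the row distributions shift monotonically with the squashed value, the soft ranks in this sketch preserve the order of the inputs for any epsilon.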