Softjax operators¤
Helper functions¤
softjax.sigmoidal(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['smooth', 'c0', 'c1', '_c1_pnorm', 'c2', '_c2_pnorm'] = 'smooth') -> Float[Array, '...']
¤
Sigmoidal functions defining a characteristic S-shaped curve.
Arguments:

- x: Input Array.
- softness: Softness of the function, should be larger than zero.
- mode: Choice of smoothing family for the surrogate step.
    - smooth: C∞ smooth (based on logistic/softmax/entropic regularizer). Smooth sigmoidal based on the logistic function.
    - c0: C0 continuous (based on euclidean/L2 regularizer). Continuous sigmoidal based on a piecewise quadratic polynomial.
    - c1: C1 differentiable (cubic Hermite). Differentiable sigmoidal based on a piecewise cubic polynomial.
    - c2: C2 twice differentiable (quintic Hermite). Twice differentiable sigmoidal based on a piecewise quintic polynomial.
Returns:
SoftBool of same shape as x (Array with values in [0, 1]).
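The default smooth mode is based on the logistic function, so it can be read as a tempered logistic curve; a minimal numpy sketch of that reading (not the library's exact implementation, and covering only the smooth mode):

```python
import numpy as np

def sigmoidal_smooth(x, softness=0.1):
    # Logistic S-curve; softness acts as a temperature.
    # As softness -> 0 this approaches the Heaviside step.
    return 1.0 / (1.0 + np.exp(-x / softness))
```

At x = 0 the curve passes through 0.5 by symmetry, and it saturates to 0/1 once |x| is a few multiples of the softness.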
softjax.softrelu(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['smooth', 'c0', 'c1', '_c1_pnorm', 'c2', '_c2_pnorm'] = 'smooth', gated: bool = False) -> jax.Array
¤
Family of soft relaxations to ReLU.
Arguments:
- x: Input Array.
- softness: Softness of the function, should be larger than zero.
- mode: Choice of softjax.sigmoidal smoothing mode; options are "smooth", "c0", "c1", "_c1_pnorm", "c2", or "_c2_pnorm".
- gated: If True, uses the 'gated' version x * sigmoidal(x). If False, uses the integral of the sigmoidal.
Returns:
Result of applying soft elementwise ReLU to x.
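For the smooth mode with a logistic sigmoidal, the two variants reduce to familiar activations: the gated form is a SiLU/swish-style curve, and the integral form is a softness-scaled softplus. A hedged numpy sketch (the library's exact implementation may differ):

```python
import numpy as np

def softrelu(x, softness=0.1, gated=False):
    if gated:
        # 'Gated' variant: x * sigmoidal(x), a SiLU/swish-style curve.
        return x / (1.0 + np.exp(-x / softness))
    # Non-gated variant: the integral of the logistic sigmoidal,
    # i.e. softness * softplus(x / softness), written stably via logaddexp.
    return softness * np.logaddexp(0.0, x / softness)
```

Both variants approach max(x, 0) as softness -> 0; the integral form is everywhere positive, while the gated form can dip slightly below zero for negative inputs.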
Elementwise operators¤
softjax.abs(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth') -> jax.Array
¤
Performs a soft version of jax.numpy.abs.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: Projection mode. "hard" returns the exact absolute value. Otherwise uses softjax.sigmoidal-based "smooth", "c0", "c1", or "c2" relaxations.
Returns:
Result of applying soft elementwise absolute value to x.
softjax.clip(x: jax.Array, a: jax.Array, b: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', gated: bool = False) -> jax.Array
¤
Performs a soft version of jax.numpy.clip.
Implemented via two softjax.softrelu calls.
Arguments:
- x: Input Array of any shape.
- a: Lower bound scalar.
- b: Upper bound scalar.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", applies jnp.clip. Otherwise uses softjax.softrelu-based "smooth", "c0", "c1", or "c2" relaxations.
- gated: See the softjax.softrelu documentation.
Returns:
Result of applying soft elementwise clipping to x.
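The two-softrelu construction follows the identity clip(x, a, b) = a + relu(x - a) - relu(x - b): the first ramp turns on at a, the second cancels growth past b. A numpy sketch using the softplus form of the soft ReLU (an assumed concrete choice for the smooth mode):

```python
import numpy as np

def soft_clip(x, a, b, softness=0.1):
    # a + relu(x - a) - relu(x - b) equals clip(x, a, b) in the hard limit;
    # replacing relu with a soft ReLU smooths both corners.
    srelu = lambda t: softness * np.logaddexp(0.0, t / softness)
    return a + srelu(x - a) - srelu(x - b)
```

Inside the interval the two softplus tails cancel, so e.g. the midpoint of [a, b] is mapped to itself exactly.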
softjax.heaviside(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth') -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.heaviside(x, 0.5).
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact Heaviside step. Otherwise uses softjax.sigmoidal-based "smooth", "c0", "c1", or "c2" relaxations.
Returns:
SoftBool of same shape as x (Array with values in [0, 1]), relaxing the elementwise Heaviside step function.
softjax.relu(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', gated: bool = False) -> jax.Array
¤
Performs a soft version of jax.nn.relu.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", applies jax.nn.relu. Otherwise uses softjax.softrelu with "smooth", "c0", "c1", or "c2" relaxations.
- gated: See the softjax.softrelu documentation.
Returns:
Result of applying soft elementwise ReLU to x.
softjax.round(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', neighbor_radius: int = 5) -> jax.Array
¤
Performs a soft version of jax.numpy.round.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", applies jnp.round. Otherwise uses a sigmoidal-based relaxation following the algorithm described in Smooth Approximations of the Rounding Function. This function thereby inherits the different softjax.sigmoidal modes "smooth", "c0", "c1", or "c2".
- neighbor_radius: Number of neighbors on each side of the floor value to consider for the soft rounding.
Returns:
Result of applying soft elementwise rounding to x.
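One simple reading of the neighbor-sum idea: place a soft step at each half-integer around floor(x); every step that x has passed contributes roughly one, so the sum recovers round(x) as softness -> 0. A scalar numpy sketch with a logistic step (an illustrative construction, not necessarily the library's exact algorithm from the cited paper):

```python
import numpy as np

def soft_round(x, softness=0.01, neighbor_radius=5):
    # Half-integer step locations around floor(x).
    base = np.floor(x)
    ks = base + np.arange(-neighbor_radius, neighbor_radius + 1)
    # Clip the exponent to keep saturated steps from overflowing exp.
    z = np.clip((x - (ks + 0.5)) / softness, -60.0, 60.0)
    # Offset by (base - radius) so the saturated steps below x sum to floor(x).
    return (base - neighbor_radius) + (1.0 / (1.0 + np.exp(-z))).sum()
```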
softjax.sign(x: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth') -> jax.Array
¤
Performs a soft version of jax.numpy.sign.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns jnp.sign. Otherwise uses softjax.sigmoidal-based "smooth", "c0", "c1", or "c2" relaxations.
Returns:
Result of applying soft elementwise sign to x.
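A rescaled sigmoidal gives one natural sign relaxation: 2·sigmoidal(x) − 1, which for the logistic smooth mode equals tanh(x / (2·softness)). A minimal sketch under that assumption:

```python
import numpy as np

def soft_sign(x, softness=0.1):
    # 2*sigmoid(x/softness) - 1 == tanh(x / (2*softness)).
    return np.tanh(x / (2.0 * softness))
```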
Array-valued operators¤
softjax.argmax(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None) -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.argmax of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the argmax. If None, the input Array is flattened before computing the argmax.
- keepdims: If True, keeps the reduced dimension as a singleton {1}.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
    - hard: Returns the result of jnp.argmax with a one-hot encoding of the indices.
    - smooth: C∞ smooth (based on logistic/softmax/entropic regularizer), computed in closed form via a softmax operation.
    - c0: C0 continuous (based on euclidean/L2 regularizer), computed via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application.
    - c1: C1 differentiable (p=3/2 p-norm), computed in closed form via the quadratic formula.
    - c2: C2 twice differentiable (p=4/3 p-norm), computed in closed form via Cardano's method.
- method: Method to compute the soft argmax. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
    - ot: Computes the max element via optimal transport projection onto a 2-point support.
    - softsort: Computes the max element of the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Reduces to projecting x onto the unit simplex.
    - neuralsort: Computes the max element of the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
- standardize: If True, standardizes and squashes the input x along the specified axis before applying the soft argmax operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the number of max iterations or tolerance.
Returns:
A SoftIndex of shape (..., {1}, ..., [n]) (positive Array which sums to 1 over the last dimension). Represents the probability of an index corresponding to the argmax along the specified axis.
Usage
This function can be used as a differentiable relaxation to jax.numpy.argmax, enabling backpropagation through index selection steps in neural networks or optimization routines. However, note that the output is not a discrete index but a SoftIndex, which is a distribution over indices. Therefore, functions which operate on indices have to be adjusted accordingly to accept a SoftIndex, see e.g. softjax.max for an example of using softjax.take_along_axis to retrieve the soft maximum value via the SoftIndex.
Difference from jax.nn.softmax
Note that softjax.argmax in smooth mode is not fully equivalent to jax.nn.softmax because it moves the probability dimension into the last axis (this is a convention in the SoftIndex data type).
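In smooth mode the SoftIndex is essentially a tempered softmax, and taking the expectation of x under it recovers a soft maximum. An illustrative numpy sketch for a 1-D input (ignoring the axis-moving convention, and not using the library):

```python
import numpy as np

def soft_argmax_smooth(x, softness=0.1):
    # Tempered softmax: sharper as softness -> 0.
    z = x / softness
    z = z - z.max()          # stabilize the exponentials
    p = np.exp(z)
    return p / p.sum()

x = np.array([0.1, 2.0, 0.5])
p = soft_argmax_smooth(x)    # distribution peaked at index 1
soft_max = (p * x).sum()     # expectation under the SoftIndex ~ max(x)
```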
softjax.max(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> jax.Array
¤
Performs a soft version of jax.numpy.max of x along the specified axis.
For methods other than fast_soft_sort and sorting_network, implemented as softjax.argmax followed by softjax.take_along_axis, see respective documentations for details.
For fast_soft_sort and sorting_network, uses softjax.sort to compute soft sorted values and retrieves the maximum as the first element. See softjax.sort for method details.
Extra Arguments:
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
Array of shape (..., {1}, ...) representing the soft maximum of x along the specified axis.
softjax.argmin(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None) -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.argmin of x along the specified axis.
Implemented as softjax.argmax on -x, see respective documentation for details.
softjax.min(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> jax.Array
¤
Performs a soft version of jax.numpy.min of x along the specified axis.
Implemented as -softjax.max on -x, see respective documentation for details.
Returns:
Array of shape (..., {1}, ...) representing the soft minimum of x along the specified axis.
softjax.argquantile(x: jax.Array, q: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', quantile_method: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] = 'linear', standardize: bool = True, ot_kwargs: dict | None = None) -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.quantile of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- q: Scalar quantile or 1-D Array of quantiles in [0, 1]. When a 1-D array of length k is passed, the q dimension is prepended to the output shape.
- axis: The axis along which to compute the argquantile. If None, the input Array is flattened before computing the argquantile.
- keepdims: If True, keeps the reduced dimension as a singleton {1}.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
    - hard: Returns a one/two-hot encoding of the indices corresponding to the jax.numpy.quantile definitions.
    - smooth: C∞ smooth (based on logistic/softmax/entropic regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation. For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
    - c0: C0 continuous (based on euclidean/L2 regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application. For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
    - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft argquantile. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
    - ot: Uses a variation of the soft quantile approach in Differentiable Ranks and Sorting using Optimal Transport, adapted to converge to the jax quantile definitions for small softness. Depending on quantile_method, either a lower and an upper quantile are computed and combined, or just a single quantile is computed. Intuition: the sorted elements are selected by specifying 4 or 3 "anchors" and then transporting the upper/lower quantile values to the appropriate anchors. Note: inaccurate for small sinkhorn_max_iter (can be passed as a keyword argument), but can be very slow for large sinkhorn_max_iter.
    - softsort: Computes the upper and lower quantiles via the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
    - neuralsort: Computes the upper and lower quantiles via the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
- quantile_method: Method to compute the quantile, following the options in jax.numpy.quantile.
- standardize: If True, standardizes and squashes the input x along the specified axis before applying the soft argquantile operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the number of max iterations or tolerance.
Returns:
A SoftIndex of shape (..., {1}, ..., [n]) for scalar q, or (k, ..., {1}, ..., [n]) for vector q of length k (q dimension prepended). Positive Array which sums to 1 over the last dimension. It represents a distribution over values in x being the q-quantile along the specified axis.
softjax.quantile(x: jax.Array, q: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', quantile_method: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] = 'linear', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> jax.Array
¤
Performs a soft version of jax.numpy.quantile of x along the specified axis.
For methods other than fast_soft_sort and sorting_network, implemented as softjax.argquantile followed by softjax.take_along_axis, see respective documentations for details.
For fast_soft_sort and sorting_network, uses softjax.sort to compute soft sorted values, then retrieves the quantile as a combination of the appropriate elements depending on the quantile method. See softjax.sort for method details.
Extra Arguments:
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
Array of shape (..., {1}, ...) for scalar q, or (k, ..., {1}, ...) for vector q of length k (q dimension prepended). Represents the soft q-quantile of x along the specified axis.
softjax.argpercentile(x: jax.Array, p: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', quantile_method: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] = 'linear', standardize: bool = True, ot_kwargs: dict | None = None) -> Float[Array, '...']
¤
Computes the soft p-argpercentile of x along the specified axis.
Implemented as softjax.argquantile with q=p/100, see respective documentation for details.
softjax.percentile(x: jax.Array, p: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', quantile_method: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] = 'linear', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> jax.Array
¤
Performs a soft version of jax.numpy.percentile of x along the specified axis.
Implemented as softjax.quantile with q=p/100, see respective documentation for details.
softjax.argmedian(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None) -> Float[Array, '...']
¤
Computes the soft argmedian of x along the specified axis.
Implemented as softjax.argquantile with q=0.5, see respective documentation for details.
softjax.median(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> jax.Array
¤
Performs a soft version of jax.numpy.median of x along the specified axis.
Implemented as softjax.quantile with q=0.5, see respective documentation for details.
softjax.argsort(x: jax.Array, axis: int | None = None, descending: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None) -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.argsort of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the argsort operation. If None, the input Array is flattened before computing the argsort.
- descending: If True, sorts in descending order.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
    - hard: Returns the result of jnp.argsort with a one-hot encoding of the indices.
    - smooth: C∞ smooth (based on logistic/softmax/entropic regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation. For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
    - c0: C0 continuous (based on euclidean/L2 regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application. For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
    - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft argsort. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
    - ot: Uses the approach in Differentiable Ranks and Sorting using Optimal Transport. Intuition: the sorted elements are selected by specifying n "anchors" and then transporting the ith-largest value to the ith-largest anchor. Note: inaccurate for small sinkhorn_max_iter (can be passed as a keyword argument), but can be very slow for large sinkhorn_max_iter.
    - softsort: Computes the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
    - neuralsort: Computes the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
- standardize: If True, standardizes and squashes the input x along the specified axis before applying the soft argsort operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the number of max iterations or tolerance.
Returns:
A SoftIndex of shape (..., n, ..., [n]) (positive Array which sums to 1 over the last dimension). The elements in (..., i, ..., [n]) represent a distribution over values in x for the ith smallest element along the specified axis.
Computing the expectation

Computing the soft sorted values means taking the expectation of x under the SoftIndex distribution. Just as with hard indices you would write

    sorted_x = jnp.take_along_axis(x, indices, axis=axis)

with a SoftIndex you instead write

    soft_sorted_x = sj.take_along_axis(x, soft_index, axis=axis)

which is what softjax.sort does internally.
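Concretely, for a 1-D input the SoftIndex for argsort is a row-stochastic (n, n) matrix, and the expectation is just a matrix-vector product. An illustrative numpy sketch with a hand-made soft permutation (not produced by the library):

```python
import numpy as np

x = np.array([3.0, 1.0, 2.0])

# Row i is a distribution over which element of x is the ith smallest:
# here a slightly softened version of the hard permutation (1, 2, 0).
P = np.array([
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
    [0.90, 0.05, 0.05],
])

soft_sorted_x = P @ x   # expectation of x under each row's distribution
```

The result is close to sorted(x) = [1, 2, 3], pulled slightly toward the mean by the off-diagonal mass.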
softjax.sort(x: jax.Array, axis: int | None = None, descending: bool = False, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> jax.Array
¤
Performs a soft version of jax.numpy.sort of x along the specified axis.
Most methods go through softjax.argsort + softjax.take_along_axis to produce soft sorted values.
The exceptions (fast_soft_sort, smooth_sort, sorting_network) bypass soft indices and compute values directly:
- fast_soft_sort: permutahedron projection via PAV isotonic regression (Blondel et al., 2020). In smooth mode, uses an entropic (log-KL) variant that is piecewise smooth but not C∞ (discontinuities at argsort chamber boundaries).
- smooth_sort (SoftJAX only, smooth mode only): permutahedron projection via ESP smooth majorization bounds + an LBFGS dual, giving a truly C∞ relaxation.
- sorting_network: soft bitonic sorting network (Petersen et al., 2021).
Extra Arguments:
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
Array of shape (..., n, ...) representing the soft sorted values of x along the specified axis.
softjax.top_k(x: jax.Array, k: int, axis: int = -1, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> tuple[jax.Array, Float[Array, '...'] | None]
¤
Performs a soft version of jax.lax.top_k of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- k: The number of top elements to select.
- axis: The axis along which to compute the top_k operation.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
    - hard: Returns the result of jax.lax.top_k with a one-hot encoding of the indices.
    - smooth: C∞ smooth (based on logistic/softmax/entropic regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation. For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
    - c0: C0 continuous (based on euclidean/L2 regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application. For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
    - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft top_k. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
    - ot: Uses the approach in Differentiable Top-k with Optimal Transport. Intuition: the top-k elements are selected by specifying k+1 "anchors" and then transporting the top-k values to the top k anchors, and the remaining (n-k) values to the last anchor. Note: inaccurate for small sinkhorn_max_iter (can be passed as a keyword argument), but can be very slow for large sinkhorn_max_iter.
    - softsort: Computes the top-k elements of the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
    - neuralsort: Computes the top-k elements of the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
    - fast_soft_sort: Uses the FastSoftSort operator from Fast Differentiable Sorting and Ranking to directly compute the soft sorted values, via projection onto the permutahedron. The projection is solved via a PAV algorithm as proposed in Fast Differentiable Sorting and Ranking. The top-k values are then retrieved by taking the first k values from the soft sorted output. Note: this method does not return the soft indices, only the soft values.
    - sorting_network: Uses a soft bitonic sorting network as proposed in Differentiable Sorting Networks for Scalable Sorting and Ranking Supervision, replacing hard compare-and-swap with soft versions. The top-k values are retrieved by taking the first k values from the soft sorted output. Note: this method does not return the soft indices, only the soft values.
- standardize: If True, standardizes and squashes the input x along the specified axis before applying the soft top_k operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the number of max iterations or tolerance.
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
- soft_values: Top-k values of x, shape (..., k, ...).
- soft_index: SoftIndex of shape (..., k, ..., [n]) (positive Array which sums to 1 over the last dimension). Represents the soft indices of the top-k values.
softjax.rank(x: jax.Array, axis: int | None = None, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', descending: bool = True, standardize: bool = True, ot_kwargs: dict | None = None) -> jax.Array
¤
Computes the soft ranks of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the rank operation. If None, the input Array is flattened before computing the rank.
- descending: If True, larger inputs receive smaller ranks (best rank is 0). If False, ranks increase with the input values.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
    - hard: Returns the rank computed as two jnp.argsort calls.
    - smooth: C∞ smooth (based on logistic/softmax/entropic regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation. For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
    - c0: C0 continuous (based on euclidean/L2 regularizer). For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application. For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
    - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft rank. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
    - ot: Uses the approach in Differentiable Ranks and Sorting using Optimal Transport. Intuition: run an OT procedure as in softjax.argsort, then transport the sorted ranks (0, 1, ..., n-1) back to the ranks of the original values by using the transpose of the transport plan. Note: inaccurate for small sinkhorn_max_iter (can be passed as a keyword argument), but can be very slow for large sinkhorn_max_iter.
    - softsort: Adapts the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator for rank by using a single column-wise projection instead of a row-wise projection. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
    - neuralsort: Adapts the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations for rank by renormalizing over columns after the row-wise projection.
    - fast_soft_sort: Uses the FastSoftSort operator from Fast Differentiable Sorting and Ranking to directly compute the soft ranks, via projection onto the permutahedron. The projection is solved via a PAV algorithm as proposed in Fast Differentiable Sorting and Ranking.
- standardize: If True, standardizes and squashes the input x along the specified axis before applying the soft rank operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the number of max iterations or tolerance.
Returns:
A positive Array of shape (..., n, ...) with values in [1, n]. The elements in (..., i, ...) represent the soft rank of the ith element along the specified axis.
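For intuition, a simple pairwise-comparison relaxation (not one of the methods listed above) also produces ranks in [1, n]: each element that softly beats x_i pushes its rank up by roughly one. A numpy sketch:

```python
import numpy as np

def pairwise_soft_rank(x, softness=0.1, descending=True):
    # rank_i = 0.5 + sum_j sigma((x_j - x_i)/softness). The diagonal term
    # is sigma(0) = 0.5, so ranks land in [1, n] as softness -> 0.
    d = x[None, :] - x[:, None] if descending else x[:, None] - x[None, :]
    z = np.clip(d / softness, -60.0, 60.0)  # avoid overflow in exp
    return 0.5 + (1.0 / (1.0 + np.exp(-z))).sum(axis=1)
```

With small softness this reproduces hard ranks on distinct inputs, while remaining differentiable everywhere.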
Comparison operators¤
softjax.greater(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x > y.
Uses a Heaviside relaxation so the output approaches 0 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, greater returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise x > y.
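A minimal sketch assuming the logistic Heaviside for the smooth mode (the other modes use the corresponding polynomial sigmoidals): the epsilon shift is what makes the relaxation tend to 0, rather than 0.5, at x == y as softness -> 0.

```python
import numpy as np

def soft_greater(x, y, softness=0.1, epsilon=1e-10):
    # Soft Heaviside of (x - y - epsilon).
    return 1.0 / (1.0 + np.exp(-(x - y - epsilon) / softness))
```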
softjax.greater_equal(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x >= y.
Uses a Heaviside relaxation so the output approaches 1 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, greater_equal returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise x >= y.
softjax.less(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x < y.
Uses a Heaviside relaxation so the output approaches 0 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, less returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise x < y.
softjax.less_equal(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x <= y.
Uses a Heaviside relaxation so the output approaches 1 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, less_equal returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise x <= y.
softjax.equal(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x == y.
Implemented as a soft abs(x - y) <= 0 comparison.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, equal returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise x == y.
softjax.not_equal(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x != y.
Implemented as a soft abs(x - y) > 0 comparison.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, not_equal returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise x != y.
softjax.isclose(x: jax.Array, y: jax.Array, softness: float | jax.Array = 0.1, rtol: float = 1e-05, atol: float = 1e-08, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to jnp.isclose for elementwise comparison.
Implemented as a soft abs(x - y) <= atol + rtol * abs(y) comparison.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero.
- rtol: Relative tolerance.
- atol: Absolute tolerance.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, isclose returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise isclose(x, y).
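A sketch of this construction in the smooth mode, applying a logistic Heaviside to the tolerance margin (the library's exact surrogate may differ; `soft_isclose` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

def soft_isclose(x, y, softness=0.1, rtol=1e-5, atol=1e-8, epsilon=1e-10):
    # Soft |x - y| <= atol + rtol * |y|: the margin is positive when the
    # values are close, and the logistic squashes it into [0, 1].
    margin = atol + rtol * jnp.abs(y) - jnp.abs(x - y)
    return jax.nn.sigmoid((margin + epsilon) / softness)

out = soft_isclose(jnp.array([1.0, 2.0]), jnp.array([1.0, 0.0]))
print(out)  # ~0.5 at equality (for finite softness), ~0 when far apart
```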
Logical operators¤
softjax.logical_and(x: Float[Array, '...'], y: Float[Array, '...'], use_geometric_mean: bool = False) -> Float[Array, '...']
¤
Computes soft elementwise logical AND between two SoftBool Arrays.
Fuzzy logic implemented as all(stack([x, y], axis=-1), axis=-1).
Arguments:
- x: First SoftBool input Array.
- y: Second SoftBool input Array.
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND. Otherwise, the product is used.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise logical AND.
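The two variants can be sketched as follows (`soft_and` is an illustrative name). The geometric mean is idempotent, so conjoining a value with itself leaves it unchanged, whereas the repeated product drives values toward 0:

```python
import jax.numpy as jnp

def soft_and(x, y, use_geometric_mean=False):
    if use_geometric_mean:
        # Geometric mean: idempotent, soft_and(x, x) == x.
        return jnp.sqrt(x * y)
    # Product t-norm: repeated conjunction drives values toward 0.
    return x * y

print(soft_and(jnp.array(0.9), jnp.array(0.9)))                           # ~0.81
print(soft_and(jnp.array(0.9), jnp.array(0.9), use_geometric_mean=True))  # ~0.9
```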
softjax.logical_not(x: Float[Array, '...']) -> Float[Array, '...']
¤
Computes soft elementwise logical NOT of a SoftBool Array.
Fuzzy logic implemented as 1.0 - x.
Arguments:
- x: SoftBool input Array.
Returns:
SoftBool of same shape as x (Array with values in [0, 1]), relaxing the elementwise logical NOT.
softjax.logical_or(x: Float[Array, '...'], y: Float[Array, '...'], use_geometric_mean: bool = False) -> Float[Array, '...']
¤
Computes soft elementwise logical OR between two SoftBool Arrays.
Fuzzy logic implemented as any(stack([x, y], axis=-1), axis=-1).
Arguments:
- x: First SoftBool input Array.
- y: Second SoftBool input Array.
- use_geometric_mean: If True, uses the geometric mean for the soft ALL inside the De Morgan construction (1 - all(not x)). Otherwise, the product is used.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise logical OR.
softjax.logical_xor(x: Float[Array, '...'], y: Float[Array, '...'], use_geometric_mean: bool = False) -> Float[Array, '...']
¤
Computes soft elementwise logical XOR between two SoftBool Arrays.
Arguments:
- x: First SoftBool input Array.
- y: Second SoftBool input Array.
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND and OR operations. Otherwise, the product is used.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the elementwise logical XOR.
softjax.all(x: Float[Array, '...'], axis: int = -1, epsilon: float = 1e-10, use_geometric_mean: bool = False) -> Float[Array, '...']
¤
Computes soft logical AND across a specified axis. Fuzzy logic implemented as the product along the axis, or the geometric mean if use_geometric_mean is True.
Arguments:
- x: SoftBool input Array.
- axis: Axis along which to compute the logical AND. Default is -1 (last axis).
- epsilon: Minimum value for numerical stability inside the log.
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND. Otherwise, the product is used.
Returns:
SoftBool (Array with values in [0, 1]) with the specified axis reduced, relaxing the logical ALL along that axis.
softjax.any(x: Float[Array, '...'], axis: int = -1, use_geometric_mean: bool = False) -> Float[Array, '...']
¤
Computes soft logical OR across a specified axis.
Fuzzy logic implemented as 1.0 - all(logical_not(x), axis=axis).
Arguments:
- x: SoftBool input Array.
- axis: Axis along which to compute the logical OR. Default is -1 (last axis).
- use_geometric_mean: If True, uses the geometric mean for the soft ALL inside the De Morgan construction. Otherwise, the product is used.
Returns:
SoftBool (Array with values in [0, 1]) with the specified axis reduced, relaxing the logical ANY along that axis.
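The two reductions can be sketched as follows, assuming the product/geometric-mean semantics described above (`soft_all` and `soft_any` are illustrative names):

```python
import jax.numpy as jnp

def soft_all(x, axis=-1, epsilon=1e-10, use_geometric_mean=False):
    if use_geometric_mean:
        # exp(mean(log(x))): idempotent, with a clipped log for stability.
        return jnp.exp(jnp.mean(jnp.log(jnp.maximum(x, epsilon)), axis=axis))
    return jnp.prod(x, axis=axis)

def soft_any(x, axis=-1, use_geometric_mean=False):
    # De Morgan: any(x) = not(all(not(x))).
    return 1.0 - soft_all(1.0 - x, axis=axis, use_geometric_mean=use_geometric_mean)

x = jnp.array([0.9, 0.8, 0.95])
print(soft_all(x))  # product: ~0.684
print(soft_any(x))  # 1 - 0.1*0.2*0.05: ~0.999
```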
Autograd-safe operators¤
softjax.arcsin(x: jax.Array) -> jax.Array
¤
Autograd-safe version of jax.numpy.arcsin via the double-where trick.
Returns jnp.arcsin(x) for |x| < 1 and ±π/2 at x = ±1,
without producing NaN gradients at the boundary (unlike jnp.arcsin).
Arguments:
x: Input Array.
Returns:
Elementwise arcsine of x, safe for autodiff.
softjax.arccos(x: jax.Array) -> jax.Array
¤
Autograd-safe version of jax.numpy.arccos via the double-where trick.
Returns jnp.arccos(x) for |x| < 1, 0 at x >= 1, and
π at x <= -1, without producing NaN gradients at the boundary
(unlike jnp.arccos).
Arguments:
x: Input Array.
Returns:
Elementwise arccosine of x, safe for autodiff.
softjax.div(x: jax.Array, y: jax.Array) -> jax.Array
¤
Autograd-safe division via the double-where trick.
Returns x / y when y != 0 and 0 otherwise, without
producing NaN gradients at y = 0 (unlike plain x / y).
Arguments:
- x: Numerator Array.
- y: Denominator Array.
Returns:
Elementwise x / y, safe for autodiff.
softjax.log(x: jax.Array) -> jax.Array
¤
Autograd-safe version of jax.numpy.log via the double-where trick.
Returns jnp.log(x) for x > 0 and 0 otherwise, without
producing NaN gradients at x = 0 (unlike jnp.log).
Arguments:
x: Input Array.
Returns:
Elementwise natural logarithm of x, safe for autodiff.
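The "double-where trick" used throughout this section can be sketched as follows: a naive single where still propagates a NaN through the gradient of the unselected branch, while sanitizing the input with an inner where does not. This is a sketch of the pattern, not the library's exact implementation:

```python
import jax
import jax.numpy as jnp

def naive_log(x):
    # Single where: the forward pass is fine, but the gradient of the
    # unselected log branch is 0 * inf = nan at x = 0.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe_log(x):
    # Double where: the inner where feeds log a harmless dummy value
    # outside the valid domain, so no NaN can leak into the gradient;
    # the outer where selects the documented output (0 for x <= 0).
    x_safe = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(x_safe), 0.0)

print(jax.grad(naive_log)(0.0))  # nan
print(jax.grad(safe_log)(0.0))   # 0.0
```

The same pattern underlies arcsin, arccos, div, and sqrt above and below: sanitize the input inside the invalid region, then mask the output.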
softjax.norm(x: jax.Array, axis=None, keepdims=False) -> jax.Array
¤
Autograd-safe L2 norm via softjax.sqrt.
Computes sqrt(sum(x**2, ...)) using the autograd-safe softjax.sqrt,
avoiding NaN gradients when the norm is zero (unlike jnp.linalg.norm).
Arguments:
- x: Input Array.
- axis: Axis or axes along which to compute the norm.
- keepdims: If True, retains reduced axes with size 1.
Returns:
L2 norm of x along the given axis, safe for autodiff.
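Following the same double-where pattern, a safe norm can be sketched by guarding the squared sum before the square root (`safe_norm` is an illustrative name, not necessarily the library's implementation):

```python
import jax
import jax.numpy as jnp

def safe_norm(x, axis=None, keepdims=False):
    sq = jnp.sum(x**2, axis=axis, keepdims=keepdims)
    # Guard the squared sum so sqrt never sees an exact zero in the
    # branch that gets differentiated.
    sq_safe = jnp.where(sq > 0, sq, 1.0)
    return jnp.where(sq > 0, jnp.sqrt(sq_safe), 0.0)

print(safe_norm(jnp.array([3.0, 4.0])))   # 5.0
print(jax.grad(safe_norm)(jnp.zeros(3)))  # zeros, not NaNs
```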
softjax.sqrt(x: jax.Array) -> jax.Array
¤
Autograd-safe version of jax.numpy.sqrt via the double-where trick.
Returns jnp.sqrt(x) for x > 0 and 0 otherwise, without
producing NaN gradients at x = 0 (unlike jnp.sqrt).
Arguments:
x: Input Array.
Returns:
Elementwise square root of x, safe for autodiff.
Selection operators¤
softjax.where(condition: Float[Array, '...'], x: jax.Array, y: jax.Array) -> jax.Array
¤
Performs a soft version of jax.numpy.where as x * condition + y * (1.0 - condition).
Arguments:
- condition: SoftBool condition Array, same shape as x and y.
- x: First input Array, same shape as condition.
- y: Second input Array, same shape as condition.
Returns:
Array of the same shape as x and y, interpolating between x and y according to condition in [0, 1].
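A minimal sketch of the stated formula (`soft_where` is an illustrative name):

```python
import jax.numpy as jnp

def soft_where(condition, x, y):
    # Convex combination: condition = 1 selects x, 0 selects y, and
    # intermediate values interpolate linearly between the two.
    return x * condition + y * (1.0 - condition)

print(soft_where(jnp.array([1.0, 0.5, 0.0]), 10.0, 20.0))  # [10. 15. 20.]
```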
softjax.take_along_axis(x: jax.Array, soft_index: Float[Array, '...'], axis: int | None = -1) -> jax.Array
¤
Performs a soft version of jax.numpy.take_along_axis via a weighted dot product.
Relation to jnp.take_along_axis:

```python
x = jnp.array([[1, 2, 3], [4, 5, 6]])
indices = jnp.array([[0, 2], [1, 0]])
print(jnp.take_along_axis(x, indices, axis=1))
indices_onehot = jax.nn.one_hot(indices, x.shape[1])
print(sj.take_along_axis(x, indices_onehot, axis=1))
# [[1. 3.]
#  [5. 4.]]
# [[1. 3.]
#  [5. 4.]]
```
Interaction with softjax.argmin:

```python
x = jnp.array([[5, 3, 4], [2, 7, 6]])
indices = jnp.argmin(x, axis=1, keepdims=True)
print("argmin_jnp:", jnp.take_along_axis(x, indices, axis=1))
indices_onehot = sj.argmin(x, axis=1, mode="hard", keepdims=True)
print("argmin_val_onehot:", sj.take_along_axis(x, indices_onehot, axis=1))
indices_soft = sj.argmin(x, axis=1, mode="smooth", softness=1.0, keepdims=True)
print("argmin_val_soft:", sj.take_along_axis(x, indices_soft, axis=1))
# argmin_jnp: [[3]
#  [2]]
# argmin_val_onehot: [[3.]
#  [2.]]
# argmin_val_soft: [[3.42478962]
#  [2.10433824]]
```
Interaction with softjax.argsort:

```python
x = jnp.array([[5, 3, 4], [2, 7, 6]])
indices = jnp.argsort(x, axis=1)
print("sorted_jnp:", jnp.take_along_axis(x, indices, axis=1))
indices_onehot = sj.argsort(x, axis=1, mode="hard")
print("sorted_sj_hard:", sj.take_along_axis(x, indices_onehot, axis=1))
indices_soft = sj.argsort(x, axis=1, mode="smooth", softness=1.0)
print("sorted_sj_soft:", sj.take_along_axis(x, indices_soft, axis=1))
# sorted_jnp: [[3 4 5]
#  [2 6 7]]
# sorted_sj_hard: [[3. 4. 5.]
#  [2. 6. 7.]]
# sorted_sj_soft: [[3.2918137  4.         4.7081863 ]
#  [2.00000045 6.26894107 6.73105858]]
```
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_index: A SoftIndex of shape (..., k, ..., [n]) (positive Array which sums to 1 over the last dimension). If axis is None, must be two-dimensional. If axis is not None, must satisfy x.ndim + 1 == soft_index.ndim, and x must be broadcast-compatible with soft_index along dimensions other than axis.
- axis: Axis along which to apply the soft index. If None, the array is flattened before indexing is applied.
Returns:
Array of shape (..., k, ...), representing the result after soft selection along the specified axis.
softjax.take(x: jax.Array, soft_index: Float[Array, '...'], axis: int | None = None) -> jax.Array
¤
Performs a soft version of jax.numpy.take via a weighted dot product.
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_index: A SoftIndex of shape (k, [n]) (positive Array which sums to 1 over the last dimension).
- axis: Axis along which to apply the soft index. If None, the input is flattened.
Returns:
Array of shape (..., k, ...) after soft selection.
softjax.choose(soft_index: Float[Array, '...'], choices: jax.Array) -> jax.Array
¤
Performs a soft version of jax.numpy.choose via a weighted dot product.
Arguments:
- soft_index: A SoftIndex of shape (..., [n]) (positive Array which sums to 1 over the last dimension). Represents the weights for each choice.
- choices: Array of shape (n, ...) supplying the values to mix.
Returns:
Array whose shape combines the leading dimensions of soft_index with the trailing dimensions of choices, after softly selecting among choices.
softjax.dynamic_index_in_dim(x: jax.Array, soft_index: Float[Array, '...'], axis: int = 0, keepdims: bool = True) -> jax.Array
¤
Performs a soft version of jax.lax.dynamic_index_in_dim via a weighted dot product.
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_index: A SoftIndex of shape ([n],) (positive Array which sums to 1 over the last dimension).
- axis: Axis along which to apply the soft index.
- keepdims: If True, keeps the reduced dimension as a singleton of size 1.
Returns:
Array after soft indexing, of shape (..., 1, ...) when keepdims is True (the singleton dimension is dropped otherwise).
softjax.dynamic_slice_in_dim(x: jax.Array, soft_start_index: Float[Array, '...'], slice_size: int, axis: int = 0) -> jax.Array
¤
Performs a soft version of jax.lax.dynamic_slice_in_dim via a weighted dot product.
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_start_index: A SoftIndex of shape ([n],) (positive Array which sums to 1 over the last dimension).
- slice_size: Length of the slice to extract.
- axis: Axis along which to apply the soft slice.
Returns:
Array of shape (..., slice_size, ...) after soft slicing.
softjax.dynamic_slice(x: jax.Array, soft_start_indices: Sequence[Float[Array, '...']], slice_sizes: Sequence[int]) -> jax.Array
¤
Performs a soft version of jax.lax.dynamic_slice via a weighted dot product.
Arguments:
- x: Input Array of shape (n_1, n_2, ..., n_k).
- soft_start_indices: Sequence of SoftIndex distributions of shapes ([n_1],), ([n_2],), ..., ([n_k],), each a positive Array summing to 1.
- slice_sizes: Sequence of slice lengths for each dimension.
Returns:
Array of shape (l_1, l_2, ..., l_k) after soft slicing.