Softjax operators¤
Elementwise operators¤
softjax.abs(x: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic') -> jax.Array
¤
Performs a soft version of jax.numpy.abs.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: Projection mode. "hard" returns the exact absolute value; otherwise uses the "entropic", "pseudohuber", "euclidean", "cubic", or "quintic" relaxation. Defaults to "entropic".
Returns:
Result of applying soft elementwise absolute value to x.
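Example (a minimal sketch; import softjax as sj is an assumed import matching the sj alias used in the examples further below):
import jax
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.linspace(-2.0, 2.0, 5)
soft = sj.abs(x, softness=0.5)                        # smooth approximation of |x|
hard = sj.abs(x, mode="hard")                         # exact jnp.abs(x)
g = jax.grad(lambda v: sj.abs(v, softness=0.5))(0.0)  # finite gradient at 0, unlike |x|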
softjax.clip(x: jax.Array, a: jax.Array, b: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'quartic', 'gated_entropic', 'gated_euclidean', 'gated_cubic', 'gated_quintic', 'gated_pseudohuber'] = 'entropic') -> jax.Array
¤
Performs a soft version of jax.numpy.clip.
Arguments:
- x: Input Array of any shape.
- a: Lower bound scalar.
- b: Upper bound scalar.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", applies jnp.clip. Otherwise uses "entropic", "euclidean", "quartic", "gated_entropic", "gated_euclidean", "gated_cubic", "gated_quintic", or "gated_pseudohuber" relaxations. Defaults to "entropic".
Returns:
Result of applying soft elementwise clipping to x.
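Example (minimal sketch, same assumed sj alias):
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.linspace(-3.0, 3.0, 7)
y = sj.clip(x, -1.0, 1.0, softness=0.5)      # corners at -1 and 1 are rounded off
y_hard = sj.clip(x, -1.0, 1.0, mode="hard")  # identical to jnp.clip(x, -1.0, 1.0)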
softjax.heaviside(x: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic') -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.heaviside(x, 0.5).
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact Heaviside step. Otherwise uses "entropic", "euclidean", "pseudohuber", "cubic", or "quintic" relaxations. Defaults to "entropic".
Returns:
SoftBool of same shape as x (Array with values in [0, 1]), relaxing the
elementwise Heaviside step function.
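Example (minimal sketch, same assumed sj alias): the SoftBool output makes the step usable inside gradient-based training.
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([-1.0, 0.0, 1.0])
p = sj.heaviside(x, softness=0.5)  # values in [0, 1], roughly 0.5 at x = 0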
softjax.relu(x: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'quartic', 'gated_entropic', 'gated_euclidean', 'gated_cubic', 'gated_quintic', 'gated_pseudohuber'] = 'entropic') -> jax.Array
¤
Performs a soft version of jax.nn.relu.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", applies jax.nn.relu. Otherwise uses "entropic", "euclidean", "quartic", "gated_entropic", "gated_euclidean", "gated_cubic", "gated_quintic", or "gated_pseudohuber" relaxations. Defaults to "entropic".
Returns:
Result of applying soft elementwise ReLU to x.
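Example (minimal sketch, same assumed sj alias): the relaxation replaces the kink of ReLU at 0 with a smooth transition.
import jax
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.linspace(-2.0, 2.0, 5)
y = sj.relu(x, softness=0.5)  # smooth everywhere, including at x = 0
g = jax.vmap(jax.grad(lambda v: sj.relu(v, softness=0.5)))(x)  # smooth gradients instead of a 0/1 step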
softjax.round(x: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', neighbor_radius: int = 5) -> jax.Array
¤
Performs a soft version of jax.numpy.round.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", applies jnp.round. Otherwise uses a sigmoid-based relaxation following the algorithm described in https://arxiv.org/pdf/2504.19026v1, thereby inheriting the sigmoid modes "entropic", "euclidean", "pseudohuber", "cubic", and "quintic". Defaults to "entropic".
- neighbor_radius: Number of neighbors on each side of the floor value to consider for the soft rounding. Defaults to 5.
Returns:
Result of applying soft elementwise rounding to x.
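Example (minimal sketch, same assumed sj alias):
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.linspace(0.0, 3.0, 7)
y = sj.round(x, softness=0.1)  # differentiable staircase; smaller softness tracks jnp.round more closely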
softjax.sign(x: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic') -> jax.Array
¤
Performs a soft version of jax.numpy.sign.
Arguments:
- x: Input Array of any shape.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns jnp.sign. Otherwise smooths via "entropic", "euclidean", "pseudohuber", "cubic", or "quintic" relaxations. Defaults to "entropic".
Returns:
Result of applying soft elementwise sign to x.
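Example (minimal sketch, same assumed sj alias):
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([-0.5, 0.0, 0.5])
s = sj.sign(x, softness=0.25)  # smooth transition through 0 instead of a jump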
Array-valued operators¤
softjax.argmax(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic') -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.argmax
of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the argmax. If None, the input Array is flattened before computing the argmax. Defaults to None.
- keepdims: If True, keeps the reduced dimension as a singleton {1}.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: Controls the type of softening:
  - "hard": Returns the result of jnp.argmax with a one-hot encoding of the indices.
  - "entropic": Returns a softmax-based relaxation of the argmax.
  - "euclidean": Returns an L2-projection-based relaxation of the argmax.
Returns:
A SoftIndex of shape (..., {1}, ..., [n]) (positive Array which sums to 1 over the last dimension). Represents the probability of an index corresponding to the argmax along the specified axis.
Usage
This function can be used as a differentiable relaxation to
jax.numpy.argmax,
enabling backpropagation through index selection steps in neural networks or
optimization routines. However, note that the output is not a discrete index
but a SoftIndex, which is a distribution over indices.
Therefore, functions which operate on indices have to be adjusted accordingly
to accept a SoftIndex, see e.g. softjax.max for an example of using
softjax.take_along_axis to retrieve the soft maximum value via the
SoftIndex.
Difference from jax.nn.softmax
Note that softjax.argmax in entropic mode is not fully equivalent to
jax.nn.softmax
because it moves the probability dimension into the last axis
(this is a convention in the SoftIndex data type).
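Example (minimal sketch, same assumed sj alias), following the pattern described under Usage:
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([[5.0, 3.0, 4.0]])
soft_idx = sj.argmax(x, axis=1, keepdims=True)      # SoftIndex of shape (1, 1, 3)
soft_max = sj.take_along_axis(x, soft_idx, axis=1)  # expectation of x under soft_idx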
softjax.max(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic') -> jax.Array
¤
Performs a soft version of jax.numpy.max
of x along the specified axis.
Implemented as softjax.argmax followed by softjax.take_along_axis, see
respective documentations for details.
Returns:
Array of shape (..., {1}, ...) representing the soft maximum of x along the
specified axis.
softjax.argmin(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic') -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.argmin
of x along the specified axis.
Implemented as softjax.argmax on -x, see respective documentation for
details.
softjax.min(x: jax.Array, axis: int | None = None, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', keepdims: bool = False) -> jax.Array
¤
Performs a soft version of jax.numpy.min
of x along the specified axis.
Implemented as -softjax.max on -x, see respective documentation for details.
Returns:
Array of shape (..., {1}, ...) representing the soft minimum of x along the
specified axis.
softjax.argmedian(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', fast: bool = True, max_iter: int = 1000) -> Float[Array, '...']
¤
Computes the soft argmedian of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the median. If None, the input Array is flattened before computing the median. Defaults to None.
- keepdims: If True, keeps the reduced dimension as a singleton {1}.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode and fast: These two arguments control the behavior of the median:
  - mode="hard": Returns the result of jnp.median with a one-hot encoding of the indices. On ties, it returns a uniform distribution over all median indices.
  - fast=False and mode="entropic": Uses entropy-regularized optimal transport (implemented via Sinkhorn iterations). We adapt the approach in Differentiable Ranks and Sorting using Optimal Transport and Differentiable Top-k with Optimal Transport to the median operation by carefully adjusting the cost matrix and marginals. Intuition: there are three "anchors"; the median is transported onto one anchor, and the larger and smaller elements are transported to the other two anchors, respectively. Can be slow for large max_iter.
  - fast=False and mode="euclidean": Similar to the entropic case, but using an L2 regularizer (implemented via projection onto the Birkhoff polytope).
  - fast=True and mode="entropic": This formulation is a well-known soft median operation based on the interpretation of the median as the minimizer of absolute deviations. The softening is achieved by replacing the argmax operator with a softmax. This also has close ties to the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: fast mode introduces gradient discontinuities when elements in x are not unique, but is much faster.
  - fast=True and mode="euclidean": Similar to the entropic fast case, but using a euclidean unit-simplex projection instead of softmax.
- max_iter: Maximum number of iterations for the Sinkhorn algorithm if mode is "entropic", or for the projection onto the Birkhoff polytope if mode is "euclidean". Unused if fast=True.
Returns:
A SoftIndex of shape (..., {1}, ..., [n]) (positive Array which sums to 1 over the last dimension). The elements in (..., 0, ...) represent a distribution over values in x being the median along the specified axis.
softjax.median(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', fast: bool = True, max_iter: int = 1000) -> jax.Array
¤
Performs a soft version of jnp.median
of x along the specified axis.
Implemented as softjax.argmedian followed by softjax.take_along_axis,
see respective documentations for details.
Returns:
An Array of shape (..., {1}, ...), representing the soft median values along the specified axis.
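Example (minimal sketch, same assumed sj alias): the soft median stays differentiable, so gradients can spread over several entries rather than hitting a single element.
import jax
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([3.0, 1.0, 2.0, 10.0, 4.0])
m = sj.median(x, softness=0.5)  # smooth, outlier-insensitive reduction
g = jax.grad(lambda v: sj.median(v, softness=0.5).sum())(x)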
softjax.median_newton(x: jax.Array, axis: int | None = None, keepdims: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'pseudohuber', 'euclidean', 'cubic', 'quintic'] = 'entropic', max_iter: int = 8, eps: float = 1e-12) -> jax.Array
¤
Performs a soft version of jax.numpy.median
of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: Axis along which to compute the median. If None, the input is flattened. Defaults to None.
- keepdims: If True, keeps the reduced dimension as a singleton {1}. Defaults to False.
- softness: Softness of the score function, should be larger than zero. Defaults to 1.0.
- mode: Smooth score choice. If "hard", returns jnp.median; "entropic", "pseudohuber", "euclidean", "cubic", and "quintic" are smooth relaxations for the M-estimator solved with Newton steps. Defaults to "entropic".
- max_iter: Maximum number of Newton iterations in the M-estimator.
- eps: Small constant added to the derivative to avoid division by zero.
Returns:
Array of shape (..., {1}, ...) representing the soft median of x along the
specified axis.
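Example (minimal sketch, same assumed sj alias):
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([3.0, 1.0, 2.0, 10.0, 4.0])
m = sj.median_newton(x, softness=0.5, max_iter=8)  # smoothed M-estimator refined by Newton steps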
softjax.argsort(x: jax.Array, axis: int | None = None, descending: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', fast: bool = True, max_iter: int = 1000) -> Float[Array, '...']
¤
Performs a soft version of jax.numpy.argsort
of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the argsort operation. If None, the input Array is flattened before computing the argsort. Defaults to None.
- descending: If True, sorts in descending order. Defaults to False (ascending).
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode and fast: These two arguments control the type of softening:
  - mode="hard": Returns the result of jnp.argsort with a one-hot encoding of the indices.
  - fast=False and mode="entropic": Uses entropy-regularized optimal transport (implemented via Sinkhorn iterations) as in Differentiable Ranks and Sorting using Optimal Transport. Intuition: the sorted elements are selected by specifying n "anchors" and then transporting the ith-largest value to the ith-largest anchor. Can be slow for large max_iter.
  - fast=False and mode="euclidean": Similar to the entropic case, but using an L2 regularizer (implemented via LBFGS projection onto the Birkhoff polytope, see Smooth and Sparse Optimal Transport).
  - fast=True and mode="entropic": Uses the "SoftSort" operator proposed in SoftSort: A Continuous Relaxation for the argsort Operator. This initializes the cost matrix based on the absolute difference of x to the sorted values and then applies a single row normalization (instead of full Sinkhorn in OT). Note: fast mode introduces gradient discontinuities when elements in x are not unique, but is much faster.
  - fast=True and mode="euclidean": Similar to the entropic fast case, but using a euclidean unit-simplex projection instead of softmax. To the best of our knowledge this variant is novel.
- max_iter: Maximum number of iterations for the Sinkhorn algorithm if mode is "entropic", or for the projection onto the Birkhoff polytope if mode is "euclidean". Unused if fast=True.
Returns:
A SoftIndex of shape (..., n, ..., [n]) (positive Array which sums to 1 over the last dimension). The elements in (..., i, ..., [n]) represent a distribution over values in x for the ith smallest element along the specified axis.
Computing the expectation
Computing the soft sorted values means taking the expectation of x under the
SoftIndex distribution. Similar to how with normal indices you would do
sorted_x = jnp.take_along_axis(x, indices, axis=axis)
with a SoftIndex you compute
soft_sorted_x = sj.take_along_axis(x, soft_index, axis=axis)
This is precisely what softjax.sort does.
softjax.sort(x: jax.Array, axis: int | None = None, descending: bool = False, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', fast: bool = True, max_iter: int = 1000) -> jax.Array
¤
Performs a soft version of jax.numpy.sort
of x along the specified axis.
Implemented as softjax.argsort followed by softjax.take_along_axis, see
respective documentations for details.
Returns:
Array of shape (..., n, ...) representing the soft sorted values of x along the
specified axis.
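Example (minimal sketch, same assumed sj alias): because the relaxation is smooth, the Jacobian of the sorted output with respect to the input is dense rather than a permutation matrix.
import jax
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([5.0, 3.0, 4.0])
s = sj.sort(x, softness=1.0)                             # soft sorted values
J = jax.jacobian(lambda v: sj.sort(v, softness=1.0))(x)  # dense, smooth Jacobian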
softjax.top_k(x: jax.Array, k: int, axis: int = -1, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', fast: bool = True, max_iter: int = 1000) -> tuple[jax.Array, Float[Array, '...']]
¤
Performs a soft version of jax.lax.top_k
of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- k: The number of top elements to select.
- axis: The axis along which to compute the top_k operation. Defaults to -1.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode and fast: These two arguments control the type of softening:
  - mode="hard": Returns the result of jax.lax.top_k with a one-hot encoding of the indices.
  - fast=False and mode="entropic": Uses entropy-regularized optimal transport (implemented via Sinkhorn iterations) as in Differentiable Top-k with Optimal Transport. Intuition: the top-k elements are selected by specifying k+1 "anchors" and then transporting the top-k values to the top k anchors, and the remaining (n-k) values to the last anchor. Can be slow for large max_iter.
  - fast=False and mode="euclidean": Similar to the entropic case, but using an L2 regularizer (implemented via projection onto the Birkhoff polytope). This version combines the approaches in Smooth and Sparse Optimal Transport (L2 regularizer for sorting) and Differentiable Top-k with Optimal Transport (entropic regularizer for top-k).
  - fast=True and mode="entropic": Uses the "SoftSort" operator proposed in SoftSort: A Continuous Relaxation for the argsort Operator. This initializes the cost matrix based on the absolute difference of x to the sorted values and then applies a single row normalization (instead of full Sinkhorn in OT). Because this is very fast, we do a full soft argsort and then take the top-k elements. Note: fast mode introduces gradient discontinuities when elements in x are not unique, but is much faster.
  - fast=True and mode="euclidean": Similar to the entropic fast case, but using a euclidean unit-simplex projection instead of softmax. To the best of our knowledge this variant is novel.
- max_iter: Maximum number of iterations for the Sinkhorn algorithm if mode is "entropic", or for the projection onto the Birkhoff polytope if mode is "euclidean". Unused if fast=True.
Returns:
- soft_values: Top-k values of x, shape (..., k, ...).
- soft_index: SoftIndex of shape (..., k, ..., [n]) (positive Array which sums to 1 over the last dimension). Represents the soft indices of the top-k values.
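Example (minimal sketch, same assumed sj alias):
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([1.0, 4.0, 2.0, 3.0])
vals, soft_idx = sj.top_k(x, k=2, softness=0.5)
# vals: soft top-2 values, shape (2,); soft_idx: SoftIndex of shape (2, 4), each row summing to 1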
softjax.ranking(x: jax.Array, axis: int | None = None, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean'] = 'entropic', fast: bool = True, max_iter: int = 1000, descending: bool = True) -> jax.Array
¤
Computes the soft rankings of x along the specified axis.
Arguments:
- x: Input Array of shape (..., n, ...).
- axis: The axis along which to compute the ranking operation. If None, the input Array is flattened before computing the ranking. Defaults to None.
- descending: If True, larger inputs receive smaller ranks (best rank is 0). If False, ranks increase with the input values.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode and fast: These two arguments control the behavior of the ranking operation:
  - mode="hard": Returns the ranking computed as two jnp.argsort calls.
  - fast=False and mode="entropic": Uses entropy-regularized optimal transport (implemented via Sinkhorn iterations) as in Differentiable Ranks and Sorting using Optimal Transport. Intuition: we can use the transportation plan obtained in soft sorting for ranking by transporting the sorted ranks (0, 1, ..., n-1) back to the ranks of the original values. Can be slow for large max_iter.
  - fast=False and mode="euclidean": Similar to the entropic case, but using an L2 regularizer (implemented via projection onto the Birkhoff polytope, see Smooth and Sparse Optimal Transport).
  - fast=True and mode="entropic": Uses an adaptation of the "SoftSort" operator proposed in SoftSort: A Continuous Relaxation for the argsort Operator. This initializes the cost matrix based on the absolute difference of x to the sorted values; we then, crucially, apply a single column normalization (instead of the row normalization in the original paper). This makes the resulting matrix a unimodal column-stochastic matrix, which is better suited for soft ranking. Note: fast mode introduces gradient discontinuities when elements in x are not unique, but is much faster.
  - fast=True and mode="euclidean": Similar to the entropic fast case, but using a euclidean unit-simplex projection instead of softmax. To the best of our knowledge this variant is novel.
- max_iter: Maximum number of iterations for the Sinkhorn algorithm if mode is "entropic", or for the projection onto the Birkhoff polytope if mode is "euclidean". Unused if fast=True.
Returns:
A positive Array of shape (..., n, ...) with values in [0, n-1]. The elements in (..., i, ...) represent the soft rank of the ith element along the specified axis.
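Example (minimal sketch, same assumed sj alias), e.g. as a building block for differentiable ranking losses:
import jax
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([0.1, 2.0, 1.0])
r = sj.ranking(x, softness=0.5)  # descending by default: the largest entry gets rank near 0
g = jax.grad(lambda v: sj.ranking(v, softness=0.5)[1])(x)  # differentiable rank of x[1]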
Comparison operators¤
softjax.greater(x: jax.Array, y: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x > y.
Uses a Heaviside relaxation so the output approaches 0 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, greater returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise x > y.
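Example (minimal sketch, same assumed sj alias):
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([0.0, 1.0, -1.0])
y = jnp.zeros_like(x)
p = sj.greater(x, y, softness=0.5)  # SoftBool in [0, 1]; tends to 0 at equality as softness -> 0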
softjax.greater_equal(x: jax.Array, y: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x >= y.
Uses a Heaviside relaxation so the output approaches 1 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, greater_equal returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise x >= y.
softjax.less(x: jax.Array, y: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x < y.
Uses a Heaviside relaxation so the output approaches 0 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, less returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise x < y.
softjax.less_equal(x: jax.Array, y: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x <= y.
Uses a Heaviside relaxation so the output approaches 1 at equality.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, less_equal returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise x <= y.
softjax.equal(x: jax.Array, y: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x == y.
Implemented as a soft abs(x - y) <= 0 comparison.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, equal returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise x == y.
softjax.not_equal(x: jax.Array, y: jax.Array, softness: float = 1.0, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to elementwise x != y.
Implemented as a soft abs(x - y) > 0 comparison.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, not_equal returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise x != y.
softjax.isclose(x: jax.Array, y: jax.Array, softness: float = 1.0, rtol: float = 1e-05, atol: float = 1e-08, mode: Literal['hard', 'entropic', 'euclidean', 'pseudohuber', 'cubic', 'quintic'] = 'entropic', epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft approximation to jnp.isclose for elementwise comparison.
Implemented as a soft abs(x - y) <= atol + rtol * abs(y) comparison.
Arguments:
- x: First input Array.
- y: Second input Array, same shape as x.
- softness: Softness of the function, should be larger than zero. Defaults to 1.
- rtol: Relative tolerance. Defaults to 1e-5.
- atol: Absolute tolerance. Defaults to 1e-8.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside: "entropic", "euclidean", "pseudohuber", "cubic" spline, or "quintic" spline. Defaults to "entropic".
- epsilon: Small offset so that as softness -> 0, isclose returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise isclose(x, y).
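Example (minimal sketch, same assumed sj alias) contrasting the soft equality tests:
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([1.0, 1.2])
y = jnp.array([1.0, 1.0])
p = sj.equal(x, y, softness=0.1)    # near 1 where x is close to y, decaying with |x - y|
q = sj.isclose(x, y, softness=0.1)  # tolerance-aware variant using rtol and atol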
Logical operators¤
softjax.logical_and(x: Float[Array, '...'], y: Float[Array, '...']) -> Float[Array, '...']
¤
Computes soft elementwise logical AND between two SoftBool Arrays.
Fuzzy logic implemented as all(stack([x, y], axis=-1), axis=-1).
Arguments:
- x: First SoftBool input Array.
- y: Second SoftBool input Array.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise logical AND.
softjax.logical_not(x: Float[Array, '...']) -> Float[Array, '...']
¤
Computes soft elementwise logical NOT of a SoftBool Array.
Fuzzy logic implemented as 1.0 - x.
Arguments:
- x: SoftBool input Array.
Returns:
SoftBool of same shape as x (Array with values in [0, 1]), relaxing the
elementwise logical NOT.
softjax.logical_or(x: Float[Array, '...'], y: Float[Array, '...']) -> Float[Array, '...']
¤
Computes soft elementwise logical OR between two SoftBool Arrays.
Fuzzy logic implemented as any(stack([x, y], axis=-1), axis=-1).
Arguments:
- x: First SoftBool input Array.
- y: Second SoftBool input Array.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise logical OR.
softjax.logical_xor(x: Float[Array, '...'], y: Float[Array, '...']) -> Float[Array, '...']
¤
Computes soft elementwise logical XOR between two SoftBool Arrays.
Arguments:
- x: First SoftBool input Array.
- y: Second SoftBool input Array.
Returns:
SoftBool of same shape as x and y (Array with values in [0, 1]), relaxing the
elementwise logical XOR.
softjax.all(x: Float[Array, '...'], axis: int = -1, epsilon: float = 1e-10) -> Float[Array, '...']
¤
Computes a soft logical AND across a specified axis.
Fuzzy logic implemented as the geometric mean along the axis.
Arguments:
- x: SoftBool input Array.
- axis: Axis along which to compute the logical AND. Default is -1 (last axis).
- epsilon: Minimum value for numerical stability inside the log.
Returns:
SoftBool (Array with values in [0, 1]) with the specified axis reduced, relaxing the logical ALL along that axis.
softjax.any(x: Float[Array, '...'], axis: int = -1) -> Float[Array, '...']
¤
Computes a soft logical OR across a specified axis.
Fuzzy logic implemented as 1.0 - all(logical_not(x), axis=axis).
Arguments:
- x: SoftBool input Array.
- axis: Axis along which to compute the logical OR. Default is -1 (last axis).
Returns:
SoftBool (Array with values in [0, 1]) with the specified axis reduced, relaxing the logical ANY along that axis.
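Example (minimal sketch, same assumed sj alias) chaining comparisons, fuzzy logic, and reductions:
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([0.5, 1.5, 2.5])
lo = sj.greater(x, jnp.zeros_like(x))   # SoftBool for x > 0
hi = sj.less(x, jnp.full_like(x, 2.0))  # SoftBool for x < 2
in_range = sj.logical_and(lo, hi)       # fuzzy conjunction, still in [0, 1]
all_in = sj.all(in_range, axis=-1)      # soft "all x in (0, 2)"
any_in = sj.any(in_range, axis=-1)      # soft "any x in (0, 2)"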
Selection operators¤
softjax.where(condition: Float[Array, '...'], x: jax.Array, y: jax.Array) -> jax.Array
¤
Computes a soft elementwise selection between two Arrays based on a SoftBool
condition. Fuzzy logic implemented as x * condition + y * (1.0 - condition).
Arguments:
- condition: SoftBool condition Array, same shape as x and y.
- x: First input Array, same shape as condition.
- y: Second input Array, same shape as condition.
Returns:
Array of the same shape as x and y, interpolating between x and y according
to condition in [0, 1].
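Example (minimal sketch, same assumed sj alias): with a SoftBool condition, where becomes a smooth blend instead of a hard switch.
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.array([0.5, 1.5, 2.5])
cond = sj.greater(x, jnp.full_like(x, 1.0))  # SoftBool condition
out = sj.where(cond, x, jnp.zeros_like(x))   # smooth blend, differentiable in x everywhere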
softjax.take_along_axis(x: jax.Array, soft_index: Float[Array, '...'], axis: int = -1) -> jax.Array
¤
Performs a soft version of jax.numpy.take_along_axis via a weighted dot product.
Relation to jnp.take_along_axis
import jax
import jax.numpy as jnp
import softjax as sj  # alias used throughout these examples

x = jnp.array([[1, 2, 3], [4, 5, 6]])
indices = jnp.array([[0, 2], [1, 0]])
print(jnp.take_along_axis(x, indices, axis=1))
indices_onehot = jax.nn.one_hot(indices, x.shape[1])
print(sj.take_along_axis(x, indices_onehot, axis=1))
[[1. 3.]
[5. 4.]]
[[1. 3.]
[5. 4.]]
Interaction with softjax.argmin
x = jnp.array([[5, 3, 4], [2, 7, 6]])
indices = jnp.argmin(x, axis=1, keepdims=True)
print("argmin_jnp:", jnp.take_along_axis(x, indices, axis=1))
indices_onehot = sj.argmin(x, axis=1, mode="hard", keepdims=True)
print("argmin_val_onehot:", sj.take_along_axis(x, indices_onehot, axis=1))
indices_soft = sj.argmin(x, axis=1, mode="entropic", softness=1.0,
keepdims=True)
print("argmin_val_soft:", sj.take_along_axis(x, indices_soft, axis=1))
argmin_jnp: [[3]
[2]]
argmin_val_onehot: [[3.]
[2.]]
argmin_val_soft: [[3.42478962]
[2.10433824]]
Interaction with softjax.argsort
x = jnp.array([[5, 3, 4], [2, 7, 6]])
indices = jnp.argsort(x, axis=1)
print("sorted_jnp:", jnp.take_along_axis(x, indices, axis=1))
indices_onehot = sj.argsort(x, axis=1, mode="hard")
print("sorted_sj_hard:", sj.take_along_axis(x, indices_onehot, axis=1))
indices_soft = sj.argsort(x, axis=1, mode="entropic", softness=1.0)
print("sorted_sj_soft:", sj.take_along_axis(x, indices_soft, axis=1))
sorted_jnp: [[3 4 5]
[2 6 7]]
sorted_sj_hard: [[3. 4. 5.]
[2. 6. 7.]]
sorted_sj_soft: [[3.2918137 4. 4.7081863 ]
[2.00000045 6.26894107 6.73105858]]
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_index: A SoftIndex of shape (..., k, ..., [n]) (positive Array which sums to 1 over the last dimension).
- axis: Axis along which to apply the soft index. Defaults to -1.
Returns:
Array of shape (..., k, ...), representing the result after soft selection along the specified axis.
softjax.take(x: jax.Array, soft_index: Float[Array, '...'], axis: int | None = None) -> jax.Array
¤
Performs a soft version of jax.numpy.take via a weighted dot product.
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_index: A SoftIndex of shape (k, [n]) (positive Array which sums to 1 over the last dimension).
- axis: Axis along which to apply the soft index. If None, the input is flattened. Defaults to None.
Returns:
Array of shape (..., k, ...) after soft selection.
softjax.choose(soft_index: Float[Array, '...'], choices: jax.Array) -> jax.Array
¤
Performs a soft version of jax.numpy.choose via a weighted dot product.
Arguments:
- soft_index: A SoftIndex of shape (..., [n]) (positive Array which sums to 1 over the last dimension). Represents the weights for each choice.
- choices: Array of shape (n, ...) supplying the values to mix.
Returns:
Array of shape (..., ...) (the batch dimensions of soft_index followed by the trailing dimensions of choices) after softly selecting among choices.
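Example (minimal sketch, same assumed sj alias), spelling out the weighted dot product:
import jax.numpy as jnp
import softjax as sj  # assumed import alias

soft_idx = jnp.array([0.8, 0.2])                 # SoftIndex over n = 2 choices
choices = jnp.array([[1.0, 2.0], [10.0, 20.0]])  # shape (n, ...)
mixed = sj.choose(soft_idx, choices)             # 0.8 * choices[0] + 0.2 * choices[1]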
softjax.dynamic_index_in_dim(x: jax.Array, soft_index: Float[Array, '...'], axis: int = 0, keepdims: bool = True) -> jax.Array
¤
Performs a soft version of jax.lax.dynamic_index_in_dim via a weighted dot product.
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_index: A SoftIndex of shape ([n],) (positive Array which sums to 1 over the last dimension).
- axis: Axis along which to apply the soft index. Defaults to 0.
- keepdims: If True, keeps the reduced dimension as a singleton {1}.
Returns:
Array after soft indexing, shape (..., {1}, ...).
softjax.dynamic_slice_in_dim(x: jax.Array, soft_start_index: Float[Array, '...'], slice_size: int, axis: int = 0) -> jax.Array
¤
Performs a soft version of jax.lax.dynamic_slice_in_dim via a weighted dot product.
Arguments:
- x: Input Array of shape (..., n, ...).
- soft_start_index: A SoftIndex of shape ([n],) (positive Array which sums to 1 over the last dimension).
- slice_size: Length of the slice to extract.
- axis: Axis along which to apply the soft slice. Defaults to 0.
Returns:
Array of shape (..., slice_size, ...) after soft slicing.
softjax.dynamic_slice(x: jax.Array, soft_start_indices: Sequence[Float[Array, '...']], slice_sizes: Sequence[int]) -> jax.Array
¤
Performs a soft version of jax.lax.dynamic_slice via a weighted dot product.
Arguments:
- x: Input Array of shape (n_1, n_2, ..., n_k).
- soft_start_indices: Sequence of SoftIndex distributions of shapes ([n_1],), ([n_2],), ..., ([n_k],), each a positive Array summing to 1.
- slice_sizes: Sequence of slice lengths for each dimension.
Returns:
Array of shape (l_1, l_2, ..., l_k), where l_i is the ith entry of slice_sizes, after soft slicing.
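Example (minimal sketch, same assumed sj alias; jax.nn.softmax is only used here to build valid SoftIndex start distributions):
import jax
import jax.numpy as jnp
import softjax as sj  # assumed import alias

x = jnp.arange(16.0).reshape(4, 4)
starts = [jax.nn.softmax(jnp.array([4.0, 0.0, 0.0, 0.0])),  # soft row start near index 0
          jax.nn.softmax(jnp.array([0.0, 4.0, 0.0, 0.0]))]  # soft column start near index 1
window = sj.dynamic_slice(x, starts, (2, 2))                # soft 2x2 window, shape (2, 2)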