All of Softjax
Softjax provides easy-to-use differentiable surrogates for non-differentiable functions and discrete logic operations in JAX. It offers soft surrogates operating on real values, Booleans, and indices, as well as wrappers that use these surrogates for gradient computation via straight-through estimation.
This page guides you through all of Softjax's key functionality.
1. Softening
Many functions are not particularly well suited for gradient-based optimization, because their gradients are either zero or undefined at discontinuities. A typical approach in machine learning is therefore to use soft surrogates of these functions, which enable gradient-based optimization via automatic differentiation. Softjax provides such soft surrogates for many operations in JAX.
As shown below, Softjax provides two hyperparameters to tune soft function surrogates:
- mode: the type of function used to obtain the soft approximation.
- softness: how closely the surrogate approximates the original function.
import jax
import jax.numpy as jnp
import softjax as sj
plot(sj.relu, modes=["entropic", "gated_quintic"])
Softjax provides surrogates for many other functions, such as the absolute value and rounding functions.
plot(sj.abs, modes=["entropic", "quintic"])
plot(sj.round, modes=["entropic", "quintic"])
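To make the role of softness concrete, here is a minimal sketch of one ReLU surrogate, the scaled softplus $s \log(1 + e^{x/s})$. It assumes that softness acts as a temperature s dividing the input; this parameterization and the name soft_relu_entropic are our own illustration, not the Softjax API.

```python
import jax
import jax.numpy as jnp


def soft_relu_entropic(x, softness=1.0):
    # Scaled softplus: softness * log(1 + exp(x / softness)).
    # Treating softness as a temperature is an assumption, not the Softjax API.
    return softness * jax.nn.softplus(x / softness)


x = jnp.linspace(-2.0, 2.0, 9)
hard = jnp.maximum(x, 0.0)
for s in (1.0, 0.1, 0.01):
    # The largest gap to the hard ReLU is softness * log(2), reached at x = 0.
    gap = jnp.max(jnp.abs(soft_relu_entropic(x, s) - hard))
    print(f"softness={s}: max |soft - hard| = {float(gap):.4f}")
```

Under this parameterization, shrinking softness shrinks the worst-case deviation from the hard ReLU linearly.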
2. Straight-through estimation
Soft function surrogates can be used to compute soft gradients (more accurately: soft vector-Jacobian products) without modifying the forward pass, via straight-through estimation. Straight-through estimation uses JAX's automatic differentiation system to replace only the function's gradient with the gradient of the surrogate function.
Note: Historically, straight-through estimators refer to the special case of treating a function as the identity operation on the backward pass. We use the term more generally to describe replacing a function with a smooth surrogate on the backward pass.
Example - ReLU activation: The rectified linear unit (aka relu) is commonly used as an activation function in neural networks. For $x<0$ the gradient of the relu is zero. In turn, neural networks containing relu activations may suffer from the "dying ReLU problem", where the gradients computed via automatic differentiation become zero once a gradient-based optimizer has pushed the inputs of ReLU functions to $x<0$. As pointed out in "The Resurrection of the ReLU", you can mitigate this problem by replacing the backward pass of the relu with a soft surrogate function. To do this with Softjax, simply replace the relu activation in your code with sj.st(sj.relu) or, for convenience, directly use our wrapped primitive sj.relu_st.
soft_relu = sj.st(sj.relu)
x = jnp.arange(-1, 1, 0.01)
values, grads = jax.vmap(jax.value_and_grad(soft_relu))(x)
plot_value_and_grad(x, values, grads, label_func="ReLU", label_grad="Soft gradient")
The STE Trick
Under the hood, sj.st() uses the stop-gradient operation to replace the gradient of a function.
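import inspect
from typing import Callable
import jax
import jax.tree_util as jtu

def st(fn: Callable) -> Callable:
    sig = inspect.signature(fn)
    mode_default = sig.parameters.get("mode").default

    def wrapped(*args, **kwargs):
        mode = kwargs.pop("mode", mode_default)
        fw_y = fn(*args, **kwargs, mode="hard")  # exact forward output
        bw_y = fn(*args, **kwargs, mode=mode)  # soft surrogate output
        # Value equals fw_y, but gradients flow only through bw_y
        return jtu.tree_map(lambda fw, bw: jax.lax.stop_gradient(fw - bw) + bw, fw_y, bw_y)

    return wrapped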
Because bw_y is both subtracted and added, the returned value equals the hard forward output fw_y. Due to jax.lax.stop_gradient, only the soft output bw_y contributes to the gradient computation.
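To see the identity at work outside Softjax, here is a minimal self-contained sketch that applies it by hand to a single ReLU. The surrogate (a softplus) and the name relu_st_sketch are our own choices for illustration; Softjax's default mode may use a different surrogate.

```python
import jax
import jax.numpy as jnp


def relu_st_sketch(x, softness=1.0):
    hard = jnp.maximum(x, 0.0)                       # exact forward value
    soft = softness * jax.nn.softplus(x / softness)  # smooth surrogate (assumed)
    # Numerically equal to `hard`, but differentiates like `soft`.
    return jax.lax.stop_gradient(hard - soft) + soft


x = jnp.array([-2.0, -0.5, 0.5, 2.0])
values = relu_st_sketch(x)
grads = jax.vmap(jax.grad(relu_st_sketch))(x)
print(values)  # matches jnp.maximum(x, 0.0) up to float rounding
print(grads)   # matches jax.nn.sigmoid(x), nonzero also for x < 0
```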
Custom differentiation rules
The @sj.st decorator can also be used to define custom straight-through operations.
This can be useful when combining multiple functions provided by the Softjax library. For this, it is important to understand that simply applying the straight-through trick to every non-smooth function does not always result in the intended behavior.
Consider, for example, multiplying the outputs of two relu functions. This function only provides meaningful gradients in the first quadrant; we would like to change it so that it provides a meaningful signal over the whole domain, as visualized below.
Note: We normalize the plotted gradient vectors to reduce clutter.
def relu_prod(x, y):
    # Standard ReLU product, no softening
    return jax.nn.relu(x) * jax.nn.relu(y)

plot_value_grad_2D(relu_prod)
A naive approach would be to replace each relu with sj.relu_st independently. However, the resulting function does not provide informative gradients for every input.
This is due to the chain rule: the gradient flowing through one relu is multiplied by the (forward) output of the other relu. As the forward pass is not smoothed, that output is exactly zero whenever the other input is negative, so the gradient is multiplied by zero and carries no information.
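Written out with the product rule, the gradient of the naive version with respect to $x$ reads as follows, where $\widetilde{\mathrm{relu}}$ denotes the soft surrogate used on the backward pass (the notation is ours):

$$
\frac{\partial}{\partial x}\Big[\mathrm{relu}_{\text{st}}(x)\,\mathrm{relu}_{\text{st}}(y)\Big]
= \widetilde{\mathrm{relu}}'(x)\cdot \mathrm{relu}(y).
$$

Since the straight-through wrapper keeps the hard forward value, the factor $\mathrm{relu}(y)$ is exactly zero for $y<0$; by symmetry, the partial derivative with respect to $y$ vanishes for $x<0$. In the third quadrant, both partial derivatives are therefore zero.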
def soft_relu_prod_naive(x, y, mode="entropic", softness=1.0):
    # Naive straight-through implementation
    return sj.relu_st(x, mode=mode, softness=softness) * sj.relu_st(
        y, mode=mode, softness=softness
    )

plot_value_grad_2D(soft_relu_prod_naive)
An alternative approach for softening this function is to apply the straight-through trick at the outer level, as illustrated below. When applied at the outer level, the forward pass computes the hard product of ReLUs as before, whereas the backward pass differentiates through the product of smooth relus.
Notice that we use sj.relu instead of sj.relu_st here.
@sj.st
def soft_relu_prod_custom_st(x, y, mode="entropic", softness=1.0):
    # Custom straight-through implementation
    return sj.relu(x, mode=mode, softness=softness) * sj.relu(
        y, mode=mode, softness=softness
    )

plot_value_grad_2D(soft_relu_prod_custom_st)
As expected, this version of the function now also produces informative gradients in the third quadrant.
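The difference between the two variants can also be checked numerically at a point in the third quadrant. This self-contained sketch assumes a softplus surrogate for both approaches; the helpers st_pair, relu_st, prod_naive and prod_outer are hypothetical and not part of Softjax.

```python
import jax
import jax.numpy as jnp


def st_pair(hard, soft):
    # Straight-through combination: value of `hard`, gradient of `soft`.
    return jax.lax.stop_gradient(hard - soft) + soft


def relu_st(x):
    # Per-ReLU straight-through, using softplus as the (assumed) surrogate.
    return st_pair(jnp.maximum(x, 0.0), jax.nn.softplus(x))


def prod_naive(x, y):
    # Straight-through applied to each ReLU separately.
    return relu_st(x) * relu_st(y)


def prod_outer(x, y):
    # Straight-through applied once, around the whole product.
    hard = jnp.maximum(x, 0.0) * jnp.maximum(y, 0.0)
    soft = jax.nn.softplus(x) * jax.nn.softplus(y)
    return st_pair(hard, soft)


grad_naive = jax.grad(prod_naive, argnums=(0, 1))(-1.0, -1.0)
grad_outer = jax.grad(prod_outer, argnums=(0, 1))(-1.0, -1.0)
print(grad_naive)  # both partial derivatives are zero
print(grad_outer)  # both equal sigmoid(-1) * softplus(-1) > 0
```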
In the simple example above, both sj.relu functions take the same parameters, so it was easy to apply the sj.st decorator to the whole function.
In general, we might want to define custom behavior. This can be implemented with the @sj.grad_replace decorator, which allows custom control flow conditioned on a boolean forward argument. The output of the decorated function with forward=True defines the forward pass, while its output with forward=False defines the gradients.
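from typing import Callable
import jax
import jax.tree_util as jtu

def grad_replace(fn: Callable) -> Callable:
    def wrapped(*args, **kwargs):
        fw_y = fn(*args, **kwargs, forward=True)  # value used on the forward pass
        bw_y = fn(*args, **kwargs, forward=False)  # surrogate used for gradients
        return jtu.tree_map(lambda fw, bw: jax.lax.stop_gradient(fw - bw) + bw, fw_y, bw_y)
    return wrapped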
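As an illustration of the forward flag, here is a hypothetical example (the function relu_prod_mixed and its choice of surrogates are ours) in which the backward pass uses a different relaxation for each factor, a case where the two factors do not share a single mode. To stay runnable without Softjax, it re-declares the grad_replace pattern shown above.

```python
import jax
import jax.tree_util as jtu


def grad_replace(fn):
    # Same pattern as above: value from forward=True, gradient from forward=False.
    def wrapped(*args, **kwargs):
        fw_y = fn(*args, **kwargs, forward=True)
        bw_y = fn(*args, **kwargs, forward=False)
        return jtu.tree_map(lambda fw, bw: jax.lax.stop_gradient(fw - bw) + bw, fw_y, bw_y)

    return wrapped


@grad_replace
def relu_prod_mixed(x, y, forward=True):
    if forward:
        # Exact forward pass: the hard ReLU product.
        return jax.nn.relu(x) * jax.nn.relu(y)
    # Backward pass: a different smooth surrogate per factor (illustrative choice).
    return jax.nn.softplus(x) * jax.nn.silu(y)


value = relu_prod_mixed(-1.0, 2.0)
gx, gy = jax.grad(relu_prod_mixed, argnums=(0, 1))(-1.0, 2.0)
print(value)   # the hard product relu(-1) * relu(2), i.e. zero
print(gx, gy)  # gx = sigmoid(-1) * silu(2) is nonzero even though x < 0
```

Because forward is a plain Python bool, the if statement is resolved at trace time and the decorated function remains compatible with jax.jit and jax.grad.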
3. Soft bools
Softjax provides differentiable surrogates of JAX's Boolean operators. A Boolean (aka Bool) is a data type that takes one of two possible values, False or True (i.e. 0 or 1). Many operations in JAX, such as greater or isclose, generate arrays containing Booleans, while other operations, such as logical_and or any, operate on such arrays.
Example - jax.numpy.greater yields zero gradients: As shown below, jax.grad does not raise an error when differentiating through Boolean operations. However, the returned gradients are zero for all array entries.
x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
bool_array = jax.numpy.greater(x, 0.5)
def boolean_loss(x):
    return jax.numpy.greater(x, 0.5).sum().astype("float32")
boolean_grads = jax.grad(boolean_loss)(x)
plot_array(x, title="x")
plot_array(bool_array, title="jax.numpy.greater(x, 0.5)")
plot_array(boolean_grads, title="jax.grad(jax.numpy.greater(x, 0.5).sum())(x)")
Example - softjax.greater_st yields useful gradients: Instead of jax.numpy.greater, let's use softjax.greater_st (the straight-through variant of softjax.greater). As shown below, thanks to straight-through estimation, softjax.greater_st yields exact Booleans while the gradient of the Boolean loss points in informative directions.
def soft_boolean_loss(x):
    return sj.greater_st(x, 0.5).sum()
x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
bool_array = sj.greater_st(x, 0.5)
boolean_grads = jax.grad(soft_boolean_loss)(x)
plot_array(x, title="x")
plot_array(bool_array, title="soft_greater_st(x, 0.5)")
plot_array(boolean_grads, title="jax.grad(soft_greater_st(x, 0.5).sum())(x)")
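The same idea can be sketched without Softjax: the forward value is the exact comparison cast to float, while the gradient comes from a shifted sigmoid. The name greater_st_sketch and the scaling (x - y) / softness are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def greater_st_sketch(x, y, softness=0.1):
    hard = (x > y).astype(x.dtype)             # exact 0/1 comparison
    soft = jax.nn.sigmoid((x - y) / softness)  # SoftBool surrogate (assumed scaling)
    return jax.lax.stop_gradient(hard - soft) + soft


def loss(x):
    return greater_st_sketch(x, 0.5).sum()


x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
print(greater_st_sketch(x, 0.5))  # 0/1 entries (up to float rounding)
print(jax.grad(loss)(x))          # positive, largest for entries close to 0.5
```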
Generating soft bools
How does Softjax make Boolean logic operations differentiable? A real number $x\in \mathbb{R}$ can be mapped to a Bool using the Heaviside function
$$
H(x) = \begin{cases}
1, & x > 0 \\
0.5, & x=0\\
0, & x < 0
\end{cases}.
$$
The gradient of the Heaviside function (as implemented in JAX) is zero everywhere, which makes it unsuited for gradient-based optimization. Instead of operating directly on Booleans, Softjax's differentiable logic operators resort to soft Booleans. A soft Boolean, or SoftBool, is a value in $[0, 1]$ that can be interpreted as the probability of a Boolean being True.
We replace the Heaviside function with a differentiable surrogate such as the sigmoid function $\sigma(x) = \frac{1}{1+e^{-x}}$. While the sigmoid is the canonical example for mapping a real number to a SoftBool, Softjax provides additional surrogates.
plot(sj.heaviside, modes=["entropic", "euclidean", "quintic"])
In the plot above, the euclidean (piecewise-linear) and quintic relaxations have the advantage of altering the original function only in a bounded region, a property that can be desirable in some cases.
Given the concept of a SoftBool, a probabilistic surrogate for binary comparison operations such as jax.numpy.equal and jax.numpy.greater is obtained by simply shifting the sigmoid.
Example - Greater operator: sj.greater(x, y) corresponds to shifting sj.heaviside by y to the right. The output can be interpreted as the probability $P(x \geq y)\in[0,1]$ with $x\in\mathbb{R}$ and $y\in\mathbb{R}$.
def greater_than_1(x, mode="entropic", softness=1.0):
    return sj.greater(x, y=jnp.array(1.0), mode=mode, softness=softness)
plot(greater_than_1, modes=["entropic", "quintic"])
Manipulating soft bools
Since Softjax replaces Booleans with SoftBools, it correspondingly replaces Boolean logic operators with fuzzy logic operators that compute the probabilities of Boolean events.
Example - Logical NOT: Given a SoftBool $P(B)$ (the probability that a Boolean event $B$ occurs), the probability of the event not occurring is $P(\bar B) = 1 - P(B)$, as implemented in sj.logical_not.
def logical_not(x: SoftBool) -> SoftBool:
    return 1 - x
Given sj.logical_not, the probability that x is not greater than or equal to 0.5 is given by sj.logical_not(sj.greater_st(x, 0.5)). Due to the straight-through trick, this function uses exact Boolean logic in the forward pass and the SoftBool probability computation in the backward pass.
def not_greater_st(x):
    return sj.logical_not(sj.greater_st(x, y=0.5, mode="entropic", softness=1.0))
x = jnp.arange(-1, 1, 0.01)
values, grads = jax.vmap(jax.value_and_grad(not_greater_st))(x)
plot_value_and_grad(x, values, grads, label_func="not_greater_st")
Example - Logical AND: Given two SoftBools $P(A)$ and $P(B)$, the probability that both (independent) events occur is $P(A \wedge B) = P(A) \cdot P(B)$.
def logical_and(x: SoftBool, y: SoftBool) -> SoftBool:
    return x * y
plot_softbool_operation(sj.logical_and)
Example - Logical XOR: Softjax computes other soft logic operators, such as sj.logical_xor, by combining sj.logical_not, sj.logical_and and sj.logical_or.
def logical_xor(x: SoftBool, y: SoftBool) -> SoftBool:
    return logical_or(logical_and(x, logical_not(y)), logical_and(logical_not(x), y))
plot_softbool_operation(sj.logical_xor)
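The XOR definition relies on logical_or, which this page does not spell out. Under the same independence assumption, one natural choice is the probabilistic sum $P(A \vee B) = 1 - (1 - P(A))(1 - P(B))$, obtained from NOT and AND via De Morgan's law. The sketch below uses that choice; whether Softjax implements logical_or this way is an assumption.

```python
import jax.numpy as jnp


def logical_not(x):
    return 1.0 - x


def logical_and(x, y):
    return x * y


def logical_or(x, y):
    # De Morgan: (A or B) = not(not A and not B)
    return logical_not(logical_and(logical_not(x), logical_not(y)))


def logical_xor(x, y):
    return logical_or(logical_and(x, logical_not(y)), logical_and(logical_not(x), y))


a = jnp.array([0.0, 0.0, 1.0, 1.0, 0.7])
b = jnp.array([0.0, 1.0, 0.0, 1.0, 0.4])
print(logical_or(a, b))   # approx. [0, 1, 1, 1, 0.82]
print(logical_xor(a, b))  # first four entries reproduce the Boolean XOR table
```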
Selection with soft bools
Using fuzzy logic operators, Softjax provides a toolbox for making many non-differentiable JAX functions differentiable.
Example - sj.where(): The function jax.numpy.where(condition, x, y) selects elements of array x where condition == True and otherwise selects y. Softjax provides a differentiable surrogate for this function via sj.where(P, x, y), which effectively computes the expected value $\mathbb{E}[X] = P \cdot x + (1-P) \cdot y$.
greater = lambda x, y: sj.greater(x, y, mode="entropic", softness=1.0)
soft_where = lambda x, y: sj.where(greater(x, y), x, y)
x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
y = jax.random.uniform(jax.random.key(1), shape=(2, 10))
plot_array(x, title="x")
plot_array(y, title="y")
plot_array(soft_where(x, y), title="soft_where(x>y, x, y)")
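The expected-value formula can be checked with a few lines of plain JAX. In this sketch, soft_where is a hypothetical stand-in for sj.where, and the sigmoid with softness 0.1 is an assumed choice for the soft condition $P(x > y)$; the result is a soft maximum of x and y.

```python
import jax
import jax.numpy as jnp


def soft_where(p, x, y):
    # Expected value of picking x with probability p and y otherwise.
    return p * x + (1.0 - p) * y


x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
y = jax.random.uniform(jax.random.key(1), shape=(2, 10))
p = jax.nn.sigmoid((x - y) / 0.1)  # SoftBool for x > y (assumed scaling)
soft = soft_where(p, x, y)
hard = jnp.where(x > y, x, y)      # exact elementwise maximum
print(jnp.max(jnp.abs(soft - hard)))
```

With a one-hot condition (p exactly 0 or 1), soft_where reduces exactly to jnp.where, and for any p in [0, 1] the result lies between the two candidates.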
4. Soft indices
Softjax offers soft surrogates for functions that generate indices as outputs, such as argmax, argmin, top_k, argmedian, and argsort.
The main mechanism here is to replace hard indices with distributions over indices (SoftIndex), allowing for informative gradients.
Just as SoftBools required moving from Boolean logic to fuzzy logic, SoftIndices require adjusting the functions that select elements via indices. We therefore provide new versions of, e.g., take_along_axis, dynamic_index_in_dim, and choose.
Combining soft index generation with selection then allows us to define surrogates for the corresponding max, min, top_k, median and sort functions.
Generating soft indices
In JAX, functions like jnp.argmax return integer indices as outputs, which take values in {0, ..., len(x)-1}.
x = jnp.array([1, 2, 3])
print("jnp.argmax(x):", jnp.argmax(x))
jnp.argmax(x): 2
In contrast, Softjax computes a SoftIndex array, in which each entry contains the probability that the corresponding index is selected.
x = jnp.array([1, 2, 3])
print("sj.argmax(x):", sj.argmax(x))
sj.argmax(x): [2.06106005e-09 4.53978686e-05 9.99954600e-01]
Example - Softmax: The "softmax" (or more precisely "softargmax") is a commonly used differentiable surrogate for the argmax function (it is also the default softening mode in sj.argmax). The softmax, $\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j\exp(x_j)}$, returns a discrete probability distribution over indices (aka a SoftIndex). As shown in the plots below, the softmax is fully differentiable. It is commonly used for multi-class classification and in transformer networks.
When softness is low, sj.argmax concentrates probability on the true maximum index (e.g., [1.0, 0.0, 0.0]), recovering the hard argmax. When softness is higher, probability mass spreads over the other indices as well, providing useful gradients for optimization.
def cross_entropy(x, class_target=5):
    probs = sj.argmax(x, softness=10.0)
    target_one_hot = jax.nn.one_hot(class_target, num_classes=x.shape[0])
    log_probs = jnp.log(probs)
    return -(target_one_hot * log_probs).mean()
x = jax.random.normal(jax.random.key(0), shape=(10,))
probs = sj.argmax(x, softness=10.0)
lossgrads = jax.grad(cross_entropy)(x)
plot_softindices_1D(x, title="logits")
plot_softindices_1D(probs, title="index probabilities (softmax)")
plot_softindices_1D(lossgrads, title="gradients of cross entropy loss")
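Assuming softness acts as a temperature that divides the logits (our assumption about the parameterization), the entropic soft argmax can be sketched as a scaled softmax. With softness=0.1, this sketch gives approximately [2.06e-09, 4.54e-05, 0.99995] for the input [1, 2, 3], matching the sj.argmax output printed earlier; smaller softness pushes the result towards the one-hot vector [0, 0, 1], larger softness towards the uniform distribution.

```python
import jax
import jax.numpy as jnp


def soft_argmax_sketch(x, softness=1.0):
    # Entropic relaxation of argmax: softmax of the logits scaled by 1 / softness.
    return jax.nn.softmax(x / softness)


x = jnp.array([1.0, 2.0, 3.0])
for s in (10.0, 1.0, 0.1, 0.01):
    print(s, soft_argmax_sketch(x, s))
```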
Note that while a conventional index array stores the index information in its integer values, a SoftIndex stores the probabilities over possible indices in an extra dimension. By convention, this additional dimension is always the final axis. Apart from this additional final dimension, the shape of the returned SoftIndex matches that of the indices returned by standard JAX.
Here are a few examples of this:
x = jnp.arange(12).reshape((3, 4))
print("x.shape:", x.shape)
print("jnp.argmax(x, axis=1).shape:", jnp.argmax(x, axis=1).shape)
print("sj.argmax(x, axis=1).shape:", sj.argmax(x, axis=1).shape)
print("jnp.argmax(x, axis=0).shape:", jnp.argmax(x, axis=0).shape)
print("sj.argmax(x, axis=0).shape:", sj.argmax(x, axis=0).shape)
print(
"jnp.argmax(x, axis=1, keepdims=True).shape:",
jnp.argmax(x, axis=1, keepdims=True).shape,
)
print(
"sj.argmax(x, axis=1, keepdims=True).shape:",
sj.argmax(x, axis=1, keepdims=True).shape,
)
print(
"jnp.argmax(x, axis=0, keepdims=True).shape:",
jnp.argmax(x, axis=0, keepdims=True).shape,
)
print(
"sj.argmax(x, axis=0, keepdims=True).shape:",
sj.argmax(x, axis=0, keepdims=True).shape,
)
x.shape: (3, 4)
jnp.argmax(x, axis=1).shape: (3,)
sj.argmax(x, axis=1).shape: (3, 4)
jnp.argmax(x, axis=0).shape: (4,)
sj.argmax(x, axis=0).shape: (4, 3)
jnp.argmax(x, axis=1, keepdims=True).shape: (3, 1)
sj.argmax(x, axis=1, keepdims=True).shape: (3, 1, 4)
jnp.argmax(x, axis=0, keepdims=True).shape: (1, 4)
sj.argmax(x, axis=0, keepdims=True).shape: (1, 4, 3)
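The shape convention can be reproduced with a small sketch. The name soft_argmax_axis is hypothetical and a temperature-scaled softmax stands in for the entropic relaxation: move the reduced axis to the end, apply the softmax there, and, for keepdims=True, re-insert a size-1 axis at the original position.

```python
import jax
import jax.numpy as jnp


def soft_argmax_axis(x, axis=-1, keepdims=False, softness=1.0):
    axis = axis % x.ndim  # normalize negative axes
    # Move the reduced axis to the end; the softmax over it becomes the
    # trailing "probability over indices" dimension of the SoftIndex.
    probs = jax.nn.softmax(jnp.moveaxis(x, axis, -1) / softness, axis=-1)
    if keepdims:
        # Re-insert a size-1 axis where the reduced axis used to be.
        probs = jnp.expand_dims(probs, axis)
    return probs


x = jnp.arange(12.0).reshape((3, 4))
print(soft_argmax_axis(x, axis=1).shape)                 # (3, 4)
print(soft_argmax_axis(x, axis=0).shape)                 # (4, 3)
print(soft_argmax_axis(x, axis=1, keepdims=True).shape)  # (3, 1, 4)
print(soft_argmax_axis(x, axis=0, keepdims=True).shape)  # (1, 4, 3)
```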
We also offer soft versions of argmedian, argsort and top_k.
x = jax.random.uniform(jax.random.key(0), shape=(4,))
print("x:", x)
print("\njnp.argmedian(x):", "Not implemented in standard JAX")
print("sj.argmedian(x):", sj.argmedian(x))
print("\njnp.argsort(x):", jnp.argsort(x))
print("sj.argsort(x):", sj.argsort(x))
print("\njax.lax.top_k(x, k=3)[1]:", jax.lax.top_k(x, k=3)[1])
print("sj.top_k(x, k=3)[1]:", sj.top_k(x, k=3)[1])
x: [0.41845711 0.21629545 0.96532146 0.57450053]

jnp.argmedian(x): Not implemented in standard JAX
sj.argmedian(x): [4.95553956e-01 8.69234800e-03 1.99739291e-04 4.95553956e-01]

jnp.argsort(x): [1 0 3 2]
sj.argsort(x): [[1.14092958e-01 8.61461280e-01 4.81124140e-04 2.39646378e-02]
 [7.42554231e-01 9.83447670e-02 3.13131302e-03 1.55969689e-01]
 [1.66975269e-01 2.21144035e-02 1.59597616e-02 7.94950566e-01]
 [4.11469085e-03 5.44954558e-04 9.75750772e-01 1.95895826e-02]]

jax.lax.top_k(x, k=3)[1]: [2 3 0]
sj.top_k(x, k=3)[1]: [[4.11469085e-03 5.44954558e-04 9.75750772e-01 1.95895826e-02]
 [1.66975269e-01 2.21144035e-02 1.59597616e-02 7.94950566e-01]
 [7.42554231e-01 9.83447670e-02 3.13131302e-03 1.55969689e-01]]
Again, the shape of the returned SoftIndex matches that of the standard index array, except for an additional last axis whose size equals the size of the input array along the specified axis. A few examples:
x = jax.random.uniform(jax.random.key(0), shape=(3, 4))
print("x.shape:", x.shape)
# Standard JAX only recently added an axis argument to jax.lax.top_k; by default it uses the last axis
print("\njax.lax.top_k(x, k=2, axis=1)[1].shape:", jax.lax.top_k(x, k=2)[1].shape)
print("sj.top_k(x, k=2, axis=1)[1].shape:", sj.top_k(x, k=2, axis=1)[1].shape)
print("sj.top_k(x, k=2, axis=0)[1].shape:", sj.top_k(x, k=2, axis=0)[1].shape)
print("\njnp.argsort(x, axis=1).shape:", jnp.argsort(x, axis=1).shape)
print("sj.argsort(x, axis=1).shape:", sj.argsort(x, axis=1).shape)
print("jnp.argsort(x, axis=0).shape:", jnp.argsort(x, axis=0).shape)
print("sj.argsort(x, axis=0).shape:", sj.argsort(x, axis=0).shape)
# standard JAX does not support argmedian
print("\nsj.argmedian(x, axis=1).shape:", sj.argmedian(x, axis=1).shape)
print("sj.argmedian(x, axis=0).shape:", sj.argmedian(x, axis=0).shape)
print(
"sj.argmedian(x, axis=1, keepdims=True).shape:",
sj.argmedian(x, axis=1, keepdims=True).shape,
)
print(
"sj.argmedian(x, axis=0, keepdims=True).shape:",
sj.argmedian(x, axis=0, keepdims=True).shape,
)
x.shape: (3, 4)

jax.lax.top_k(x, k=2, axis=1)[1].shape: (3, 2)
sj.top_k(x, k=2, axis=1)[1].shape: (3, 2, 4)
sj.top_k(x, k=2, axis=0)[1].shape: (2, 3, 4)

jnp.argsort(x, axis=1).shape: (3, 4)
sj.argsort(x, axis=1).shape: (3, 4, 4)
jnp.argsort(x, axis=0).shape: (3, 4)
sj.argsort(x, axis=0).shape: (3, 4, 3)

sj.argmedian(x, axis=1).shape: (3, 4)
sj.argmedian(x, axis=0).shape: (4, 3)
sj.argmedian(x, axis=1, keepdims=True).shape: (3, 1, 4)
sj.argmedian(x, axis=0, keepdims=True).shape: (1, 4, 3)
Note: All of the functions in this section come with three modes: hard, entropic and euclidean.
The hard mode produces one-hot soft indices and is mainly used in straight-through estimation. The entropic mode is the recommended soft default; it reduces all operations to either a softmax or an entropy-regularized optimal transport problem.
Finally, the euclidean mode reduces operations to an L2 projection onto the unit simplex or the Birkhoff polytope, which can produce sparse outputs. See the API documentation for details.
Selection with SoftIndices
Given a SoftIndex, Softjax provides (differentiable) helper functions for selecting array elements, mirroring the non-differentiable indexing in standard JAX. Put simply, entries of an array are selected by computing the expected value:
$$\mathbb{E}(\mathrm{arr}, p) = \sum_{i} \mathrm{arr}[i] \cdot p[i] = \mathrm{arr}^{\top} p$$
where $p$ is the SoftIndex.
Example - sj.take_along_axis: The function sj.take_along_axis is central to this selection mechanism. It generalizes jnp.take_along_axis to work with probability distributions (SoftIndices) instead of just integer indices.
The standard jnp.take_along_axis(arr, indices, axis) selects elements from arr using integer indices. Conceptually, it works by:
- Slicing along the specified axis to get 1D arrays
- Using the corresponding indices to select elements
- Assembling the results into the output array
One of its main uses is to accept the index output of e.g. jnp.argmax and select the maximum values at the indexed locations.
While jnp.take_along_axis uses integer indices, out_1d[j] = arr_1d[indices_1d[j]], sj.take_along_axis accepts a SoftIndex and computes the corresponding weighted sum: out_1d[j] = sum_i(arr_1d[i] * soft_indices_2d[j, i]).
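For the 1D case, the weighted sum above is a broadcasted multiply followed by a sum over the last axis. The sketch below (the helper soft_take_1d is hypothetical, not the Softjax implementation) shows that a one-hot SoftIndex reproduces integer indexing exactly, while a soft one returns a blend of entries.

```python
import jax
import jax.numpy as jnp


def soft_take_1d(arr_1d, soft_indices_2d):
    # out[j] = sum_i arr_1d[i] * soft_indices_2d[j, i]
    return jnp.sum(soft_indices_2d * arr_1d, axis=-1)


arr = jnp.array([0.41845711, 0.21629545, 0.96532146])
one_hot = jax.nn.one_hot(jnp.array([1, 2]), num_classes=3)  # hard indices 1 and 2
soft = jax.nn.softmax(jnp.array([[0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]), axis=-1)

print(soft_take_1d(arr, one_hot))  # same values as arr[jnp.array([1, 2])]
print(soft_take_1d(arr, soft))     # weighted blends of all entries of arr
```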
x = jax.random.uniform(jax.random.key(0), shape=(2, 3))
print("x:\n", x)
indices = jnp.argmin(x, axis=1, keepdims=True)
print("min_jnp:\n", jnp.take_along_axis(x, indices, axis=1))
indices_onehot = sj.argmin(x, axis=1, mode="hard", keepdims=True)
print("min_sj_hard:\n", sj.take_along_axis(x, indices_onehot, axis=1))
indices_soft = sj.argmin(x, axis=1, keepdims=True)
print("min_sj_soft:\n", sj.take_along_axis(x, indices_soft, axis=1))
x:
 [[0.41845711 0.21629545 0.96532146]
 [0.57450053 0.53222649 0.35490518]]
min_jnp:
 [[0.21629545]
 [0.35490518]]
min_sj_hard:
 [[0.21629545]
 [0.35490518]]
min_sj_soft:
 [[0.24029622]
 [0.39747788]]
For convenience, this combination of sj.take_along_axis with SoftIndex-generating functions is already implemented in Softjax's max, median, top_k and sort functions.
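To see how the two steps compose, here is a sketch of a soft max built from a temperature-scaled softmax followed by expected-value selection (the name soft_max_sketch, softness=0.1 and the temperature parameterization are our assumptions). For the vector x used in the example below, it gives roughly 0.955, close to the sj.max(x) value printed there, and it stays below the hard maximum because it is a convex combination of the entries.

```python
import jax
import jax.numpy as jnp


def soft_max_sketch(x, softness=0.1):
    # Step 1: SoftIndex from an entropic (softmax) relaxation of argmax.
    p = jax.nn.softmax(x / softness)
    # Step 2: selection via the expected value sum_i x[i] * p[i].
    return jnp.sum(x * p)


x = jnp.array([0.41845711, 0.21629545, 0.96532146, 0.57450053])
print(jnp.max(x), soft_max_sketch(x))
```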
x = jax.random.uniform(jax.random.key(0), shape=(4,))
print("x:", x)
print("\njnp.max(x):", jnp.max(x))
print("sj.max(x, mode='hard'):", sj.max(x, mode="hard"))
print("sj.max(x):", sj.max(x))
print("\njnp.median(x):", jnp.median(x))
print("sj.median(x, mode='hard'):", sj.median(x, mode="hard"))
print("sj.median(x):", sj.median(x))
print("\njax.lax.top_k(x, k=2)[0]:", jax.lax.top_k(x, k=2)[0])
print("sj.top_k(x, k=2, mode='hard')[0]:", sj.top_k(x, k=2, mode="hard")[0])
print("sj.top_k(x, k=2)[0]:", sj.top_k(x, k=2)[0])
print("\njnp.sort(x):", jnp.sort(x))
print("sj.sort(x, mode='hard'):", sj.sort(x, mode="hard"))
print("sj.sort(x):", sj.sort(x))
x: [0.41845711 0.21629545 0.96532146 0.57450053]

jnp.max(x): 0.9653214611189975
sj.max(x, mode='hard'): 0.9653214611189975
sj.max(x): 0.9550070794223621

jnp.median(x): 0.49647882272194555
sj.median(x, mode='hard'): 0.49647882272194555
sj.median(x): 0.494137017678798

jax.lax.top_k(x, k=2)[0]: [0.96532146 0.57450053]
sj.top_k(x, k=2, mode='hard')[0]: [0.96532146 0.57450053]
sj.top_k(x, k=2)[0]: [0.95500708 0.54676106]

jnp.sort(x): [0.21629545 0.41845711 0.57450053 0.96532146]
sj.sort(x, mode='hard'): [0.21629545 0.41845711 0.57450053 0.96532146]
sj.sort(x): [0.24830531 0.42462602 0.54676106 0.95500708]
Finally, we also offer a soft ranking operation. While it does not return a SoftIndex (its output has the same shape as the input), it relies under the hood on computations similar to those used by, e.g., sort, and offers the same modes.
x = jax.random.uniform(jax.random.key(0), shape=(5,))
print("x:\n", x)
# This computes the ranking operation
print("jnp.argsort(jnp.argsort(x)):\n", jnp.argsort(jnp.argsort(x)))
print(
"sj.ranking(x, descending=False, mode='hard'):\n",
sj.ranking(x, descending=False, mode="hard"),
)
print("sj.ranking(x, descending=False):\n", sj.ranking(x, descending=False))
x:
 [0.41845711 0.21629545 0.96532146 0.57450053 0.53222649]
jnp.argsort(jnp.argsort(x)):
 [1 0 4 3 2]
sj.ranking(x, descending=False, mode='hard'):
 [1. 0. 4. 3. 2.]
sj.ranking(x, descending=False):
 [1.37238141 0.25184717 3.94097212 2.40480632 2.13591075]
A note on modes
The naming of the modes stems from reducing simple higher-level functions like sj.abs to more complex lower-level functions like sj.argmax.
Starting from the argmax function, we relax it with a projection onto the unit simplex. We offer two modes, entropic and euclidean (in fact, all functions in Softjax support at least these two and the hard mode), which determine the regularizer used in the relaxed optimization problem.
The solution to the entropic case is available in closed form via the classic softmax function, whereas the euclidean case is a simple L2 projection onto the unit simplex, which boils down to a sort+cumsum operation.
Given the argmax relaxation, we directly get a max relaxation by taking the inner product of the soft indices with the original vector.
Now we can define a relaxation of the Heaviside function from the softened argmax operation by observing that
$\text{heaviside}(x)=\text{argmax}([x,0])[0]$.
This results in different S-shaped sigmoid functions: the standard logistic sigmoid is the closed-form solution for the entropic mode, whereas a linear interpolation between 0 and 1 is the closed-form solution for the euclidean mode.
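The reduction can be checked numerically. The sketch below implements the two argmax relaxations described above, a softmax (entropic) and a sort+cumsum Euclidean projection onto the unit simplex (euclidean), and reads off heaviside(x) as the first entry of the relaxed argmax of [x, 0]. The function names and the softness scaling are ours; the projection follows the standard sort-based algorithm and is not necessarily how Softjax implements it.

```python
import jax
import jax.numpy as jnp


def project_simplex(v):
    # Euclidean (L2) projection of a 1D vector onto {p : p >= 0, sum(p) = 1},
    # using the standard sort + cumsum algorithm.
    u = jnp.sort(v)[::-1]                 # sort in descending order
    css = jnp.cumsum(u) - 1.0
    k = jnp.arange(1, v.shape[0] + 1)
    rho = jnp.sum(u - css / k > 0)        # number of nonzero output entries
    tau = css[rho - 1] / rho              # threshold subtracted from v
    return jnp.maximum(v - tau, 0.0)


def heaviside_from_argmax(x, relaxed_argmax, softness=1.0):
    # heaviside(x) = argmax([x, 0])[0], with argmax replaced by a relaxation.
    return relaxed_argmax(jnp.stack([x, jnp.zeros_like(x)]) / softness)[0]


xs = jnp.linspace(-3.0, 3.0, 13)
entropic = jax.vmap(lambda x: heaviside_from_argmax(x, jax.nn.softmax))(xs)
euclidean = jax.vmap(lambda x: heaviside_from_argmax(x, project_simplex))(xs)

# True: the entropic mode yields the logistic sigmoid
print(jnp.allclose(entropic, jax.nn.sigmoid(xs), atol=1e-6))
# True: the euclidean mode yields a clipped linear ramp
print(jnp.allclose(euclidean, jnp.clip(xs / 2.0 + 0.5, 0.0, 1.0), atol=1e-6))
```

With softness 1, the euclidean result is exactly 0 for x ≤ -1 and exactly 1 for x ≥ 1, which is the bounded-region property mentioned earlier; the sigmoid only approaches 0 and 1 asymptotically.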
Besides these modes, we define additional heaviside modes like cubic, quintic and pseudohuber, which all define different sigmoidal functions with different properties.
Our Heaviside relaxations can now be used to define relaxations of, e.g., the sign, abs and round functions.
Most importantly though, we can move up the ladder to even higher-level functions based on the ReLU function. We first observe that we can generate the ReLU function from the heaviside function in two ways:
- By integrating $\text{heaviside}(t)$ from $-\infty$ to $x$.
- By taking $x \cdot \text{heaviside}(x)$ (a "gating"-mechanism).
Therefore, for each of our heaviside relaxations we can define two ReLU relaxations, some of which are well known. For example, the entropic case leads to the classic softplus function when integrated, and to the SiLU function when "gated" (we refer to this relaxation mode as gated_entropic).
Similarly, the euclidean, cubic and quintic relaxations are simple piecewise polynomials that we can integrate in closed form.
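Both ReLU constructions are easy to verify for the entropic case with standard JAX functions: differentiating jax.nn.softplus recovers the sigmoid (so softplus is its integral from $-\infty$), and gating with the sigmoid gives jax.nn.silu. This sketch only checks the identities and does not call Softjax.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-4.0, 4.0, 9)

# 1) Integration: softplus(x) -> 0 as x -> -inf and d/dx softplus(x) = sigmoid(x),
#    so softplus is the integral of the entropic Heaviside relaxation up to x.
dsoftplus = jax.vmap(jax.grad(jax.nn.softplus))(x)
print(jnp.allclose(dsoftplus, jax.nn.sigmoid(x), atol=1e-6))  # True

# 2) Gating: multiplying x by the sigmoid gives the SiLU.
print(jnp.allclose(x * jax.nn.sigmoid(x), jax.nn.silu(x), atol=1e-6))  # True
```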
We can now use the soft ReLU relaxations to define further relaxations, e.g. of clip.
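One standard way to express clip through ReLUs, which can then be softened by swapping in a smooth ReLU, is the identity clip(x, low, high) = low + relu(x - low) - relu(x - high) for low ≤ high. Whether Softjax uses exactly this construction is an assumption; the sketch, with the hypothetical names clip_via_relu and soft_relu, only illustrates the idea.

```python
import jax
import jax.numpy as jnp


def clip_via_relu(x, low, high, relu=jax.nn.relu):
    # For low <= high: clip(x) = low + relu(x - low) - relu(x - high).
    return low + relu(x - low) - relu(x - high)


def soft_relu(z, softness=0.1):
    # Entropic ReLU relaxation (assumed temperature scaling).
    return softness * jax.nn.softplus(z / softness)


x = jnp.linspace(-2.0, 2.0, 9)
hard = clip_via_relu(x, -1.0, 1.0)
soft = clip_via_relu(x, -1.0, 1.0, relu=soft_relu)
print(jnp.allclose(hard, jnp.clip(x, -1.0, 1.0)))  # True
print(soft)  # smooth, stays within [-1, 1]
```

Because softplus is increasing and 1-Lipschitz, the soft version stays within [low, high], just like the hard clip.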