All of Softjax¶
Softjax provides easy-to-use differentiable function surrogates of non-differentiable functions and discrete logic operations in JAX. Softjax offers soft function surrogates operating on real values, Booleans, and indices as well as wrappers to use these functions for gradient computation via straight-through estimation.
This page guides you through all of Softjax's key functionalities.
1. Softening¶
Many functions are not particularly well suited for gradient-based optimization, as their gradients are either zero or undefined at discontinuities. Therefore, a typical approach in machine learning is to use soft surrogates for these functions to enable gradient-based optimization via automatic differentiation. Softjax provides such soft surrogates for many operations in JAX.
As shown below, Softjax provides two hyper-parameters to tune soft function surrogates:
mode being the type of function used for obtaining a soft approximation.
softness defining how closely the surrogate function approximates the original function.
import jax
import jax.numpy as jnp
import softjax as sj
plot(sj.relu, modes=["smooth", "c0", "c2"])
Softjax provides surrogates for many other functions, such as the absolute value function and rounding.
plot(sj.abs, modes=["smooth", "c0", "c2"])
plot(sj.round, modes=["smooth", "c0", "c2"])
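To make the role of softness concrete, here is a minimal sketch of one possible smooth ReLU surrogate, a scaled softplus, written in plain JAX. The name soft_relu and the softplus formula are assumptions chosen for illustration; they are not necessarily what sj.relu computes in any of its modes.

```python
import jax
import jax.numpy as jnp


def soft_relu(x, softness=0.1):
    # Hypothetical surrogate: scaled softplus. As softness -> 0 it approaches relu(x).
    return softness * jax.nn.softplus(x / softness)


x = jnp.array([-1.0, -0.1, 0.0, 0.1, 1.0])
hard = jax.nn.relu(x)
nearly_hard = soft_relu(x, softness=0.01)  # stays close to the hard relu
very_soft = soft_relu(x, softness=1.0)     # visibly smoothed around x = 0

# The hard relu has zero gradient for x < 0, the surrogate does not.
hard_grad = jax.grad(jax.nn.relu)(-0.1)
soft_grad = jax.grad(soft_relu)(-0.1)      # sigmoid(-0.1 / softness), strictly positive
```

A smaller softness trades gradient signal far from the kink for a closer match to the original function.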
2. Straight-through estimation¶
Soft function surrogates can be used to compute soft gradients (or more accurately: soft vector-Jacobian products) without modifying the forward pass via straight-through estimation. Straight-through estimation uses JAX's automatic differentiation system to replace only the function's gradient with the gradient of the surrogate function.
Note: Historically, straight-through estimators refer to the special case of treating a function as the identity operation on the backward pass. We use the term more generally to describe the case of replacing a function with a smooth surrogate on the backward pass.
Example - ReLU activation: The rectified linear unit (aka relu) is commonly used as an activation function in neural networks. For $x<0$ the gradient of the relu is zero. In turn, neural networks containing relu activations may suffer from the "dying ReLU problem", where the gradients computed via automatic differentiation become zero once a gradient-based optimizer pushes the inputs of ReLU functions to $x<0$. As pointed out in "The Resurrection of the ReLU", you can mitigate these problems by replacing its backward pass with a soft surrogate function. To do this with Softjax, simply replace the relu activation in your code with sj.st(sj.relu), or equivalently use our wrapped primitive sj.relu_st directly for convenience.
soft_relu = sj.st(sj.relu)
x = jnp.arange(-1, 1, 0.01)
values, grads = jax.vmap(jax.value_and_grad(soft_relu))(x)
plot_value_and_grad(x, values, grads, label_func="ReLU", label_grad="Soft gradient")
The STE Trick¶
Under the hood, sj.st() uses the stop-gradient operation to replace the gradient of a function.
def st(fn: Callable) -> Callable:
    sig = inspect.signature(fn)
    mode_default = sig.parameters.get("mode").default

    def wrapped(*args, **kwargs):
        mode = kwargs.pop("mode", mode_default)
        fw_y = fn(*args, **kwargs, mode="hard")
        bw_y = fn(*args, **kwargs, mode=mode)
        return jtu.tree_map(lambda fw, bw: jax.lax.stop_gradient(fw - bw) + bw, fw_y, bw_y)

    return wrapped
Because the backward output bw_y is both subtracted and added, the returned value equals the hard forward output fw_y. Due to jax.lax.stop_gradient, only the soft backward output bw_y contributes to the gradient computation.
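The same trick can be reproduced in a few lines of plain JAX. The sketch below is an illustration only: straight_through and soft_relu are hypothetical names, the hard and soft functions are passed separately instead of being selected via mode, and the softplus surrogate is an assumption.

```python
import jax
import jax.numpy as jnp


def straight_through(hard_fn, soft_fn):
    # Forward pass: hard_fn. Backward pass: gradient of soft_fn.
    def wrapped(x):
        hard_y = hard_fn(x)
        soft_y = soft_fn(x)
        return jax.lax.stop_gradient(hard_y - soft_y) + soft_y

    return wrapped


def soft_relu(x, softness=0.1):
    # Hypothetical smooth surrogate (scaled softplus).
    return softness * jax.nn.softplus(x / softness)


relu_st = straight_through(jax.nn.relu, soft_relu)

# At x = -0.05 the hard relu is 0 with zero gradient;
# the straight-through version keeps the value but borrows the soft gradient.
value, grad = jax.value_and_grad(relu_st)(-0.05)
expected_grad = jax.grad(soft_relu)(-0.05)
```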
Custom differentiation rules¶
The @sj.st decorator can also be used to define custom straight-through operations.
This can be useful when combining multiple functions provided by the softjax library. For this, it is important to understand that simply applying the straight-through trick to every non-smooth function does not always result in the intended behavior.
Consider, for example, multiplying the outputs of two relu functions. This function only provides meaningful gradients in the first quadrant. We would like to change it such that we get a meaningful signal in the whole domain, as visualized below.
Note: We normalize the plotted gradient vectors for reduced cluttering.
def relu_prod(x, y):
    # Standard ReLU product, no softening
    return jax.nn.relu(x) * jax.nn.relu(y)
plot_value_grad_2D(relu_prod)
A naive approach would be to replace each relu with sj.relu_st independently. However, the resulting function will not provide informative gradients for every input.
This is due to the chain rule, in which the gradient flowing through one relu is multiplied by the (forward) output of the other relu. As the forward pass is not smoothed, the gradient will sometimes be multiplied by zero, resulting in no informative gradient.
def soft_relu_prod_naive(x, y, mode="smooth", softness=0.1):
    # Naive straight-through implementation
    return sj.relu_st(x, mode=mode, softness=softness) * sj.relu_st(
        y, mode=mode, softness=softness
    )
plot_value_grad_2D(soft_relu_prod_naive)
An alternative approach for softening this function is to apply the straight-through trick on the outer level, as illustrated below. When applied on the outer level, the forward pass computes the hard product of ReLUs as before, whereas the backward pass differentiates through the product of smooth ReLUs.
Notice that we use sj.relu instead of sj.relu_st here.
@sj.st
def soft_relu_prod_custom_st(x, y, mode="smooth", softness=0.1):
    # Custom straight-through implementation
    return sj.relu(x, mode=mode, softness=softness) * sj.relu(
        y, mode=mode, softness=softness
    )
plot_value_grad_2D(soft_relu_prod_custom_st)
We observe that as expected, this version of the function now also produces informative gradients in the third quadrant.
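The difference between the two variants can also be checked numerically at a point in the third quadrant. The following plain-JAX sketch mirrors both constructions; soft_relu (a scaled softplus) and the helper names are assumptions for illustration and do not call into Softjax.

```python
import jax
import jax.numpy as jnp


def soft_relu(x, softness=0.1):
    # Hypothetical smooth surrogate (scaled softplus), used only for illustration.
    return softness * jax.nn.softplus(x / softness)


def relu_st(x):
    # Straight-through ReLU: hard forward value, soft backward pass.
    soft = soft_relu(x)
    return jax.lax.stop_gradient(jax.nn.relu(x) - soft) + soft


def naive_prod(x, y):
    # Straight-through applied to each factor separately.
    return relu_st(x) * relu_st(y)


def outer_prod(x, y):
    # Straight-through applied once, around the whole product.
    hard = jax.nn.relu(x) * jax.nn.relu(y)
    soft = soft_relu(x) * soft_relu(y)
    return jax.lax.stop_gradient(hard - soft) + soft


point = (-0.1, -0.1)  # both inputs negative: third quadrant
naive_grads = jax.grad(naive_prod, argnums=(0, 1))(*point)
outer_grads = jax.grad(outer_prod, argnums=(0, 1))(*point)
```

In the naive version each partial derivative is scaled by the hard forward value of the other factor, which is zero here. In the outer version it is scaled by the soft value instead, so the gradient stays nonzero.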
In the simple example above, both sj.relu calls take the same parameters, so it was easy to just apply the sj.st decorator once around the whole function.
In general, we might want to define custom behavior. This can be implemented by using the @sj.grad_replace decorator, which allows custom control flow conditioned on a forward boolean variable. The function will then execute with forward=True on the forward pass and forward=False on the backward pass.
def grad_replace(fn: Callable) -> Callable:
    def wrapped(*args, **kwargs):
        fw_y = fn(*args, **kwargs, forward=True)
        bw_y = fn(*args, **kwargs, forward=False)
        return jtu.tree_map(lambda fw, bw: jax.lax.stop_gradient(fw - bw) + bw, fw_y, bw_y)

    return wrapped
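As a usage sketch, the decorator from the listing above can express the classic straight-through estimator for rounding: round on the forward pass, identity on the backward pass. To keep the example self-contained, it uses a local copy of grad_replace instead of importing Softjax; the quantize function is a hypothetical example.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def grad_replace(fn):
    # Local copy of the decorator shown above, so the sketch runs without Softjax.
    def wrapped(*args, **kwargs):
        fw_y = fn(*args, **kwargs, forward=True)
        bw_y = fn(*args, **kwargs, forward=False)
        return jtu.tree_map(lambda fw, bw: jax.lax.stop_gradient(fw - bw) + bw, fw_y, bw_y)

    return wrapped


@grad_replace
def quantize(x, forward=True):
    # Hypothetical example: hard rounding forward, identity backward.
    if forward:
        return jnp.round(x)
    return x


value, grad = jax.value_and_grad(quantize)(1.3)
```

Since forward is a plain Python bool, ordinary if/else control flow can be used to build arbitrarily different forward and backward computations.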
3. Soft bools¶
Softjax provides differentiable surrogates of JAX's Boolean operators. A Boolean (aka Bool) is a data type that takes one of two possible values, false or true (aka 0 or 1). Many operations in JAX such as greater or isclose generate arrays containing Booleans, while other operations such as logical_and or any operate on such arrays.
Example - jax.numpy.greater yields zero gradients: As shown below, jax.grad does not raise an error when called on a loss built from Boolean operations. However, the returned gradients are zero for all array entries.
x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
bool_array = jax.numpy.greater(x, 0.5)
def boolean_loss(x):
    return jax.numpy.greater(x, 0.5).sum().astype("float32")
boolean_grads = jax.grad(boolean_loss)(x)
plot_array(x, title="x")
plot_array(bool_array, title="jax.numpy.greater(x, 0.5)")
plot_array(boolean_grads, title="jax.grad(jax.numpy.greater(x, 0.5).sum())(x)")
Example - softjax.greater_st yields useful gradients: Instead of jax.numpy.greater, let's use softjax.greater_st (the straight-through variant of softjax.greater). As shown below, thanks to straight-through estimation, softjax.greater_st yields exact Booleans while the gradient of the Boolean loss points in informative directions.
def soft_boolean_loss(x):
    return sj.greater_st(x, 0.5).sum()
x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
bool_array = sj.greater_st(x, 0.5)
boolean_grads = jax.grad(soft_boolean_loss)(x)
plot_array(x, title="x")
plot_array(bool_array, title="soft_greater_st(x, 0.5)")
plot_array(boolean_grads, title="jax.grad(soft_greater_st(x, 0.5).sum())(x)")
Generating soft bools¶
How does Softjax make Boolean logic operations differentiable? A real number $x\in \mathbb{R}$ could be mapped to a Bool using the Heaviside function
$$
H(x) = \begin{cases}
1, & x > 0 \\
0.5, & x=0\\
0, & x < 0
\end{cases}.
$$
The gradient of the Heaviside function (as implemented in JAX) is zero everywhere and hence unsuited for differentiable optimization. Instead of operating directly on Booleans, Softjax's differentiable logic operators resort to soft Booleans. A soft Boolean aka SoftBool can be interpreted as the probability of a Boolean being True.
We replace the Heaviside function with a differentiable surrogate such as the sigmoid function $\sigma(x) = \frac{1}{1+e^{-x}}$. While the sigmoid is the canonical example for mapping a real number to a SoftBool, Softjax provides additional surrogates.
plot(sj.heaviside, modes=["smooth", "c0", "c2"])
In the above case, the c0 and c2 relaxations have the advantage of altering the original function only in a bounded region, a property that can be desirable in some cases.
Given the concept of a SoftBool, a probabilistic surrogate for binary logical operations such as jax.numpy.equal and jax.numpy.greater is obtained by simply shifting the sigmoid.
Example - Greater operator: sj.greater(x, y) corresponds to shifting sj.heaviside by y to the right. The output can be interpreted as the probability $P(x > y)\in[0,1]$ with $x\in\mathbb{R}$ and $y\in\mathbb{R}$.
def greater_than_1(x, mode="smooth", softness=0.1):
    return sj.greater(x, y=jnp.array(1.0), mode=mode, softness=softness)
plot(greater_than_1, modes=["smooth", "c2"])
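The shifted-sigmoid idea can be sketched in plain JAX as follows. This is an assumption about how a smooth mode might look, not Softjax's implementation; soft_greater and the way softness scales the input are hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_greater(x, y, softness=0.1):
    # Hypothetical surrogate: sigmoid shifted by y and sharpened by 1 / softness.
    return jax.nn.sigmoid((x - y) / softness)


x = jnp.array([0.0, 0.9, 1.0, 1.1, 2.0])
probs = soft_greater(x, 1.0)  # soft Booleans for "x > 1"

# Unlike jnp.greater, the surrogate has a nonzero gradient w.r.t. x,
# which is largest at the decision boundary x == y.
grads = jax.vmap(jax.grad(soft_greater), in_axes=(0, None))(x, 1.0)
```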
Manipulating soft bools¶
Softjax replaces a Boolean with a SoftBool. In turn, Boolean logic operators are replaced with fuzzy logic operators that effectively compute the probabilities of Boolean events.
Example - Logical NOT: Given a SoftBool $P(B)$ (being the probability that a Boolean event $B$ occurs), the probability of the event not occurring is $P(\bar B) = 1 - P(B)$, as implemented in sj.logical_not.
def logical_not(x: SoftBool) -> SoftBool:
    return 1 - x
Given sj.logical_not, the probability that x is not greater than 0.5 is given by sj.logical_not(sj.greater_st(x, 0.5)). Due to the straight-through trick, the function sj.logical_not(sj.greater_st(x, 0.5)) uses exact Boolean logic in the forward pass and the SoftBool probability computation in the backward pass.
def not_greater_st(x):
    return sj.logical_not(sj.greater_st(x, y=0.5, mode="smooth", softness=0.1))
x = jnp.arange(-1, 1, 0.01)
values, grads = jax.vmap(jax.value_and_grad(not_greater_st))(x)
plot_value_and_grad(x, values, grads, label_func="not_greater_st")
Example - Logical AND: Given two SoftBools $P(A)$ and $P(B)$, the probability that both events occur is $P(A \wedge B) = P(A) \cdot P(B)$, assuming the events are independent.
def logical_and(x: SoftBool, y: SoftBool) -> SoftBool:
    return x * y
plot_softbool_operation(sj.logical_and)
Example - Logical XOR: Softjax computes other soft logic operators such as sj.logical_xor by combining sj.logical_not and sj.logical_and.
def logical_xor(x: SoftBool, y: SoftBool) -> SoftBool:
    return logical_or(logical_and(x, logical_not(y)), logical_and(logical_not(x), y))
plot_softbool_operation(sj.logical_xor)
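The probabilistic operators above can be checked against the Boolean truth table with a small standalone sketch. The soft_* names are hypothetical, and the OR operator is assumed to follow De Morgan's law (NOT of the AND of the negations); Softjax's own logical_or may be defined differently.

```python
import jax.numpy as jnp


def soft_not(p):
    return 1.0 - p


def soft_and(p, q):
    # Probability that two independent events both occur.
    return p * q


def soft_or(p, q):
    # Assumption: OR via De Morgan, i.e. NOT(AND(NOT p, NOT q)).
    return soft_not(soft_and(soft_not(p), soft_not(q)))


def soft_xor(p, q):
    # Same composition as the logical_xor listing above.
    return soft_or(soft_and(p, soft_not(q)), soft_and(soft_not(p), q))


# Hard inputs (0 or 1) recover the Boolean truth table of XOR.
p = jnp.array([0.0, 0.0, 1.0, 1.0])
q = jnp.array([0.0, 1.0, 0.0, 1.0])
hard_xor = soft_xor(p, q)

# Uncertain inputs give an intermediate probability.
uncertain_xor = soft_xor(0.5, 0.5)
```

Under these assumptions the soft operators agree with Boolean logic on the corners of the unit square and interpolate smoothly in between.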
Selection with soft bools¶
Through the use of fuzzy logic operators, Softjax provides a toolbox to make many non-differentiable functions of JAX differentiable.
Example - sj.where(): The function jax.numpy.where(condition, x, y) selects elements of array x where condition == True and otherwise selects elements of y. Softjax provides a differentiable surrogate for this function via sj.where(P, x, y), which effectively computes the expected value $\mathbb{E}[X] = P \cdot x + (1-P) \cdot y$.
greater = lambda x, y: sj.greater(x, y, mode="smooth", softness=0.1)
soft_where = lambda x, y: sj.where(greater(x, y), x, y)
x = jax.random.uniform(jax.random.key(0), shape=(2, 10))
y = jax.random.uniform(jax.random.key(1), shape=(2, 10))
plot_array(x, title="x")
plot_array(y, title="y")
plot_array(soft_where(x, y), title="soft_where(x>y, x, y)")
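The expected-value formula is simple enough to write out directly. The sketch below uses plain JAX with hypothetical helper names and an assumed sigmoid surrogate for the comparison; it compares the gradients of a hard and a soft "maximum of two values".

```python
import jax
import jax.numpy as jnp


def soft_where(p, x, y):
    # Expected value of selecting x with probability p and y otherwise.
    return p * x + (1.0 - p) * y


def hard_max_of_two(x, y):
    return jnp.where(x > y, x, y)


def soft_max_of_two(x, y, softness=0.1):
    p = jax.nn.sigmoid((x - y) / softness)  # assumed surrogate for x > y
    return soft_where(p, x, y)


# The hard selection only passes gradient to the selected branch,
# the soft selection passes gradient to both inputs.
hard_grads = jax.grad(hard_max_of_two, argnums=(0, 1))(1.0, 1.05)
soft_grads = jax.grad(soft_max_of_two, argnums=(0, 1))(1.0, 1.05)
```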
4. Soft indices¶
Softjax offers soft surrogates for functions that generate indices as outputs, such as argmax, argmin, top_k, argmedian, and argsort.
The main mechanism here is to replace hard indices with distributions over indices (SoftIndex), allowing for informative gradients.
Similar to how SoftBool required going from boolean logic to fuzzy logic, this now requires adjusting functions that do selection via indices. As such, we provide new versions of e.g. take_along_axis, dynamic_index_in_dim, and choose.
Combining the soft index generation with the selection then allows us to define surrogates for the corresponding max, min, top_k, median and sort functions.
Generating soft indices¶
In JAX, functions like jnp.argmax return integer indices as outputs, which can take values within {0, ..., len(x)-1}.
x = jnp.array([1, 2, 3])
print("jnp.argmax(x):", jnp.argmax(x))
jnp.argmax(x): 2
In comparison, Softjax computes a SoftIndex array. Each entry of a SoftIndex array contains the probability of the corresponding index being selected.
x = jnp.array([1, 2, 3])
print("sj.argmax(x):", sj.argmax(x))
sj.argmax(x): [0.00398514 0.06104098 0.93497388]
Example - Softmax: The "softmax" (or more precisely "softargmax") is a commonly used differentiable surrogate for the argmax function (it is also the default softening mode in sj.argmax). The softmax, $\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j\exp(x_j)}$, returns a discrete probability distribution over indices (aka a SoftIndex). As shown in the plots below, the softmax is fully differentiable. It is commonly used for multi-class classification and in transformer networks.
When softness is low, sj.argmax concentrates probability on the true maximum index (e.g., [1.0, 0.0, 0.0]), recovering the hard maximum. When softness is higher, the result smoothly interpolates between values, providing useful gradients for optimization.
def cross_entropy(x, class_target=5):
    probs = sj.argmax(x, softness=10.0)
    target_one_hot = jax.nn.one_hot(class_target, num_classes=x.shape[0])
    log_probs = jnp.log(probs)
    return -(target_one_hot * log_probs).mean()
x = jax.random.normal(jax.random.key(0), shape=(10,))
probs = sj.argmax(x, softness=10.0)
lossgrads = jax.grad(cross_entropy)(x)
plot_softindices_1D(x, title="logits")
plot_softindices_1D(probs, title="index probabilities (softmax)")
plot_softindices_1D(lossgrads, title="gradients of cross entropy loss")
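The effect of softness on a softmax-based soft index can be illustrated with plain JAX. In this sketch softness is assumed to act as a temperature that divides the logits; how sj.argmax scales or normalizes its input internally is not specified here, so the numbers will generally differ from Softjax's output.

```python
import jax
import jax.numpy as jnp


def soft_argmax(x, softness=1.0):
    # Assumption: softness acts as a temperature dividing the logits.
    return jax.nn.softmax(x / softness)


x = jnp.array([1.0, 2.0, 3.0])
sharp = soft_argmax(x, softness=0.01)   # nearly one-hot on the maximum
smooth = soft_argmax(x, softness=10.0)  # close to uniform, same ordering
```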
Note that while in a conventional array of indices, the index information is stored in the integer values, a SoftIndex stores the probabilities over possible indices in an extra dimension. By convention, we always put this additional dimension into the final axis. Except for this additional final dimension, the shape of the returned soft index matches that of the indices returned by standard JAX.
Here are a few examples of this:
x = jnp.arange(12).reshape((3, 4))
print("x.shape:", x.shape)
print("jnp.argmax(x, axis=1).shape:", jnp.argmax(x, axis=1).shape)
print("sj.argmax(x, axis=1).shape:", sj.argmax(x, axis=1).shape)
print("jnp.argmax(x, axis=0).shape:", jnp.argmax(x, axis=0).shape)
print("sj.argmax(x, axis=0).shape:", sj.argmax(x, axis=0).shape)
print(
"jnp.argmax(x, axis=1, keepdims=True).shape:",
jnp.argmax(x, axis=1, keepdims=True).shape,
)
print(
"sj.argmax(x, axis=1, keepdims=True).shape:",
sj.argmax(x, axis=1, keepdims=True).shape,
)
print(
"jnp.argmax(x, axis=0, keepdims=True).shape:",
jnp.argmax(x, axis=0, keepdims=True).shape,
)
print(
"sj.argmax(x, axis=0, keepdims=True).shape:",
sj.argmax(x, axis=0, keepdims=True).shape,
)
x.shape: (3, 4)
jnp.argmax(x, axis=1).shape: (3,)
sj.argmax(x, axis=1).shape: (3, 4)
jnp.argmax(x, axis=0).shape: (4,)
sj.argmax(x, axis=0).shape: (4, 3)
jnp.argmax(x, axis=1, keepdims=True).shape: (3, 1)
sj.argmax(x, axis=1, keepdims=True).shape: (3, 1, 4)
jnp.argmax(x, axis=0, keepdims=True).shape: (1, 4)
sj.argmax(x, axis=0, keepdims=True).shape: (1, 4, 3)
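The shape convention can be reproduced with a softmax followed by an axis move. The following is a minimal sketch in plain JAX with a hypothetical soft_argmax helper; it only illustrates the layout of the output and makes no claim about how Softjax computes the probabilities.

```python
import jax
import jax.numpy as jnp


def soft_argmax(x, axis=-1, keepdims=False, softness=1.0):
    axis = axis % x.ndim                            # normalize negative axes
    probs = jax.nn.softmax(x / softness, axis=axis)
    probs = jnp.moveaxis(probs, axis, -1)           # distribution goes to the last axis
    if keepdims:
        probs = jnp.expand_dims(probs, axis)        # size-1 axis where the reduction happened
    return probs


x = jnp.arange(12.0).reshape((3, 4))
shape_axis1 = soft_argmax(x, axis=1).shape
shape_axis0 = soft_argmax(x, axis=0).shape
shape_axis1_keep = soft_argmax(x, axis=1, keepdims=True).shape
shape_axis0_keep = soft_argmax(x, axis=0, keepdims=True).shape
```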
We also offer soft versions of argmedian, argsort and top_k.
x = jax.random.uniform(jax.random.key(0), shape=(4,))
print("x:", x)
print("\njnp.argmedian(x):", "Not implemented in standard JAX")
print("sj.argmedian(x):", sj.argmedian(x))
print("\njnp.argsort(x):", jnp.argsort(x))
print("sj.argsort(x):", sj.argsort(x))
print("\njax.lax.top_k(x, k=3)[1]:", jax.lax.top_k(x, k=3)[1])
print("sj.top_k(x, k=3)[1]:", sj.top_k(x, k=3)[1])
x: [0.41845711 0.21629545 0.96532146 0.57450053]

jnp.argmedian(x): Not implemented in standard JAX
sj.argmedian(x): [0.46390511 0.05535614 0.01005734 0.47068141]

jnp.argsort(x): [1 0 3 2]
sj.argsort(x): [[2.28724416e-01 7.68077209e-01 6.28715863e-10 3.19837430e-03]
 [7.23853520e-01 1.09326810e-01 1.19146313e-05 1.66807756e-01]
 [2.03956693e-01 1.38547145e-03 2.01027700e-02 7.74555066e-01]
 [1.52953501e-03 4.67307417e-07 9.02745621e-01 9.57243770e-02]]

jax.lax.top_k(x, k=3)[1]: [2 3 0]
sj.top_k(x, k=3)[1]: [[1.52953501e-03 4.67307417e-07 9.02745621e-01 9.57243770e-02]
 [2.03956693e-01 1.38547145e-03 2.01027700e-02 7.74555066e-01]
 [7.23853520e-01 1.09326810e-01 1.19146313e-05 1.66807756e-01]]
Again, the shape of the returned SoftIndex matches that of the normal index array, except for an additional dimension in the last axis that matches the size of the input array along the specified axis. A few examples:
x = jax.random.uniform(jax.random.key(0), shape=(3, 4))
print("x.shape:", x.shape)
# standard JAX only added support for axis argument in jax.lax.top_k recently, normally uses last axis
print("\njax.lax.top_k(x, k=2, axis=1)[1].shape:", jax.lax.top_k(x, k=2)[1].shape)
print("sj.top_k(x, k=2, axis=1)[1].shape:", sj.top_k(x, k=2, axis=1)[1].shape)
print("sj.top_k(x, k=2, axis=0)[1].shape:", sj.top_k(x, k=2, axis=0)[1].shape)
print("\njnp.argsort(x, axis=1).shape:", jnp.argsort(x, axis=1).shape)
print("sj.argsort(x, axis=1).shape:", sj.argsort(x, axis=1).shape)
print("jnp.argsort(x, axis=0).shape:", jnp.argsort(x, axis=0).shape)
print("sj.argsort(x, axis=0).shape:", sj.argsort(x, axis=0).shape)
# standard JAX does not support argmedian
print("\nsj.argmedian(x, axis=1).shape:", sj.argmedian(x, axis=1).shape)
print("sj.argmedian(x, axis=0).shape:", sj.argmedian(x, axis=0).shape)
print(
"sj.argmedian(x, axis=1, keepdims=True).shape:",
sj.argmedian(x, axis=1, keepdims=True).shape,
)
print(
"sj.argmedian(x, axis=0, keepdims=True).shape:",
sj.argmedian(x, axis=0, keepdims=True).shape,
)
x.shape: (3, 4)

jax.lax.top_k(x, k=2, axis=1)[1].shape: (3, 2)
sj.top_k(x, k=2, axis=1)[1].shape: (3, 2, 4)
sj.top_k(x, k=2, axis=0)[1].shape: (2, 4, 3)

jnp.argsort(x, axis=1).shape: (3, 4)
sj.argsort(x, axis=1).shape: (3, 4, 4)
jnp.argsort(x, axis=0).shape: (3, 4)
sj.argsort(x, axis=0).shape: (3, 4, 3)

sj.argmedian(x, axis=1).shape: (3, 4)
sj.argmedian(x, axis=0).shape: (4, 3)
sj.argmedian(x, axis=1, keepdims=True).shape: (3, 1, 4)
sj.argmedian(x, axis=0, keepdims=True).shape: (1, 4, 3)
Note: All of the functions in this section come with the five modes: hard, smooth, c0, c1, and c2.
hard mode produces one-hot soft indices and is mainly used in straight-through estimation. smooth (entropic) is the recommended soft default and reduces all operations to either a softmax or an entropy-regularized optimal transport problem.
c0 (L2) reduces operations to L2 projection onto the unit simplex or the permutahedron, which can be used to produce sparse outputs. c1 (p=3/2 p-norm) provides C1 differentiable projections, and c2 (p=4/3 p-norm) provides C2 twice-differentiable projections. See the API documentation for details.
Selection with SoftIndices¶
Given a SoftIndex, Softjax provides (differentiable) helper functions for selecting array elements, mirroring the non-differentiable indexing in standard JAX. Put simply, entries of an array are selected by computing the expected value:
$$\mathrm{E}(arr, p) = \sum_{i} arr[i] \cdot p[i] = arr^{\top} \cdot p$$
where $p$ is the SoftIndex.
Example - sj.take_along_axis: The function sj.take_along_axis is central to this selection mechanism. It generalizes jnp.take_along_axis to work with probability distributions (SoftIndices) instead of just integer indices.
The standard jnp.take_along_axis(arr, indices, axis) selects elements from arr using integer indices. Conceptually, it works by:
- Slicing along the specified axis to get 1D arrays
- Using the corresponding indices to select elements
- Assembling the results into the output array
One of its main uses is to accept the index output of e.g. jnp.argmax and select the maximum values at the indexed locations.
While jnp.take_along_axis uses integer indices, out_1d[j] = arr_1d[indices_1d[j]], sj.take_along_axis accepts a SoftIndex and computes the corresponding weighted sum: out_1d[j] = sum_i(arr_1d[i] * soft_indices_2d[j, i]).
x = jax.random.uniform(jax.random.key(0), shape=(2, 3))
print("x:\n", x)
indices = jnp.argmin(x, axis=1, keepdims=True)
print("min_jnp:\n", jnp.take_along_axis(x, indices, axis=1))
indices_onehot = sj.argmin(x, axis=1, mode="hard", keepdims=True)
print("min_sj_hard:\n", sj.take_along_axis(x, indices_onehot, axis=1))
indices_soft = sj.argmin(x, axis=1, keepdims=True)
print("min_sj_soft:\n", sj.take_along_axis(x, indices_soft, axis=1))
x:
[[0.41845711 0.21629545 0.96532146]
 [0.57450053 0.53222649 0.35490518]]
min_jnp:
[[0.21629545]
 [0.35490518]]
min_sj_hard:
[[0.21629545]
 [0.35490518]]
min_sj_soft:
[[0.25864805]
 [0.35883677]]
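For the 2D case with axis=1, the weighted-sum selection can be written as a single einsum. The sketch below is a plain-JAX illustration with a hypothetical soft_take_rows helper and an assumed softmax-based soft argmin; it reuses the values of x printed above, and its soft result is not expected to match Softjax's output exactly.

```python
import jax
import jax.numpy as jnp


def soft_take_rows(arr, soft_indices):
    # arr: (rows, n); soft_indices: (rows, k, n) -> out: (rows, k)
    # out[r, k] = sum_n arr[r, n] * soft_indices[r, k, n]
    return jnp.einsum("rn,rkn->rk", arr, soft_indices)


x = jnp.array([[0.41845711, 0.21629545, 0.96532146],
               [0.57450053, 0.53222649, 0.35490518]])

# One-hot soft indices reproduce the hard minimum.
one_hot = jax.nn.one_hot(jnp.argmin(x, axis=1), 3)[:, None, :]  # shape (2, 1, 3)
hard_min = soft_take_rows(x, one_hot)

# Assumed soft argmin: softmax over negated values with temperature 0.1.
soft = jax.nn.softmax(-x / 0.1, axis=1)[:, None, :]             # shape (2, 1, 3)
soft_min = soft_take_rows(x, soft)
```

Because the soft result is a convex combination of the row entries, it can never fall below the true minimum of that row.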
As a convenience, this combination of sj.take_along_axis with SoftIndex-generating functions is already implemented in Softjax's max, median, top_k and sort functions.
x = jax.random.uniform(jax.random.key(0), shape=(4,))
print("x:", x)
print("\njnp.max(x):", jnp.max(x))
print("sj.max(x, mode='hard'):", sj.max(x, mode="hard"))
print("sj.max(x):", sj.max(x))
print("\njnp.median(x):", jnp.median(x))
print("sj.median(x, mode='hard'):", sj.median(x, mode="hard"))
print("sj.median(x):", sj.median(x))
print("\njax.lax.top_k(x, k=2)[0]:", jax.lax.top_k(x, k=2)[0])
print("sj.top_k(x, k=2, mode='hard')[0]:", sj.top_k(x, k=2, mode="hard")[0])
print("sj.top_k(x, k=2)[0]:", sj.top_k(x, k=2)[0])
print("\njnp.sort(x):", jnp.sort(x))
print("sj.sort(x, mode='hard'):", sj.sort(x, mode="hard"))
print("sj.sort(x):", sj.sort(x))
x: [0.41845711 0.21629545 0.96532146 0.57450053]

jnp.max(x): 0.9653214611189975
sj.max(x, mode='hard'): 0.9653214611189975
sj.max(x): 0.9375883761565286

jnp.median(x): 0.49647882272194555
sj.median(x, mode='hard'): 0.49647882272194555
sj.median(x): 0.4862129625709338

jax.lax.top_k(x, k=2)[0]: [0.96532146 0.57450053]
sj.top_k(x, k=2, mode='hard')[0]: [0.96532146 0.57450053]
sj.top_k(x, k=2)[0]: [0.92707357 0.55003473]

jnp.sort(x): [0.21629545 0.41845711 0.57450053 0.96532146]
sj.sort(x, mode='hard'): [0.21629545 0.41845711 0.57450053 0.96532146]
sj.sort(x): [0.26368044 0.42239119 0.55003473 0.92707357]
Finally, we also offer a soft rank operation. While it does not return a SoftIndex (because its output is the same shape as the input), it relies on similar computations under the hood as e.g. sort, and also offers the same modes.
x = jax.random.uniform(jax.random.key(0), shape=(5,))
print("x:\n", x)
# This computes the rank operation
print("jnp.argsort(jnp.argsort(x)):\n", jnp.argsort(jnp.argsort(x)))
print(
"sj.rank(x, descending=False, mode='hard'):\n",
sj.rank(x, descending=False, mode="hard"),
)
print("sj.rank(x, descending=False):\n", sj.rank(x, descending=False))
x:
[0.41845711 0.21629545 0.96532146 0.57450053 0.53222649]
jnp.argsort(jnp.argsort(x)):
[1 0 4 3 2]
sj.rank(x, descending=False, mode='hard'):
[2. 1. 5. 4. 3.]
sj.rank(x, descending=False):
[1.99489473 1.09253493 4.98783305 3.69345064 3.23808629]
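For intuition about what a soft rank looks like, here is a sketch of a different and simpler technique than the sort-based computations described above: a pairwise-sigmoid rank, where each element's rank is one plus the soft count of elements below it. It is not the method Softjax uses, and pairwise_soft_rank is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def pairwise_soft_rank(x, softness=0.01):
    # rank_i = 1 + sum_{j != i} sigmoid((x_i - x_j) / softness).
    # The j == i term contributes sigmoid(0) = 0.5, hence the 0.5 offset.
    diffs = (x[:, None] - x[None, :]) / softness
    return 0.5 + jnp.sum(jax.nn.sigmoid(diffs), axis=1)


x = jnp.array([0.41845711, 0.21629545, 0.96532146, 0.57450053, 0.53222649])

# With a small softness the result approaches the 1-based ascending ranks.
ranks_sharp = pairwise_soft_rank(x, softness=0.001)

# Unlike argsort-based ranks, the soft ranks are differentiable w.r.t. x.
ranks_jacobian = jax.jacobian(pairwise_soft_rank)(x)
```

Like sj.rank in the output above, this sketch uses 1-based ranks, whereas jnp.argsort(jnp.argsort(x)) is 0-based.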
5. Autograd-safe operators¶
Several standard math functions produce infinite or NaN gradients at boundary points. For example:
- jnp.arcsin(x) has gradient 1/sqrt(1-x²), which is Inf at x=±1
- jnp.log(x) has gradient 1/x, which is Inf at x=0
- x/y has gradient -x/y² w.r.t. y, which is Inf at y=0
- jnp.linalg.norm(x) has gradient x/‖x‖, which is NaN at x=0
SoftJAX provides autograd-safe replacements for these functions using the double-where trick: the inner jnp.where clamps the input away from the dangerous boundary to avoid feeding it to the unsafe function, and the outer jnp.where selects the correct output value at the boundary. This ensures that both the forward pass and the backward pass (via JAX's autodiff) produce correct, finite results.
# The double-where pattern:
def safe_fn(x):
    safe_x = jnp.where(is_safe, x, safe_default)  # inner: avoid unsafe input
    return jnp.where(is_safe, unsafe_fn(safe_x), boundary_value)  # outer: correct output
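As a concrete instance of the pattern, the sketch below applies it to the logarithm in plain JAX and contrasts it with a single jnp.where. The function names, the safe default of 1.0, and the boundary value of 0.0 are assumptions for illustration; sj.log may choose different constants.

```python
import jax
import jax.numpy as jnp


def naive_log(x):
    # Single where: the forward value is fine, but jnp.log still receives x = 0,
    # so the backward pass multiplies a zero cotangent by an infinite derivative.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe_log(x):
    is_safe = x > 0
    safe_x = jnp.where(is_safe, x, 1.0)               # inner: log never sees x <= 0
    return jnp.where(is_safe, jnp.log(safe_x), 0.0)   # outer: assumed boundary value 0.0


naive_grad = jax.grad(naive_log)(0.0)    # NaN
safe_grad = jax.grad(safe_log)(0.0)      # finite
regular_grad = jax.grad(safe_log)(2.0)   # unchanged away from the boundary: 1 / x
```

The inner where matters because the unselected branch of a where is still evaluated and differentiated; masking only the output does not prevent NaNs from appearing in the gradient.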
The available autograd-safe functions are: sj.sqrt, sj.arcsin, sj.arccos, sj.log, sj.div, and sj.norm.
import softjax as sj
# arcsin: safe at x=±1 where jnp.arcsin has infinite gradient
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print("jnp.arcsin:", jnp.arcsin(x))
print("sj.arcsin: ", sj.arcsin(x))
print("Grad jnp.arcsin at x=1:", jax.grad(lambda z: jnp.arcsin(z))(1.0), " (inf!)")
print("Grad sj.arcsin at x=1:", jax.grad(lambda z: sj.arcsin(z))(1.0), " (finite)")
# arccos: safe at x=±1
print("\njnp.arccos:", jnp.arccos(x))
print("sj.arccos: ", sj.arccos(x))
# log: safe at x=0 where jnp.log has -inf value and infinite gradient
x_log = jnp.array([0.0, 0.5, 1.0, 2.0])
print("\njnp.log:", jnp.log(x_log))
print("sj.log: ", sj.log(x_log))
print("Grad jnp.log at x=0:", jax.grad(lambda z: jnp.log(z))(0.0), " (inf!)")
print("Grad sj.log at x=0:", jax.grad(lambda z: sj.log(z))(0.0), " (finite)")
# div: safe when denominator is zero
print("\n1.0 / 0.0: ", jnp.array(1.0) / jnp.array(0.0))
print("sj.div(1, 0): ", sj.div(jnp.array(1.0), jnp.array(0.0)))
# norm: safe for the zero vector where jnp.linalg.norm has NaN gradient
z = jnp.zeros(3)
print("\njnp.linalg.norm(zeros):", jnp.linalg.norm(z))
print("sj.norm(zeros): ", sj.norm(z))
print(
"Grad jnp.linalg.norm at zeros:",
jax.grad(lambda v: jnp.linalg.norm(v))(z),
" (NaN!)",
)
print("Grad sj.norm at zeros: ", jax.grad(lambda v: sj.norm(v))(z), " (finite)")
jnp.arcsin: [-1.57079633 -0.52359878 0. 0.52359878 1.57079633]
sj.arcsin: [-1.57079633 -0.52359878 0. 0.52359878 1.57079633]
Grad jnp.arcsin at x=1: inf (inf!)
Grad sj.arcsin at x=1: 0.0 (finite)

jnp.arccos: [3.14159265 2.0943951 1.57079633 1.04719755 0. ]
sj.arccos: [3.14159265 2.0943951 1.57079633 1.04719755 0. ]

jnp.log: [ -inf -0.69314718 0. 0.69314718]
sj.log: [ 0. -0.69314718 0. 0.69314718]
Grad jnp.log at x=0: inf (inf!)
Grad sj.log at x=0: 0.0 (finite)

1.0 / 0.0: inf
sj.div(1, 0): 0.0

jnp.linalg.norm(zeros): 0.0
sj.norm(zeros): 0.0
Grad jnp.linalg.norm at zeros: [nan nan nan] (NaN!)
Grad sj.norm at zeros: [0. 0. 0.] (finite)