Quick Example
This notebook contains the quick examples from the Readme.
import jax
import jax.numpy as jnp
import softjax as sj
jnp.set_printoptions(precision=4, suppress=True)
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_platforms", "cpu")
The default mode of all our Softjax functions is mode="entropic". This mode uses well-known relaxations built on the exponential function, e.g. the logistic sigmoid as a soft Heaviside step function and the softplus function as a soft ReLU.
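Setting mode="hard" reproduces the original JAX outputs, with two differences in format: comparison functions return 0.0/1.0 floats instead of booleans, and index-returning functions (argmax, argsort, the indices of top_k, ...) return one-hot vectors or matrices instead of integer indices (for argmedian on an even-length input, the weight is split 0.5/0.5 between the two middle elements). Being able to switch easily between hard and soft modes is especially useful for straight-through estimators, where the forward pass uses the hard function and the backward pass uses the gradient of the soft relaxation.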
All Softjax functions that take a mode argument also take a softness argument, which controls how soft the relaxation is. The hard function should always be recovered in the limit softness → 0, but setting softness=0 itself is not allowed; use mode="hard" instead.
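To make the entropic relaxations concrete, here is a minimal pure-JAX sketch of a soft Heaviside step and a soft ReLU controlled by a softness (temperature) parameter. The names soft_heaviside and soft_relu and the default softness=0.1 are our own assumptions for illustration; 0.1 was chosen because it reproduces the soft-mode values printed further down for sj.heaviside and sj.relu, but this is not SoftJAX's documented implementation.

```python
import jax
import jax.numpy as jnp


def soft_heaviside(x, softness=0.1):
    # Logistic sigmoid with temperature `softness`: sigmoid(x / softness).
    return jax.nn.sigmoid(x / softness)


def soft_relu(x, softness=0.1):
    # Scaled softplus: softness * log(1 + exp(x / softness)).
    return softness * jax.nn.softplus(x / softness)


x = jnp.array([-0.2, -1.0, 0.3, 1.0])
print(soft_heaviside(x))  # roughly [0.1192, 0.0000, 0.9526, 1.0000]
print(soft_relu(x))       # roughly [0.0127, 0.0000, 0.3049, 1.0000]

# Smaller softness moves both functions toward their hard versions.
# softness=0 itself would evaluate x / 0 (inf or nan) in this sketch.
for s in (0.1, 0.01, 0.001):
    print(s, soft_heaviside(x, s), soft_relu(x, s))
```

Scaling softplus by the softness keeps the soft ReLU in the same units as x, so it converges to max(x, 0) rather than to a rescaled version of it.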
x = jnp.array([-0.2, -1.0, 0.3, 1.0])
# Elementwise operators
print("\nJAX absolute:", jnp.abs(x))
print("SoftJAX absolute (hard mode):", sj.abs(x, mode="hard"))
print("SoftJAX absolute (soft mode):", sj.abs(x))
print("\nJAX clip:", jnp.clip(x, -0.5, 0.5))
print("SoftJAX clip (hard mode):", sj.clip(x, -0.5, 0.5, mode="hard"))
print("SoftJAX clip (soft mode):", sj.clip(x, -0.5, 0.5))
print("\nJAX heaviside:", jnp.heaviside(x, 0.5))
print("SoftJAX heaviside (hard mode):", sj.heaviside(x, mode="hard"))
print("SoftJAX heaviside (soft mode):", sj.heaviside(x))
print("\nJAX ReLU:", jax.nn.relu(x))
print("SoftJAX ReLU (hard mode):", sj.relu(x, mode="hard"))
print("SoftJAX ReLU (soft mode):", sj.relu(x))
print("\nJAX round:", jnp.round(x))
print("SoftJAX round (hard mode):", sj.round(x, mode="hard"))
print("SoftJAX round (soft mode):", sj.round(x))
print("\nJAX sign:", jnp.sign(x))
print("SoftJAX sign (hard mode):", sj.sign(x, mode="hard"))
print("SoftJAX sign (soft mode):", sj.sign(x))

JAX absolute: [0.2 1. 0.3 1. ]
SoftJAX absolute (hard mode): [0.2 1. 0.3 1. ]
SoftJAX absolute (soft mode): [0.1523 0.9999 0.2715 0.9999]

JAX clip: [-0.2 -0.5 0.3 0.5]
SoftJAX clip (hard mode): [-0.2 -0.5 0.3 0.5]
SoftJAX clip (soft mode): [-0.1952 -0.4993 0.2873 0.4993]

JAX heaviside: [0. 0. 1. 1.]
SoftJAX heaviside (hard mode): [0. 0. 1. 1.]
SoftJAX heaviside (soft mode): [0.1192 0. 0.9526 1. ]

JAX ReLU: [0. 0. 0.3 1. ]
SoftJAX ReLU (hard mode): [0. 0. 0.3 1. ]
SoftJAX ReLU (soft mode): [0.0127 0. 0.3049 1. ]

JAX round: [-0. -1. 0. 1.]
SoftJAX round (hard mode): [-0. -1. 0. 1.]
SoftJAX round (soft mode): [-0.0465 -1. 0.1189 1. ]

JAX sign: [-1. -1. 1. 1.]
SoftJAX sign (hard mode): [-1. -1. 1. 1.]
SoftJAX sign (soft mode): [-0.7616 -0.9999 0.9051 0.9999]
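Several of the soft elementwise outputs above are consistent with building sign, absolute value and clip out of a sigmoid-based soft step and a softplus-based soft ReLU. The sketch below is our reconstruction, with hypothetical function names and an assumed softness of 0.1; with these choices it matches the printed soft sign, absolute and clip values to four decimals, but SoftJAX may compute them differently internally. Soft round is not covered.

```python
import jax
import jax.numpy as jnp


def soft_heaviside(x, softness=0.1):
    return jax.nn.sigmoid(x / softness)


def soft_relu(x, softness=0.1):
    return softness * jax.nn.softplus(x / softness)


def soft_sign(x, softness=0.1):
    # Map the soft step from (0, 1) to (-1, 1); equals tanh(x / (2 * softness)).
    return 2.0 * soft_heaviside(x, softness) - 1.0


def soft_abs(x, softness=0.1):
    # |x| = x * sign(x), with the sign replaced by its soft version.
    return x * soft_sign(x, softness)


def soft_clip(x, lo, hi, softness=0.1):
    # For lo < hi: clip(x, lo, hi) = lo + relu(x - lo) - relu(x - hi),
    # here with both ReLUs replaced by soft ReLUs.
    return lo + soft_relu(x - lo, softness) - soft_relu(x - hi, softness)


x = jnp.array([-0.2, -1.0, 0.3, 1.0])
print(soft_sign(x))             # roughly [-0.7616, -0.9999, 0.9051, 0.9999]
print(soft_abs(x))              # roughly [0.1523, 0.9999, 0.2715, 0.9999]
print(soft_clip(x, -0.5, 0.5))  # roughly [-0.1952, -0.4993, 0.2873, 0.4993]
```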
# Array-valued operators
print("\nJAX max:", jnp.max(x))
print("SoftJAX max (hard mode):", sj.max(x, mode="hard"))
print("SoftJAX max (soft mode):", sj.max(x))
print("\nJAX min:", jnp.min(x))
print("SoftJAX min (hard mode):", sj.min(x, mode="hard"))
print("SoftJAX min (soft mode):", sj.min(x))
print("\nJAX sort:", jnp.sort(x))
print("SoftJAX sort (hard mode):", sj.sort(x, mode="hard"))
print("SoftJAX sort (soft mode):", sj.sort(x))
print("\nJAX median:", jnp.median(x))
print("SoftJAX median (hard mode):", sj.median(x, mode="hard"))
print("SoftJAX median (soft mode):", sj.median(x))
print("\nJAX top_k:", jax.lax.top_k(x, k=3)[0])
print("SoftJAX top_k (hard mode):", sj.top_k(x, k=3, mode="hard")[0])
print("SoftJAX top_k (soft mode):", sj.top_k(x, k=3)[0])
# Ascending ranks via double argsort: the position of each element in sorted order
print("\nJAX ranking:", jnp.argsort(jnp.argsort(x)))
print("SoftJAX ranking (hard mode):", sj.ranking(x, mode="hard", descending=False))
print("SoftJAX ranking (soft mode):", sj.ranking(x, descending=False))

JAX max: 1.0
SoftJAX max (hard mode): 1.0
SoftJAX max (soft mode): 0.9993548976691374

JAX min: -1.0
SoftJAX min (hard mode): -1.0
SoftJAX min (soft mode): -0.9997287789452775

JAX sort: [-1. -0.2 0.3 1. ]
SoftJAX sort (hard mode): [-1. -0.2 0.3 1. ]
SoftJAX sort (soft mode): [-0.9997 -0.1969 0.2973 0.9994]

JAX median: 0.04999999999999999
SoftJAX median (hard mode): 0.04999999999999999
SoftJAX median (soft mode): 0.05000033589501627

JAX top_k: [ 1. 0.3 -0.2]
SoftJAX top_k (hard mode): [ 1. 0.3 -0.2]
SoftJAX top_k (soft mode): [ 0.9994 0.2973 -0.1969]

JAX ranking: [1 0 2 3]
SoftJAX ranking (hard mode): [1. 0. 2. 3.]
SoftJAX ranking (soft mode): [1.0064 0.0003 1.9942 2.9991]
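Two of the soft array-valued outputs above can be reproduced with short formulas: max and min as softmax-weighted averages, and ranking as a sum of pairwise sigmoids. The sketch below assumes softness 0.1 and uses hypothetical function names; it matches the printed soft max, min and ranking values, but we do not claim SoftJAX is implemented this way (soft sort and median are not covered).

```python
import jax
import jax.numpy as jnp


def soft_max(x, softness=0.1):
    # Softmax weights concentrate on the largest entry as softness -> 0;
    # the soft max is the weighted average of x under those weights.
    w = jax.nn.softmax(x / softness)
    return jnp.sum(w * x)


def soft_min(x, softness=0.1):
    # min(x) = -max(-x).
    return -soft_max(-x, softness)


def soft_ranking(x, softness=0.1):
    # Ascending rank of x[i] = number of j != i with x[j] < x[i].
    # Each indicator [x[i] > x[j]] becomes sigmoid((x[i] - x[j]) / softness).
    diff = x[:, None] - x[None, :]
    p = jax.nn.sigmoid(diff / softness)
    # Remove the diagonal term j == i, which contributes sigmoid(0) = 0.5.
    return jnp.sum(p, axis=1) - 0.5


x = jnp.array([-0.2, -1.0, 0.3, 1.0])
print(soft_max(x))      # close to 0.99935
print(soft_min(x))      # close to -0.99973
print(soft_ranking(x))  # close to [1.0064, 0.0003, 1.9942, 2.9991]
print(jnp.argsort(jnp.argsort(x)))  # hard ranking for comparison
```

The pairwise construction costs O(n²) memory for the difference matrix, which is fine for short vectors like this one.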
# Operators returning indices
print("\nJAX argmax:", jnp.argmax(x))
print("SoftJAX argmax (hard mode):", sj.argmax(x, mode="hard"))
print("SoftJAX argmax (soft mode):", sj.argmax(x))
print("\nJAX argmin:", jnp.argmin(x))
print("SoftJAX argmin (hard mode):", sj.argmin(x, mode="hard"))
print("SoftJAX argmin (soft mode):", sj.argmin(x))
print("\nJAX argmedian:", "Not implemented in standard JAX")
print("SoftJAX argmedian (hard mode):", sj.argmedian(x, mode="hard"))
print("SoftJAX argmedian (soft mode):", sj.argmedian(x))
print("\nJAX argsort:", jnp.argsort(x))
print("SoftJAX argsort (hard mode):", sj.argsort(x, mode="hard"))
print("SoftJAX argsort (soft mode):", sj.argsort(x))
print("\nJAX argtop_k:", jax.lax.top_k(x, k=3)[1])
print("SoftJAX argtop_k (hard mode):", sj.top_k(x, k=3, mode="hard")[1])
print("SoftJAX argtop_k (soft mode):", sj.top_k(x, k=3)[1])

JAX argmax: 3
SoftJAX argmax (hard mode): [0. 0. 0. 1.]
SoftJAX argmax (soft mode): [0. 0. 0.0009 0.9991]

JAX argmin: 1
SoftJAX argmin (hard mode): [0. 1. 0. 0.]
SoftJAX argmin (soft mode): [0.0003 0.9997 0. 0. ]

JAX argmedian: Not implemented in standard JAX
SoftJAX argmedian (hard mode): [0.5 0. 0.5 0. ]
SoftJAX argmedian (soft mode): [0.5 0. 0.5 0. ]

JAX argsort: [1 0 2 3]
SoftJAX argsort (hard mode): [[0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
SoftJAX argsort (soft mode): [[0.0003 0.9997 0. 0. ]
 [0.993 0.0003 0.0067 0. ]
 [0.0067 0. 0.9924 0.0009]
 [0. 0. 0.0009 0.9991]]

JAX argtop_k: [3 2 0]
SoftJAX argtop_k (hard mode): [[0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]]
SoftJAX argtop_k (soft mode): [[0. 0. 0.0009 0.9991]
 [0.0067 0. 0.9924 0.0009]
 [0.993 0.0003 0.0067 0. ]]
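In hard mode, the index functions return one-hot rows instead of integer indices. The sketch below uses only standard JAX to show how those hard outputs relate to JAX's integer indices via jax.nn.one_hot, and uses a softmax with an assumed temperature of 0.1 as one common soft argmax; that softmax matches the soft argmax row printed above to four decimals. It does not attempt to reproduce the soft argsort and argtop_k matrices.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-0.2, -1.0, 0.3, 1.0])
n = x.shape[0]

# Hard argmax as a one-hot vector instead of an integer index.
hard_argmax = jax.nn.one_hot(jnp.argmax(x), n)
# One common soft relaxation: softmax with temperature 0.1.
soft_argmax = jax.nn.softmax(x / 0.1)

# Hard argsort as a permutation matrix: row r is one-hot at the index of the
# r-th smallest element, so P @ x gives the sorted array.
P = jax.nn.one_hot(jnp.argsort(x), n)

# Top-k indices as k one-hot rows, largest element first.
k = 3
top_idx = jax.lax.top_k(x, k)[1]
top_rows = jax.nn.one_hot(top_idx, n)

print(hard_argmax)  # one-hot at index 3
print(soft_argmax)  # close to [0, 0, 0.0009, 0.9991]
print(P @ x)        # same values as jnp.sort(x)
print(top_rows)     # rows one-hot at indices 3, 2, 0
```

The matrix form is what makes these outputs relaxable: a soft version can replace each one-hot row with a probability vector while P @ x still yields a (soft) sorted array.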
y = jnp.array([0.2, -0.5, 0.5, -1.0])
# Comparison operators
print("\nJAX greater:", jnp.greater(x, y))
print("SoftJAX greater (hard mode):", sj.greater(x, y, mode="hard"))
print("SoftJAX greater (soft mode):", sj.greater(x, y))
print("\nJAX greater equal:", jnp.greater_equal(x, y))
print("SoftJAX greater equal (hard mode):", sj.greater_equal(x, y, mode="hard"))
print("SoftJAX greater equal (soft mode):", sj.greater_equal(x, y))
print("\nJAX less:", jnp.less(x, y))
print("SoftJAX less (hard mode):", sj.less(x, y, mode="hard"))
print("SoftJAX less (soft mode):", sj.less(x, y))
print("\nJAX less equal:", jnp.less_equal(x, y))
print("SoftJAX less equal (hard mode):", sj.less_equal(x, y, mode="hard"))
print("SoftJAX less equal (soft mode):", sj.less_equal(x, y))
print("\nJAX equal:", jnp.equal(x, y))
print("SoftJAX equal (hard mode):", sj.equal(x, y, mode="hard"))
print("SoftJAX equal (soft mode):", sj.equal(x, y))
print("\nJAX not equal:", jnp.not_equal(x, y))
print("SoftJAX not equal (hard mode):", sj.not_equal(x, y, mode="hard"))
print("SoftJAX not equal (soft mode):", sj.not_equal(x, y))
print("\nJAX isclose:", jnp.isclose(x, y))
print("SoftJAX isclose (hard mode):", sj.isclose(x, y, mode="hard"))
print("SoftJAX isclose (soft mode):", sj.isclose(x, y))

JAX greater: [False False False True]
SoftJAX greater (hard mode): [0. 0. 0. 1.]
SoftJAX greater (soft mode): [0.018 0.0067 0.1192 1. ]

JAX greater equal: [False False False True]
SoftJAX greater equal (hard mode): [0. 0. 0. 1.]
SoftJAX greater equal (soft mode): [0.018 0.0067 0.1192 1. ]

JAX less: [ True True True False]
SoftJAX less (hard mode): [1. 1. 1. 0.]
SoftJAX less (soft mode): [0.982 0.9933 0.8808 0. ]

JAX less equal: [ True True True False]
SoftJAX less equal (hard mode): [1. 1. 1. 0.]
SoftJAX less equal (soft mode): [0.982 0.9933 0.8808 0. ]

JAX equal: [False False False False]
SoftJAX equal (hard mode): [0. 0. 0. 0.]
SoftJAX equal (soft mode): [0.018 0.0067 0.1192 0. ]

JAX not equal: [ True True True True]
SoftJAX not equal (hard mode): [1. 1. 1. 1.]
SoftJAX not equal (soft mode): [0.982 0.9933 0.8808 1. ]

JAX isclose: [False False False False]
SoftJAX isclose (hard mode): [0. 0. 0. 0.]
SoftJAX isclose (soft mode): [0.018 0.0067 0.1192 0. ]
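The soft greater and less outputs above are consistent with a sigmoid of the scaled difference x - y (softness 0.1 assumed), with less as its complement. This is our reconstruction with hypothetical function names, not SoftJAX's documented implementation. Under this construction greater and greater_equal coincide, which is also what the printed soft-mode outputs show.

```python
import jax
import jax.numpy as jnp


def soft_greater(x, y, softness=0.1):
    # Soft indicator of x > y: a sigmoid of the scaled difference.
    return jax.nn.sigmoid((x - y) / softness)


def soft_less(x, y, softness=0.1):
    # Complement of soft_greater, so the two sum to one elementwise.
    return 1.0 - soft_greater(x, y, softness)


x = jnp.array([-0.2, -1.0, 0.3, 1.0])
y = jnp.array([0.2, -0.5, 0.5, -1.0])
print(soft_greater(x, y))  # close to [0.018, 0.0067, 0.1192, 1.0]
print(soft_less(x, y))     # close to [0.982, 0.9933, 0.8808, 0.0]
# The hard comparison cast to floats, matching the hard-mode output format.
print(jnp.greater(x, y).astype(x.dtype))
```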
# Logical operators
fuzzy_a = jnp.array([0.1, 0.2, 0.8, 1.0])
fuzzy_b = jnp.array([0.7, 0.3, 0.1, 0.9])
# Threshold the fuzzy truth values at 0.5 to get Boolean inputs for standard JAX
bool_a = fuzzy_a >= 0.5
bool_b = fuzzy_b >= 0.5
print("\nJAX AND:", jnp.logical_and(bool_a, bool_b))
print("SoftJAX AND:", sj.logical_and(fuzzy_a, fuzzy_b))
print("\nJAX OR:", jnp.logical_or(bool_a, bool_b))
print("SoftJAX OR:", sj.logical_or(fuzzy_a, fuzzy_b))
print("\nJAX NOT:", jnp.logical_not(bool_a))
print("SoftJAX NOT:", sj.logical_not(fuzzy_a))
print("\nJAX XOR:", jnp.logical_xor(bool_a, bool_b))
print("SoftJAX XOR:", sj.logical_xor(fuzzy_a, fuzzy_b))
print("\nJAX ALL:", jnp.all(bool_a))
print("SoftJAX ALL:", sj.all(fuzzy_a))
print("\nJAX ANY:", jnp.any(bool_a))
print("SoftJAX ANY:", sj.any(fuzzy_a))
# Selection operators
print("\nJAX Where:", jnp.where(bool_a, x, y))
print("SoftJAX Where:", sj.where(fuzzy_a, x, y))

JAX AND: [False False False True]
SoftJAX AND: [0.2646 0.2449 0.2828 0.9487]

JAX OR: [ True False True True]
SoftJAX OR: [0.4804 0.2517 0.5757 1. ]

JAX NOT: [ True True False False]
SoftJAX NOT: [0.9 0.8 0.2 0. ]

JAX XOR: [ True False True False]
SoftJAX XOR: [0.587 0.435 0.6394 0.1731]

JAX ALL: False
SoftJAX ALL: 0.35565588200778464

JAX ANY: True
SoftJAX ANY: 0.9980519925071494

JAX Where: [ 0.2 -0.5 0.3 1. ]
SoftJAX Where: [ 0.16 -0.6 0.34 1. ]
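The soft logical outputs above are consistent with a fuzzy logic in which AND is the geometric mean of the two truth values, NOT is 1 - a, and OR and XOR follow from De Morgan's laws; ALL is the geometric mean over the array, and where is a convex combination weighted by the condition. The sketch below uses hypothetical function names and reproduces the printed AND, OR, NOT, XOR, ALL and Where values to four decimals; we do not claim SoftJAX implements them this way. ANY is left out because the De Morgan dual of this ALL gives exactly 1 here, not the printed 0.998.

```python
import jax.numpy as jnp


def f_and(a, b):
    # Geometric mean of the two truth values.
    return jnp.sqrt(a * b)


def f_not(a):
    return 1.0 - a


def f_or(a, b):
    # De Morgan: a OR b = NOT(NOT a AND NOT b).
    return f_not(f_and(f_not(a), f_not(b)))


def f_xor(a, b):
    # (a AND NOT b) OR (NOT a AND b).
    return f_or(f_and(a, f_not(b)), f_and(f_not(a), b))


def f_all(a):
    # Geometric mean over all entries.
    return jnp.prod(a) ** (1.0 / a.size)


def f_where(cond, x, y):
    # Convex combination: cond = 1 picks x, cond = 0 picks y.
    return cond * x + (1.0 - cond) * y


a = jnp.array([0.1, 0.2, 0.8, 1.0])
b = jnp.array([0.7, 0.3, 0.1, 0.9])
x = jnp.array([-0.2, -1.0, 0.3, 1.0])
y = jnp.array([0.2, -0.5, 0.5, -1.0])
print(f_and(a, b))       # roughly [0.2646, 0.2449, 0.2828, 0.9487]
print(f_or(a, b))        # roughly [0.4804, 0.2517, 0.5757, 1.0]
print(f_xor(a, b))       # roughly [0.587, 0.435, 0.6394, 0.1731]
print(f_all(a))          # roughly 0.3557
print(f_where(a, x, y))  # roughly [0.16, -0.6, 0.34, 1.0]
```

One caveat of the geometric mean: the derivative of sqrt is unbounded at 0, so gradients through f_and can become inf or nan when a truth value is exactly 0.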
# Straight-through operators: hard function in the forward pass, gradient of the soft relaxation in the backward pass
print("Straight-through ReLU:", sj.relu_st(x))
print("Straight-through sort:", sj.sort_st(x))
print("Straight-through argtop_k:", sj.top_k_st(x, k=3)[1])
print("Straight-through greater:", sj.greater_st(x, y))
# And many more...

Straight-through ReLU: [0. 0. 0.3 1. ]
Straight-through sort: [-1. -0.2 0.3 1. ]
Straight-through argtop_k: [[0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]]
Straight-through greater: [0. 0. 0. 1.]
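The straight-through functions print the hard values on the forward pass. A standard way to get this behavior in JAX is the stop-gradient trick sketched below: it returns the hard value while differentiating through the soft relaxation. The name relu_straight_through and the softness of 0.1 are assumptions for illustration, and we do not claim sj.relu_st is implemented this way.

```python
import jax
import jax.numpy as jnp


def soft_relu(x, softness=0.1):
    return softness * jax.nn.softplus(x / softness)


def relu_straight_through(x, softness=0.1):
    # Forward value: the hard ReLU (soft + (hard - soft)).
    # stop_gradient hides (hard - soft) from autodiff, so only `soft`
    # contributes to the derivative.
    soft = soft_relu(x, softness)
    hard = jax.nn.relu(x)
    return soft + jax.lax.stop_gradient(hard - soft)


x = jnp.array([-0.2, -1.0, 0.3, 1.0])
print(relu_straight_through(x))  # matches jax.nn.relu(x) up to rounding

# Elementwise gradient via grad of the sum; equals sigmoid(x / softness),
# the derivative of the soft ReLU, instead of the hard 0/1 step.
g = jax.grad(lambda v: jnp.sum(relu_straight_through(v)))(x)
print(g)
```

Unlike the hard ReLU, this gradient is nonzero for negative inputs such as -0.2, which is the point of using a straight-through estimator.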