Plots¶
Below we plot many of the Softjax functions for different modes and softness values to give an idea of what they look like. All functions are designed to behave softly on inputs in the interval [0, 1] when using the default softness of 1.0. For inputs in other ranges, the softness can be scaled accordingly; e.g., for inputs distributed in the interval [0, 100], a softness of 100.0 should result in similarly soft behavior.
Sigmoid-based¶
In [2]:
Copied!
# Smoothing modes shared by all sigmoid-based Softjax functions plotted below.
sigmoid_modes = [
    "entropic",
    "euclidean",
    "pseudohuber",
    "cubic",
    "quintic",
]
# Smoothing modes shared by all sigmoid-based Softjax functions plotted below.
sigmoid_modes = "entropic euclidean pseudohuber cubic quintic".split()
In [3]:
Copied!
def median_newton(x, **kwargs):
    """Scalar view of ``sj.median_newton`` applied to ``[x, -0.5, 0.5]``."""
    return sj.median_newton(jnp.array([x, -0.5, 0.5]), **kwargs)


def greater_than_1(x, **kwargs):
    """Soft comparison ``x > 1`` via ``sj.greater``."""
    return sj.greater(x, jnp.array(1.0), **kwargs)


def less_than_1(x, **kwargs):
    """Soft comparison ``x < 1`` via ``sj.less``."""
    return sj.less(x, jnp.array(1.0), **kwargs)


def equal_to_0(x, **kwargs):
    """Soft comparison ``x == 0`` via ``sj.equal``."""
    return sj.equal(x, jnp.array(0.0), **kwargs)


def not_equal_to_0(x, **kwargs):
    """Soft comparison ``x != 0`` via ``sj.not_equal``."""
    return sj.not_equal(x, jnp.array(0.0), **kwargs)


def isclose_to_0(x, **kwargs):
    """Soft ``isclose(x, 0)`` via ``sj.isclose``.

    Must call ``sj.isclose`` (not ``sj.not_equal``); otherwise this plot is
    just a duplicate of the ``not_equal_to_0`` plot.
    """
    return sj.isclose(x, jnp.array(0.0), **kwargs)


# Plot every sigmoid-based function across all sigmoid modes, in this order.
for sigmoid_fn in (
    sj.heaviside,
    sj.abs,
    sj.sign,
    sj.round,
    median_newton,
    greater_than_1,
    less_than_1,
    equal_to_0,
    not_equal_to_0,
    isclose_to_0,
):
    plot(sigmoid_fn, modes=sigmoid_modes)
def median_newton(x, **kwargs):
    """Scalar view of ``sj.median_newton`` applied to ``[x, -0.5, 0.5]``."""
    return sj.median_newton(jnp.array([x, -0.5, 0.5]), **kwargs)


def greater_than_1(x, **kwargs):
    """Soft comparison ``x > 1`` via ``sj.greater``."""
    return sj.greater(x, jnp.array(1.0), **kwargs)


def less_than_1(x, **kwargs):
    """Soft comparison ``x < 1`` via ``sj.less``."""
    return sj.less(x, jnp.array(1.0), **kwargs)


def equal_to_0(x, **kwargs):
    """Soft comparison ``x == 0`` via ``sj.equal``."""
    return sj.equal(x, jnp.array(0.0), **kwargs)


def not_equal_to_0(x, **kwargs):
    """Soft comparison ``x != 0`` via ``sj.not_equal``."""
    return sj.not_equal(x, jnp.array(0.0), **kwargs)


def isclose_to_0(x, **kwargs):
    """Soft ``isclose(x, 0)`` via ``sj.isclose``.

    Must call ``sj.isclose`` (not ``sj.not_equal``); otherwise this plot is
    just a duplicate of the ``not_equal_to_0`` plot.
    """
    return sj.isclose(x, jnp.array(0.0), **kwargs)


# Plot every sigmoid-based function across all sigmoid modes, in this order.
for sigmoid_fn in (
    sj.heaviside,
    sj.abs,
    sj.sign,
    sj.round,
    median_newton,
    greater_than_1,
    less_than_1,
    equal_to_0,
    not_equal_to_0,
    isclose_to_0,
):
    plot(sigmoid_fn, modes=sigmoid_modes)
Softplus-based¶
In [4]:
Copied!
# Plain softplus-style modes.
softplus_modes = "entropic euclidean quartic".split()
# Gated variants: each base mode name prefixed with "gated_".
softplus_modes_gated = [
    f"gated_{base}"
    for base in ("entropic", "euclidean", "cubic", "quintic", "pseudohuber")
]
# Plain softplus-style modes.
softplus_modes = "entropic euclidean quartic".split()
# Gated variants: each base mode name prefixed with "gated_".
softplus_modes_gated = [
    f"gated_{base}"
    for base in ("entropic", "euclidean", "cubic", "quintic", "pseudohuber")
]
In [5]:
Copied!
def clip_between_0_and_1(x, **kwargs):
    """Soft clip of ``x`` to the interval [0, 1] via ``sj.clip``."""
    lower, upper = jnp.array(0.0), jnp.array(1.0)
    return sj.clip(x, lower, upper, **kwargs)


# ReLU first, then clip; each under the plain modes and then the gated modes.
for softplus_fn in (sj.relu, clip_between_0_and_1):
    for mode_set in (softplus_modes, softplus_modes_gated):
        plot(softplus_fn, modes=mode_set)
def clip_between_0_and_1(x, **kwargs):
    """Soft clip of ``x`` to the interval [0, 1] via ``sj.clip``."""
    lower, upper = jnp.array(0.0), jnp.array(1.0)
    return sj.clip(x, lower, upper, **kwargs)


# ReLU first, then clip; each under the plain modes and then the gated modes.
for softplus_fn in (sj.relu, clip_between_0_and_1):
    for mode_set in (softplus_modes, softplus_modes_gated):
        plot(softplus_fn, modes=mode_set)
Simplex-projection / optimal transport based¶
In [6]:
Copied!
# Modes used by the simplex-projection / optimal-transport based functions.
projection_modes = [
    "entropic",
    "euclidean",
]
# Modes used by the simplex-projection / optimal-transport based functions.
projection_modes = list(("entropic", "euclidean"))
In [7]:
Copied!
def sj_max(x, **kwargs):
    """Soft max of the pair ``[x, 0.5]`` as a scalar function of ``x``."""
    pair = jnp.array([x, 0.5])
    return sj.max(pair, **kwargs)


def sj_min(x, **kwargs):
    """Soft min of the pair ``[x, 0.5]`` as a scalar function of ``x``."""
    pair = jnp.array([x, 0.5])
    return sj.min(pair, **kwargs)


def sj_median(x, **kwargs):
    """Soft median of ``[x, -1, 1, -2, 2]`` as a scalar function of ``x``."""
    values = jnp.array([x, -1.0, 1.0, -2.0, 2.0])
    return sj.median(values, **kwargs)


def sj_sort(x, **kwargs):
    """Entry at index 1 of the soft sort of ``[x, -1, 1]``."""
    sorted_values = sj.sort(jnp.array([x, -1.0, 1.0]), **kwargs)
    return sorted_values[1]


def sj_top_k(x, **kwargs):
    """Entry at index 1 of the first output of ``sj.top_k(..., k=2)`` on ``[x, -1, 1]``."""
    top = sj.top_k(jnp.array([x, -1.0, 1.0]), k=2, **kwargs)
    # NOTE(review): presumably top == (values, indices); confirm in softjax docs.
    first_output = top[0]
    return first_output[1]


for projection_fn in (sj_max, sj_min, sj_median, sj_sort, sj_top_k):
    plot(projection_fn, modes=projection_modes)

# Disabled variants passing fast=False:
# plot(sj_median, modes=projection_modes, fast=False)
# plot(sj_sort, modes=projection_modes, fast=False)
# plot(sj_top_k, modes=projection_modes, fast=False)
def sj_max(x, **kwargs):
    """Soft max of the pair ``[x, 0.5]`` as a scalar function of ``x``."""
    pair = jnp.array([x, 0.5])
    return sj.max(pair, **kwargs)


def sj_min(x, **kwargs):
    """Soft min of the pair ``[x, 0.5]`` as a scalar function of ``x``."""
    pair = jnp.array([x, 0.5])
    return sj.min(pair, **kwargs)


def sj_median(x, **kwargs):
    """Soft median of ``[x, -1, 1, -2, 2]`` as a scalar function of ``x``."""
    values = jnp.array([x, -1.0, 1.0, -2.0, 2.0])
    return sj.median(values, **kwargs)


def sj_sort(x, **kwargs):
    """Entry at index 1 of the soft sort of ``[x, -1, 1]``."""
    sorted_values = sj.sort(jnp.array([x, -1.0, 1.0]), **kwargs)
    return sorted_values[1]


def sj_top_k(x, **kwargs):
    """Entry at index 1 of the first output of ``sj.top_k(..., k=2)`` on ``[x, -1, 1]``."""
    top = sj.top_k(jnp.array([x, -1.0, 1.0]), k=2, **kwargs)
    # NOTE(review): presumably top == (values, indices); confirm in softjax docs.
    first_output = top[0]
    return first_output[1]


for projection_fn in (sj_max, sj_min, sj_median, sj_sort, sj_top_k):
    plot(projection_fn, modes=projection_modes)

# Disabled variants passing fast=False:
# plot(sj_median, modes=projection_modes, fast=False)
# plot(sj_sort, modes=projection_modes, fast=False)
# plot(sj_top_k, modes=projection_modes, fast=False)
In [ ]:
Copied!
In [ ]:
Copied!