Plots¶
Below we plot many of the Softtorch functions for different modes and softnesses to give an idea of what they look like. All functions are designed to be soft when used on inputs in an interval like [-1, 1] or [0, 1], and to match the respective hard functions outside of these intervals.
In [2]:
Copied!
# Modes shared by all sigmoid-shaped soft functions plotted below.
# (The export duplicated this assignment; one copy is sufficient.)
sigmoid_modes = ["entropic", "euclidean", "pseudohuber", "cubic", "quintic"]
In [3]:
Copied!
def median_newton(x, **kwargs):
    """Soft median of ``x`` against the fixed anchors -0.5 and 0.5.

    Uses ``st.median_newton`` when the library provides it, otherwise
    falls back to ``st.median``; normalizes namedtuple-style results.
    """
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    lo, hi = torch.full_like(x, -0.5), torch.full_like(x, 0.5)
    stacked = torch.stack([x, lo, hi], dim=-1)
    fn = getattr(st, "median_newton", None) or st.median
    out = fn(stacked, dim=-1, **kwargs)
    return getattr(out, "values", out)
def greater_than_1(x, **kwargs):
    """Soft indicator for ``x > 1``."""
    threshold = torch.tensor(1.0)
    return st.greater(x, threshold, **kwargs)
def less_than_1(x, **kwargs):
    """Soft indicator for ``x < 1``."""
    threshold = torch.tensor(1.0)
    return st.less(x, threshold, **kwargs)
def equal_to_0(x, **kwargs):
    """Soft indicator for ``x == 0``."""
    zero = torch.tensor(0.0)
    return st.equal(x, zero, **kwargs)
def not_equal_to_0(x, **kwargs):
    """Soft indicator for ``x != 0``."""
    zero = torch.tensor(0.0)
    return st.not_equal(x, zero, **kwargs)
def isclose_to_0(x, **kwargs):
    """Soft indicator for ``x`` being close to 0.

    Bug fix: this wrapper previously called ``st.not_equal`` (a copy-paste
    of ``not_equal_to_0``), so its plot merely duplicated that function
    instead of showing ``st.isclose``.
    """
    return st.isclose(x, torch.tensor(0.0), **kwargs)
# Plot each sigmoid-shaped soft function for every mode in sigmoid_modes.
plot(st.heaviside, modes=sigmoid_modes)
plot(st.abs, modes=sigmoid_modes)
plot(st.sign, modes=sigmoid_modes)
plot(st.round, modes=sigmoid_modes)
plot(greater_than_1, modes=sigmoid_modes)
plot(less_than_1, modes=sigmoid_modes)
plot(equal_to_0, modes=sigmoid_modes)
plot(not_equal_to_0, modes=sigmoid_modes)
plot(isclose_to_0, modes=sigmoid_modes)
def median_newton(x, **kwargs):
    """Soft median of ``x`` against the fixed anchors -0.5 and 0.5.

    Falls back to ``st.median`` when ``st.median_newton`` is unavailable.
    """
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    columns = [x, torch.full_like(x, -0.5), torch.full_like(x, 0.5)]
    median_fn = getattr(st, "median_newton", None) or st.median
    result = median_fn(torch.stack(columns, dim=-1), dim=-1, **kwargs)
    if hasattr(result, "values"):
        return result.values
    return result
def greater_than_1(x, **kwargs):
    """Soft comparison ``x > 1``."""
    one = torch.tensor(1.0)
    return st.greater(x, one, **kwargs)
def less_than_1(x, **kwargs):
    """Soft comparison ``x < 1``."""
    one = torch.tensor(1.0)
    return st.less(x, one, **kwargs)
def equal_to_0(x, **kwargs):
    """Soft comparison ``x == 0``."""
    origin = torch.tensor(0.0)
    return st.equal(x, origin, **kwargs)
def not_equal_to_0(x, **kwargs):
    """Soft comparison ``x != 0``."""
    origin = torch.tensor(0.0)
    return st.not_equal(x, origin, **kwargs)
def isclose_to_0(x, **kwargs):
    """Soft indicator for ``x`` being close to 0.

    Bug fix: previously called ``st.not_equal`` (copy-paste of
    ``not_equal_to_0``) rather than ``st.isclose``, so the isclose plot
    just repeated the not-equal plot.
    """
    return st.isclose(x, torch.tensor(0.0), **kwargs)
# Plot each sigmoid-shaped soft function for every mode in sigmoid_modes.
# NOTE(review): this cell is a duplicate of the one above — an artifact of
# the notebook export; re-running it is harmless but redundant.
plot(st.heaviside, modes=sigmoid_modes)
plot(st.abs, modes=sigmoid_modes)
plot(st.sign, modes=sigmoid_modes)
plot(st.round, modes=sigmoid_modes)
plot(greater_than_1, modes=sigmoid_modes)
plot(less_than_1, modes=sigmoid_modes)
plot(equal_to_0, modes=sigmoid_modes)
plot(not_equal_to_0, modes=sigmoid_modes)
plot(isclose_to_0, modes=sigmoid_modes)
Softplus-based¶
In [4]:
Copied!
# Modes for the softplus-based functions; the "gated_*" variants use the
# gated formulation. (The export duplicated both assignments; deduplicated.)
softplus_modes = ["entropic", "euclidean", "quartic"]
softplus_modes_gated = [
    "gated_entropic",
    "gated_euclidean",
    "gated_cubic",
    "gated_quintic",
    "gated_pseudohuber",
]
In [5]:
Copied!
def clamp_between_0_and_1(x, **kwargs):
    """Soft clamp of ``x`` into the interval [0, 1]."""
    lower, upper = torch.tensor(0.0), torch.tensor(1.0)
    return st.clamp(x, lower, upper, **kwargs)
# Plot relu and clamp with both the plain and the gated softplus modes.
plot(st.relu, modes=softplus_modes)
plot(st.relu, modes=softplus_modes_gated)
plot(clamp_between_0_and_1, modes=softplus_modes)
plot(clamp_between_0_and_1, modes=softplus_modes_gated)
def clamp_between_0_and_1(x, **kwargs):
    """Soft clamp of ``x`` to [0, 1]."""
    zero = torch.tensor(0.0)
    one = torch.tensor(1.0)
    return st.clamp(x, zero, one, **kwargs)
# Plot relu and clamp with both the plain and the gated softplus modes.
# NOTE(review): duplicate of the cell above (notebook-export artifact).
plot(st.relu, modes=softplus_modes)
plot(st.relu, modes=softplus_modes_gated)
plot(clamp_between_0_and_1, modes=softplus_modes)
plot(clamp_between_0_and_1, modes=softplus_modes_gated)
Simplex-projection / optimal transport based¶
In [6]:
Copied!
# Modes for simplex-projection / optimal-transport based functions.
# (The export duplicated this assignment; one copy is sufficient.)
projection_modes = ["entropic", "euclidean"]
In [7]:
Copied!
def soft_max(x, **kwargs):
    """Soft elementwise max of ``x`` and the constant 0.5."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    half = torch.full_like(x, 0.5)
    pair = torch.stack([x, half], dim=-1)
    return st.max(pair, dim=-1, **kwargs).values
def soft_min(x, **kwargs):
    """Soft elementwise min of ``x`` and the constant 0.5."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    half = torch.full_like(x, 0.5)
    pair = torch.stack([x, half], dim=-1)
    return st.min(pair, dim=-1, **kwargs).values
def soft_median(x, **kwargs):
    """Soft median of ``x`` against the fixed anchors -1, 1, -2, 2."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    anchors = (-1.0, 1.0, -2.0, 2.0)
    columns = [x] + [torch.full_like(x, a) for a in anchors]
    augmented = torch.stack(columns, dim=-1)
    return st.median(augmented, dim=-1, **kwargs).values
def soft_sort(x, **kwargs):
    """Soft sort of ``x`` stacked with anchors -1 and 1; returns the last
    (largest) sorted entry."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    columns = [x, torch.full_like(x, -1.0), torch.full_like(x, 1.0)]
    sorted_out = st.sort(torch.stack(columns, dim=-1), **kwargs)
    return sorted_out.values[..., 2]
def soft_top_k(x, **kwargs):
    """Soft top-2 of ``x`` stacked with anchors -1 and 1; returns the
    second of the two top values."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    columns = [x, torch.full_like(x, -1.0), torch.full_like(x, 1.0)]
    top = st.topk(torch.stack(columns, dim=-1), k=2, **kwargs)
    return top.values[..., 1]
# Plot each projection/OT-based soft function for every projection mode.
plot(soft_max, modes=projection_modes)
plot(soft_min, modes=projection_modes)
plot(soft_median, modes=projection_modes)
plot(soft_sort, modes=projection_modes)
plot(soft_top_k, modes=projection_modes)
def soft_max(x, **kwargs):
    """Soft elementwise maximum of ``x`` with 0.5."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    augmented = torch.stack([x, torch.full_like(x, 0.5)], dim=-1)
    result = st.max(augmented, dim=-1, **kwargs)
    return result.values
def soft_min(x, **kwargs):
    """Soft elementwise minimum of ``x`` with 0.5."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    augmented = torch.stack([x, torch.full_like(x, 0.5)], dim=-1)
    result = st.min(augmented, dim=-1, **kwargs)
    return result.values
def soft_median(x, **kwargs):
    """Soft median of ``x`` against fixed anchors -1, 1, -2, 2."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    cols = [
        x,
        torch.full_like(x, -1.0),
        torch.full_like(x, 1.0),
        torch.full_like(x, -2.0),
        torch.full_like(x, 2.0),
    ]
    result = st.median(torch.stack(cols, dim=-1), dim=-1, **kwargs)
    return result.values
def soft_sort(x, **kwargs):
    """Soft sort of ``x`` with anchors -1 and 1; yields the largest entry."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    lo, hi = torch.full_like(x, -1.0), torch.full_like(x, 1.0)
    augmented = torch.stack([x, lo, hi], dim=-1)
    return st.sort(augmented, **kwargs).values[..., 2]
def soft_top_k(x, **kwargs):
    """Soft top-2 of ``x`` with anchors -1 and 1; yields the second value."""
    x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    lo, hi = torch.full_like(x, -1.0), torch.full_like(x, 1.0)
    augmented = torch.stack([x, lo, hi], dim=-1)
    return st.topk(augmented, k=2, **kwargs).values[..., 1]
# Plot each projection/OT-based soft function for every projection mode.
# NOTE(review): duplicate of the cell above (notebook-export artifact).
plot(soft_max, modes=projection_modes)
plot(soft_min, modes=projection_modes)
plot(soft_median, modes=projection_modes)
plot(soft_sort, modes=projection_modes)
plot(soft_top_k, modes=projection_modes)
In [ ]:
Copied!