Quick Example
This notebook contains the quick examples from the Readme.
import softtorch as st
import torch
The default mode of all SoftTorch functions is `mode="entropic"`. This mode uses well-known relaxations based on the exponential function, e.g. an exponential (logistic) sigmoid as a soft Heaviside function and the softplus function as a soft ReLU.
Setting `mode="hard"` replicates the outputs of the original PyTorch functions, except that index-returning functions produce one-hot encodings instead of integer indices. This is especially useful for straight-through estimators, where we want to switch easily between hard and soft modes.
Note that all SoftTorch functions taking a `mode` argument also take a `softness` argument, which determines how soft the relaxation is. Although the limit of `softness` going to zero should always recover the hard function, setting `softness=0` is not allowed; use `mode="hard"` instead.
x = torch.tensor([-0.2, -1.0, 0.3, 1.0])

# Elementwise functions
# Each group of three rows pairs the PyTorch reference with the SoftTorch
# hard mode and the SoftTorch default (soft) mode of the same operation.
elementwise_rows = [
    ("\nTorch absolute:", torch.abs(x)),
    ("SoftTorch absolute (hard mode):", st.abs(x, mode="hard")),
    ("SoftTorch absolute (soft mode):", st.abs(x)),
    ("\nTorch clamp:", torch.clamp(x, -0.5, 0.5)),
    ("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard")),
    ("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5)),
    ("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5))),
    ("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard")),
    ("SoftTorch heaviside (soft mode):", st.heaviside(x)),
    ("\nTorch ReLU:", torch.nn.functional.relu(x)),
    ("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard")),
    ("SoftTorch ReLU (soft mode):", st.relu(x)),
    ("\nTorch round:", torch.round(x)),
    ("SoftTorch round (hard mode):", st.round(x, mode="hard")),
    ("SoftTorch round (soft mode):", st.round(x)),
    ("\nTorch sign:", torch.sign(x)),
    ("SoftTorch sign (hard mode):", st.sign(x, mode="hard")),
    ("SoftTorch sign (soft mode):", st.sign(x)),
]
for caption, result in elementwise_rows:
    print(caption, result)
Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000]) SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000]) SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999]) Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000]) SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000]) SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993]) Torch heaviside: tensor([0., 0., 1., 1.]) SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.]) SoftTorch heaviside (soft mode): tensor([1.1920e-01, 4.5398e-05, 9.5257e-01, 9.9995e-01]) Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000]) SoftTorch ReLU (soft mode): tensor([1.2693e-02, 4.5399e-06, 3.0486e-01, 1.0000e+00]) Torch round: tensor([-0., -1., 0., 1.]) SoftTorch round (hard mode): tensor([-0., -1., 0., 1.]) SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000]) Torch sign: tensor([-1., -1., 1., 1.]) SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.]) SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999])
# Tensor-valued operators
# Rows are (caption, result) pairs: PyTorch reference first, then the
# SoftTorch hard mode and the SoftTorch default (soft) mode.
# Note on median: for this even-length input torch.median prints the lower
# middle value (-0.2), whereas st.median prints 0.05, the mean of the two
# middle values (-0.2 and 0.3), so those two rows are not expected to match.
tensor_valued_rows = [
    ("\nTorch max:", torch.max(x)),
    ("SoftTorch max (hard mode):", st.max(x, mode="hard")),
    ("SoftTorch max (soft mode):", st.max(x)),
    ("\nTorch min:", torch.min(x)),
    ("SoftTorch min (hard mode):", st.min(x, mode="hard")),
    ("SoftTorch min (soft mode):", st.min(x)),
    ("\nTorch median:", torch.median(x)),
    ("SoftTorch median (hard mode):", st.median(x, mode="hard")),
    ("SoftTorch median (soft mode):", st.median(x)),
    ("\nTorch sort:", torch.sort(x).values),
    ("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values),
    ("SoftTorch sort (soft mode):", st.sort(x).values),
    ("\nTorch topk:", torch.topk(x, k=2).values),
    ("SoftTorch topk (hard mode):", st.topk(x, k=2, mode="hard").values),
    ("SoftTorch topk (soft mode):", st.topk(x, k=2).values),
    # Applying argsort twice turns sort order into per-element ranks.
    ("\nTorch ranking:", torch.argsort(torch.argsort(x))),
    ("SoftTorch ranking (hard mode):", st.ranking(x, descending=False, mode="hard")),
    ("SoftTorch ranking (soft mode):", st.ranking(x, descending=False)),
]
for caption, result in tensor_valued_rows:
    print(caption, result)
Torch max: tensor(1.) SoftTorch max (hard mode): tensor(1.) SoftTorch max (soft mode): tensor(0.9994) Torch min: tensor(-1.) SoftTorch min (hard mode): tensor(-1.) SoftTorch min (soft mode): tensor(-0.9997) Torch median: tensor(-0.2000) SoftTorch median (hard mode): tensor(0.0500) SoftTorch median (soft mode): tensor(0.0500) Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) SoftTorch sort (soft mode): tensor([-0.9997, -0.1969, 0.2973, 0.9994]) Torch topk: tensor([1.0000, 0.3000]) SoftTorch topk (hard mode): tensor([1.0000, 0.3000]) SoftTorch topk (soft mode): tensor([0.9994, 0.2973]) Torch ranking: tensor([1, 0, 2, 3]) SoftTorch ranking (hard mode): tensor([1., 0., 2., 3.]) SoftTorch ranking (soft mode): tensor([1.0064e+00, 3.3987e-04, 1.9942e+00, 2.9991e+00])
# Operators returning indices
# PyTorch prints integer indices; the SoftTorch rows print one-hot (hard mode)
# or soft (default mode) weight vectors over the input positions instead.
index_rows = [
    ("\nTorch argmax:", torch.argmax(x)),
    ("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard")),
    ("SoftTorch argmax (soft mode):", st.argmax(x)),
    ("\nTorch argmin:", torch.argmin(x)),
    ("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard")),
    ("SoftTorch argmin (soft mode):", st.argmin(x)),
    ("\nTorch argmedian:", torch.median(x, dim=0).indices),
    ("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices),
    ("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices),
    ("\nTorch argsort:", torch.argsort(x)),
    ("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard")),
    ("SoftTorch argsort (soft mode):", st.argsort(x)),
    ("\nTorch argtopk:", torch.topk(x, k=2).indices),
    ("SoftTorch argtopk (hard mode):", st.topk(x, k=2, mode="hard").indices),
    ("SoftTorch argtopk (soft mode):", st.topk(x, k=2).indices),
]
for caption, result in index_rows:
    print(caption, result)
Torch argmax: tensor(3)
SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.])
SoftTorch argmax (soft mode): tensor([6.1386e-06, 2.0593e-09, 9.1105e-04, 9.9908e-01])
Torch argmin: tensor(1)
SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.])
SoftTorch argmin (soft mode): tensor([3.3535e-04, 9.9966e-01, 2.2596e-06, 2.0605e-09])
Torch argmedian: tensor(0)
SoftTorch argmedian (hard mode): tensor([0.5000, 0.0000, 0.5000, 0.0000])
SoftTorch argmedian (soft mode): tensor([5.0000e-01, 5.6268e-08, 5.0000e-01, 4.1576e-07])
Torch argsort: tensor([1, 0, 2, 3])
SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
SoftTorch argsort (soft mode): tensor([[3.3535e-04, 9.9966e-01, 2.2596e-06, 2.0605e-09],
[9.9297e-01, 3.3310e-04, 6.6906e-03, 6.1010e-06],
[6.6868e-03, 2.2432e-06, 9.9241e-01, 9.0496e-04],
[6.1386e-06, 2.0593e-09, 9.1105e-04, 9.9908e-01]])
Torch argtopk: tensor([3, 2])
SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.],
[0., 0., 1., 0.]])
SoftTorch argtopk (soft mode): tensor([[6.1386e-06, 2.0593e-09, 9.1105e-04, 9.9908e-01],
[6.6868e-03, 2.2432e-06, 9.9241e-01, 9.0496e-04]])
y = torch.tensor([0.2, -0.5, 0.5, -1.0])
# Comparison operators
print("\nTorch greater:", torch.greater(x, y))
print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard"))
print("SoftTorch greater (soft mode):", st.greater(x, y))
print("\nTorch greater equal:", torch.greater_equal(x, y))
print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard"))
print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y))
print("\nTorch less:", torch.less(x, y))
print("SoftTorch less (hard mode):", st.less(x, y, mode="hard"))
print("SoftTorch less (soft mode):", st.less(x, y))
print("\nTorch less equal:", torch.less_equal(x, y))
print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard"))
print("SoftTorch less equal (soft mode):", st.less_equal(x, y))
print("\nTorch equal:", torch.equal(x, y))
print("SoftTorch equal (hard mode):", st.equal(x, y, mode="hard"))
print("SoftTorch equal (soft mode):", st.equal(x, y))
print("\nTorch not equal:", torch.not_equal(x, y))
print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard"))
print("SoftTorch not equal (soft mode):", st.not_equal(x, y))
print("\nTorch isclose:", torch.isclose(x, y))
print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard"))
print("SoftTorch isclose (soft mode):", st.isclose(x, y))
Torch greater: tensor([False, False, False, True]) SoftTorch greater (hard mode): tensor([0., 0., 0., 1.]) SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) Torch greater equal: tensor([False, False, False, True]) SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.]) SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) Torch less: tensor([ True, True, True, False]) SoftTorch less (hard mode): tensor([1., 1., 1., 0.]) SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) Torch less equal: tensor([ True, True, True, False]) SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.]) SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) Torch equal: False SoftTorch equal (hard mode): tensor([0., 0., 0., 0.]) SoftTorch equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 0.0000]) Torch not equal: tensor([True, True, True, True]) SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.]) SoftTorch not equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 1.0000]) Torch isclose: tensor([False, False, False, False]) SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.]) SoftTorch isclose (soft mode): tensor([0.0180, 0.0067, 0.1192, 0.0000])
# Logical operators
# The fuzzy tensors hold truth values between 0 and 1; thresholding them at
# 0.5 gives the boolean inputs for the PyTorch reference operators.
fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0])
fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9])
bool_a, bool_b = fuzzy_a >= 0.5, fuzzy_b >= 0.5

logical_rows = [
    ("\nTorch AND:", torch.logical_and(bool_a, bool_b)),
    ("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b)),
    ("\nTorch OR:", torch.logical_or(bool_a, bool_b)),
    ("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b)),
    ("\nTorch NOT:", torch.logical_not(bool_a)),
    ("SoftTorch NOT:", st.logical_not(fuzzy_a)),
    ("\nTorch XOR:", torch.logical_xor(bool_a, bool_b)),
    ("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b)),
    ("\nTorch ALL:", torch.all(bool_a)),
    ("SoftTorch ALL:", st.all(fuzzy_a)),
    ("\nTorch ANY:", torch.any(bool_a)),
    ("SoftTorch ANY:", st.any(fuzzy_a)),
]
for caption, result in logical_rows:
    print(caption, result)

# Selection operators
# The fuzzy condition is passed directly to st.where along with x and y.
soft_selection = st.where(fuzzy_a, x, y)
print("SoftTorch Where:", soft_selection)
Torch AND: tensor([False, False, False, True]) SoftTorch AND: tensor([0.2646, 0.2449, 0.2828, 0.9487]) Torch OR: tensor([ True, False, True, True]) SoftTorch OR: tensor([0.4804, 0.2517, 0.5757, 1.0000]) Torch NOT: tensor([ True, True, False, False]) SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000]) Torch XOR: tensor([ True, False, True, False]) SoftTorch XOR: tensor([0.5870, 0.4350, 0.6394, 0.1731]) Torch ALL: tensor(False) SoftTorch ALL: tensor(0.3557) Torch ANY: tensor(True) SoftTorch ANY: tensor(0.9981) SoftTorch Where: tensor([ 0.1600, -0.6000, 0.3400, 1.0000])
# Straight-through operators: Use hard function on forward and soft on backward
straight_through_rows = [
    ("Straight-through ReLU:", st.relu_st(x)),
    ("Straight-through sort:", st.sort_st(x)),
    ("Straight-through topk:", st.topk_st(x, k=3)),
    ("Straight-through greater:", st.greater_st(x, y)),
]
for caption, result in straight_through_rows:
    print(caption, result)
# And many more...
Straight-through ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000])
Straight-through sort: torch.return_types.sort(
values=tensor([-1.0000, -0.2000, 0.3000, 1.0000]),
indices=tensor([[0., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]]))
Straight-through topk: torch.return_types.topk(
values=tensor([ 1.0000, 0.3000, -0.2000]),
indices=tensor([[0., 0., 0., 1.],
[0., 0., 1., 0.],
[1., 0., 0., 0.]]))
Straight-through greater: tensor([0., 0., 0., 1.])