Straight-through operators
Straight-through utility
softtorch.st(fn: Callable) -> Callable
This decorator calls the decorated function twice: once with mode="hard" and once with the specified soft mode. The hard output is returned in the forward pass, while gradients are computed through the soft output in the backward pass.
Arguments:
fn: The function to be wrapped. It may accept a mode parameter. If fn has no mode parameter, it defaults to "smooth" and mode is passed through via **kwargs.
Returns:
A wrapped function that behaves like the mode="hard" version during the forward pass, but computes gradients using the specified mode and softness during the backward pass.
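The straight-through mechanism behind st can be sketched with the standard detach trick. The st_sketch wrapper and the my_round example op below are illustrative stand-ins, not softtorch's actual implementation:

```python
import torch

def st_sketch(fn):
    # Straight-through wrapper (assumed semantics): the forward value is the
    # hard output, while gradients flow through the soft output.
    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)
        soft = fn(*args, mode=mode, **kwargs)
        # Value equals `hard`; gradient is that of `soft`.
        return soft + (hard - soft).detach()
    return wrapped

def my_round(x, mode="smooth"):
    # Hypothetical op: hard rounding vs. a crude identity relaxation.
    return torch.round(x) if mode == "hard" else x

round_st_sketch = st_sketch(my_round)

x = torch.tensor([0.2, 0.8], requires_grad=True)
y = round_st_sketch(x)   # forward value: the hard rounding [0., 1.]
y.sum().backward()       # gradient: that of the soft (identity) branch
```

Even though rounding has zero gradient almost everywhere, x.grad here is all ones, because the backward pass sees only the identity relaxation.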
softtorch.grad_replace(fn: Callable) -> Callable
This decorator calls the decorated function twice: once with forward=True and once with forward=False. The forward=True output is returned in the forward pass, while gradients are computed through the forward=False output in the backward pass.
Arguments:
fn: The function to be wrapped. It should accept a forward argument that selects which computation to perform: the forward-pass version (forward=True) or the backward-pass surrogate (forward=False).
Returns:
A wrapped function that behaves like the forward=True version during the forward pass, but computes gradients using the forward=False version during the backward pass.
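grad_replace can be sketched the same way. The step function below, with a hard 0/1 threshold forward and a sigmoid surrogate backward, is a hypothetical example rather than part of softtorch:

```python
import torch

def grad_replace_sketch(fn):
    # Assumed semantics: forward value from fn(..., forward=True),
    # gradients from fn(..., forward=False).
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)
        bwd = fn(*args, forward=False, **kwargs)
        return bwd + (fwd - bwd).detach()
    return wrapped

@grad_replace_sketch
def step(x, forward=True):
    # Hard threshold in the forward pass, sigmoid surrogate in the backward.
    return (x > 0).float() if forward else torch.sigmoid(x)

x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = step(x)          # forward value: the hard threshold [0., 1.]
y.sum().backward()   # x.grad follows the sigmoid's derivative
```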
Elementwise operators
softtorch.abs_st(*args, **kwargs)
Straight-through version of softtorch.abs.
Implemented using the softtorch.st decorator as st(softtorch.abs).
Returns the hard abs during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.clamp_st(*args, **kwargs)
Straight-through version of softtorch.clamp.
Implemented using the softtorch.st decorator as st(softtorch.clamp).
Returns the hard clamp during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.heaviside_st(*args, **kwargs)
Straight-through version of softtorch.heaviside.
Implemented using the softtorch.st decorator as st(softtorch.heaviside).
Returns the hard heaviside during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.relu_st(*args, **kwargs)
Straight-through version of softtorch.relu.
Implemented using the softtorch.st decorator as st(softtorch.relu).
Returns the hard relu during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.round_st(*args, **kwargs)
Straight-through version of softtorch.round.
Implemented using the softtorch.st decorator as st(softtorch.round).
Returns the hard round during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.sign_st(*args, **kwargs)
Straight-through version of softtorch.sign.
Implemented using the softtorch.st decorator as st(softtorch.sign).
Returns the hard sign during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
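To make the elementwise behavior concrete, here is a hand-rolled stand-in for sign_st, assuming (purely for illustration) a tanh(x / softness) relaxation; softtorch's actual soft sign may differ:

```python
import torch

def sign_st_sketch(x, softness=0.1):
    soft = torch.tanh(x / softness)  # assumed soft relaxation of sign
    hard = torch.sign(x)             # hard forward value
    # Forward value is exactly hard; gradients flow through soft.
    return soft + (hard - soft).detach()

x = torch.tensor([-0.3, 0.5], requires_grad=True)
y = sign_st_sketch(x)
y.sum().backward()   # y is exactly -1/+1, yet x.grad is nonzero everywhere
```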
Tensor-valued operators
softtorch.argmax_st(*args, **kwargs)
Straight-through version of softtorch.argmax.
Implemented using the softtorch.st decorator as st(softtorch.argmax).
Returns the hard argmax during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.max_st(*args, **kwargs)
Straight-through version of softtorch.max.
Implemented using the softtorch.st decorator as st(softtorch.max).
Returns the hard max during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.argmin_st(*args, **kwargs)
Straight-through version of softtorch.argmin.
Implemented using the softtorch.st decorator as st(softtorch.argmin).
Returns the hard argmin during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.min_st(*args, **kwargs)
Straight-through version of softtorch.min.
Implemented using the softtorch.st decorator as st(softtorch.min).
Returns the hard min during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.argquantile_st(*args, **kwargs)
Straight-through version of softtorch.argquantile.
Implemented using the softtorch.st decorator as st(softtorch.argquantile).
Returns the hard argquantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.quantile_st(*args, **kwargs)
Straight-through version of softtorch.quantile.
Implemented using the softtorch.st decorator as st(softtorch.quantile).
Returns the hard quantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.argmedian_st(*args, **kwargs)
Straight-through version of softtorch.argmedian.
Implemented using the softtorch.st decorator as st(softtorch.argmedian).
Returns the hard argmedian during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.median_st(*args, **kwargs)
Straight-through version of softtorch.median.
Implemented using the softtorch.st decorator as st(softtorch.median).
Returns the hard median during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.argsort_st(*args, **kwargs)
Straight-through version of softtorch.argsort.
Implemented using the softtorch.st decorator as st(softtorch.argsort).
Returns the hard argsort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.sort_st(*args, **kwargs)
Straight-through version of softtorch.sort.
Implemented using the softtorch.st decorator as st(softtorch.sort).
Returns the hard sort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.topk_st(*args, **kwargs)
Straight-through version of softtorch.topk.
Implemented using the softtorch.st decorator as st(softtorch.topk).
Returns the hard topk during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.rank_st(*args, **kwargs)
Straight-through version of softtorch.rank.
Implemented using the softtorch.st decorator as st(softtorch.rank).
Returns the hard rank during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
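The same pattern applies to tensor-valued ops. Below is a hand-rolled stand-in for max_st over the last dimension, assuming (for illustration only) a softmax-weighted average as the soft relaxation; the relaxation softtorch actually uses may differ:

```python
import torch

def max_st_sketch(x, softness=0.5):
    # Soft relaxation: softmax-weighted average; hard value: the true maximum.
    soft = (torch.softmax(x / softness, dim=-1) * x).sum(dim=-1)
    hard = x.max(dim=-1).values
    return soft + (hard - soft).detach()

x = torch.tensor([[1.0, 3.0, 2.0]], requires_grad=True)
y = max_st_sketch(x)
y.sum().backward()   # unlike hard max, every entry receives gradient
```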
Comparison operators
softtorch.greater_st(*args, **kwargs)
Straight-through version of softtorch.greater.
Implemented using the softtorch.st decorator as st(softtorch.greater).
Returns the hard greater during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.greater_equal_st(*args, **kwargs)
Straight-through version of softtorch.greater_equal.
Implemented using the softtorch.st decorator as st(softtorch.greater_equal).
Returns the hard greater_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.less_st(*args, **kwargs)
Straight-through version of softtorch.less.
Implemented using the softtorch.st decorator as st(softtorch.less).
Returns the hard less during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.less_equal_st(*args, **kwargs)
Straight-through version of softtorch.less_equal.
Implemented using the softtorch.st decorator as st(softtorch.less_equal).
Returns the hard less_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.eq_st(*args, **kwargs)
Straight-through version of softtorch.eq.
Implemented using the softtorch.st decorator as st(softtorch.eq).
Returns the hard eq during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.not_equal_st(*args, **kwargs)
Straight-through version of softtorch.not_equal.
Implemented using the softtorch.st decorator as st(softtorch.not_equal).
Returns the hard not_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softtorch.isclose_st(*args, **kwargs)
Straight-through version of softtorch.isclose.
Implemented using the softtorch.st decorator as st(softtorch.isclose).
Returns the hard isclose during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
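Comparison _st operators follow the same recipe. A hand-rolled stand-in for greater_st, assuming (for illustration) a sigmoid((x - y) / softness) relaxation in the backward pass:

```python
import torch

def greater_st_sketch(x, y, softness=0.1):
    soft = torch.sigmoid((x - y) / softness)  # assumed soft comparison
    hard = (x > y).float()                    # exact 0/1 forward value
    return soft + (hard - soft).detach()

x = torch.tensor([0.4, 0.6], requires_grad=True)
t = torch.tensor([0.5, 0.5])
out = greater_st_sketch(x, t)
out.sum().backward()   # a boolean-valued forward, yet x.grad is nonzero
```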