Straight-through operators
Straight-through utility
softtorch.st(fn: Callable) -> Callable
This decorator calls the decorated function twice: once with mode="hard" and
once with the specified mode. The wrapped function returns the output of the
hard call, but computes gradients through the output of the soft call during
the backward pass.
Arguments:
fn: The function to be wrapped. It should accept a mode argument.
Returns:
A wrapped function that behaves like the mode="hard" version during the
forward pass, but computes gradients using the specified mode and softness
during the backward pass.
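For intuition, this behavior can be sketched with the standard straight-through identity soft + (hard - soft).detach(), whose value equals the hard output while its gradient flows through the soft output. The sketch below is illustrative only, not softtorch's actual implementation; the default mode value and the helper soft_round are assumptions made for the example.

```python
import torch
from typing import Callable


def st_sketch(fn: Callable) -> Callable:
    """Illustrative straight-through wrapper (not softtorch's actual code)."""
    def wrapped(*args, mode="linear", **kwargs):  # default mode is an assumption
        hard = fn(*args, mode="hard", **kwargs)
        soft = fn(*args, mode=mode, **kwargs)
        # The result's value equals `hard`; gradients flow only through `soft`,
        # because the difference (hard - soft) is detached from the graph.
        return soft + (hard - soft).detach()
    return wrapped


# Example: a straight-through rounding op built from a soft round-like function
# (`soft_round` is a hypothetical helper, not part of softtorch).
def soft_round(x, mode="linear"):
    return torch.round(x) if mode == "hard" else x


round_st = st_sketch(soft_round)
x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)
round_st(x).sum().backward()  # forward values are rounded; x.grad is all ones
```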
softtorch.grad_replace(fn: Callable) -> Callable
This decorator calls the decorated function twice: once with forward=True and
once with forward=False. The wrapped function returns the output of the
forward=True call, but computes gradients through the output of the
forward=False call during the backward pass.
Arguments:
fn: The function to be wrapped. It should accept a forward argument that selects which of the two computations to perform, for the forward or the backward pass respectively.
Returns:
A wrapped function that behaves like the forward=True version during the
forward pass, but computes gradients using the forward=False version
during the backward pass.
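As a rough sketch, the same detach trick as above can implement this behavior. This is illustrative only, not softtorch's actual code, and it assumes both branches return tensors of the same shape; the helper exp_or_linear is a made-up example function.

```python
import torch
from typing import Callable


def grad_replace_sketch(fn: Callable) -> Callable:
    """Illustrative gradient-replacement wrapper (not softtorch's actual code)."""
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)
        bwd = fn(*args, forward=False, **kwargs)
        # The result's value equals `fwd`; gradients are taken from `bwd`,
        # because the difference (fwd - bwd) is detached from the graph.
        return bwd + (fwd - bwd).detach()
    return wrapped


# Example: exact exponential in the forward pass, a first-order surrogate
# gradient in the backward pass (purely illustrative).
def exp_or_linear(x, forward=True):
    return torch.exp(x) if forward else 1.0 + x


x = torch.tensor([0.0, 1.0], requires_grad=True)
y = grad_replace_sketch(exp_or_linear)(x)
y.sum().backward()  # y equals exp(x); x.grad is all ones (from 1 + x)
```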
Elementwise operators
softtorch.abs_st(*args, **kwargs)
softtorch.clamp_st(*args, **kwargs)
softtorch.heaviside_st(*args, **kwargs)
softtorch.relu_st(*args, **kwargs)
softtorch.round_st(*args, **kwargs)
softtorch.sign_st(*args, **kwargs)
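These are straight-through versions of the corresponding elementwise operators. The snippet below is a hypothetical usage of one of them; it assumes the operator can be called with only a tensor argument, with keyword arguments such as mode and softness left at their defaults.

```python
import torch
import softtorch

x = torch.linspace(-2.0, 2.0, steps=9, requires_grad=True)

# Forward values match the hard operator (rounding to the nearest integer);
# the backward pass uses the soft relaxation, so x.grad is not zero almost
# everywhere as it would be when differentiating torch.round directly.
y = softtorch.round_st(x)
y.sum().backward()
print(y)       # rounded values
print(x.grad)  # gradients of the soft relaxation
```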
Tensor-valued operators
softtorch.argmax_st(*args, **kwargs)
Straight-through version of softtorch.argmax.
This function returns the hard argmax during the forward pass, but uses a
soft relaxation (controlled by the mode argument) for the backward pass
(i.e., gradients are computed through the soft version).
Implemented using the softtorch.st decorator as st(softtorch.argmax).
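Since argmax_st is st(softtorch.argmax), the composition can also be spelled out explicitly. The snippet below is a hypothetical usage; the format of the hard output (e.g. one-hot versus integer indices) and the availability of default keyword arguments are assumptions.

```python
import torch
import softtorch

logits = torch.randn(4, 10, requires_grad=True)

# Per the definition above, these two are equivalent:
argmax_st = softtorch.st(softtorch.argmax)  # same as softtorch.argmax_st

# Forward pass: the hard argmax of `logits`.
# Backward pass: gradients flow through the soft relaxation selected by `mode`.
y = argmax_st(logits)
```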