Straight-through operators
Straight-through utility
softjax.st(fn: typing.Callable) -> typing.Callable
This decorator calls the decorated function twice: once with mode="hard" and
once with the specified mode. It returns the output of the hard call, but
computes gradients through the soft call.
Arguments:
fn: The function to be wrapped. It should accept a mode argument.
Returns:
A wrapped function that behaves like the mode="hard" version during the
forward pass, but computes gradients using the specified mode and softness
during the backward pass.
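The straight-through behavior described above can be sketched in plain JAX with jax.lax.stop_gradient. This is only an illustration under stated assumptions, not softjax's actual implementation: st_sketch, toy_relu, and the mode name "smooth" are hypothetical names introduced for this example.

```python
import jax
import jax.numpy as jnp


def st_sketch(fn):
    """Hypothetical stand-in for a straight-through decorator (not softjax code)."""
    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)   # value returned to the caller
        soft = fn(*args, mode=mode, **kwargs)     # value used for gradients
        # Numerically this equals `hard`; the gradient flows only through `soft`.
        return soft + jax.lax.stop_gradient(hard - soft)
    return wrapped


def toy_relu(x, mode="hard"):
    # Illustrative function with a `mode` argument; "smooth" is an assumed name.
    if mode == "hard":
        return jnp.maximum(x, 0.0)
    return jax.nn.softplus(x)


relu_st_sketch = st_sketch(toy_relu)
x = jnp.array(-1.0)
value = relu_st_sketch(x)            # hard ReLU value
grad = jax.grad(relu_st_sketch)(x)   # gradient of softplus, i.e. sigmoid(x)
```

With a hard ReLU alone, the gradient at a negative input would be zero; in this sketch the soft surrogate supplies a nonzero gradient while the returned value stays hard.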
softjax.grad_replace(fn: typing.Callable) -> typing.Callable
This decorator calls the decorated function twice: once with forward=True and
once with forward=False. It returns the output of the forward=True call, but
computes gradients through the forward=False call.
Arguments:
fn: The function to be wrapped. It should accept a forward argument that selects which computation to perform: forward=True for the forward pass, forward=False for the backward pass.
Returns:
A wrapped function that behaves like the forward=True version during the
forward pass, but computes gradients using the forward=False version
during the backward pass.
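A minimal sketch of the same idea with a forward flag, again in plain JAX. The names grad_replace_sketch and sign_with_tanh_grad are hypothetical, and the stop_gradient construction is an assumption about how such a decorator could work, not a description of softjax internals.

```python
import jax
import jax.numpy as jnp


def grad_replace_sketch(fn):
    """Hypothetical stand-in for a gradient-replacement decorator."""
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)    # value returned to the caller
        bwd = fn(*args, forward=False, **kwargs)   # value used for gradients
        # Numerically this equals `fwd`; the gradient flows only through `bwd`.
        return bwd + jax.lax.stop_gradient(fwd - bwd)
    return wrapped


@grad_replace_sketch
def sign_with_tanh_grad(x, forward=True):
    # Forward pass: exact sign. Backward pass: gradient of tanh as a surrogate.
    if forward:
        return jnp.sign(x)
    return jnp.tanh(x)


x = jnp.array(0.5)
value = sign_with_tanh_grad(x)
grad = jax.grad(sign_with_tanh_grad)(x)   # equals 1 - tanh(x)**2
```

Because the function body decides what each branch computes, the surrogate does not have to be a smoothed version of the forward computation; any differentiable expression of matching shape would work in this sketch.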
Elementwise operators
softjax.abs_st(*args, **kwargs)
softjax.clip_st(*args, **kwargs)
softjax.heaviside_st(*args, **kwargs)
softjax.relu_st(*args, **kwargs)
softjax.round_st(*args, **kwargs)
softjax.sign_st(*args, **kwargs)
Array-valued operators
softjax.argmax_st(*args, **kwargs)
Straight-through version of softjax.argmax.
This function returns the hard argmax during the forward pass, but uses a
soft relaxation (controlled by the mode argument) for the backward pass
(i.e., gradients are computed through the soft version).
Implemented using the softjax.st decorator as st(softjax.argmax).
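To illustrate what a straight-through argmax does, the sketch below uses a one-hot hard output and a softmax relaxation. These are assumptions made for the example: toy_argmax, toy_argmax_st, the "smooth" mode name, and the one-hot output convention are hypothetical and may not match softjax.argmax.

```python
import jax
import jax.numpy as jnp


def toy_argmax(x, mode="hard", softness=1.0):
    # Hypothetical stand-in: hard mode gives a one-hot vector,
    # any other mode gives a softmax relaxation with temperature `softness`.
    if mode == "hard":
        return jax.nn.one_hot(jnp.argmax(x), x.shape[-1])
    return jax.nn.softmax(x / softness)


def toy_argmax_st(x, softness=1.0):
    hard = toy_argmax(x, mode="hard")
    soft = toy_argmax(x, mode="smooth", softness=softness)
    # Forward value is the hard one-hot; gradients come from the softmax.
    return soft + jax.lax.stop_gradient(hard - soft)


x = jnp.array([0.1, 2.0, -1.0])
weights = jnp.array([1.0, 2.0, 3.0])


def loss(x):
    # Picks out the weight at the (hard) argmax position.
    return jnp.dot(toy_argmax_st(x), weights)


value = toy_argmax_st(x)      # one-hot at index 1
grad = jax.grad(loss)(x)      # same gradient as the softmax-based loss
```

A loss built on the hard argmax alone would have zero gradient almost everywhere; here the gradient matches that of the softmax-weighted loss while the forward value remains a hard selection.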