Straight-through operators
Straight-through utility
softjax.st(fn: Callable) -> Callable
This decorator calls the decorated function twice: once with mode="hard" and once with the specified (soft) mode.
It returns the output of the hard call in the forward pass, but computes gradients through the output of the soft call in the backward pass.
Arguments:
fn: The function to be wrapped. It may accept a mode argument. If fn has no mode parameter, mode defaults to "smooth" and is passed through via **kwargs.
Returns:
A wrapped function that behaves like the mode="hard" version during the forward pass, but computes gradients using the specified mode and softness during the backward pass.
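The behaviour described above can be sketched in plain JAX. The decorator st_sketch and the toy operator toy_step below are hypothetical stand-ins, not softjax's implementation. The sketch assumes the common stop_gradient formulation of a straight-through estimator: soft - stop_gradient(soft) + stop_gradient(hard) evaluates exactly to hard, while its gradient is that of soft.

```python
import functools

import jax
import jax.numpy as jnp


def st_sketch(fn):
    """Hypothetical stand-in for softjax.st (not the library's code)."""

    @functools.wraps(fn)
    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)  # value returned to the caller
        soft = fn(*args, mode=mode, **kwargs)    # computation used for gradients
        # soft - stop_gradient(soft) is exactly zero in the forward pass,
        # so the output equals `hard`, while gradients flow only through `soft`.
        return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)

    return wrapped


# Hypothetical toy operator with a hard mode and a sigmoid-smoothed mode.
def toy_step(x, mode="smooth", softness=0.1):
    if mode == "hard":
        return jnp.where(x > 0, 1.0, 0.0)
    return jax.nn.sigmoid(x / softness)


toy_step_st = st_sketch(toy_step)
x = jnp.array([-0.5, 0.2])
print(toy_step_st(x))  # hard step values
print(jax.grad(lambda v: toy_step_st(v).sum())(x))  # nonzero sigmoid slopes
```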
softjax.grad_replace(fn: Callable) -> Callable
This decorator calls the decorated function twice: once with forward=True and once with forward=False.
It returns the output of the forward=True call, but computes gradients through the output of the forward=False call.
Arguments:
fn: The function to be wrapped. It should accept a boolean forward argument that selects which computation to perform: the forward-pass computation (forward=True) or the computation used for gradients in the backward pass (forward=False).
Returns:
A wrapped function that behaves like the forward=True version during the forward pass, but computes gradients using the forward=False version during the backward pass.
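A minimal sketch of the grad_replace behaviour described above, written in plain JAX with a hypothetical grad_replace_sketch decorator rather than softjax's own code. The toy function returns a true floor when forward=True and the identity when forward=False, which reproduces the classic straight-through estimator for quantization (forward value floor(x), gradient 1).

```python
import functools

import jax
import jax.numpy as jnp


def grad_replace_sketch(fn):
    """Hypothetical stand-in for softjax.grad_replace (not the library's code)."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)   # value returned to the caller
        bwd = fn(*args, forward=False, **kwargs)  # computation used for gradients
        # bwd - stop_gradient(bwd) is exactly zero, so the output equals fwd,
        # while the gradient is that of bwd.
        return bwd - jax.lax.stop_gradient(bwd) + jax.lax.stop_gradient(fwd)

    return wrapped


@grad_replace_sketch
def floor_with_identity_grad(x, forward=True):
    # Forward: a true floor. Backward: the identity, so d/dx = 1 everywhere.
    return jnp.floor(x) if forward else x


x = jnp.array([0.3, 1.7])
print(floor_with_identity_grad(x))  # floor values: 0.0 and 1.0
print(jax.grad(lambda v: floor_with_identity_grad(v).sum())(x))  # all ones
```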
Elementwise operators
softjax.abs_st(*args, **kwargs)
Straight-through version of softjax.abs.
Implemented using the softjax.st decorator as st(softjax.abs).
This function returns the hard abs during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
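As an illustration only, the sketch below combines a hard abs with an assumed smooth surrogate, sqrt(x^2 + softness^2) - softness, through a stop_gradient-based helper. The helper straight_through, the function abs_st_sketch, and the surrogate itself are hypothetical; softjax's relaxations for its various modes may differ.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def abs_st_sketch(x, softness=0.1):
    # Assumed smooth surrogate for |x| (not necessarily softjax's relaxation).
    soft = jnp.sqrt(x**2 + softness**2) - softness
    return straight_through(jnp.abs(x), soft)


x = jnp.array([-2.0, 0.05])
print(abs_st_sketch(x))  # exact |x|
# Gradient is x / sqrt(x^2 + softness^2): close to sign(x) far from 0, damped near 0.
print(jax.grad(lambda v: abs_st_sketch(v).sum())(x))
```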
softjax.clip_st(*args, **kwargs)
Straight-through version of softjax.clip.
Implemented using the softjax.st decorator as st(softjax.clip).
This function returns the hard clip during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
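A hypothetical sketch of a straight-through clip in plain JAX, assuming a softplus-based soft clip that stays close to x inside [lo, hi] and saturates smoothly outside. It shows the practical benefit: inputs outside the range still receive a small nonzero gradient, which a hard clip would zero out. The names and the surrogate are illustrative, not softjax's API.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def clip_st_sketch(x, lo=-1.0, hi=1.0, softness=0.5):
    # Assumed soft clip: about x inside [lo, hi], about lo below, about hi above.
    sp = jax.nn.softplus
    soft = lo + softness * sp((x - lo) / softness) - softness * sp((x - hi) / softness)
    return straight_through(jnp.clip(x, lo, hi), soft)


x = jnp.array([-3.0, 0.5, 3.0])
print(clip_st_sketch(x))  # exact clipped values
# Gradient is sigmoid((x - lo)/s) - sigmoid((x - hi)/s): positive everywhere.
print(jax.grad(lambda v: clip_st_sketch(v).sum())(x))
```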
softjax.heaviside_st(*args, **kwargs)
Straight-through version of softjax.heaviside.
Implemented using the softjax.st decorator as st(softjax.heaviside).
This function returns the hard heaviside during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
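A sketch, assuming a sigmoid relaxation, of why a straight-through heaviside is useful: counting the entries above a threshold t gives an exact integer count in the forward pass, yet the count still has a usable gradient with respect to t. heaviside_st_sketch, straight_through, and count_above are hypothetical helpers, not part of softjax.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def heaviside_st_sketch(x, softness=0.1):
    hard = jnp.where(x > 0, 1.0, 0.0)
    soft = jax.nn.sigmoid(x / softness)  # assumed sigmoid relaxation
    return straight_through(hard, soft)


x = jnp.array([0.1, 0.4, 0.7])


def count_above(t):
    # Number of entries of x strictly above the threshold t.
    return heaviside_st_sketch(x - t).sum()


print(count_above(0.5))            # exact count: 1.0
print(jax.grad(count_above)(0.5))  # negative: raising t lowers the soft count
```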
softjax.relu_st(*args, **kwargs)
Straight-through version of softjax.relu.
Implemented using the softjax.st decorator as st(softjax.relu).
This function returns the hard relu during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
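Illustrative only: a straight-through relu in plain JAX, assuming a softplus surrogate scaled by softness (its derivative is sigmoid(x / softness)). The forward values match jax.nn.relu exactly, while negative inputs get a small nonzero gradient instead of zero. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def relu_st_sketch(x, softness=0.5):
    # Assumed softplus relaxation; its derivative is sigmoid(x / softness).
    soft = softness * jax.nn.softplus(x / softness)
    return straight_through(jax.nn.relu(x), soft)


x = jnp.array([-1.0, 2.0])
print(relu_st_sketch(x))  # exact relu values
# A negative input still gets a small, nonzero gradient, unlike plain relu.
print(jax.grad(lambda v: relu_st_sketch(v).sum())(x))
```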
softjax.round_st(*args, **kwargs)
Straight-through version of softjax.round.
Implemented using the softjax.st decorator as st(softjax.round).
This function returns the hard round during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
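One possible relaxation for round, assumed here purely for illustration, is floor(x) plus a sigmoid step centred on the half-integer point. The sketch pairs it with jnp.round through a stop_gradient helper; round_st_sketch is a hypothetical name, and softjax may relax round differently. Note that jnp.round rounds halves to the nearest even integer.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def round_st_sketch(x, softness=0.1):
    # Assumed relaxation: floor(x) plus a sigmoid step centred at the .5 point.
    base = jnp.floor(x)  # floor contributes zero gradient in JAX
    soft = base + jax.nn.sigmoid((x - base - 0.5) / softness)
    return straight_through(jnp.round(x), soft)


x = jnp.array([0.4, 1.6])
print(round_st_sketch(x))  # exact rounded values
# Gradient comes only from the sigmoid step, largest near the .5 points.
print(jax.grad(lambda v: round_st_sketch(v).sum())(x))
```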
softjax.sign_st(*args, **kwargs)
Straight-through version of softjax.sign.
Implemented using the softjax.st decorator as st(softjax.sign).
This function returns the hard sign during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
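A minimal sketch of a straight-through sign, assuming a tanh(x / softness) relaxation; the helper names are hypothetical and this is not softjax's implementation. The forward pass returns exactly -1, 0, or 1, and the gradient is the derivative of the tanh surrogate.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def sign_st_sketch(x, softness=0.5):
    # Assumed tanh relaxation of sign.
    return straight_through(jnp.sign(x), jnp.tanh(x / softness))


x = jnp.array([-0.3, 0.0, 0.8])
print(sign_st_sketch(x))  # exact signs
# Gradient is (1 - tanh(x / softness)^2) / softness, positive everywhere.
print(jax.grad(lambda v: sign_st_sketch(v).sum())(x))
```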
Array-valued operators
softjax.argmax_st(*args, **kwargs)
Straight-through version of softjax.argmax.
Implemented using the softjax.st decorator as st(softjax.argmax).
This function returns the hard argmax during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
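The sketch below assumes, for illustration, that the hard argmax is represented as a one-hot vector and relaxed with softmax(logits / softness); whether softjax.argmax uses this representation is not stated here. Selecting an entry of a value vector then returns the exact selected value in the forward pass, while the gradient tells the logits how to shift toward higher-value entries. All names are hypothetical.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def argmax_st_sketch(logits, softness=1.0):
    # Assumption: argmax as a one-hot vector, relaxed by a tempered softmax.
    hard = jax.nn.one_hot(jnp.argmax(logits), logits.shape[-1])
    soft = jax.nn.softmax(logits / softness)
    return straight_through(hard, soft)


logits = jnp.array([1.0, 3.0, 2.0])
values = jnp.array([1.0, 0.0, 5.0])


def picked(z):
    # Value of the entry selected by the (hard) argmax.
    return (argmax_st_sketch(z) * values).sum()


print(picked(logits))            # exactly values[1], i.e. 0.0
print(jax.grad(picked)(logits))  # pushes logits toward the entry worth 5.0
```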
softjax.max_st(*args, **kwargs)
Straight-through version of softjax.max.
Implemented using the softjax.st decorator as st(softjax.max).
This function returns the hard max during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
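A hypothetical sketch of a straight-through max, assuming the softness * logsumexp(x / softness) relaxation. The forward value is exactly jnp.max(x), but the gradient is softmax(x / softness), so credit is spread over near-maximal entries and sums to 1 instead of being one-hot.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def max_st_sketch(x, softness=1.0):
    # Assumed relaxation: a tempered log-sum-exp, whose gradient is a softmax.
    soft = softness * logsumexp(x / softness)
    return straight_through(jnp.max(x), soft)


x = jnp.array([1.0, 3.0, 2.0])
print(max_st_sketch(x))            # exact max
print(jax.grad(max_st_sketch)(x))  # softmax weights, all positive
```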
softjax.argmin_st(*args, **kwargs)
Straight-through version of softjax.argmin.
Implemented using the softjax.st decorator as st(softjax.argmin).
This function returns the hard argmin during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.min_st(*args, **kwargs)
Straight-through version of softjax.min.
Implemented using the softjax.st decorator as st(softjax.min).
This function returns the hard min during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.argquantile_st(*args, **kwargs)
Straight-through version of softjax.argquantile.
Implemented using the softjax.st decorator as st(softjax.argquantile).
This function returns the hard argquantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.quantile_st(*args, **kwargs)
Straight-through version of softjax.quantile.
Implemented using the softjax.st decorator as st(softjax.quantile).
This function returns the hard quantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.argpercentile_st(*args, **kwargs)
Straight-through version of softjax.argpercentile.
Implemented using the softjax.st decorator as st(softjax.argpercentile).
This function returns the hard argpercentile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.percentile_st(*args, **kwargs)
Straight-through version of softjax.percentile.
Implemented using the softjax.st decorator as st(softjax.percentile).
This function returns the hard percentile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.argmedian_st(*args, **kwargs)
Straight-through version of softjax.argmedian.
Implemented using the softjax.st decorator as st(softjax.argmedian).
This function returns the hard argmedian during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.median_st(*args, **kwargs)
Straight-through version of softjax.median.
Implemented using the softjax.st decorator as st(softjax.median).
This function returns the hard median during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.argsort_st(*args, **kwargs)
Straight-through version of softjax.argsort.
Implemented using the softjax.st decorator as st(softjax.argsort).
This function returns the hard argsort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.sort_st(*args, **kwargs)
Straight-through version of softjax.sort.
Implemented using the softjax.st decorator as st(softjax.sort).
This function returns the hard sort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.top_k_st(*args, **kwargs)
Straight-through version of softjax.top_k.
Implemented using the softjax.st decorator as st(softjax.top_k).
This function returns the hard top_k during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.rank_st(*args, **kwargs)
Straight-through version of softjax.rank.
Implemented using the softjax.st decorator as st(softjax.rank).
This function returns the hard rank during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
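Illustrative only: a straight-through rank in plain JAX, assuming a pairwise-sigmoid relaxation in which the soft rank of x_i softly counts how many entries lie below it. The hard branch computes 0-based ascending ranks with a double argsort. The names and the relaxation are assumptions, not softjax's method.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def rank_st_sketch(x, softness=0.5):
    # Hard 0-based ascending rank (ties broken by position).
    hard = jnp.argsort(jnp.argsort(x)).astype(x.dtype)
    # Assumed pairwise-sigmoid relaxation: soft count of entries below x_i,
    # minus 0.5 for the self-comparison sigmoid(0).
    diffs = (x[:, None] - x[None, :]) / softness
    soft = jax.nn.sigmoid(diffs).sum(axis=1) - 0.5
    return straight_through(hard, soft)


x = jnp.array([0.3, -1.0, 2.0])
print(rank_st_sketch(x))  # ranks 1, 0, 2
# Gradient of the first entry's rank: raising x[0] raises it, raising others lowers it.
print(jax.grad(lambda v: rank_st_sketch(v)[0])(x))
```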
Comparison operators
softjax.greater_st(*args, **kwargs)
Straight-through version of softjax.greater.
Implemented using the softjax.st decorator as st(softjax.greater).
This function returns the hard greater during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
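A minimal sketch, assuming a sigmoid((x - y) / softness) relaxation, of a straight-through greater: the forward pass returns an exact 0/1 mask, and gradients flow to both operands. greater_st_sketch and straight_through are hypothetical helpers, not softjax's code.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def greater_st_sketch(x, y, softness=0.1):
    # Assumed sigmoid relaxation of x > y.
    hard = (x > y).astype(jnp.float32)
    soft = jax.nn.sigmoid((x - y) / softness)
    return straight_through(hard, soft)


x = jnp.array([1.0, 0.0])
y = jnp.array([0.5, 0.5])
print(greater_st_sketch(x, y))  # hard 0/1 mask
# Gradients with respect to both operands: positive for x, negative for y.
gx, gy = jax.grad(lambda a, b: greater_st_sketch(a, b).sum(), argnums=(0, 1))(x, y)
print(gx, gy)
```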
softjax.greater_equal_st(*args, **kwargs)
Straight-through version of softjax.greater_equal.
Implemented using the softjax.st decorator as st(softjax.greater_equal).
This function returns the hard greater_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.less_st(*args, **kwargs)
Straight-through version of softjax.less.
Implemented using the softjax.st decorator as st(softjax.less).
This function returns the hard less during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.less_equal_st(*args, **kwargs)
Straight-through version of softjax.less_equal.
Implemented using the softjax.st decorator as st(softjax.less_equal).
This function returns the hard less_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.equal_st(*args, **kwargs)
Straight-through version of softjax.equal.
Implemented using the softjax.st decorator as st(softjax.equal).
This function returns the hard equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
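For equality, a sketch can assume a Gaussian-kernel relaxation exp(-((x - y) / softness)^2), which equals 1 at x == y and decays as the operands move apart. This is an illustrative choice, not necessarily softjax's: when the hard test fails, the forward value is exactly 0, yet the surrogate gradient still pulls x toward y. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Exactly `hard` in the forward pass; gradient of `soft` in the backward pass.
    return soft - jax.lax.stop_gradient(soft) + jax.lax.stop_gradient(hard)


def equal_st_sketch(x, y, softness=0.1):
    # Assumed Gaussian-kernel relaxation of x == y.
    hard = (x == y).astype(jnp.float32)
    soft = jnp.exp(-(((x - y) / softness) ** 2))
    return straight_through(hard, soft)


x = jnp.array(0.2)
y = jnp.array(0.0)
print(equal_st_sketch(x, y))            # hard test fails: 0.0
print(jax.grad(equal_st_sketch)(x, y))  # negative: decreasing x moves it toward y
```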
softjax.not_equal_st(*args, **kwargs)
Straight-through version of softjax.not_equal.
Implemented using the softjax.st decorator as st(softjax.not_equal).
This function returns the hard not_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
softjax.isclose_st(*args, **kwargs)
Straight-through version of softjax.isclose.
Implemented using the softjax.st decorator as st(softjax.isclose).
This function returns the hard isclose during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).