Straight-through operators¤

Straight-through utility¤

softjax.st(fn: Callable) -> Callable ¤

This decorator calls the decorated function twice: once with mode="hard" and once with the specified mode. The wrapped function returns the output of the hard call in the forward pass, while gradients are computed through the soft call in the backward pass.

Arguments:

  • fn: The function to be wrapped. It may accept a mode argument; if fn does not declare a mode parameter explicitly, mode defaults to "smooth" and is forwarded through **kwargs.

Returns: A wrapped function that behaves like the mode="hard" version during the forward pass, but computes gradients using the specified mode and softness during the backward pass.
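
The straight-through pattern that st implements can be sketched in plain JAX with jax.lax.stop_gradient. This is a minimal illustration of the documented contract only, not softjax's actual implementation (which may, for instance, use jax.custom_vjp); the toy round_fn below is hypothetical:

```python
import jax
import jax.numpy as jnp


def st_sketch(fn):
    """Straight-through sketch: hard output forward, soft gradients backward."""

    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)
        soft = fn(*args, mode=mode, **kwargs)
        # The stop_gradient term makes the value equal `hard`, but contributes
        # zero gradient, so derivatives flow only through `soft`.
        return soft + jax.lax.stop_gradient(hard - soft)

    return wrapped


@st_sketch
def round_fn(x, mode="smooth"):
    # Hypothetical toy function: hard rounding vs. an identity relaxation.
    return jnp.round(x) if mode == "hard" else x
```

With the identity relaxation, round_fn(0.4) evaluates to the hard round(0.4) == 0.0, while jax.grad(round_fn)(0.4) is 1.0, the surrogate's slope.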

softjax.grad_replace(fn: Callable) -> Callable ¤

This decorator calls the decorated function twice: once with forward=True and once with forward=False. The wrapped function returns the output of the forward=True call in the forward pass, while gradients are computed through the forward=False call in the backward pass.

Arguments:

  • fn: The function to be wrapped. It should accept a boolean forward argument that selects which of the two computations to perform.

Returns: A wrapped function that behaves like the forward=True version during the forward pass, but computes gradients using the forward=False version during the backward pass.
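
For illustration, grad_replace can be approximated with the same stop_gradient trick. Again this is a hedged sketch of the documented behavior rather than the actual implementation, and the toy step function is hypothetical:

```python
import jax
import jax.numpy as jnp


def grad_replace_sketch(fn):
    """Return fn(forward=True)'s value and fn(forward=False)'s gradients."""

    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)
        bwd = fn(*args, forward=False, **kwargs)
        # Value equals `fwd`; gradients flow only through `bwd`.
        return bwd + jax.lax.stop_gradient(fwd - bwd)

    return wrapped


@grad_replace_sketch
def step(x, forward=True):
    # Hard sign in the forward pass, tanh surrogate in the backward pass.
    return jnp.sign(x) if forward else jnp.tanh(x)
```

Here step(2.0) is the hard sign(2.0) == 1.0, while jax.grad(step)(2.0) is the tanh derivative 1 - tanh(2)^2 ≈ 0.0707 instead of the zero gradient of sign.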

Elementwise operators¤

softjax.abs_st(*args, **kwargs) ¤

Straight-through version of softjax.abs. Implemented using the softjax.st decorator as st(softjax.abs).

This function returns the hard abs during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.clip_st(*args, **kwargs) ¤

Straight-through version of softjax.clip. Implemented using the softjax.st decorator as st(softjax.clip).

This function returns the hard clip during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.heaviside_st(*args, **kwargs) ¤

Straight-through version of softjax.heaviside. Implemented using the softjax.st decorator as st(softjax.heaviside).

This function returns the hard heaviside during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
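
To illustrate what a straight-through Heaviside buys you, here is a self-contained plain-JAX sketch with a sigmoid surrogate (the choice of surrogate and the softness temperature scaling are assumptions made for this example; in softjax the relaxation is selected by the mode argument):

```python
import jax
import jax.numpy as jnp


def heaviside_st_sketch(x, softness=1.0):
    # Smooth surrogate: a temperature-scaled sigmoid.
    soft = jax.nn.sigmoid(x / softness)
    # Hard step: exactly 0 or 1 in the forward pass.
    hard = jnp.where(x > 0, 1.0, 0.0)
    return soft + jax.lax.stop_gradient(hard - soft)
```

The forward value is exactly binary, yet jax.grad returns the sigmoid's slope instead of zero, so a thresholding layer remains trainable.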

softjax.relu_st(*args, **kwargs) ¤

Straight-through version of softjax.relu. Implemented using the softjax.st decorator as st(softjax.relu).

This function returns the hard relu during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.round_st(*args, **kwargs) ¤

Straight-through version of softjax.round. Implemented using the softjax.st decorator as st(softjax.round).

This function returns the hard round during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
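
A typical use of straight-through rounding is fake quantization: values are snapped to a discrete grid in the forward pass while gradients treat the operation as the identity. A minimal plain-JAX sketch (the identity surrogate and the scale parameter are illustrative assumptions, not softjax's API):

```python
import jax
import jax.numpy as jnp


def round_st_sketch(x):
    # Identity surrogate: gradients pass through the rounding unchanged.
    return x + jax.lax.stop_gradient(jnp.round(x) - x)


def fake_quantize(w, scale=127.0):
    # Snap weights to a uniform grid of step 1/scale, differentiably.
    return round_st_sketch(w * scale) / scale
```

Because the surrogate is the identity, jax.grad(fake_quantize) is 1 everywhere, so the quantized model can still be trained with ordinary gradient descent.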

softjax.sign_st(*args, **kwargs) ¤

Straight-through version of softjax.sign. Implemented using the softjax.st decorator as st(softjax.sign).

This function returns the hard sign during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

Array-valued operators¤

softjax.argmax_st(*args, **kwargs) ¤

Straight-through version of softjax.argmax. Implemented using the softjax.st decorator as st(softjax.argmax).

This function returns the hard argmax during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.max_st(*args, **kwargs) ¤

Straight-through version of softjax.max. Implemented using the softjax.st decorator as st(softjax.max).

This function returns the hard max during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
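
The same pattern applies to array-valued reductions; for instance, a straight-through max with a log-sum-exp surrogate (the surrogate choice and the softness temperature are assumptions made for this sketch):

```python
import jax
import jax.numpy as jnp


def max_st_sketch(x, softness=1.0):
    # Smooth surrogate: temperature-scaled log-sum-exp.
    soft = softness * jax.nn.logsumexp(x / softness)
    hard = jnp.max(x)
    return soft + jax.lax.stop_gradient(hard - soft)
```

The forward value is the exact maximum, while the gradient is the softmax of x / softness, spreading credit over near-maximal entries instead of a one-hot spike.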

softjax.argmin_st(*args, **kwargs) ¤

Straight-through version of softjax.argmin. Implemented using the softjax.st decorator as st(softjax.argmin).

This function returns the hard argmin during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.min_st(*args, **kwargs) ¤

Straight-through version of softjax.min. Implemented using the softjax.st decorator as st(softjax.min).

This function returns the hard min during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.argquantile_st(*args, **kwargs) ¤

Straight-through version of softjax.argquantile. Implemented using the softjax.st decorator as st(softjax.argquantile).

This function returns the hard argquantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.quantile_st(*args, **kwargs) ¤

Straight-through version of softjax.quantile. Implemented using the softjax.st decorator as st(softjax.quantile).

This function returns the hard quantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.argpercentile_st(*args, **kwargs) ¤

Straight-through version of softjax.argpercentile. Implemented using the softjax.st decorator as st(softjax.argpercentile).

This function returns the hard argpercentile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.percentile_st(*args, **kwargs) ¤

Straight-through version of softjax.percentile. Implemented using the softjax.st decorator as st(softjax.percentile).

This function returns the hard percentile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.argmedian_st(*args, **kwargs) ¤

Straight-through version of softjax.argmedian. Implemented using the softjax.st decorator as st(softjax.argmedian).

This function returns the hard argmedian during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.median_st(*args, **kwargs) ¤

Straight-through version of softjax.median. Implemented using the softjax.st decorator as st(softjax.median).

This function returns the hard median during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.argsort_st(*args, **kwargs) ¤

Straight-through version of softjax.argsort. Implemented using the softjax.st decorator as st(softjax.argsort).

This function returns the hard argsort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.sort_st(*args, **kwargs) ¤

Straight-through version of softjax.sort. Implemented using the softjax.st decorator as st(softjax.sort).

This function returns the hard sort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.top_k_st(*args, **kwargs) ¤

Straight-through version of softjax.top_k. Implemented using the softjax.st decorator as st(softjax.top_k).

This function returns the hard top_k during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.rank_st(*args, **kwargs) ¤

Straight-through version of softjax.rank. Implemented using the softjax.st decorator as st(softjax.rank).

This function returns the hard rank during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

Comparison operators¤

softjax.greater_st(*args, **kwargs) ¤

Straight-through version of softjax.greater. Implemented using the softjax.st decorator as st(softjax.greater).

This function returns the hard greater during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
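
Soft comparisons follow the same recipe; as an illustration, a straight-through greater-than using a sigmoid of the difference (the surrogate and the softness parameter are assumptions made for this sketch):

```python
import jax
import jax.numpy as jnp


def greater_st_sketch(x, y, softness=1.0):
    # Smooth surrogate: sigmoid of the (temperature-scaled) difference.
    soft = jax.nn.sigmoid((x - y) / softness)
    # Hard comparison: exactly 0.0 or 1.0 in the forward pass.
    hard = jnp.where(x > y, 1.0, 0.0)
    return soft + jax.lax.stop_gradient(hard - soft)
```

The output is exactly binary, but the gradient with respect to x peaks at the decision boundary (0.25 at x == y with softness 1), so the comparison can sit inside a differentiable loss.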

softjax.greater_equal_st(*args, **kwargs) ¤

Straight-through version of softjax.greater_equal. Implemented using the softjax.st decorator as st(softjax.greater_equal).

This function returns the hard greater_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.less_st(*args, **kwargs) ¤

Straight-through version of softjax.less. Implemented using the softjax.st decorator as st(softjax.less).

This function returns the hard less during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.less_equal_st(*args, **kwargs) ¤

Straight-through version of softjax.less_equal. Implemented using the softjax.st decorator as st(softjax.less_equal).

This function returns the hard less_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.equal_st(*args, **kwargs) ¤

Straight-through version of softjax.equal. Implemented using the softjax.st decorator as st(softjax.equal).

This function returns the hard equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.not_equal_st(*args, **kwargs) ¤

Straight-through version of softjax.not_equal. Implemented using the softjax.st decorator as st(softjax.not_equal).

This function returns the hard not_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softjax.isclose_st(*args, **kwargs) ¤

Straight-through version of softjax.isclose. Implemented using the softjax.st decorator as st(softjax.isclose).

This function returns the hard isclose during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).