
Straight-through operators

Straight-through utility

softjax.st(fn: typing.Callable) -> typing.Callable

This decorator calls the decorated function twice: once with mode="hard" and once with the specified mode. It returns the output of the hard call as the forward value, but computes gradients through the output of the soft call.

Arguments:

  • fn: The function to be wrapped. It should accept a mode argument.

Returns: A wrapped function that behaves like the mode="hard" version during the forward pass, but computes gradients using the specified mode and softness during the backward pass.
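One common way to get "hard value, soft gradient" behaviour in JAX is `soft + jax.lax.stop_gradient(hard - soft)`: the value equals `hard`, while only `soft` contributes to the derivative. The following is a minimal sketch of an `st`-style decorator built on that trick. It is not softjax's implementation; the toy operator `step`, its `softness` parameter, and the default `mode="smooth"` are hypothetical and only assume that the wrapped function accepts a `mode` keyword where `"hard"` selects the exact operator.

```python
import jax
import jax.numpy as jnp


def st(fn):
    """Sketch of a straight-through decorator (not softjax's code).

    Assumes `fn` accepts a `mode` keyword: "hard" selects the exact
    operator, any other value selects a smooth relaxation.
    """
    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)
        soft = fn(*args, mode=mode, **kwargs)
        # Value of `hard`, gradient of `soft`: the correction term is
        # excluded from differentiation by stop_gradient.
        return jax.tree_util.tree_map(
            lambda h, s: s + jax.lax.stop_gradient(h - s), hard, soft
        )
    return wrapped


def step(x, mode="smooth", softness=0.1):
    # Hypothetical operator: exact Heaviside step for mode="hard",
    # sigmoid relaxation with temperature `softness` otherwise.
    if mode == "hard":
        return jnp.where(x > 0, 1.0, 0.0)
    return jax.nn.sigmoid(x / softness)


step_st = st(step)
x = jnp.array([-0.5, 0.2, 1.0])
y = step_st(x)  # hard step values
g = jax.grad(lambda v: step_st(v).sum())(x)  # derivative of the sigmoid
```

Because the hard branch sits entirely inside `stop_gradient`, the zero (or undefined) derivative of the exact step never reaches the optimizer.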

softjax.grad_replace(fn: typing.Callable) -> typing.Callable

This decorator calls the decorated function twice: once with forward=True and once with forward=False. It returns the output of the forward=True call as the forward value, but computes gradients through the output of the forward=False call.

Arguments:

  • fn: The function to be wrapped. It should accept a boolean forward argument that selects the computation to use for the forward pass (forward=True) or for the backward pass (forward=False).

Returns: A wrapped function that behaves like the forward=True version during the forward pass, but computes gradients using the forward=False version during the backward pass.
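The same stop-gradient trick can express the `forward` flag directly. Below is a hedged sketch of a `grad_replace`-style decorator, not softjax's implementation. The example function `floor_identity_grad` is hypothetical: it uses an exact floor on the forward pass and the identity on the backward pass, which gives the classic straight-through estimator with gradient 1.

```python
import jax
import jax.numpy as jnp


def grad_replace(fn):
    """Sketch of a forward/backward swap (not softjax's code).

    Assumes `fn` accepts a boolean `forward` keyword.
    """
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)
        bwd = fn(*args, forward=False, **kwargs)
        # Value of `fwd`, gradient of `bwd`.
        return jax.tree_util.tree_map(
            lambda f, b: b + jax.lax.stop_gradient(f - b), fwd, bwd
        )
    return wrapped


@grad_replace
def floor_identity_grad(x, forward):
    # Hypothetical example: exact floor forward, identity backward.
    return jnp.floor(x) if forward else x


x = jnp.array([0.3, 1.7, -2.5])
y = floor_identity_grad(x)  # floor values
g = jax.grad(lambda v: floor_identity_grad(v).sum())(x)  # all ones
```

Unlike `st`, nothing here requires the two computations to be a hard/soft pair; any two functions with matching output structure can be combined.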

Elementwise operators

softjax.abs_st(*args, **kwargs)

softjax.clip_st(*args, **kwargs)

softjax.heaviside_st(*args, **kwargs)

softjax.relu_st(*args, **kwargs)

softjax.round_st(*args, **kwargs)

softjax.sign_st(*args, **kwargs)
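To illustrate what a straight-through elementwise operator computes, here is a minimal sketch for sign and ReLU using plain JAX. The helper `straight_through` and the functions `sign_st_sketch` and `relu_st_sketch` are hypothetical, and the `tanh` and softplus relaxations with a `softness` temperature are assumptions, not necessarily the relaxations softjax uses.

```python
import jax
import jax.numpy as jnp


def straight_through(hard, soft):
    # Forward value equals `hard`; gradients flow through `soft` only.
    return soft + jax.lax.stop_gradient(hard - soft)


def sign_st_sketch(x, softness=0.1):
    # Hypothetical: exact sign forward, tanh(x / softness) for gradients.
    return straight_through(jnp.sign(x), jnp.tanh(x / softness))


def relu_st_sketch(x, softness=0.1):
    # Hypothetical: exact ReLU forward, scaled softplus for gradients,
    # so negative inputs still receive a small nonzero gradient.
    soft = softness * jax.nn.softplus(x / softness)
    return straight_through(jax.nn.relu(x), soft)


x = jnp.array([-0.3, 0.05, 2.0])
sign_vals = sign_st_sketch(x)  # same values as jnp.sign(x)
relu_vals = relu_st_sketch(x)  # same values as jax.nn.relu(x)
relu_grad = jax.grad(lambda v: relu_st_sketch(v).sum())(x)  # sigmoid(x / 0.1)
```

The derivative of `softness * softplus(x / softness)` is `sigmoid(x / softness)`, so `relu_grad` is strictly positive even where the forward output is zero.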

Array-valued operators

softjax.argmax_st(*args, **kwargs)

Straight-through version of softjax.argmax.

This function returns the hard argmax in the forward pass, but computes gradients through a soft relaxation selected by the mode argument.

Implemented using the softjax.st decorator as st(softjax.argmax).
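For a straight-through argmax, the hard and soft outputs must have the same shape. The sketch below assumes the argmax is represented as a one-hot vector and uses a temperature-scaled softmax as the relaxation; both choices, the name `argmax_st_sketch`, and the `softness` parameter are hypothetical and may differ from how softjax represents and relaxes argmax.

```python
import jax
import jax.numpy as jnp


def argmax_st_sketch(x, softness=1.0):
    # Hard forward value: one-hot vector at the argmax of the last axis.
    hard = jax.nn.one_hot(jnp.argmax(x, axis=-1), x.shape[-1], dtype=x.dtype)
    # Assumed relaxation for gradients: softmax with temperature `softness`.
    soft = jax.nn.softmax(x / softness, axis=-1)
    return soft + jax.lax.stop_gradient(hard - soft)


x = jnp.array([0.1, 2.0, -1.0])
w = jnp.array([1.0, 2.0, 3.0])
y = argmax_st_sketch(x)  # one-hot at index 1
g = jax.grad(lambda v: jnp.dot(argmax_st_sketch(v), w))(x)
```

Here `g` equals the gradient of `jnp.dot(jax.nn.softmax(x), w)`, whereas differentiating a pure one-hot argmax would give zero everywhere.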

softjax.max_st(*args, **kwargs)

softjax.argmin_st(*args, **kwargs)

softjax.min_st(*args, **kwargs)

softjax.argmedian_st(*args, **kwargs)

softjax.median_st(*args, **kwargs)

softjax.median_newton_st(*args, **kwargs)

softjax.argsort_st(*args, **kwargs)

softjax.sort_st(*args, **kwargs)

softjax.top_k_st(*args, **kwargs)

softjax.ranking_st(*args, **kwargs)
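As a sketch of the same pattern for a reduction, the function below returns the exact maximum but differentiates a temperature-scaled log-sum-exp, whose gradient is a softmax over all entries. The name `max_st_sketch`, the `softness` parameter, and the log-sum-exp relaxation are assumptions chosen for illustration, not a description of softjax's `max_st`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def max_st_sketch(x, softness=0.5):
    # Exact maximum for the forward value.
    hard = jnp.max(x, axis=-1)
    # Assumed smooth maximum for gradients: softness * logsumexp(x / softness).
    soft = softness * logsumexp(x / softness, axis=-1)
    return soft + jax.lax.stop_gradient(hard - soft)


x = jnp.array([1.0, 3.0, 2.5])
v = max_st_sketch(x)          # equals jnp.max(x)
g = jax.grad(max_st_sketch)(x)  # softmax(x / 0.5): spread over all entries
```

The gradient of `jnp.max` alone would be one-hot on the largest entry; the straight-through version also sends signal to the runner-up `2.5`.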

Comparison operators

softjax.greater_st(*args, **kwargs)

softjax.greater_equal_st(*args, **kwargs)

softjax.less_st(*args, **kwargs)

softjax.less_equal_st(*args, **kwargs)

softjax.equal_st(*args, **kwargs)

softjax.not_equal_st(*args, **kwargs)

softjax.isclose_st(*args, **kwargs)
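A straight-through comparison returns exact 0/1 values but gives both operands a usable gradient. The following minimal sketch, with the hypothetical name `greater_st_sketch` and an assumed sigmoid relaxation with temperature `softness`, shows this for `x > y`; softjax's own relaxation may differ.

```python
import jax
import jax.numpy as jnp


def greater_st_sketch(x, y, softness=0.1):
    # Exact comparison as floats for the forward value.
    hard = (x > y).astype(x.dtype)
    # Assumed relaxation for gradients: sigmoid of the scaled difference.
    soft = jax.nn.sigmoid((x - y) / softness)
    return soft + jax.lax.stop_gradient(hard - soft)


x = jnp.array([0.0, 0.5, 1.0])
y = jnp.array([0.2, 0.5, 0.1])
out = greater_st_sketch(x, y)  # 0, 0, 1 (equal inputs are not "greater")
gx, gy = jax.grad(
    lambda a, b: greater_st_sketch(a, b).sum(), argnums=(0, 1)
)(x, y)  # gy == -gx, since the relaxation depends only on x - y
```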