
Straight-through operators

Straight-through utilities

softtorch.st(fn: Callable) -> Callable

This decorator calls the decorated function twice: once with mode="hard" and once with the specified (soft) mode. The wrapped function returns the output of the hard call, while gradients are computed through the output of the soft call.

Arguments:

  • fn: The function to be wrapped. It may accept a mode parameter. If fn does not declare a mode parameter, the soft mode defaults to "smooth" and is passed through to fn via **kwargs.

Returns:

A wrapped function that behaves like the mode="hard" version during the forward pass, but computes gradients using the specified mode and softness during the backward pass.
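The pattern st describes can be sketched in plain PyTorch as follows. This is not softtorch's source: `st_sketch`, the toy operator `toy_sign`, and its `softness` keyword are hypothetical, and the sketch assumes fn accepts a `mode` keyword. The key line adds `soft - soft.detach()`, which is exactly zero in value, to the detached hard result, so the caller sees the hard output while autograd differentiates only the soft branch.

```python
import functools

import torch


def st_sketch(fn):
    """Hypothetical straight-through wrapper, assuming `fn` takes a `mode` keyword."""

    @functools.wraps(fn)
    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)  # exact result, used as the forward value
        soft = fn(*args, mode=mode, **kwargs)    # relaxed result, used for gradients
        # soft - soft.detach() is exactly zero, but its gradient is d(soft)/d(input).
        return hard.detach() + (soft - soft.detach())

    return wrapped


def toy_sign(x, mode="smooth", softness=0.1):
    """Hypothetical operator: exact sign for mode="hard", tanh relaxation otherwise."""
    if mode == "hard":
        return torch.sign(x)
    return torch.tanh(x / softness)


sign_st_sketch = st_sketch(toy_sign)

x = torch.tensor([-0.5, 0.05, 2.0], requires_grad=True)
y = sign_st_sketch(x)
y.sum().backward()
# y holds the hard values [-1., 1., 1.]; x.grad is (1 - tanh(x / 0.1)**2) / 0.1.
```

Computing the hard branch and detaching it keeps the forward value exact even when the hard operator has zero or undefined gradients.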

softtorch.grad_replace(fn: Callable) -> Callable

This decorator calls the decorated function twice: once with forward=True and once with forward=False. The wrapped function returns the output of the forward=True call, while gradients are computed through the output of the forward=False call.

Arguments:

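  • fn: The function to be wrapped. It should accept a boolean forward argument that selects the computation to perform: forward=True for the value returned in the forward pass, forward=False for the surrogate used to compute gradients in the backward pass.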

Returns:

A wrapped function that behaves like the forward=True version during the forward pass, but computes gradients using the forward=False version during the backward pass.
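A minimal sketch of the grad_replace pattern, again not softtorch's code: `grad_replace_sketch` and `rounded` are hypothetical names. When the forward=False branch returns its input unchanged, the wrapper reproduces the classic straight-through estimator for rounding, whose backward pass treats round as the identity.

```python
import functools

import torch


def grad_replace_sketch(fn):
    """Hypothetical: value from fn(forward=True), gradient from fn(forward=False)."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)   # value seen by the caller
        bwd = fn(*args, forward=False, **kwargs)  # surrogate that autograd differentiates
        return fwd.detach() + (bwd - bwd.detach())

    return wrapped


@grad_replace_sketch
def rounded(x, forward=True):
    # Forward: exact rounding. Backward: identity (classic straight-through estimator).
    return torch.round(x) if forward else x


x = torch.tensor([0.2, 1.7, -2.6], requires_grad=True)
y = rounded(x)
(3 * y).sum().backward()
# y holds [0., 2., -3.]; x.grad is [3., 3., 3.] because the surrogate has slope 1.
```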

Elementwise operators

softtorch.abs_st(*args, **kwargs)

Straight-through version of softtorch.abs. Implemented using the softtorch.st decorator as st(softtorch.abs).

Returns the hard abs during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
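To make the description concrete, the sketch below hand-writes a straight-through abs. The smooth surrogate sqrt(x² + softness²) and the name `abs_st_sketch` are assumptions for illustration; softtorch's own relaxation and parameters may differ. Unlike torch.abs, whose gradient jumps from -1 to 1 around zero, the surrogate's gradient x / sqrt(x² + softness²) varies smoothly there.

```python
import torch


def abs_st_sketch(x, softness=0.1):
    """Hypothetical straight-through |x|; the surrogate is an assumed relaxation."""
    hard = torch.abs(x.detach())            # exact forward value
    soft = torch.sqrt(x * x + softness**2)  # assumed smooth surrogate for |x|
    # Zero contribution to the value, full contribution to the gradient.
    return hard + (soft - soft.detach())


x = torch.tensor([-2.0, 0.0, 0.05], requires_grad=True)
y = abs_st_sketch(x)
y.sum().backward()
# y holds [2., 0., 0.05]; x.grad is x / sqrt(x**2 + 0.01).
```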

softtorch.clamp_st(*args, **kwargs)

Straight-through version of softtorch.clamp. Implemented using the softtorch.st decorator as st(softtorch.clamp).

Returns the hard clamp during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
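A hypothetical straight-through clamp in the same style. The softplus-based smooth clamp lo + softness·(softplus((x - lo)/softness) - softplus((x - hi)/softness)) is an assumed relaxation, not necessarily the one softtorch uses. Its gradient, sigmoid((x - lo)/softness) - sigmoid((x - hi)/softness), stays nonzero just outside [lo, hi], where the hard clamp's gradient is exactly zero.

```python
import torch
import torch.nn.functional as F


def clamp_st_sketch(x, lo=0.0, hi=1.0, softness=0.1):
    """Hypothetical straight-through clamp with an assumed softplus-based surrogate."""
    hard = torch.clamp(x.detach(), lo, hi)
    # Approaches lo far below the range and hi far above it.
    soft = lo + softness * (F.softplus((x - lo) / softness) - F.softplus((x - hi) / softness))
    return hard + (soft - soft.detach())


x = torch.tensor([-0.5, 0.3, 1.05], requires_grad=True)
y = clamp_st_sketch(x)
y.sum().backward()
# y holds [0., 0.3, 1.]; x.grad[2] is about 0.38 even though 1.05 lies above hi.
```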

softtorch.heaviside_st(*args, **kwargs)

Straight-through version of softtorch.heaviside. Implemented using the softtorch.st decorator as st(softtorch.heaviside).

Returns the hard heaviside during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.relu_st(*args, **kwargs)

Straight-through version of softtorch.relu. Implemented using the softtorch.st decorator as st(softtorch.relu).

Returns the hard relu during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
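A hypothetical straight-through relu, assuming a softplus surrogate scaled by a `softness` parameter (softtorch's relaxation may differ). Negative inputs still output 0 in the forward pass, but they receive the gradient sigmoid(x / softness) instead of 0, so units with negative pre-activations still get a learning signal.

```python
import torch
import torch.nn.functional as F


def relu_st_sketch(x, softness=0.5):
    """Hypothetical: exact relu forward, gradient of softness * softplus(x / softness) backward."""
    hard = torch.relu(x.detach())
    soft = softness * F.softplus(x / softness)  # gradient is sigmoid(x / softness)
    return hard + (soft - soft.detach())


x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
y = relu_st_sketch(x)
y.sum().backward()
# y holds [0., 0., 2.]; x.grad is sigmoid(x / 0.5), e.g. 0.5 at x = 0.
```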

softtorch.round_st(*args, **kwargs)

Straight-through version of softtorch.round. Implemented using the softtorch.st decorator as st(softtorch.round).

Returns the hard round during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.sign_st(*args, **kwargs)

Straight-through version of softtorch.sign. Implemented using the softtorch.st decorator as st(softtorch.sign).

Returns the hard sign during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

Tensor-valued operators

softtorch.argmax_st(*args, **kwargs)

Straight-through version of softtorch.argmax. Implemented using the softtorch.st decorator as st(softtorch.argmax).

Returns the hard argmax during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
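A sketch of a straight-through argmax, assuming the soft version returns a softmax distribution over the reduced dimension so that the hard counterpart is a one-hot vector; whether softtorch.argmax uses this output format is an assumption here, and `argmax_st_sketch` is a hypothetical name. Selecting from `values` with the one-hot output returns the exact selected value while gradients reach every logit.

```python
import torch


def argmax_st_sketch(logits, dim=-1, softness=1.0):
    """Hypothetical: one-hot argmax forward, softmax(logits / softness) backward."""
    soft = torch.softmax(logits / softness, dim=dim)
    index = logits.detach().argmax(dim=dim, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(dim, index, 1.0)  # one-hot along `dim`
    return hard + (soft - soft.detach())


logits = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)
values = torch.tensor([10.0, 20.0, 30.0])
out = (argmax_st_sketch(logits) * values).sum()
out.backward()
# out equals 20.0 (the value at the argmax); logits.grad follows the softmax Jacobian,
# p_k * (v_k - sum_i p_i * v_i) with p = softmax(logits).
```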

softtorch.max_st(*args, **kwargs)

Straight-through version of softtorch.max. Implemented using the softtorch.st decorator as st(softtorch.max).

Returns the hard max during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
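A hypothetical straight-through max, assuming a softness-scaled logsumexp surrogate (softtorch's relaxation may differ). The gradient of softness · logsumexp(x / softness) is softmax(x / softness), so near-ties share the gradient instead of it all flowing to a single entry.

```python
import torch


def max_st_sketch(x, dim=-1, softness=0.5):
    """Hypothetical: exact max forward, smooth logsumexp surrogate backward."""
    hard = x.detach().amax(dim=dim)
    soft = softness * torch.logsumexp(x / softness, dim=dim)  # smooth upper bound on max
    return hard + (soft - soft.detach())


x = torch.tensor([1.0, 3.0, 2.9], requires_grad=True)
m = max_st_sketch(x)
m.backward()
# m equals 3.0; x.grad equals softmax(x / 0.5), splitting weight between 3.0 and 2.9.
```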

softtorch.argmin_st(*args, **kwargs)

Straight-through version of softtorch.argmin. Implemented using the softtorch.st decorator as st(softtorch.argmin).

Returns the hard argmin during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.min_st(*args, **kwargs)

Straight-through version of softtorch.min. Implemented using the softtorch.st decorator as st(softtorch.min).

Returns the hard min during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.argquantile_st(*args, **kwargs)

Straight-through version of softtorch.argquantile. Implemented using the softtorch.st decorator as st(softtorch.argquantile).

Returns the hard argquantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.quantile_st(*args, **kwargs)

Straight-through version of softtorch.quantile. Implemented using the softtorch.st decorator as st(softtorch.quantile).

Returns the hard quantile during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.argmedian_st(*args, **kwargs)

Straight-through version of softtorch.argmedian. Implemented using the softtorch.st decorator as st(softtorch.argmedian).

Returns the hard argmedian during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.median_st(*args, **kwargs)

Straight-through version of softtorch.median. Implemented using the softtorch.st decorator as st(softtorch.median).

Returns the hard median during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.argsort_st(*args, **kwargs)

Straight-through version of softtorch.argsort. Implemented using the softtorch.st decorator as st(softtorch.argsort).

Returns the hard argsort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.sort_st(*args, **kwargs)

Straight-through version of softtorch.sort. Implemented using the softtorch.st decorator as st(softtorch.sort).

Returns the hard sort during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
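One way to build a straight-through sort is sketched below using the NeuralSort relaxation (Grover et al., ICLR 2019), swapped in for illustration; softtorch's own soft sort may use a different method, and the hypothetical `sort_st_sketch` handles only 1-D inputs. Row i of the relaxed permutation matrix is softmax(((n + 1 - 2i)·x - A·1) / softness) with A[j, k] = |x_j - x_k|, which approaches the one-hot selector of the i-th largest element as softness goes to 0; flipping the rows gives ascending order.

```python
import torch


def sort_st_sketch(x, softness=0.1):
    """Hypothetical 1-D straight-through ascending sort with a NeuralSort surrogate."""
    n = x.shape[0]
    pairwise = (x.unsqueeze(0) - x.unsqueeze(1)).abs()  # A[j, k] = |x_j - x_k|
    row_sums = pairwise.sum(dim=1)                      # (A 1)_j
    i = torch.arange(1, n + 1, dtype=x.dtype, device=x.device)
    scaling = (n + 1 - 2 * i).unsqueeze(1)              # shape (n, 1)
    logits = (scaling * x.unsqueeze(0) - row_sums.unsqueeze(0)) / softness
    perm_desc = torch.softmax(logits, dim=-1)           # row i ~ one-hot of i-th largest
    soft = perm_desc.flip(0) @ x                        # relaxed ascending sort
    hard = torch.sort(x.detach()).values                # exact ascending sort
    return hard + (soft - soft.detach())


x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
y = sort_st_sketch(x, softness=0.05)
(y * torch.tensor([1.0, 10.0, 100.0])).sum().backward()
# y equals [1., 2., 3.]; x.grad is close to [100., 1., 10.]: each weight is routed
# back to the input position that landed in that sorted slot.
```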

softtorch.topk_st(*args, **kwargs)

Straight-through version of softtorch.topk. Implemented using the softtorch.st decorator as st(softtorch.topk).

Returns the hard topk during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.rank_st(*args, **kwargs)

Straight-through version of softtorch.rank. Implemented using the softtorch.st decorator as st(softtorch.rank).

Returns the hard rank during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
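A hypothetical straight-through rank, assuming 1-based ascending ranks (the smallest element gets rank 1); softtorch's conventions for rank base, order, and ties are not stated on this page. The soft surrogate counts pairwise comparisons with sigmoids, 0.5 + Σ_j sigmoid((x_i - x_j)/softness), where the 0.5 offset compensates for the j = i term sigmoid(0) = 0.5.

```python
import torch


def rank_st_sketch(x, softness=0.5):
    """Hypothetical 1-based ascending rank along the last dimension."""
    diff = x.unsqueeze(-1) - x.unsqueeze(-2)                 # diff[..., i, j] = x_i - x_j
    soft = 0.5 + torch.sigmoid(diff / softness).sum(dim=-1)  # j == i contributes 0.5
    hard = 1 + (diff.detach() > 0).sum(dim=-1).to(x.dtype)   # 1 + number of smaller elements
    return hard + (soft - soft.detach())


x = torch.tensor([0.3, -1.2, 2.0], requires_grad=True)
r = rank_st_sketch(x)
r[0].backward()
# r equals [2., 1., 3.]; raising x[0] raises its soft rank, so x.grad[0] > 0,
# while x.grad[1] and x.grad[2] are negative and the three entries sum to zero.
```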

Comparison operators

softtorch.greater_st(*args, **kwargs)

Straight-through version of softtorch.greater. Implemented using the softtorch.st decorator as st(softtorch.greater).

Returns the hard greater during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
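A hypothetical straight-through greater for two broadcastable tensors, assuming a sigmoid surrogate (softtorch's relaxation may differ) and 0/1 float outputs. Because the surrogate depends only on x - y, the gradients with respect to x and y are equal and opposite, and a near miss such as 1.0 > 1.05 still produces a large gradient.

```python
import torch


def greater_st_sketch(x, y, softness=0.1):
    """Hypothetical: hard (x > y) as 0/1 floats forward, sigmoid((x - y) / softness) backward."""
    hard = (x.detach() > y.detach()).to(x.dtype)
    soft = torch.sigmoid((x - y) / softness)
    return hard + (soft - soft.detach())


x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
y = torch.tensor(1.05, requires_grad=True)
out = greater_st_sketch(x, y)
out.sum().backward()
# out equals [0., 0., 1.]; y.grad equals -x.grad.sum() since y is broadcast over x.
```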

softtorch.greater_equal_st(*args, **kwargs)

Straight-through version of softtorch.greater_equal. Implemented using the softtorch.st decorator as st(softtorch.greater_equal).

Returns the hard greater_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.less_st(*args, **kwargs)

Straight-through version of softtorch.less. Implemented using the softtorch.st decorator as st(softtorch.less).

Returns the hard less during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.less_equal_st(*args, **kwargs)

Straight-through version of softtorch.less_equal. Implemented using the softtorch.st decorator as st(softtorch.less_equal).

Returns the hard less_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.eq_st(*args, **kwargs)

Straight-through version of softtorch.eq. Implemented using the softtorch.st decorator as st(softtorch.eq).

Returns the hard eq during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).
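A hypothetical straight-through eq using a Gaussian bump exp(-((x - y)/softness)²) as the surrogate; this choice is an assumption for illustration. Exact equality is almost never hit by continuous values and the hard comparison has zero gradient everywhere, whereas the bump pulls nearby values toward equality.

```python
import torch


def eq_st_sketch(x, y, softness=0.1):
    """Hypothetical: hard (x == y) as 0/1 floats forward, Gaussian-bump gradient backward."""
    hard = (x.detach() == y.detach()).to(x.dtype)
    soft = torch.exp(-(((x - y) / softness) ** 2))
    return hard + (soft - soft.detach())


x = torch.tensor([1.0, 1.02, 3.0], requires_grad=True)
y = torch.tensor(1.0)
out = eq_st_sketch(x, y)
out.sum().backward()
# out equals [1., 0., 0.]; x.grad[1] is negative, so gradient ascent on out.sum()
# moves 1.02 toward 1.0, while x = 3.0 lies far outside the bump and gets no gradient.
```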

softtorch.not_equal_st(*args, **kwargs)

Straight-through version of softtorch.not_equal. Implemented using the softtorch.st decorator as st(softtorch.not_equal).

Returns the hard not_equal during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).

softtorch.isclose_st(*args, **kwargs)

Straight-through version of softtorch.isclose. Implemented using the softtorch.st decorator as st(softtorch.isclose).

Returns the hard isclose during the forward pass, but uses a soft relaxation (controlled by the mode argument) for the backward pass (i.e., gradients are computed through the soft version).