Straight-through operators

Straight-through utilities

softtorch.st(fn: Callable) -> Callable

This decorator calls the decorated function twice: once with mode="hard" and once with the mode specified by the caller. In the forward pass it returns the output of the hard call, but gradients in the backward pass are computed through the output of the soft call.

Arguments:

  • fn: The function to be wrapped. It should accept a mode argument.

Returns: A wrapped function that behaves like the mode="hard" version during the forward pass, but computes gradients using the specified mode and softness during the backward pass.
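The behaviour described above matches the familiar straight-through trick soft + (hard - soft).detach(): the result is numerically equal to the hard output, but autograd only differentiates the soft term. The sketch below is a minimal illustration under that assumption, not softtorch's actual source; toy_heaviside and its "soft" mode are hypothetical stand-ins for a softtorch function that accepts a mode argument.

```python
import functools
from typing import Callable

import torch


def st(fn: Callable) -> Callable:
    """Minimal straight-through sketch (an assumption, not softtorch's source)."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # Hard call: mode is forced to "hard" regardless of what the caller passed.
        hard = fn(*args, **{**kwargs, "mode": "hard"})
        # Soft call: uses the caller's mode (or fn's default).
        soft = fn(*args, **kwargs)
        # Value equals `hard`; gradients flow only through `soft`.
        return soft + (hard - soft).detach()

    return wrapped


def toy_heaviside(x: torch.Tensor, mode: str = "soft") -> torch.Tensor:
    """Hypothetical relaxed Heaviside step, used only for illustration."""
    if mode == "hard":
        return (x > 0).to(x.dtype)
    return torch.sigmoid(x)  # smooth surrogate for any non-hard mode


heaviside_st = st(toy_heaviside)

x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
y = heaviside_st(x)
y.sum().backward()
print(y)       # forward values equal the hard step: 0., 1., 1.
print(x.grad)  # sigmoid'(x) = s * (1 - s), the gradient of the soft version
```

Merging mode into kwargs for the hard call means it always runs with mode="hard", even when the caller passes a different mode explicitly.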

softtorch.grad_replace(fn: Callable) -> Callable

This decorator calls the decorated function twice: once with forward=True and once with forward=False. In the forward pass it returns the output of the forward=True call, but gradients in the backward pass are computed through the output of the forward=False call.

Arguments:

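  • fn: The function to be wrapped. It should accept a boolean forward argument that selects which computation to perform: forward=True for the value returned in the forward pass, and forward=False for the expression through which gradients are computed.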

Returns: A wrapped function that behaves like the forward=True version during the forward pass, but computes gradients using the forward=False version during the backward pass.
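A minimal sketch of this mechanism, assuming grad_replace is built on the same detach trick: the forward=True output supplies the value and the forward=False output supplies the gradient. This is not softtorch's actual source, and clamp_identity_grad is a hypothetical example: it clamps to [0, 1] in the forward pass but lets gradients through unchanged, including for inputs outside [0, 1] where torch.clamp alone would give a zero gradient.

```python
import functools
from typing import Callable

import torch


def grad_replace(fn: Callable) -> Callable:
    """Minimal sketch of the documented behaviour (an assumption, not softtorch's source)."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        fwd = fn(*args, **{**kwargs, "forward": True})
        bwd = fn(*args, **{**kwargs, "forward": False})
        # Value of `fwd`, gradient of `bwd`.
        return bwd + (fwd - bwd).detach()

    return wrapped


@grad_replace
def clamp_identity_grad(x: torch.Tensor, forward: bool = True) -> torch.Tensor:
    """Hypothetical example: clamp forward, identity for gradients."""
    if forward:
        return x.clamp(0.0, 1.0)
    return x


x = torch.tensor([-0.5, 0.3, 1.7], requires_grad=True)
y = clamp_identity_grad(x)
y.sum().backward()
print(y)       # clamped values 0.0, 0.3, 1.0 (up to float rounding)
print(x.grad)  # all ones: the clamp's zero gradients are bypassed
```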

Elementwise operators

softtorch.abs_st(*args, **kwargs)

softtorch.clamp_st(*args, **kwargs)

softtorch.heaviside_st(*args, **kwargs)

softtorch.relu_st(*args, **kwargs)

softtorch.round_st(*args, **kwargs)

softtorch.sign_st(*args, **kwargs)

Tensor-valued operators

softtorch.argmax_st(*args, **kwargs)

Straight-through version of softtorch.argmax.

This function returns the hard argmax in the forward pass, but computes gradients in the backward pass through a soft relaxation selected by the mode argument.

Implemented using the softtorch.st decorator as st(softtorch.argmax).
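To illustrate what st(softtorch.argmax) does, the sketch below wraps a hypothetical toy_argmax that returns a one-hot vector in hard mode and softmax weights otherwise. The one-hot output format, the "soft" mode name, and the st implementation shown here are assumptions for illustration only; softtorch.argmax may represent its output and relaxations differently.

```python
import functools
from typing import Callable

import torch


def st(fn: Callable) -> Callable:
    # Minimal straight-through sketch (an assumption, not softtorch's source).
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        hard = fn(*args, **{**kwargs, "mode": "hard"})
        soft = fn(*args, **kwargs)
        return soft + (hard - soft).detach()

    return wrapped


def toy_argmax(x: torch.Tensor, dim: int = -1, mode: str = "soft") -> torch.Tensor:
    """Hypothetical argmax: one-hot vector (hard) or softmax weights (soft)."""
    if mode == "hard":
        idx = x.argmax(dim=dim, keepdim=True)
        return torch.zeros_like(x).scatter_(dim, idx, 1.0)
    return torch.softmax(x, dim=dim)


argmax_st = st(toy_argmax)

logits = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)
values = torch.tensor([10.0, 20.0, 30.0])

weights = argmax_st(logits)          # one-hot at the largest logit (index 1)
selected = (weights * values).sum()  # picks values[1] = 20.0
selected.backward()
print(weights)      # one-hot at index 1
print(logits.grad)  # nonzero: gradient of the softmax-weighted sum
```

Because the forward value is an exact one-hot, weights * values selects a single entry, while the gradient with respect to the logits is that of the softmax-weighted sum, so every logit receives a learning signal.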

softtorch.max_st(*args, **kwargs)

softtorch.argmin_st(*args, **kwargs)

softtorch.min_st(*args, **kwargs)

softtorch.median_st(*args, **kwargs)

softtorch.argsort_st(*args, **kwargs)

softtorch.sort_st(*args, **kwargs)

softtorch.topk_st(*args, **kwargs)

softtorch.ranking_st(*args, **kwargs)

Comparison operators

softtorch.greater_st(*args, **kwargs)

softtorch.greater_equal_st(*args, **kwargs)

softtorch.less_st(*args, **kwargs)

softtorch.less_equal_st(*args, **kwargs)

softtorch.equal_st(*args, **kwargs)

softtorch.not_equal_st(*args, **kwargs)

softtorch.isclose_st(*args, **kwargs)