Softtorch operators
Helper functions
softtorch.sigmoidal(x: torch.Tensor, softness: float = 0.1, mode: Literal['smooth', 'c0', 'c1', '_c1_pnorm', 'c2', '_c2_pnorm'] = 'smooth') -> torch.Tensor
Sigmoidal functions defining a characteristic S-shaped curve.
Arguments:
- x: Input Tensor.
- softness: Softness of the function, should be larger than zero.
- mode: Choice of smoothing family for the surrogate step.
  - smooth: Smooth sigmoidal based on the logistic function.
  - c0: Continuous sigmoidal based on a piecewise quadratic polynomial.
  - c1: Differentiable sigmoidal based on a piecewise cubic polynomial.
  - c2: Twice differentiable sigmoidal based on a piecewise quintic polynomial.
Returns:
SoftBool of the same shape as x (Tensor with values in [0, 1]).
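The 'smooth' mode is based on the logistic function, which can be sketched in a few lines of pure Python. The function name below is illustrative, not the library's internal name; the piecewise-polynomial c0/c1/c2 modes are not shown.

```python
import math

def sigmoidal_smooth(x, softness=0.1):
    """Sketch of the 'smooth' sigmoidal: a logistic step of width ~softness."""
    return 1.0 / (1.0 + math.exp(-x / softness))

# As softness -> 0 the curve approaches a hard step at x = 0.
print(sigmoidal_smooth(0.0))          # 0.5 exactly at the origin
print(sigmoidal_smooth(1.0) > 0.99)   # True: well inside the upper plateau
```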
softtorch.softrelu(x: torch.Tensor, softness: float = 0.1, mode: Literal['smooth', 'c0', 'c1', '_c1_pnorm', 'c2', '_c2_pnorm'] = 'smooth', gated: bool = False) -> torch.Tensor
Family of soft relaxations of ReLU.
Arguments:
- x: Input Tensor.
- softness: Softness of the function, should be larger than zero.
- mode: Choice of softtorch.sigmoidal smoothing mode; options are "smooth", "c0", "c1", "_c1_pnorm", "c2", or "_c2_pnorm".
- gated: If True, uses the 'gated' version x * sigmoidal(x). If False, uses the integral of the sigmoidal.
Returns:
Result of applying the soft elementwise ReLU to x.
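For the 'smooth' mode, the two variants have familiar closed forms: the gated version is x * sigmoid(x/s) (SiLU-style), and the integral of the logistic sigmoidal is the scaled softplus s * log(1 + exp(x/s)). A pure-Python sketch under those assumptions (names are illustrative):

```python
import math

def softrelu_smooth(x, softness=0.1, gated=False):
    """Sketch of the smooth softrelu: gated x*sigmoid(x/s), or scaled softplus."""
    z = x / softness
    if gated:
        return x / (1.0 + math.exp(-z))
    # s * log(1 + exp(z)), written stably for large |z|
    return softness * (max(z, 0.0) + math.log1p(math.exp(-abs(z))))

# Both variants approach relu(x) = max(x, 0) away from the kink at 0.
print(round(softrelu_smooth(2.0), 6))   # close to 2.0
print(round(softrelu_smooth(-2.0), 6))  # close to 0.0
```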
Elementwise operators
softtorch.abs(x: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth') -> torch.Tensor
Performs a soft version of torch.abs.
Arguments:
- x: Input Tensor of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: Projection mode. "hard" returns the exact absolute value. Otherwise uses softtorch.sigmoidal-based "smooth", "c0", "c1", "c2" relaxations.
Returns:
Result of applying the soft elementwise absolute value to x.
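One natural sigmoidal-based construction for a soft absolute value is x times a soft sign, x * (2*sigmoid(x/s) - 1), which equals x * tanh(x / (2s)). This is a sketch of the idea only; the library's exact formula may differ.

```python
import math

def softabs_smooth(x, softness=0.1):
    """Sketch: soft abs as x times a soft sign, i.e. x * tanh(x / (2*s))."""
    return x * math.tanh(x / (2.0 * softness))

print(softabs_smooth(0.0))   # 0.0: the kink at the origin is rounded off
# Far from the origin the relaxation is close to |x|.
```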
softtorch.clamp(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', gated: bool = False) -> torch.Tensor
Performs a soft version of torch.clamp.
Arguments:
- x: Input Tensor of any shape.
- a: Lower bound scalar.
- b: Upper bound scalar.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", applies torch.clamp. Otherwise uses softtorch.softrelu-based "smooth", "c0", "c1", "c2" relaxations.
- gated: See the softtorch.softrelu documentation.
Returns:
Result of applying soft elementwise clamping to x.
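A softrelu-based clamp can be assembled as a + softrelu(x - a) - softrelu(x - b), which recovers torch.clamp(x, a, b) as softness goes to 0. The sketch below uses the scaled-softplus form of the smooth softrelu; it is an illustration of the composition, not necessarily the library's exact one.

```python
import math

def _softrelu(x, s):
    # Scaled softplus s * log(1 + exp(x/s)): a smooth relaxation of max(x, 0).
    z = x / s
    return s * (max(z, 0.0) + math.log1p(math.exp(-abs(z))))

def softclamp_smooth(x, a, b, softness=0.1):
    """Sketch: soft clamp as a + softrelu(x - a) - softrelu(x - b)."""
    return a + _softrelu(x - a, softness) - _softrelu(x - b, softness)

print(round(softclamp_smooth(5.0, 0.0, 1.0), 6))   # close to 1.0 (clamped above)
print(round(softclamp_smooth(-5.0, 0.0, 1.0), 6))  # close to 0.0 (clamped below)
```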
softtorch.heaviside(x: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth') -> torch.Tensor
Performs a soft version of torch.heaviside(x, 0.5).
Arguments:
- x: Input Tensor of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact Heaviside step. Otherwise uses softtorch.sigmoidal-based "smooth", "c0", "c1", "c2" relaxations.
Returns:
SoftBool of the same shape as x (Tensor with values in [0, 1]), relaxing the elementwise Heaviside step function.
softtorch.relu(x: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', gated: bool = False) -> torch.Tensor
Performs a soft version of torch.relu.
Arguments:
- x: Input Tensor of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", applies torch.relu. Otherwise uses softtorch.softrelu with "smooth", "c0", "c1", "c2" relaxations.
- gated: See the softtorch.softrelu documentation.
Returns:
Result of applying the soft elementwise ReLU to x.
softtorch.round(x: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', neighbor_radius: int = 5) -> torch.Tensor
Performs a soft version of torch.round.
Arguments:
- x: Input Tensor of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", applies torch.round. Otherwise uses a sigmoidal-based relaxation based on the algorithm described in Smooth Approximations of the Rounding Function. Supports the softtorch.sigmoidal modes "smooth", "c0", "c1", "c2".
- neighbor_radius: Number of neighbors on each side of the floor value to consider for the soft rounding.
Returns:
Result of applying soft elementwise rounding to x.
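One plausible reading of a sigmoidal neighbor-window rounding scheme: hard rounding places a unit step at every half-integer, so a soft version sums a logistic step at each half-integer within neighbor_radius of floor(x). This is an illustrative sketch only, not necessarily the algorithm the library implements.

```python
import math

def softround_smooth(x, softness=0.05, neighbor_radius=5):
    """Sketch: soft round as a sum of logistic steps at nearby half-integers."""
    m = math.floor(x)
    total = float(m - neighbor_radius)   # steps far below x each contribute ~1
    for k in range(m - neighbor_radius, m + neighbor_radius + 1):
        total += 1.0 / (1.0 + math.exp(-(x - k - 0.5) / softness))
    return total

print(round(softround_smooth(1.9), 3))  # close to 2.0
print(round(softround_smooth(1.1), 3))  # close to 1.0
print(round(softround_smooth(1.5), 3))  # the transition passes through 1.5
```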
softtorch.sign(x: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth') -> torch.Tensor
Performs a soft version of torch.sign.
Arguments:
- x: Input Tensor of any shape.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns torch.sign. Otherwise uses softtorch.sigmoidal-based "smooth", "c0", "c1", "c2" relaxations.
Returns:
Result of applying the soft elementwise sign to x.
Tensor-valued operators
softtorch.argmax(x: torch.Tensor, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None) -> torch.Tensor
Performs a soft version of torch.argmax of x along the specified dim.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- dim: The dimension along which to compute the argmax. If None, the input Tensor is flattened before computing the argmax.
- keepdim: If True, keeps the reduced dimension as a singleton {1}.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
  - hard: Returns the result of torch.argmax with a one-hot encoding of the indices.
  - smooth: C∞ smooth (entropy-based). Softens the unit simplex projection via an entropic regularizer, computed in closed form via a softmax operation.
  - c0: C0 continuous (L2-based). Softens the unit simplex projection via a quadratic regularizer, computed via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application.
  - c1: C1 differentiable (p=3/2 p-norm). Computed in closed form via the quadratic formula.
  - c2: C2 twice differentiable (p=4/3 p-norm). Computed in closed form via Cardano's method.
- method: Method to compute the soft argmax. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
  - ot: Computes the max element via an optimal transport projection onto a 2-point support.
  - softsort: Computes the max element of the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Reduces to projecting x onto the unit simplex.
  - neuralsort: Computes the max element of the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
- standardize: If True, standardizes and squashes the input x along the specified dim before applying the soft argmax operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the maximum number of iterations or the tolerance.
Returns:
A SoftIndex of shape (..., {1}, ..., [n]) (positive Tensor which sums to 1 over the last dimension). Represents the probability of an index corresponding to the argmax along the specified dim.
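In the smooth mode, the entropic simplex projection reduces to a softmax of x / softness, which already conveys the SoftIndex idea. A pure-Python sketch (the function name is illustrative):

```python
import math

def softargmax_smooth(x, softness=0.1):
    """Sketch: entropic simplex projection = softmax of x / softness."""
    m = max(x)                                    # subtract max for stability
    w = [math.exp((v - m) / softness) for v in x]
    z = sum(w)
    return [v / z for v in w]

p = softargmax_smooth([0.2, 1.0, 0.4])
print(round(sum(p), 6))     # the weights form a distribution (sums to 1)
print(p.index(max(p)))      # 1: most mass sits on the hard argmax
```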
softtorch.max(x: torch.Tensor, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> torch.Tensor | torch.return_types.max[torch.Tensor, torch.Tensor]
Performs a soft version of torch.max of x along the specified dim.
For methods other than fast_soft_sort and sorting_network, implemented as softtorch.argmax followed by softtorch.take_along_dim; see the respective documentation for details.
For fast_soft_sort and sorting_network, uses softtorch.sort to compute soft sorted values and retrieves the maximum as the first element. See softtorch.sort for method details.
Extra Arguments:
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
- If dim is None (default): Scalar tensor representing the soft maximum of the flattened x.
- If dim is specified: Namedtuple containing two fields:
  - values: Tensor of shape (..., {1}, ...) representing the soft maximum of x along the specified dim.
  - indices: SoftIndex of shape (..., {1}, ..., [n]) (positive Tensor which sums to 1 over the last dimension). Represents the soft indices of the maximum values.
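The argmax-then-gather composition can be sketched in the smooth mode as the expected value of x under the softmax index distribution, which converges to max(x) as softness shrinks. An illustrative pure-Python version (not the library code path):

```python
import math

def softmax_value(x, softness=0.1):
    """Sketch: soft max as the softmax-weighted average of the values."""
    m = max(x)
    w = [math.exp((v - m) / softness) for v in x]
    z = sum(w)
    return sum(wi * xi for wi, xi in zip(w, x)) / z

print(round(softmax_value([0.2, 1.0, 0.4]), 3))   # just under the hard max 1.0
```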
softtorch.argmin(x: torch.Tensor, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None) -> torch.Tensor
Performs a soft version of torch.argmin of x along the specified dim.
Implemented as softtorch.argmax on -x; see the respective documentation for details.
softtorch.min(x: torch.Tensor, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'softsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> torch.Tensor | torch.return_types.min[torch.Tensor, torch.Tensor]
Performs a soft version of torch.min of x along the specified dim.
Implemented via softtorch.max on -x; see the respective documentation for details.
softtorch.argquantile(x: torch.Tensor, q: torch.Tensor | float, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', interpolation: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] = 'linear', standardize: bool = True, ot_kwargs: dict | None = None) -> torch.Tensor
Performs a soft version of torch.quantile of x along the specified dim.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- q: Scalar quantile or 1-D Tensor of quantiles in [0, 1]. When a 1-D tensor of length k is passed, the q dimension is prepended to the output shape.
- dim: The dim along which to compute the argquantile. If None, the input Tensor is flattened before computing the argquantile.
- keepdim: If True, keeps the reduced dimension as a singleton {1}.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
  - hard: Returns a one/two-hot encoding of the indices corresponding to the torch.quantile definitions.
  - smooth: Softens projections via an entropic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation.
    - For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
  - c0: Softens projections via a quadratic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application.
    - For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
  - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft argquantile. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
  - ot: Uses a variation of the soft quantile approach in Differentiable Ranks and Sorting using Optimal Transport, adapted to converge to the standard quantile definitions for small softness. Depending on the interpolation, either a lower and an upper quantile are computed and combined, or just a single quantile is computed. Intuition: the sorted elements are selected by specifying 4 or 3 "anchors" and then transporting the upper/lower quantile values to the appropriate anchors. Note: sinkhorn_max_iter and sinkhorn_tol can be passed via ot_kwargs to control convergence.
  - softsort: Computes the upper and lower quantiles via the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
  - neuralsort: Computes the upper and lower quantiles via the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
- interpolation: Method to compute the quantile, following the options in torch.quantile.
- standardize: If True, standardizes and squashes the input x along the specified dim before applying the soft argquantile operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the maximum number of iterations or the tolerance.
Returns:
A SoftIndex of shape (..., {1}, ..., [n]) for scalar q, or (k, ..., {1}, ..., [n]) for vector q of length k (q dimension prepended). Positive Tensor which sums to 1 over the last dimension. It represents a distribution over values in x being the q-quantile along the specified dim.
softtorch.quantile(x: torch.Tensor, q: torch.Tensor | float, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', interpolation: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] = 'linear', standardize: bool = True, ot_kwargs: dict | None = None, return_argquantile: bool = False, gated_grad: bool = True) -> torch.Tensor
Performs a soft version of torch.quantile of x along the specified dim.
For methods other than fast_soft_sort and sorting_network, implemented as softtorch.argquantile followed by softtorch.take_along_dim; see the respective documentation for details.
For fast_soft_sort and sorting_network, uses softtorch.sort to compute soft sorted values, then retrieves the quantile as a combination of the appropriate elements depending on the interpolation method. See softtorch.sort for method details.
Extra Arguments:
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
Tensor of shape (..., {1}, ...) for scalar q, or (k, ..., {1}, ...) for vector q of length k (q dimension prepended). Represents the soft q-quantile of x along the specified dim.
softtorch.argmedian(x: torch.Tensor, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None) -> torch.Tensor
Computes the soft argmedian of x along the specified dim.
Implemented as softtorch.argquantile with q=0.5; see the respective documentation for details.
softtorch.median(x: torch.Tensor, dim: int | None = None, keepdim: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> torch.Tensor | torch.return_types.median[torch.Tensor, torch.Tensor]
Performs a soft version of torch.median of x along the specified dim.
Implemented as softtorch.quantile with q=0.5; see the respective documentation for details.
softtorch.argsort(x: torch.Tensor, dim: int | None = -1, descending: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None) -> torch.Tensor
Performs a soft version of torch.argsort of x along the specified dim.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- dim: The dim along which to compute the argsort operation. If None, uses the last dimension.
- descending: If True, sorts in descending order.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
  - hard: Returns the result of torch.argsort with a one-hot encoding of the indices.
  - smooth: Softens projections via an entropic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation.
    - For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
  - c0: Softens projections via a quadratic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application.
    - For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
  - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft argsort. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
  - ot: Uses the approach in Differentiable Ranks and Sorting using Optimal Transport. Intuition: the sorted elements are selected by specifying n "anchors" and then transporting the ith-largest value to the ith-largest anchor. Note: sinkhorn_max_iter and sinkhorn_tol can be passed via ot_kwargs to control convergence.
  - softsort: Computes the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
  - neuralsort: Computes the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
- standardize: If True, standardizes and squashes the input x along the specified dim before applying the soft argsort operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the maximum number of iterations or the tolerance.
Returns:
A SoftIndex of shape (..., n, ..., [n]) (positive Tensor which sums to 1 over the last dimension). The elements in (..., i, ..., [n]) represent a distribution over values in x for the ith smallest element along the specified dim.
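The SoftSort construction has a compact closed form in the smooth mode: row i of the soft permutation matrix is a softmax over -|sorted(x)[i] - x[j]| / softness. The pure-Python sketch below illustrates that row structure for a 1-D input; it is a simplified rendering of the SoftSort paper's operator, not the library implementation.

```python
import math

def soft_argsort_softsort(x, softness=0.1):
    """Sketch: SoftSort-style soft permutation matrix for a 1-D list.

    Row i is a distribution over which input position holds the
    ith smallest value (ascending order)."""
    P = []
    for si in sorted(x):
        logits = [-abs(si - xj) / softness for xj in x]
        m = max(logits)
        w = [math.exp(t - m) for t in logits]
        z = sum(w)
        P.append([wi / z for wi in w])
    return P

P = soft_argsort_softsort([0.3, 2.0, 1.1])
# Row 0 concentrates on index 0 (smallest input), row 2 on index 1 (largest).
print(P[0].index(max(P[0])), P[2].index(max(P[2])))   # 0 1
```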
softtorch.sort(x: torch.Tensor, dim: int | None = -1, descending: bool = False, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True, return_indices: bool = True) -> torch.return_types.sort[torch.Tensor, torch.Tensor]
Performs a soft version of torch.sort of x along the specified dim.
Most methods go through softtorch.argsort + softtorch.take_along_dim to produce soft sorted values.
The exceptions (fast_soft_sort, smooth_sort, sorting_network) bypass soft indices and compute values directly:
- fast_soft_sort: permutahedron projection via PAV isotonic regression (Blondel et al., 2020). In smooth mode, uses an entropic (log-KL) variant that is piecewise smooth but not C∞ (discontinuities at argsort chamber boundaries). The PAV step uses Numba JIT (CPU-only); GPU tensors are transferred to CPU for the forward pass.
- smooth_sort: not available in SoftTorch; see softjax.sort for the SoftJAX-only implementation.
- sorting_network: soft bitonic sorting network (Petersen et al., 2021).
Extra Arguments:
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
- return_indices: If False, skips computation of the soft index (indices will be None). This avoids the O(n²) memory cost of materializing the n×n soft permutation matrix.
Returns:
Namedtuple containing two fields:
- values: Soft sorted values of x, shape (..., n, ...).
- indices: SoftIndex of shape (..., n, ..., [n]) (positive Tensor which sums to 1 over the last dimension). Represents the soft indices of the sorted values. None if return_indices=False, or when using the fast_soft_sort or sorting_network methods.
softtorch.topk(x: torch.Tensor, k: int, dim: int | None = None, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', standardize: bool = True, ot_kwargs: dict | None = None, gated_grad: bool = True) -> torch.return_types.topk[torch.Tensor, torch.Tensor]
Performs a soft version of torch.topk.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- k: The number of top elements to select.
- dim: The dim along which to compute the topk operation. If None, the last dimension of the input is chosen.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
  - hard: Returns the result of torch.topk with a one-hot encoding of the indices.
  - smooth: Softens projections via an entropic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation.
    - For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
  - c0: Softens projections via a quadratic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application.
    - For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
  - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft topk. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
  - ot: Uses the approach in Differentiable Top-k with Optimal Transport. Intuition: the top-k elements are selected by specifying k+1 "anchors" and then transporting the top-k values to the top k anchors, and the remaining (n-k) values to the last anchor. Note: sinkhorn_max_iter and sinkhorn_tol can be passed via ot_kwargs to control convergence.
  - softsort: Computes the top-k elements of the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
  - neuralsort: Computes the top-k elements of the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations.
  - fast_soft_sort: Uses the FastSoftSort operator from Fast Differentiable Sorting and Ranking to directly compute the soft sorted values, via projection onto the permutahedron. The projection is solved via a PAV algorithm as proposed in Fast Differentiable Sorting and Ranking. The top-k values are then retrieved by taking the first k values from the soft sorted output. Note: this method does not return the soft indices, only the soft values. The PAV step uses Numba JIT (CPU-only); GPU tensors are transferred to CPU for the forward pass.
  - sorting_network: Uses a soft bitonic sorting network as proposed in Differentiable Sorting Networks for Scalable Sorting and Ranking Supervision, replacing hard compare-and-swap operations with soft versions. The top-k values are retrieved by taking the first k values from the soft sorted output. Note: this method does not return the soft indices, only the soft values.
- standardize: If True, standardizes and squashes the input x along the specified dim before applying the soft topk operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the maximum number of iterations or the tolerance.
- gated_grad: If False, stops the gradient flow through the soft index. True gives gated 'SiLU-style' gradients, while False gives integrated 'Softplus-style' gradients.
Returns:
Namedtuple containing two fields:
- values: Top-k values of x, shape (..., k, ...).
- indices: SoftIndex of shape (..., k, ..., [n]) (positive Tensor which sums to 1 over the last dimension). Represents the soft indices of the top-k values.
softtorch.rank(x: torch.Tensor, dim: int | None = None, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', method: Literal['ot', 'softsort', 'neuralsort', 'fast_soft_sort', 'smooth_sort', 'sorting_network'] = 'neuralsort', descending: bool = True, standardize: bool = True, ot_kwargs: dict | None = None) -> torch.Tensor
Computes the soft rankings of x along the specified dim.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- dim: The dim along which to compute the ranking operation. If None, the input Tensor is flattened before computing the ranking.
- descending: If True, larger inputs receive smaller ranks (best rank is 0). If False, ranks increase with the input values.
- softness: Softness of the function, should be larger than zero.
- mode: Type of regularizer in the projection operators.
  - hard: Returns the rank computed as two torch.argsort calls.
  - smooth: Softens projections via an entropic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via a softmax operation.
    - For optimal transport (ot method), the transport plan is computed via Sinkhorn iterations (see Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances).
  - c0: Softens projections via a quadratic regularizer.
    - For unit simplex projection (softsort/neuralsort methods), the projection is computed in closed form via the algorithm in Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application.
    - For optimal transport (ot method), the transport plan is computed via LBFGS (see Smooth and Sparse Optimal Transport).
  - c1/c2: C1 differentiable / C2 twice differentiable. Similar to c0, but using p-norm regularizers with p=3/2 and p=4/3, respectively.
- method: Method to compute the soft rank. All approaches were originally proposed for the smooth mode; we extend them to the c0, c1, and c2 modes as well.
  - ot: Uses the approach in Differentiable Ranks and Sorting using Optimal Transport. Intuition: run an OT procedure as in softtorch.argsort, then transport the sorted ranks (0, 1, ..., n-1) back to the ranks of the original values by using the transpose of the transport plan. Note: sinkhorn_max_iter and sinkhorn_tol can be passed via ot_kwargs to control convergence.
  - softsort: Adapts the "SoftSort" operator from SoftSort: A Continuous Relaxation for the argsort Operator for rank by using a single column-wise projection instead of a row-wise projection. Note: can introduce gradient discontinuities when elements in x are not unique, but is much faster than the OT-based method.
  - neuralsort: Adapts the "NeuralSort" operator from Stochastic Optimization of Sorting Networks via Continuous Relaxations for rank by renormalizing over columns after the row-wise projection.
  - fast_soft_sort: Uses the FastSoftSort operator from Fast Differentiable Sorting and Ranking to directly compute the soft ranks, via projection onto the permutahedron. The projection is solved via a PAV algorithm as proposed in Fast Differentiable Sorting and Ranking. Note: the PAV step uses Numba JIT (CPU-only); GPU tensors are transferred to CPU for the forward pass.
- standardize: If True, standardizes and squashes the input x along the specified dim before applying the soft rank operation. This can improve numerical stability and performance, especially when the values in x vary widely in scale.
- ot_kwargs: Additional optional keyword arguments to pass to the OT projection operator, e.g., to control the maximum number of iterations or the tolerance.
Returns:
A positive Tensor of shape (..., n, ...) with values in [1, n]. The elements in (..., i, ...) represent the soft rank of the ith element along the specified dim.
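A common way to build intuition for soft ranking is the pairwise-sigmoid construction: each comparison x_i > x_j contributes a soft count through a logistic step, giving ascending ranks in [1, n]. This sketch is for illustration only; the library's OT- and sort-based methods differ in detail.

```python
import math

def soft_rank_smooth(x, softness=0.1):
    """Sketch: ascending soft ranks in [1, n] from pairwise logistic steps."""
    ranks = []
    for i, xi in enumerate(x):
        r = 1.0
        for j, xj in enumerate(x):
            if i != j:
                # soft count of "xi is larger than xj"
                r += 1.0 / (1.0 + math.exp(-(xi - xj) / softness))
        ranks.append(r)
    return ranks

r = soft_rank_smooth([0.3, 2.0, 1.1])
print([round(v, 2) for v in r])   # close to the hard ranks [1, 3, 2]
```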
Comparison operators
softtorch.greater(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
Computes a soft approximation to elementwise x > y.
Uses a Heaviside relaxation so the output approaches 0 at equality.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, greater returns 0 at equality.
Returns:
SoftBool of the same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise x > y.
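In the smooth mode this construction amounts to a logistic Heaviside applied to x - y, nudged by -epsilon so that equality lands strictly below one half. A pure-Python sketch (the function name is illustrative):

```python
import math

def soft_greater(x, y, softness=0.1, epsilon=1e-10):
    """Sketch: smooth x > y as a logistic step on (x - y - epsilon)."""
    return 1.0 / (1.0 + math.exp(-(x - y - epsilon) / softness))

print(round(soft_greater(1.0, 0.0), 4))   # close to 1.0
print(round(soft_greater(0.0, 1.0), 4))   # close to 0.0
print(soft_greater(0.5, 0.5) < 0.5)       # True: equality leans below one half
```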
softtorch.greater_equal(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
Computes a soft approximation to elementwise x >= y.
Uses a Heaviside relaxation so the output approaches 1 at equality.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, greater_equal returns 1 at equality.
Returns:
SoftBool of the same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise x >= y.
softtorch.less(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
Computes a soft approximation to elementwise x < y.
Uses a Heaviside relaxation so the output approaches 0 at equality.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside ("smooth", "c0", "c1", or "c2"). Defaults to "smooth".
- epsilon: Small offset so that as softness -> 0, less returns 0 at equality.
Returns:
SoftBool of the same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise x < y.
softtorch.less_equal(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
Computes a soft approximation to elementwise x <= y.
Uses a Heaviside relaxation so the output approaches 1 at equality.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, less_equal returns 1 at equality.
Returns:
SoftBool of the same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise x <= y.
softtorch.eq(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
¤
Computes a soft approximation to elementwise x == y.
Implemented as a soft abs(x - y) <= 0 comparison.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, eq returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise x == y.
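The "soft abs(x - y) <= 0" construction can be sketched in scalar form by feeding the negated distance through a soft Heaviside; not_equal then falls out as the complement. These helpers are illustrative assumptions, not the library's code:

```python
import math

def soft_heaviside(z, softness=0.1):
    # Numerically stable logistic step: sigma(z / softness).
    t = z / softness
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def soft_eq(x, y, softness=0.1, epsilon=1e-10):
    # Soft abs(x - y) <= 0: step applied to the epsilon-shifted,
    # negated distance, so equality tends to 1 as softness -> 0.
    return soft_heaviside(epsilon - abs(x - y), softness)

def soft_not_equal(x, y, softness=0.1, epsilon=1e-10):
    # Complement: soft abs(x - y) > 0.
    return 1.0 - soft_eq(x, y, softness, epsilon)
```

By construction, soft_eq and soft_not_equal always sum to 1 for the same inputs.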
softtorch.not_equal(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
¤
Computes a soft approximation to elementwise x != y.
Implemented as a soft abs(x - y) > 0 comparison.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, not_equal returns 0 at equality.
Returns:
SoftBool of same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise x != y.
softtorch.isclose(x: torch.Tensor, y: torch.Tensor, softness: float = 0.1, rtol: float = 1e-05, atol: float = 1e-08, mode: Literal['hard', 'smooth', 'c0', 'c1', 'c2'] = 'smooth', epsilon: float = 1e-10) -> torch.Tensor
¤
Computes a soft approximation to torch.isclose for elementwise comparison.
Implemented as a soft abs(x - y) <= atol + rtol * abs(y) comparison.
Arguments:
- x: First input Tensor.
- y: Second input Tensor, same shape as x.
- softness: Softness of the function, should be larger than zero.
- rtol: Relative tolerance. Defaults to 1e-5.
- atol: Absolute tolerance. Defaults to 1e-8.
- mode: If "hard", returns the exact comparison. Otherwise uses a soft Heaviside.
- epsilon: Small offset so that as softness -> 0, isclose returns 1 at equality.
Returns:
SoftBool of same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise isclose(x, y).
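The documented comparison abs(x - y) <= atol + rtol * abs(y) can be relaxed by applying a logistic step to the tolerance margin. This scalar sketch is an assumption about the shape of the relaxation, not the library's kernel:

```python
import math

def soft_isclose(x, y, softness=0.1, rtol=1e-5, atol=1e-8, epsilon=1e-10):
    # Margin is positive when the values are "close" under the
    # torch.isclose criterion, negative otherwise.
    margin = atol + rtol * abs(y) - abs(x - y) + epsilon
    t = margin / softness
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

print(soft_isclose(1.0, 1.0))    # near 0.5 at the default softness
print(soft_isclose(0.0, 10.0))   # far apart: saturates toward 0
```

Shrinking softness sharpens the step, so equal inputs approach an output of 1.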
Logical operators¤
softtorch.logical_and(x: torch.Tensor, y: torch.Tensor, use_geometric_mean: bool = False) -> torch.Tensor
¤
Computes soft elementwise logical AND between two SoftBool Tensors.
Fuzzy logic implemented as all(stack([x, y], dim=-1), dim=-1).
Arguments:
- x: First SoftBool input Tensor.
- y: Second SoftBool input Tensor.
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND. Otherwise, the product is used.
Returns:
SoftBool of same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise logical AND.
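In scalar form, the two AND variants reduce to the product t-norm and its geometric-mean counterpart. A minimal sketch of the documented behavior:

```python
import math

def soft_logical_and(x, y, use_geometric_mean=False):
    # Product t-norm; the geometric-mean variant takes sqrt(x * y),
    # which keeps repeated ANDs from collapsing toward zero.
    if use_geometric_mean:
        return math.sqrt(x * y)
    return x * y

print(soft_logical_and(0.9, 0.9))        # product shrinks: close to 0.81
print(soft_logical_and(0.9, 0.9, True))  # geometric mean keeps the scale: close to 0.9
```

The geometric mean is the natural choice when chaining many ANDs, since the product of n confidences decays exponentially in n.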
softtorch.logical_not(x: torch.Tensor) -> torch.Tensor
¤
Computes soft elementwise logical NOT of a SoftBool Tensor.
Fuzzy logic implemented as 1.0 - x.
Arguments:
- x: SoftBool input Tensor.
Returns:
SoftBool of same shape as x (Tensor with values in [0, 1]), relaxing the elementwise logical NOT.
softtorch.logical_or(x: torch.Tensor, y: torch.Tensor, use_geometric_mean: bool = False) -> torch.Tensor
¤
Computes soft elementwise logical OR between two SoftBool Tensors.
Fuzzy logic implemented as any(stack([x, y], dim=-1), dim=-1).
Arguments:
- x: First SoftBool input Tensor.
- y: Second SoftBool input Tensor.
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND inside the logical NOT. Otherwise, the product is used.
Returns:
SoftBool of same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise logical OR.
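With the product t-norm (use_geometric_mean=False), the De Morgan construction behind the OR can be sketched in scalar form:

```python
def soft_logical_not(x):
    # Fuzzy NOT: 1.0 - x.
    return 1.0 - x

def soft_logical_or(x, y):
    # De Morgan: OR(x, y) = NOT(AND(NOT x, NOT y)), product t-norm.
    return 1.0 - (1.0 - x) * (1.0 - y)

print(soft_logical_or(0.0, 1.0))  # hard values behave classically: 1.0
print(soft_logical_or(0.5, 0.5))  # 1 - 0.25 = 0.75
```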
softtorch.logical_xor(x: torch.Tensor, y: torch.Tensor, use_geometric_mean: bool = False) -> torch.Tensor
¤
Computes soft elementwise logical XOR between two SoftBool Tensors.
Arguments:
- x: First SoftBool input Tensor.
- y: Second SoftBool input Tensor.
- use_geometric_mean: If True, uses the geometric mean to compute the soft ANDs inside the logical OR. Otherwise, the product is used.
Returns:
SoftBool of same shape as x and y (Tensor with values in [0, 1]), relaxing the elementwise logical XOR.
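A scalar sketch of the XOR composition implied above, using the product t-norm throughout (the library may compose its primitives differently):

```python
def soft_logical_xor(x, y):
    # XOR = OR(AND(x, NOT y), AND(NOT x, y)):
    # the OR of the two exclusive ANDs, product t-norm throughout.
    a = x * (1.0 - y)
    b = (1.0 - x) * y
    return 1.0 - (1.0 - a) * (1.0 - b)

print(soft_logical_xor(1.0, 0.0))  # 1.0
print(soft_logical_xor(1.0, 1.0))  # 0.0
print(soft_logical_xor(0.5, 0.5))  # 0.4375: a soft middle ground
```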
softtorch.all(x: torch.Tensor, dim: int = -1, epsilon: float = 1e-10, use_geometric_mean: bool = False) -> torch.Tensor
¤
Computes a soft logical AND reduction across a specified dim.
Fuzzy logic implemented as the product (or, if use_geometric_mean is True, the geometric mean) along the dim.
Arguments:
- x: SoftBool input Tensor.
- dim: Axis along which to compute the logical AND. Default is -1 (last dim).
- epsilon: Minimum value for numerical stability inside the log.
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND. Otherwise, the product is used.
Returns:
SoftBool (Tensor with values in [0, 1]) with the specified dim reduced, relaxing the logical ALL along that dim.
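A list-based sketch of the reduction, assuming epsilon simply clamps each entry away from zero before the product (a log-space implementation would use it the same way):

```python
def soft_all(xs, epsilon=1e-10, use_geometric_mean=False):
    # Reduce with the product t-norm; the geometric mean takes the
    # len(xs)-th root so the result stays on the inputs' scale.
    # Clamping by epsilon guards a log/root formulation against zeros.
    product = 1.0
    for v in xs:
        product *= max(v, epsilon)
    if use_geometric_mean:
        return product ** (1.0 / len(xs))
    return product

print(soft_all([0.9, 0.9, 0.9, 0.9]))                          # product decays with length
print(soft_all([0.9, 0.9, 0.9, 0.9], use_geometric_mean=True)) # stays near 0.9
```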
softtorch.any(x: torch.Tensor, dim: int = -1, use_geometric_mean: bool = False) -> torch.Tensor
¤
Computes a soft logical OR reduction across a specified dim.
Fuzzy logic implemented as 1.0 - all(logical_not(x), dim=dim).
Arguments:
- x: SoftBool input Tensor.
- dim: Axis along which to compute the logical OR. Default is -1 (last dim).
- use_geometric_mean: If True, uses the geometric mean to compute the soft AND inside the logical NOT. Otherwise, the product is used.
Returns:
SoftBool (Tensor with values in [0, 1]) with the specified dim reduced, relaxing the logical ANY along that dim.
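The 1.0 - all(logical_not(x)) construction, sketched on a plain list (omitting the epsilon clamp for brevity):

```python
def soft_any(xs, use_geometric_mean=False):
    # De Morgan reduction: ANY(x) = NOT(ALL(NOT x)).
    product = 1.0
    for v in xs:
        product *= 1.0 - v
    if use_geometric_mean:
        product = product ** (1.0 / len(xs))
    return 1.0 - product

print(soft_any([0.0, 0.0, 1.0]))  # a single hard 1 forces the result to 1
print(soft_any([0.1, 0.2]))       # 1 - 0.9 * 0.8, close to 0.28
```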
Selection operators¤
softtorch.where(condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
¤
Computes a soft version of torch.where as x * condition + y * (1.0 - condition).
Arguments:
- condition: SoftBool condition Tensor, same shape as x and y.
- x: First input Tensor, same shape as condition.
- y: Second input Tensor, same shape as condition.
Returns:
Tensor of the same shape as x and y, interpolating between x and y according to condition in [0, 1].
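The interpolation formula given above is simple enough to sketch directly in scalar form:

```python
def soft_where(condition, x, y):
    # Convex combination: condition = 1 selects x, condition = 0
    # selects y, and values in between interpolate linearly.
    return x * condition + y * (1.0 - condition)

print(soft_where(1.0, 5.0, -5.0))  # 5.0
print(soft_where(0.0, 5.0, -5.0))  # -5.0
print(soft_where(0.25, 4.0, 0.0))  # 1.0
```

Because the blend is linear in condition, gradients flow through the selection, which the hard torch.where does not allow with respect to its condition.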
softtorch.take_along_dim(x: torch.Tensor, soft_index: torch.Tensor, dim: int | None = None) -> torch.Tensor
¤
Performs a soft version of torch.take_along_dim via a weighted dot product.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- soft_index: A SoftIndex of shape (..., k, ..., [n]) (positive Tensor which sums to 1 over the last dimension). If dim is None, must be two-dimensional. If dim is not None, must have x.ndim + 1 == soft_index.ndim, and x must be broadcast-compatible with soft_index along dimensions other than dim.
- dim: Dim along which to apply the soft index. If None, the input is flattened before applying the soft indices.
Returns:
Tensor of shape (..., k, ...), representing the result after soft selection along the specified dim.
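The weighted dot product at the heart of soft indexing can be sketched on plain lists for the 1-D case (the library handles arbitrary dims and broadcasting on top of this idea):

```python
def soft_take(x, soft_index):
    # Each row of soft_index is a distribution over positions of x;
    # soft selection is the expectation of x under that distribution.
    return [sum(w * v for w, v in zip(row, x)) for row in soft_index]

x = [10.0, 20.0, 30.0]
# A one-hot row recovers hard indexing; a spread-out row blends entries.
print(soft_take(x, [[0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]))  # [20.0, 15.0]
```

A one-hot soft index reproduces torch.take_along_dim exactly, while a softened distribution gives a differentiable approximation of it.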
softtorch.take(x: torch.Tensor, soft_index: torch.Tensor, dim: int | None = None) -> torch.Tensor
¤
Performs a soft version of torch.take via a weighted dot product.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- soft_index: A SoftIndex of shape (k, [n]) (positive Tensor which sums to 1 over the last dimension).
- dim: Dim along which to apply the soft index. If None, the input is flattened before applying the soft indices.
Returns:
Tensor of shape (..., k, ...) after soft selection.
softtorch.index_select(x: torch.Tensor, soft_index: torch.Tensor, dim: int, keepdim: bool = True) -> torch.Tensor
¤
Performs a soft version of torch.index_select via a weighted dot product.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- soft_index: A SoftIndex of shape ([n],) (positive Tensor which sums to 1 over the last dimension).
- dim: Dim along which to apply the soft index.
- keepdim: If True, keeps the reduced dimension as a singleton {1}.
Returns:
Tensor after soft indexing, shape (..., {1}, ...).
softtorch.narrow(x: torch.Tensor, soft_start: torch.Tensor, length: int, dim: int = 0) -> torch.Tensor
¤
Performs a soft version of torch.narrow via a weighted dot product.
Arguments:
- x: Input Tensor of shape (..., n, ...).
- soft_start: A SoftIndex of shape ([n],) (positive Tensor which sums to 1 over the last dimension).
- length: Length of the slice to extract.
- dim: Dim along which to apply the soft slice.
Returns:
Tensor of shape (..., length, ...) after soft slicing.
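One plausible reading of soft slicing, sketched on plain lists: each output position mixes the entries the slice would cover under every possible start, weighted by soft_start. This sketch simply drops out-of-range mass; the library's boundary handling may differ:

```python
def soft_narrow(x, soft_start, length):
    # Expected slice under the soft start distribution: output position j
    # mixes x[s + j] weighted by the probability the slice starts at s.
    n = len(x)
    return [
        sum(w * x[s + j] for s, w in enumerate(soft_start) if s + j < n)
        for j in range(length)
    ]

x = [0.0, 10.0, 20.0, 30.0]
print(soft_narrow(x, [0.0, 1.0, 0.0, 0.0], 2))  # hard start at 1: [10.0, 20.0]
print(soft_narrow(x, [0.5, 0.5, 0.0, 0.0], 2))  # blended start: [5.0, 15.0]
```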