Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor in half along the last dimension.

Usage

nn_reglu()

References

Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.

Examples

x = torch::torch_randn(10, 10)  # input with an even last dimension
reglu = nn_reglu()
reglu(x)  # output shape is 10 x 5: the last dimension is halved
#> torch_tensor
#>  0.5048  0.0824  4.3947 -0.4379 -0.0000
#>  0.0000  0.1177 -0.0000 -0.0000  0.8755
#> -0.0000 -0.0000 -0.0000  0.0000  0.5810
#> -0.0000 -0.0000  0.0143  0.0000  0.1846
#>  0.0065 -0.1198 -1.2562  0.0000  2.6531
#> -0.0000 -0.1154  0.0000 -0.0000  0.3259
#>  0.0000  0.0000 -0.0000  0.2414  0.0004
#> -0.6159  0.5157  0.2463 -0.0000  0.0047
#>  0.7439  0.0000  1.1910  0.0000 -0.0000
#>  0.6378 -0.0314  0.0000  0.0000 -2.6684
#> [ CPUFloatType{10,5} ]
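
For illustration, the computation described above can be reproduced with plain torch operations by chunking the input in half along the last dimension and gating the first half with the ReLU of the second. This is a minimal sketch of the formula, not necessarily the module's internal code:

# split a 10 x 10 input into two 10 x 5 halves along the last dimension
x = torch::torch_randn(10, 10)
halves = torch::torch_chunk(x, chunks = 2, dim = 2)
# ReGLU: first half times ReLU of the second half
out = halves[[1]] * torch::nnf_relu(halves[[2]])
out$shape  # 10 5: the last dimension is halved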