
Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\) where \(x\) and \(g\) are created by splitting the input tensor in half along the last dimension.

Usage

nn_reglu()

References

Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.

Examples

x <- torch::torch_randn(10, 10)
reglu <- nn_reglu()
reglu(x)
#> torch_tensor
#>  0.8159 -0.0000 -0.7747 -0.0677 -2.5267
#>  0.0000 -0.0158  0.0000 -0.0000 -3.3878
#> -0.2122  0.9207  0.0000  0.0000 -0.0000
#>  0.4702  0.0000  0.0000 -0.0000 -0.0705
#> -0.0000  0.0000  0.0535  0.6026 -0.3574
#> -0.0000  0.0000  0.8842 -0.0000 -0.1450
#> -0.2433 -0.8799  0.0000  0.2056  0.0000
#> -1.3344  0.0000  0.0340  0.2439 -0.5074
#>  0.0000  0.2539 -0.2752 -0.0000  0.9829
#> -0.0023 -0.0000  0.8591  0.0275 -0.0000
#> [ CPUFloatType{10,5} ]
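
For intuition, the split-and-gate computation can also be written out with plain torch primitives. The sketch below is illustrative only: reglu_manual() is a hypothetical helper, not part of the package, and nn_reglu() may be implemented differently internally.

library(torch)

# Minimal manual ReGLU: split the last dimension in half into x and g,
# then return x * ReLU(g).
reglu_manual <- function(input) {
  last_dim <- input$dim()                                  # 1-based index of the last dimension
  halves <- torch_chunk(input, chunks = 2, dim = last_dim) # halves[[1]] is x, halves[[2]] is g
  halves[[1]] * torch_relu(halves[[2]])
}

x <- torch_randn(10, 10)
reglu_manual(x)  # tensor of shape 10 x 5, the same shape as nn_reglu()(x)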