
Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor in half along the last dimension. The size of the last dimension must therefore be even.

Usage

nn_reglu()

References

Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.

Examples

x = torch::torch_randn(10, 10)
reglu = nn_reglu()
reglu(x)
#> torch_tensor
#> -0.0305  0.4024 -0.9155 -0.0000 -0.0000
#>  0.0066  0.0000 -0.0000  0.6590 -0.0000
#>  0.0000 -0.7126 -0.5516 -0.6000 -0.0000
#> -0.0000 -0.0000  0.0000 -0.0000 -0.8532
#> -0.3769 -0.0000  0.0000  0.0000  0.1320
#>  0.0754 -0.3070 -0.0000 -0.0000  0.0659
#> -0.0000  0.0000  0.6847  0.1908 -0.0000
#>  0.4723  0.0000 -0.0000 -0.0000  0.0000
#>  0.0000  0.0000  0.0000 -0.0000 -0.0000
#> -1.5769 -0.8469  0.0889  0.0000  0.0000
#> [ CPUFloatType{10,5} ]
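
The gating can also be reproduced with plain torch operations. This is an illustrative sketch only, assuming the module takes the first half of the split as \(x\) and the second half as the gate \(g\):

# Split the 10 x 10 input in half along its last dimension (dimension 2),
# then gate the first half with the ReLU of the second half.
halves = torch::torch_chunk(x, chunks = 2, dim = 2)
manual = halves[[1]] * torch::nnf_relu(halves[[2]])
# Should agree with the module output if the split convention above holds.
torch::torch_allclose(reglu(x), manual)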