Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor in half along its last dimension (which must therefore have an even size).
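For illustration, the same computation can be written directly with torch primitives. The snippet below is a minimal sketch of an equivalent module, not the package's actual implementation; the name nn_reglu_sketch and the use of torch_chunk() and nnf_relu() are assumptions made for this example.

nn_reglu_sketch = torch::nn_module(
  "nn_reglu_sketch",
  initialize = function() {},
  forward = function(input) {
    # split the input in half along the last dimension: x first, then the gate g
    halves = torch::torch_chunk(input, chunks = 2, dim = -1)
    # element-wise product of x with the rectified gate
    halves[[1]] * torch::nnf_relu(halves[[2]])
  }
)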
References
Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.
Examples
# input with an even-sized last dimension (split into two halves of size 5)
x = torch::torch_randn(10, 10)
reglu = nn_reglu()
# output keeps the first dimension and halves the last one: shape (10, 5)
reglu(x)
#> torch_tensor
#> -5.9348 -0.3414 0.0000 0.0000 0.0000
#> 0.0000 -0.4270 -0.0000 -0.0000 -2.5472
#> 0.3446 0.0000 -0.0000 0.7185 0.6021
#> -0.0000 -0.0000 0.0000 0.0000 0.3275
#> -0.0000 -0.2929 -0.1820 0.0000 -0.0000
#> -0.6274 -0.0000 -0.0000 -0.8201 -0.0000
#> -1.1675 -0.0000 -0.0000 -2.1920 -1.2770
#> -0.0000 0.5259 -0.0451 0.0000 0.5556
#> 0.0000 -0.0000 -0.0000 0.0000 -0.0000
#> 0.0378 1.5079 -0.1216 -0.0000 -0.0000
#> [ CPUFloatType{10,5} ]
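As a sanity check against the formula above, the output can be reproduced by splitting x by hand; this assumes the first half of the input is x and the second half is the gate g.

halves = torch::torch_chunk(x, chunks = 2, dim = -1)
# should be TRUE if the split order (x, then g) matches the documented formula
torch::torch_allclose(reglu(x), halves[[1]] * torch::nnf_relu(halves[[2]]))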