Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor in half along its last dimension. The output's last dimension is therefore half the size of the input's.
References
Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.
Examples
x <- torch::torch_randn(10, 10)
reglu <- nn_reglu()
reglu(x)
#> torch_tensor
#> -0.0000 -0.8718 -0.0210 0.0000 0.1261
#> -0.9289 -0.2623 -0.0000 0.0230 1.4270
#> -0.0727 0.1273 -2.2165 0.0000 0.0000
#> 0.2225 0.0788 -0.0266 -0.6551 -0.0000
#> 0.0521 0.0000 -0.0000 -0.2298 -0.1420
#> 0.0000 -0.0000 -0.0003 -0.8754 -0.0000
#> 0.0000 -0.0000 0.0000 0.0000 -0.0000
#> -0.0000 0.3983 0.0000 0.0000 -0.0000
#> 2.0135 -0.0170 0.0000 -0.0000 0.0000
#> 0.0000 0.8987 2.5761 0.0000 0.0000
#> [ CPUFloatType{10,5} ]
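The result can also be reproduced by hand from the formula above. The lines below are only a sketch of the equivalent computation (continuing the example), assuming the first half of the last dimension is taken as \(x\) and the second half as the gate \(g\):
halves <- torch::torch_chunk(x, chunks = 2, dim = 2)  # split the last dimension in half
manual <- halves[[1]] * torch::nnf_relu(halves[[2]])  # x * ReLU(g)
manual$shape                                          # 10 x 5, matching reglu(x) above
torch::torch_allclose(manual, reglu(x))               # TRUE if the module uses this split order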