Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor into two halves along its last dimension and the product is taken elementwise.
References
Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.
Examples
# a 10 x 10 input is split into two 10 x 5 halves along the last dimension
x = torch::torch_randn(10, 10)
reglu = nn_reglu()
reglu(x)
#> torch_tensor
#> 0.7128 -0.0000 0.0000 -0.0454 0.0000
#> -0.0000 -0.1432 -1.2471 -0.0488 -0.0000
#> -0.0000 -0.0000 0.0014 -1.6710 0.0560
#> 0.0000 0.0000 0.0000 -0.0000 -0.0000
#> 0.0000 -0.0000 -0.0000 -0.3214 0.0000
#> 0.0000 -0.1017 -1.1942 0.4999 -0.6515
#> -0.0000 -0.0000 -1.4429 0.0280 -0.0000
#> 0.3413 0.5825 0.2722 -0.4106 -0.0000
#> 0.0244 0.2014 -0.0000 -0.0000 -0.0000
#> 0.0000 0.0000 5.7381 -0.0000 0.0000
#> [ CPUFloatType{10,5} ]
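Continuing the example above, the split-and-gate computation from the description can also be written out by hand. This is a minimal sketch using only standard torch functions (`torch_chunk()` and `nnf_relu()`); the module's internal implementation may differ in detail.

# split the 2-D input into two halves along its last dimension (dim = 2)
chunks = torch::torch_chunk(x, chunks = 2, dim = 2)
x_half = chunks[[1]]
g_half = chunks[[2]]
# gate the first half with the ReLU of the second half
out = x_half * torch::nnf_relu(g_half)
# this should agree elementwise with reglu(x)
torch::torch_equal(out, reglu(x))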