Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor into two halves along its last dimension.
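A minimal sketch of how such a module could be written with torch for R is shown below; the class name nn_reglu_sketch and the use of torch_chunk() are illustrative assumptions, not necessarily how nn_reglu() is implemented.

# Illustrative sketch only: split the input into two halves along the
# last dimension and gate the first half with ReLU of the second.
nn_reglu_sketch <- torch::nn_module(
  "nn_reglu_sketch",
  forward = function(input) {
    halves <- torch::torch_chunk(input, chunks = 2, dim = -1)
    halves[[1]] * torch::nnf_relu(halves[[2]])
  }
)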
References
Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.
Examples
x <- torch::torch_randn(10, 10)
reglu <- nn_reglu()
reglu(x)
#> torch_tensor
#> 0.4318 -0.2161 0.0000 0.0572 0.0000
#> -0.0000 -0.0000 -0.1832 -0.7971 -0.0000
#> 0.4366 -0.0000 0.1494 0.3086 -0.0000
#> 0.9792 0.0000 -0.0000 0.1754 -0.0000
#> -0.3526 -0.0000 -0.7102 -0.0000 -0.0000
#> 0.7202 -0.3128 0.0000 -0.5192 0.8357
#> -0.0000 0.0000 -0.0000 0.0665 0.2250
#> 0.2766 0.4052 -0.0000 0.0000 -3.2107
#> 0.0000 0.0000 -0.0000 -0.6722 -0.0000
#> 0.0000 0.0000 -0.1028 -0.0000 -0.0000
#> [ CPUFloatType{10,5} ]
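The result can be checked against the definition above by recomputing the two halves manually; torch_allclose() should return TRUE provided nn_reglu() splits the input in the documented order (x first, g second), which is an assumption here.

# Manual reconstruction of ReGLU(x, g) = x * ReLU(g) for comparison.
halves <- torch::torch_chunk(x, chunks = 2, dim = -1)
manual <- halves[[1]] * torch::nnf_relu(halves[[2]])
torch::torch_allclose(reglu(x), manual)  # expected TRUE under the assumption above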