Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\) where \(x\) and \(g\) are created by splitting the input tensor in half along the last dimension.
References
Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.
Examples
x <- torch::torch_randn(10, 10)  # 10 input features, split into two halves of 5
reglu <- nn_reglu()
reglu(x)
#> torch_tensor
#> 1.8164 -0.3219 0.0000 -1.8453 0.0000
#> -0.0000 0.0000 -0.2367 -0.2391 -0.0000
#> 0.0000 -0.0000 0.0000 -0.0000 0.0000
#> 0.2730 0.0000 0.0000 0.0000 0.3078
#> 2.4144 0.0205 0.0000 0.0000 0.1125
#> 0.0000 2.9861 0.0000 -0.0000 -0.0000
#> 0.0000 -0.0000 0.0000 -0.0320 0.6442
#> -1.3642 0.0801 1.5106 -0.5688 0.0000
#> -0.0000 -0.0416 -0.0000 0.0000 0.0000
#> 0.2364 -0.0000 -0.1206 -0.8978 -0.0000
#> [ CPUFloatType{10,5} ]
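For reference, the same result can be reproduced with plain tensor operations. This is a minimal sketch assuming the torch R package (using torch_chunk() and nnf_relu()); it is not necessarily how nn_reglu() is implemented internally.

x <- torch::torch_randn(10, 10)
parts <- torch::torch_chunk(x, chunks = 2, dim = 2)  # split in half along the last (feature) dimension
out <- parts[[1]] * torch::nnf_relu(parts[[2]])      # x * ReLU(g), same 10 x 5 shape as reglu(x)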