Rectified Gated Linear Unit (ReGLU) module. Computes the output as \(\text{ReGLU}(x, g) = x \cdot \text{ReLU}(g)\), where \(x\) and \(g\) are obtained by splitting the input tensor into two equal halves along its last dimension.
References
Shazeer N (2020). “GLU Variants Improve Transformer.” arXiv:2002.05202, https://arxiv.org/abs/2002.05202.
Examples
x <- torch::torch_randn(10, 10)  # input with an even-sized last dimension
reglu <- nn_reglu()
reglu(x)                         # output keeps half of the last dimension: 10 x 5
#> torch_tensor
#> -0.0000 0.0561 -0.0000 0.9413 0.5324
#> 0.1230 -0.0000 -1.0723 -0.0430 -0.6069
#> -0.0000 0.2986 0.0000 -0.0000 0.1807
#> -0.0000 0.2848 0.0582 -0.0997 0.0276
#> 0.3495 -0.0374 -0.0000 -0.0000 0.0000
#> -0.0000 0.0000 -0.0000 -1.0855 0.9958
#> -0.0000 0.0000 0.0000 0.0000 -0.6228
#> 0.3993 -0.0000 0.0000 -2.5017 -0.0542
#> -0.0000 -2.3652 0.0000 0.0000 -0.0000
#> -0.0000 0.2234 -0.0000 -0.0000 0.0000
#> [ CPUFloatType{10,5} ]
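For illustration, the same gating can be reproduced with plain torch operations. This is a minimal sketch of the formula above, assuming only that the input is split into two halves along its last dimension; reglu_forward is a hypothetical helper, not the module's internal code.
reglu_forward <- function(input) {
  last_dim <- input$dim()                                          # index of the last dimension
  halves <- torch::torch_chunk(input, chunks = 2, dim = last_dim)  # split into x and g
  x <- halves[[1]]                                                 # value half
  g <- halves[[2]]                                                 # gate half
  x * torch::nnf_relu(g)                                           # ReGLU(x, g) = x * ReLU(g)
}
reglu_forward(torch::torch_randn(10, 10))                          # a 10 x 5 tensor, like reglu(x) above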