Applies a user-supplied function to a tensor.
Parameters
By default, the parameters are inferred from the signature of fn: every argument except the first (which receives the input tensor) becomes a parameter.
They can also be specified explicitly via the param_set constructor argument.
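The inference rule can be illustrated with base R alone. The helper infer_param_ids() below is hypothetical: it only mimics the documented behaviour (all arguments but the first become parameters) and is not the package's actual implementation.

```r
# Hypothetical helper mimicking the documented inference rule:
# every formal argument of fn except the first becomes a parameter.
infer_param_ids = function(fn) {
  arg_names = names(formals(fn))
  # The first argument receives the input tensor, so it is not a parameter
  arg_names[-1]
}

custom_fn = function(x, a) x / a
infer_param_ids(custom_fn)
#> [1] "a"

scale_shift = function(x, scale, shift = 0) x * scale + shift
infer_param_ids(scale_shift)
#> [1] "scale" "shift"
```

A function with only the tensor argument would therefore yield no parameters at all.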
Input and Output Channels
There is one input channel called "input" and one output channel called "output".
For an explanation, see PipeOpTorch.
Super classes
mlr3pipelines::PipeOp -> mlr3torch::PipeOpTorch -> PipeOpTorchFn
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpTorchFn$new(
  fn,
  id = "nn_fn",
  param_vals = list(),
  param_set = NULL,
  shapes_out = NULL
)
Arguments
fn (function)
The function to be applied. It takes a torch tensor as its first argument and returns a torch tensor.
id (character(1))
Identifier of the resulting object.
param_vals (list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
param_set (ParamSet or NULL)
A ParamSet wrapping the arguments to fn. If omitted, the ParamSet for this PipeOp is inferred from the function signature.
shapes_out (function or NULL)
A function that computes the output shapes of fn. See PipeOpTorch's .shapes_out() method for details on its parameters, and PipeOpTaskPreprocTorch for details on how the shapes are inferred when this parameter is NULL.
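As a sketch, a shapes_out function for a flattening fn might look as follows. It assumes that the signature mirrors PipeOpTorch's .shapes_out(shapes_in, param_vals, task), that each shape is an integer vector with NA marking the batch dimension, and that one output shape is returned per output channel. The names flatten_fn and shapes_out_flatten are hypothetical.

```r
# Hypothetical fn: flatten all non-batch dimensions of a torch tensor.
# Commented out so this sketch runs without torch; shown only for pairing.
# flatten_fn = function(x) torch::torch_flatten(x, start_dim = 2)

# Assumed signature, mirroring PipeOpTorch's .shapes_out():
# shapes_in holds one integer shape per input channel, NA = batch dimension.
shapes_out_flatten = function(shapes_in, param_vals, task) {
  shape = shapes_in[[1]]
  # Keep the batch dimension and collapse the remaining ones into one
  list(as.integer(c(shape[1], prod(shape[-1]))))
}

shapes_out_flatten(list(input = c(NA_integer_, 3L, 4L)), param_vals = list(), task = NULL)
#> [[1]]
#> [1] NA 12
```

Under these assumptions it could then be supplied as po("nn_fn", fn = flatten_fn, shapes_out = shapes_out_flatten), avoiding shape inference for this PipeOp.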
Examples
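library(mlr3torch)

# Divide the input tensor by the parameter a
custom_fn = function(x, a) x / a
obj = po("nn_fn", fn = custom_fn, a = 2)
# The ParamSet is inferred from custom_fn: a is its only parameter
obj$param_set
#> <ParamSet(1)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: a ParamUty NA NA Inf <NoDefault[0]> 2
# Prepend an ingress PipeOp that feeds the lazy tensor feature into the network
graph = po("torch_ingress_ltnsr") %>>% obj
# Use only the first observation of the lazy iris task
task = tsk("lazy_iris")$filter(1)
tnsr = materialize(task$data()$x)[[1]]
# Training the graph yields a model descriptor; training its graph on a tensor runs the modules
md_trained = graph$train(task)
trained = md_trained[[1]]$graph$train(tnsr)
trained[[1]]
#> torch_tensor
#> 2.5500
#> 1.7500
#> 0.7000
#> 0.1000
#> [ CPUFloatType{4} ]
# The result matches calling custom_fn directly
custom_fn(tnsr, a = 2)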
#> torch_tensor
#> 2.5500
#> 1.7500
#> 0.7000
#> 0.1000
#> [ CPUFloatType{4} ]
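The parameter can also be declared explicitly instead of being inferred. This sketch assumes paradox's ps() and p_dbl() are used to build the ParamSet passed as param_set; the lower bound on a is illustrative only. Declaring a as a number, rather than the inferred untyped ParamUty, lets paradox reject invalid values when they are set. Whether PipeOpTorchFn imposes further requirements on a custom ParamSet (for example, parameter tags) is not covered here.

```r
library(mlr3torch)

custom_fn = function(x, a) x / a

# Explicit ParamSet: a is typed as a number with an illustrative lower bound
obj = po("nn_fn",
  fn = custom_fn,
  param_set = paradox::ps(a = paradox::p_dbl(lower = 0)),
  a = 2
)
obj$param_set$ids()
obj$param_set$values$a
```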