Builds a torch regression model and trains it.
Parameters
See LearnerTorch
Input and Output Channels
There is one input channel "input" that takes in a ModelDescriptor during training and a Task of the specified task_type during prediction.
The output is NULL during training and a Prediction of the given task_type during prediction.
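The channel behaviour can be checked directly: the same "input" channel first receives a ModelDescriptor and then a Task. The sketch below assumes mlr3, mlr3pipelines and mlr3torch are installed (with a working torch backend) and builds the ModelDescriptor the same way as the Examples section further down.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

task = tsk("mtcars")

# Build the ModelDescriptor the same way as in the Examples section
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam"))$train(task)[[1L]]

po_model = po("torch_model_regr", batch_size = 20, epochs = 1)

# Training: channel "input" receives the ModelDescriptor, output is NULL
train_out = po_model$train(list(md))

# Prediction: the same channel receives a Task, output is a Prediction
pred_out = po_model$predict(list(task))
print(pred_out$output)
```

Because the task type here is regression, the prediction output is a PredictionRegr with one response per row of the task.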
State
A trained LearnerTorchModel.
Internals
A LearnerTorchModel is created by calling model_descriptor_to_learner() on the ModelDescriptor received through the input channel.
Its hyperparameters are then set according to those specified in the PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.
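The internal steps can be reproduced by hand roughly as follows. This is a minimal sketch, assuming that model_descriptor_to_learner() is exported by mlr3torch and that ParamSet$set_values() from paradox (>= 1.0) is available; it is not the PipeOp's actual source, which may do additional bookkeeping.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam"))$train(tsk("mtcars"))[[1L]]

# Step 1: convert the ModelDescriptor into a LearnerTorchModel
learner = model_descriptor_to_learner(md)

# Step 2: set hyperparameters (the values the PipeOp would otherwise forward)
learner$param_set$set_values(batch_size = 20, epochs = 1)

# Step 3: train on the Task stored inside the ModelDescriptor
learner$train(md$task)
```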
See also
Other PipeOps: mlr_pipeops_nn_adaptive_avg_pool1d, mlr_pipeops_nn_adaptive_avg_pool2d, mlr_pipeops_nn_adaptive_avg_pool3d, mlr_pipeops_nn_avg_pool1d, mlr_pipeops_nn_avg_pool2d, mlr_pipeops_nn_avg_pool3d, mlr_pipeops_nn_batch_norm1d, mlr_pipeops_nn_batch_norm2d, mlr_pipeops_nn_batch_norm3d, mlr_pipeops_nn_block, mlr_pipeops_nn_celu, mlr_pipeops_nn_conv1d, mlr_pipeops_nn_conv2d, mlr_pipeops_nn_conv3d, mlr_pipeops_nn_conv_transpose1d, mlr_pipeops_nn_conv_transpose2d, mlr_pipeops_nn_conv_transpose3d, mlr_pipeops_nn_dropout, mlr_pipeops_nn_elu, mlr_pipeops_nn_flatten, mlr_pipeops_nn_gelu, mlr_pipeops_nn_glu, mlr_pipeops_nn_hardshrink, mlr_pipeops_nn_hardsigmoid, mlr_pipeops_nn_hardtanh, mlr_pipeops_nn_head, mlr_pipeops_nn_layer_norm, mlr_pipeops_nn_leaky_relu, mlr_pipeops_nn_linear, mlr_pipeops_nn_log_sigmoid, mlr_pipeops_nn_max_pool1d, mlr_pipeops_nn_max_pool2d, mlr_pipeops_nn_max_pool3d, mlr_pipeops_nn_merge, mlr_pipeops_nn_merge_cat, mlr_pipeops_nn_merge_prod, mlr_pipeops_nn_merge_sum, mlr_pipeops_nn_prelu, mlr_pipeops_nn_relu, mlr_pipeops_nn_relu6, mlr_pipeops_nn_reshape, mlr_pipeops_nn_rrelu, mlr_pipeops_nn_selu, mlr_pipeops_nn_sigmoid, mlr_pipeops_nn_softmax, mlr_pipeops_nn_softplus, mlr_pipeops_nn_softshrink, mlr_pipeops_nn_softsign, mlr_pipeops_nn_squeeze, mlr_pipeops_nn_tanh, mlr_pipeops_nn_tanhshrink, mlr_pipeops_nn_threshold, mlr_pipeops_nn_unsqueeze, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, mlr_pipeops_torch_loss, mlr_pipeops_torch_model, mlr_pipeops_torch_model_classif
Super classes
mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> mlr3torch::PipeOpTorchModel -> PipeOpTorchModelRegr
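Since these are R6 classes, the inheritance chain is visible in the object's class vector, so generic code written for PipeOpLearner or PipeOp also accepts this object. A small sketch, assuming mlr3torch and its dependencies are installed:

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

op = po("torch_model_regr")

# R6 objects carry their full class chain, most specific class first
print(class(op))
```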
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpTorchModelRegr$new(id = "torch_model_regr", param_vals = list())
Arguments
id
(character(1))
Identifier of the resulting object.
param_vals
(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
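Both constructor arguments can be supplied through the po() shorthand, which passes arguments matching the constructor's formals (here id and param_vals) on to $new(). The id "my_regr_model" and the hyperparameter values below are illustrative only.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

# Custom id plus hyperparameters supplied through param_vals at construction
op = po("torch_model_regr", id = "my_regr_model",
  param_vals = list(batch_size = 16, epochs = 2))

print(op$id)
print(op$param_set$values$batch_size)
```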
Examples
# simple linear regression
library(mlr3torch)
# build the model descriptor
md = as_graph(po("torch_ingress_num") %>>%
po("nn_head") %>>%
po("torch_loss", "mse") %>>%
po("torch_optimizer", "adam"))$train(tsk("mtcars"))[[1L]]
print(md)
#> <ModelDescriptor: 2 ops>
#> * Ingress: torch_ingress_num.input: [(NA,10)]
#> * Task: mtcars [regr]
#> * Callbacks: N/A
#> * Optimizer: Adaptive Moment Estimation
#> * Loss: Mean Squared Error
#> * pointer: nn_head.output [(NA,1)]
# build the learner from the model descriptor and train it
po_model = po("torch_model_regr", batch_size = 20, epochs = 1)
po_model$train(list(md))
#> $output
#> NULL
#>
po_model$state
#> $model
#> $network
#> An `nn_module` containing 11 parameters.
#>
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • module_list: <nn_module_list> #11 parameters
#>
#> $internal_valid_scores
#> NULL
#>
#> $loss_fn
#> list()
#>
#> $optimizer
#> $optimizer$param_groups
#> $optimizer$param_groups[[1]]
#> $optimizer$param_groups[[1]]$params
#> [1] 1 2
#>
#> $optimizer$param_groups[[1]]$lr
#> [1] 0.001
#>
#> $optimizer$param_groups[[1]]$betas
#> [1] 0.900 0.999
#>
#> $optimizer$param_groups[[1]]$eps
#> [1] 1e-08
#>
#> $optimizer$param_groups[[1]]$weight_decay
#> [1] 0
#>
#> $optimizer$param_groups[[1]]$amsgrad
#> [1] FALSE
#>
#>
#>
#> $optimizer$state
#> $optimizer$state$`1`
#> $optimizer$state$`1`$step
#> [1] 2
#>
#> $optimizer$state$`1`$exp_avg
#> torch_tensor
#> -3.6073 -12.0676 -24.0698 -664.1591 -19.2834 -20.4528 -561.1848 -88.6206 -3.3210 -11.9917
#> [ CPUFloatType{1,10} ]
#>
#> $optimizer$state$`1`$exp_avg_sq
#> torch_tensor
#> Columns 1 to 6 7.3793e-01 8.2739e+00 3.2016e+01 2.4739e+04 2.0536e+01 2.3239e+01
#>
#> Columns 7 to 10 1.8256e+04 4.3417e+02 6.2616e-01 7.9415e+00
#> [ CPUFloatType{1,10} ]
#>
#>
#> $optimizer$state$`2`
#> $optimizer$state$`2`$step
#> [1] 2
#>
#> $optimizer$state$`2`$exp_avg
#> torch_tensor
#> -4.9108
#> [ CPUFloatType{1} ]
#>
#> $optimizer$state$`2`$exp_avg_sq
#> torch_tensor
#> 1.3321
#> [ CPUFloatType{1} ]
#>
#>
#>
#>
#> $epochs
#> [1] 1
#>
#> $callbacks
#> named list()
#>
#> $seed
#> [1] 1984753776
#>
#> $task_col_info
#> id type levels
#> <char> <char> <list>
#> 1: am numeric [NULL]
#> 2: carb numeric [NULL]
#> 3: cyl numeric [NULL]
#> 4: disp numeric [NULL]
#> 5: drat numeric [NULL]
#> 6: gear numeric [NULL]
#> 7: hp numeric [NULL]
#> 8: qsec numeric [NULL]
#> 9: vs numeric [NULL]
#> 10: wt numeric [NULL]
#> 11: mpg numeric [NULL]
#>
#> attr(,"class")
#> [1] "learner_torch_model" "list"
#>
#> $log
#> Empty data.table (0 rows and 3 cols): stage,class,msg
#>
#> $train_time
#> [1] 0.04
#>
#> $param_vals
#> $param_vals$epochs
#> [1] 1
#>
#> $param_vals$device
#> [1] "auto"
#>
#> $param_vals$num_threads
#> [1] 1
#>
#> $param_vals$num_interop_threads
#> [1] 1
#>
#> $param_vals$seed
#> [1] "random"
#>
#> $param_vals$eval_freq
#> [1] 1
#>
#> $param_vals$measures_train
#> list()
#>
#> $param_vals$measures_valid
#> list()
#>
#> $param_vals$patience
#> [1] 0
#>
#> $param_vals$min_delta
#> [1] 0
#>
#> $param_vals$batch_size
#> [1] 20
#>
#>
#> $task_hash
#> [1] "6fbb93eed42adb31"
#>
#> $feature_names
#> [1] "am" "carb" "cyl" "disp" "drat" "gear" "hp" "qsec" "vs" "wt"
#>
#> $validate
#> NULL
#>
#> $mlr3_version
#> [1] ‘0.21.1’
#>
#> $internal_tuned_values
#> named list()
#>
#> $data_prototype
#> Empty data.table (0 rows and 11 cols): mpg,am,carb,cyl,disp,drat...
#>
#> $task_prototype
#> Empty data.table (0 rows and 11 cols): mpg,am,carb,cyl,disp,drat...
#>
#> $train_task
#> <TaskRegr:mtcars> (32 x 11): Motor Trends
#> * Target: mpg
#> * Properties: -
#> * Features (10):
#> - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
#>
#> attr(,"class")
#> [1] "learner_state" "list"
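The fitted network can be pulled out of the state printed above via $model$network. The sketch below repeats the training step so that it runs on its own and then counts the network's scalar parameters; for 10 numeric features feeding a single linear output unit this is 10 weights plus 1 bias, matching the "11 parameters" reported in the printed state.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam"))$train(tsk("mtcars"))[[1L]]

po_model = po("torch_model_regr", batch_size = 20, epochs = 1)
po_model$train(list(md))

# The trained nn_module is stored in the learner state
net = po_model$state$model$network

# Sum the number of elements over all parameter tensors (weights and bias)
n_params = sum(sapply(net$parameters, function(p) as.numeric(p$numel())))
print(n_params)
```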