Builds a torch classifier and trains it.
Parameters
See LearnerTorch
Input and Output Channels
There is one input channel "input" that takes in a ModelDescriptor during training and a Task of the specified task_type during prediction.
The output is NULL during training and a Prediction of the given task_type during prediction.
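The following is a minimal sketch of both phases. It assumes that mlr3torch (with a working torch installation) is available, and it reuses the model descriptor from the Examples section. Training consumes the ModelDescriptor and returns NULL. Prediction consumes an iris Task and returns a classification Prediction. Predicting on the training task is only for illustration.

```r
library(mlr3torch)

# Build the same model descriptor as in the Examples section
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]

po_model = po("torch_model_classif", batch_size = 50, epochs = 1)

# Training phase: a ModelDescriptor goes in, NULL comes out
train_out = po_model$train(list(md))
print(train_out$output)

# Prediction phase: a Task of the same task_type goes in, a Prediction comes out
pred_out = po_model$predict(list(tsk("iris")))
pred = pred_out$output
print(class(pred))
```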
State
A trained LearnerTorchModel.
Internals
A LearnerTorchModel is created by calling model_descriptor_to_learner() on the ModelDescriptor received through the input channel.
Its parameters are then set according to the parameters specified in PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.
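The internal steps can be reproduced by hand, roughly as shown below. This sketch assumes that model_descriptor_to_learner() and the task element ($task) of the ModelDescriptor behave as described above. The PipeOp itself may do additional bookkeeping, for example by passing on all of its parameter values rather than only the two set here.

```r
library(mlr3torch)

# Build a model descriptor (same pipeline as in the Examples section)
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]

# Step 1: create a LearnerTorchModel from the model descriptor
learner = model_descriptor_to_learner(md)

# Step 2: set the training parameters (here only batch_size and epochs)
learner$param_set$set_values(batch_size = 50, epochs = 1)

# Step 3: train on the task stored in the model descriptor
learner$train(md$task)
print(class(learner))
```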
See also
Other PipeOps: mlr_pipeops_nn_adaptive_avg_pool1d, mlr_pipeops_nn_adaptive_avg_pool2d, mlr_pipeops_nn_adaptive_avg_pool3d, mlr_pipeops_nn_avg_pool1d, mlr_pipeops_nn_avg_pool2d, mlr_pipeops_nn_avg_pool3d, mlr_pipeops_nn_batch_norm1d, mlr_pipeops_nn_batch_norm2d, mlr_pipeops_nn_batch_norm3d, mlr_pipeops_nn_block, mlr_pipeops_nn_celu, mlr_pipeops_nn_conv1d, mlr_pipeops_nn_conv2d, mlr_pipeops_nn_conv3d, mlr_pipeops_nn_conv_transpose1d, mlr_pipeops_nn_conv_transpose2d, mlr_pipeops_nn_conv_transpose3d, mlr_pipeops_nn_dropout, mlr_pipeops_nn_elu, mlr_pipeops_nn_flatten, mlr_pipeops_nn_gelu, mlr_pipeops_nn_glu, mlr_pipeops_nn_hardshrink, mlr_pipeops_nn_hardsigmoid, mlr_pipeops_nn_hardtanh, mlr_pipeops_nn_head, mlr_pipeops_nn_layer_norm, mlr_pipeops_nn_leaky_relu, mlr_pipeops_nn_linear, mlr_pipeops_nn_log_sigmoid, mlr_pipeops_nn_max_pool1d, mlr_pipeops_nn_max_pool2d, mlr_pipeops_nn_max_pool3d, mlr_pipeops_nn_merge, mlr_pipeops_nn_merge_cat, mlr_pipeops_nn_merge_prod, mlr_pipeops_nn_merge_sum, mlr_pipeops_nn_prelu, mlr_pipeops_nn_relu, mlr_pipeops_nn_relu6, mlr_pipeops_nn_reshape, mlr_pipeops_nn_rrelu, mlr_pipeops_nn_selu, mlr_pipeops_nn_sigmoid, mlr_pipeops_nn_softmax, mlr_pipeops_nn_softplus, mlr_pipeops_nn_softshrink, mlr_pipeops_nn_softsign, mlr_pipeops_nn_squeeze, mlr_pipeops_nn_tanh, mlr_pipeops_nn_tanhshrink, mlr_pipeops_nn_threshold, mlr_pipeops_nn_unsqueeze, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, mlr_pipeops_torch_loss, mlr_pipeops_torch_model, mlr_pipeops_torch_model_regr
Super classes
mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> mlr3torch::PipeOpTorchModel -> PipeOpTorchModelClassif
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpTorchModelClassif$new(id = "torch_model_classif", param_vals = list())
Arguments
id (character(1))
Identifier of the resulting object.
param_vals (list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
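The constructor can also be called directly instead of through po(), as in the sketch below. The id "my_model" is a hypothetical choice for illustration. The sketch assumes that PipeOpTorchModelClassif is exported by mlr3torch and that the entries of param_vals end up in $param_set$values.

```r
library(mlr3torch)

# Construct the PipeOp via its R6 constructor, passing hyperparameters
# through param_vals (equivalent in spirit to
# po("torch_model_classif", batch_size = 50, epochs = 1))
po_model = PipeOpTorchModelClassif$new(
  id = "my_model",
  param_vals = list(batch_size = 50, epochs = 1)
)

print(po_model$id)
print(po_model$param_set$values$batch_size)
print(po_model$param_set$values$epochs)
```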
Examples
library(mlr3torch)

# simple logistic regression
# configure the model descriptor
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]
print(md)
#> <ModelDescriptor: 2 ops>
#> * Ingress: torch_ingress_num.input: [(NA,4)]
#> * Task: iris [classif]
#> * Callbacks: N/A
#> * Optimizer: Adaptive Moment Estimation
#> * Loss: Cross Entropy
#> * pointer: nn_head.output [(NA,3)]
# build the learner from the model descriptor and train it
po_model = po("torch_model_classif", batch_size = 50, epochs = 1)
po_model$train(list(md))
#> $output
#> NULL
#>
po_model$state
#> $model
#> $network
#> An `nn_module` containing 15 parameters.
#>
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • module_list: <nn_module_list> #15 parameters
#>
#> $internal_valid_scores
#> NULL
#>
#> $loss_fn
#> list()
#>
#> $optimizer
#> $optimizer$param_groups
#> $optimizer$param_groups[[1]]
#> $optimizer$param_groups[[1]]$params
#> [1] 1 2
#>
#> $optimizer$param_groups[[1]]$lr
#> [1] 0.001
#>
#> $optimizer$param_groups[[1]]$betas
#> [1] 0.900 0.999
#>
#> $optimizer$param_groups[[1]]$eps
#> [1] 1e-08
#>
#> $optimizer$param_groups[[1]]$weight_decay
#> [1] 0
#>
#> $optimizer$param_groups[[1]]$amsgrad
#> [1] FALSE
#>
#>
#>
#> $optimizer$state
#> $optimizer$state$`1`
#> $optimizer$state$`1`$step
#> [1] 3
#>
#> $optimizer$state$`1`$exp_avg
#> torch_tensor
#> 0.1824 0.0781 0.0404 -0.0493
#> -0.0506 -0.0145 0.0062 0.0439
#> -0.1318 -0.0637 -0.0466 0.0055
#> [ CPUFloatType{3,4} ]
#>
#> $optimizer$state$`1`$exp_avg_sq
#> torch_tensor
#> 0.01 *
#> 0.5263 0.0525 2.0401 0.7900
#> 1.1783 0.1205 2.5981 0.6958
#> 1.3648 0.1695 2.2630 0.5236
#> [ CPUFloatType{3,4} ]
#>
#>
#> $optimizer$state$`2`
#> $optimizer$state$`2`$step
#> [1] 3
#>
#> $optimizer$state$`2`$exp_avg
#> torch_tensor
#> 0.001 *
#> -5.6068
#> 3.9102
#> 1.6966
#> [ CPUFloatType{3} ]
#>
#> $optimizer$state$`2`$exp_avg_sq
#> torch_tensor
#> 0.0001 *
#> 7.1996
#> 7.7785
#> 5.8956
#> [ CPUFloatType{3} ]
#>
#>
#>
#>
#> $epochs
#> [1] 1
#>
#> $callbacks
#> named list()
#>
#> $seed
#> [1] 1285070263
#>
#> $task_col_info
#> Key: <id>
#> id type levels
#> <char> <char> <list>
#> 1: Petal.Length numeric [NULL]
#> 2: Petal.Width numeric [NULL]
#> 3: Sepal.Length numeric [NULL]
#> 4: Sepal.Width numeric [NULL]
#> 5: Species factor setosa,versicolor,virginica
#>
#> attr(,"class")
#> [1] "learner_torch_model" "list"
#>
#> $log
#> Empty data.table (0 rows and 3 cols): stage,class,msg
#>
#> $train_time
#> [1] 0.055
#>
#> $param_vals
#> $param_vals$epochs
#> [1] 1
#>
#> $param_vals$device
#> [1] "auto"
#>
#> $param_vals$num_threads
#> [1] 1
#>
#> $param_vals$num_interop_threads
#> [1] 1
#>
#> $param_vals$seed
#> [1] "random"
#>
#> $param_vals$eval_freq
#> [1] 1
#>
#> $param_vals$measures_train
#> list()
#>
#> $param_vals$measures_valid
#> list()
#>
#> $param_vals$patience
#> [1] 0
#>
#> $param_vals$min_delta
#> [1] 0
#>
#> $param_vals$batch_size
#> [1] 50
#>
#>
#> $task_hash
#> [1] "34a6beb27c3181c3"
#>
#> $feature_names
#> [1] "Petal.Length" "Petal.Width" "Sepal.Length" "Sepal.Width"
#>
#> $validate
#> NULL
#>
#> $mlr3_version
#> [1] ‘0.21.1’
#>
#> $internal_tuned_values
#> named list()
#>
#> $data_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#>
#> $task_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#>
#> $train_task
#> <TaskClassif:iris> (150 x 5): Iris Flowers
#> * Target: Species
#> * Properties: multiclass
#> * Features (4):
#> - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
#>
#> attr(,"class")
#> [1] "learner_state" "list"