Builds a torch classifier and trains it.
Parameters
See LearnerTorch
Input and Output Channels
There is one input channel "input" that takes in a ModelDescriptor during training and a Task of the specified
task_type during prediction.
The output is NULL during training and a Prediction of the given task_type during prediction.
State
A trained LearnerTorchModel.
Internals
A LearnerTorchModel is created by calling model_descriptor_to_learner() on the
provided ModelDescriptor that is received through the input channel.
Then the parameters are set according to the parameters specified in PipeOpTorchModel and
its $train() method is called on the Task stored in the ModelDescriptor.
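The internals described above can be sketched as follows. This is an illustrative simplification, not the actual implementation; model_descriptor_to_learner() is the real mlr3torch function, while the surrounding variable names (md, pv) are hypothetical:

```r
library(mlr3torch)

# md is assumed to be a ModelDescriptor received through the "input" channel,
# and pv the hyperparameter values set on the PipeOp.

# 1. Wrap the ModelDescriptor in a LearnerTorchModel.
learner = model_descriptor_to_learner(md)

# 2. Transfer the PipeOp's hyperparameter settings to the learner.
learner$param_set$values = pv

# 3. Train the learner on the Task stored in the ModelDescriptor.
learner$train(md$task)
```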
See also
Other PipeOps:
mlr_pipeops_nn_adaptive_avg_pool1d,
mlr_pipeops_nn_adaptive_avg_pool2d,
mlr_pipeops_nn_adaptive_avg_pool3d,
mlr_pipeops_nn_avg_pool1d,
mlr_pipeops_nn_avg_pool2d,
mlr_pipeops_nn_avg_pool3d,
mlr_pipeops_nn_batch_norm1d,
mlr_pipeops_nn_batch_norm2d,
mlr_pipeops_nn_batch_norm3d,
mlr_pipeops_nn_block,
mlr_pipeops_nn_celu,
mlr_pipeops_nn_conv1d,
mlr_pipeops_nn_conv2d,
mlr_pipeops_nn_conv3d,
mlr_pipeops_nn_conv_transpose1d,
mlr_pipeops_nn_conv_transpose2d,
mlr_pipeops_nn_conv_transpose3d,
mlr_pipeops_nn_dropout,
mlr_pipeops_nn_elu,
mlr_pipeops_nn_flatten,
mlr_pipeops_nn_ft_cls,
mlr_pipeops_nn_ft_transformer_block,
mlr_pipeops_nn_geglu,
mlr_pipeops_nn_gelu,
mlr_pipeops_nn_glu,
mlr_pipeops_nn_hardshrink,
mlr_pipeops_nn_hardsigmoid,
mlr_pipeops_nn_hardtanh,
mlr_pipeops_nn_head,
mlr_pipeops_nn_identity,
mlr_pipeops_nn_layer_norm,
mlr_pipeops_nn_leaky_relu,
mlr_pipeops_nn_linear,
mlr_pipeops_nn_log_sigmoid,
mlr_pipeops_nn_max_pool1d,
mlr_pipeops_nn_max_pool2d,
mlr_pipeops_nn_max_pool3d,
mlr_pipeops_nn_merge,
mlr_pipeops_nn_merge_cat,
mlr_pipeops_nn_merge_prod,
mlr_pipeops_nn_merge_sum,
mlr_pipeops_nn_prelu,
mlr_pipeops_nn_reglu,
mlr_pipeops_nn_relu,
mlr_pipeops_nn_relu6,
mlr_pipeops_nn_reshape,
mlr_pipeops_nn_rrelu,
mlr_pipeops_nn_selu,
mlr_pipeops_nn_sigmoid,
mlr_pipeops_nn_softmax,
mlr_pipeops_nn_softplus,
mlr_pipeops_nn_softshrink,
mlr_pipeops_nn_softsign,
mlr_pipeops_nn_squeeze,
mlr_pipeops_nn_tanh,
mlr_pipeops_nn_tanhshrink,
mlr_pipeops_nn_threshold,
mlr_pipeops_nn_tokenizer_categ,
mlr_pipeops_nn_tokenizer_num,
mlr_pipeops_nn_unsqueeze,
mlr_pipeops_torch_ingress,
mlr_pipeops_torch_ingress_categ,
mlr_pipeops_torch_ingress_ltnsr,
mlr_pipeops_torch_ingress_num,
mlr_pipeops_torch_loss,
mlr_pipeops_torch_model,
mlr_pipeops_torch_model_regr
Super classes
mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> mlr3torch::PipeOpTorchModel -> PipeOpTorchModelClassif
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpTorchModelClassif$new(id = "torch_model_classif", param_vals = list())
Arguments
id (character(1))
Identifier of the resulting object.
param_vals (list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
Examples
# simple logistic regression
# configure the model descriptor
md = as_graph(po("torch_ingress_num") %>>%
po("nn_head") %>>%
po("torch_loss", "cross_entropy") %>>%
po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]
print(md)
#> <ModelDescriptor: 2 ops>
#> * Ingress: torch_ingress_num.input: [(NA,4)]
#> * Task: iris [classif]
#> * Callbacks: N/A
#> * Optimizer: Adaptive Moment Estimation
#> * Loss: Cross Entropy
#> * pointer: nn_head.output [(NA,3)]
# build the learner from the model descriptor and train it
po_model = po("torch_model_classif", batch_size = 50, epochs = 1)
po_model$train(list(md))
#> $output
#> NULL
#>
po_model$state
#> $model
#> $network
#> An `nn_module` containing 15 parameters.
#>
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • module_list: <nn_module_list> #15 parameters
#>
#> $internal_valid_scores
#> NULL
#>
#> $loss_fn
#> list()
#>
#> $optimizer
#> $optimizer$param_groups
#> $optimizer$param_groups[[1]]
#> $optimizer$param_groups[[1]]$params
#> [1] 1 2
#>
#> $optimizer$param_groups[[1]]$lr
#> [1] 0.001
#>
#> $optimizer$param_groups[[1]]$weight_decay
#> [1] 0
#>
#> $optimizer$param_groups[[1]]$betas
#> [1] 0.900 0.999
#>
#> $optimizer$param_groups[[1]]$eps
#> [1] 1e-08
#>
#> $optimizer$param_groups[[1]]$amsgrad
#> [1] FALSE
#>
#>
#>
#> $optimizer$state
#> $optimizer$state$`1`
#> $optimizer$state$`1`$exp_avg
#> torch_tensor
#> -0.1204 -0.0188 -0.4292 -0.2963
#> 0.5516 0.1799 0.9056 0.4998
#> -0.4312 -0.1610 -0.4764 -0.2036
#> [ CPUFloatType{3,4} ]
#>
#> $optimizer$state$`1`$exp_avg_sq
#> torch_tensor
#> 0.01 *
#> 0.0590 0.0015 0.7501 0.3571
#> 1.2417 0.1321 3.3294 1.0125
#> 0.7645 0.1061 0.9374 0.1709
#> [ CPUFloatType{3,4} ]
#>
#> $optimizer$state$`1`$max_exp_avg_sq
#> torch_tensor
#> [ CPUFloatType{0} ]
#>
#> $optimizer$state$`1`$step
#> torch_tensor
#> 3
#> [ CPULongType{1} ]
#>
#>
#> $optimizer$state$`2`
#> $optimizer$state$`2`$exp_avg
#> torch_tensor
#> -0.0860
#> 0.1552
#> -0.0692
#> [ CPUFloatType{3} ]
#>
#> $optimizer$state$`2`$exp_avg_sq
#> torch_tensor
#> 0.0001 *
#> 3.0104
#> 9.7702
#> 1.9676
#> [ CPUFloatType{3} ]
#>
#> $optimizer$state$`2`$max_exp_avg_sq
#> torch_tensor
#> [ CPUFloatType{0} ]
#>
#> $optimizer$state$`2`$step
#> torch_tensor
#> 3
#> [ CPULongType{1} ]
#>
#>
#>
#>
#> $epochs
#> [1] 1
#>
#> $callbacks
#> named list()
#>
#> $seed
#> [1] 415631006
#>
#> $task_col_info
#> Key: <id>
#> id type levels
#> <char> <char> <list>
#> 1: Petal.Length numeric [NULL]
#> 2: Petal.Width numeric [NULL]
#> 3: Sepal.Length numeric [NULL]
#> 4: Sepal.Width numeric [NULL]
#> 5: Species factor setosa,versicolor,virginica
#>
#> attr(,"class")
#> [1] "learner_torch_model" "list"
#>
#> $param_vals
#> $param_vals$epochs
#> [1] 1
#>
#> $param_vals$device
#> [1] "auto"
#>
#> $param_vals$num_threads
#> [1] 1
#>
#> $param_vals$num_interop_threads
#> [1] 1
#>
#> $param_vals$seed
#> [1] "random"
#>
#> $param_vals$eval_freq
#> [1] 1
#>
#> $param_vals$measures_train
#> list()
#>
#> $param_vals$measures_valid
#> list()
#>
#> $param_vals$patience
#> [1] 0
#>
#> $param_vals$min_delta
#> [1] 0
#>
#> $param_vals$batch_size
#> [1] 50
#>
#> $param_vals$shuffle
#> [1] TRUE
#>
#> $param_vals$tensor_dataset
#> [1] FALSE
#>
#> $param_vals$jit_trace
#> [1] FALSE
#>
#>
#> $log
#> Empty data.table (0 rows and 3 cols): stage,class,msg
#>
#> $train_time
#> [1] 0.062
#>
#> $task_hash
#> [1] "abc694dd29a7a8ce"
#>
#> $feature_names
#> [1] "Petal.Length" "Petal.Width" "Sepal.Length" "Sepal.Width"
#>
#> $validate
#> NULL
#>
#> $mlr3_version
#> [1] ‘1.2.0’
#>
#> $internal_tuned_values
#> named list()
#>
#> $data_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#>
#> $task_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#>
#> $train_task
#>
#> ── <TaskClassif> (150x5): Iris Flowers ─────────────────────────────────────────
#> • Target: Species
#> • Target classes: setosa, versicolor, virginica
#> • Properties: multiclass
#> • Features (4):
#> • dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
#>
#> attr(,"class")
#> [1] "learner_state" "list"
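After training, the PipeOp can be used for prediction by passing a Task of the matching task_type through the same "input" channel. A minimal sketch continuing the example above (output omitted, as it varies with the random seed):

```r
# Predict on the training task; the PipeOp returns a list with one
# Prediction object on its output channel.
pred = po_model$predict(list(tsk("iris")))[[1L]]
pred$score(msr("classif.acc"))
```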