Builds a torch classifier and trains it.

Parameters

See LearnerTorch

Input and Output Channels

There is one input channel, "input", which takes a ModelDescriptor during training and a Task of the specified task_type during prediction. The output is NULL during training and a Prediction of the given task_type during prediction.
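The channel behaviour described above can be sketched as follows. This is a minimal illustration rather than an official example: it assumes mlr3, mlr3pipelines and mlr3torch are installed, and it reuses the iris task and the parameter values from the Examples section. The variable names (`md`, `po_model`, `train_out`, `pred`) are illustrative only.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

task = tsk("iris")

# Build a ModelDescriptor by training the upstream part of the graph
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(task)[[1L]]

po_model = po("torch_model_classif", batch_size = 50, epochs = 1)

# Training: the input channel receives the ModelDescriptor, the output is NULL
train_out = po_model$train(list(md))

# Prediction: the input channel receives a Task, the output is a Prediction
pred = po_model$predict(list(task))[[1L]]
class(pred)
```

In a full pipeline the PipeOp is usually appended to the graph with `%>>%`, so that the ModelDescriptor and the Task are routed to the input channel automatically.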

State

A trained LearnerTorchModel.

Internals

A LearnerTorchModel is created by calling model_descriptor_to_learner() on the ModelDescriptor received through the input channel. The learner's parameters are then set to those specified in the PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.
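The steps above can be reproduced by hand, which may help when debugging a pipeline. The following is a rough sketch of the equivalent manual workflow, not the actual internal implementation. It assumes the ModelDescriptor exposes its task as `md$task` and that the parameter values match those in the Examples section.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

# Build a ModelDescriptor as in the Examples section
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]

# Step 1: convert the ModelDescriptor into a learner
learner = model_descriptor_to_learner(md)

# Step 2: set the training parameters (here the values used in the Examples)
learner$param_set$set_values(batch_size = 50, epochs = 1)

# Step 3: train on the task stored in the ModelDescriptor
learner$train(md$task)

class(learner)
```

The resulting learner can then be used like any other mlr3 learner, for example with `learner$predict(md$task)`.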

Super classes

mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> mlr3torch::PipeOpTorchModel -> PipeOpTorchModelClassif

Methods

Method new()

Creates a new instance of this R6 class.

Usage

PipeOpTorchModelClassif$new(id = "torch_model_classif", param_vals = list())

Arguments

id

(character(1))
Identifier of the resulting object.

param_vals

(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.


Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpTorchModelClassif$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# simple logistic regression

# configure the model descriptor
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]

print(md)
#> <ModelDescriptor: 2 ops>
#> * Ingress:  torch_ingress_num.input: [(NA,4)]
#> * Task:  iris [classif]
#> * Callbacks:  N/A
#> * Optimizer:  Adaptive Moment Estimation
#> * Loss:  Cross Entropy
#> * pointer:  nn_head.output [(NA,3)]

# build the learner from the model descriptor and train it
po_model = po("torch_model_classif", batch_size = 50, epochs = 1)
po_model$train(list(md))
#> $output
#> NULL
#> 
po_model$state
#> $model
#> $network
#> An `nn_module` containing 15 parameters.
#> 
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • module_list: <nn_module_list> #15 parameters
#> 
#> $internal_valid_scores
#> NULL
#> 
#> $loss_fn
#> list()
#> 
#> $optimizer
#> $optimizer$param_groups
#> $optimizer$param_groups[[1]]
#> $optimizer$param_groups[[1]]$params
#> [1] 1 2
#> 
#> $optimizer$param_groups[[1]]$lr
#> [1] 0.001
#> 
#> $optimizer$param_groups[[1]]$betas
#> [1] 0.900 0.999
#> 
#> $optimizer$param_groups[[1]]$eps
#> [1] 1e-08
#> 
#> $optimizer$param_groups[[1]]$weight_decay
#> [1] 0
#> 
#> $optimizer$param_groups[[1]]$amsgrad
#> [1] FALSE
#> 
#> 
#> 
#> $optimizer$state
#> $optimizer$state$`1`
#> $optimizer$state$`1`$step
#> [1] 3
#> 
#> $optimizer$state$`1`$exp_avg
#> torch_tensor
#> -0.0899 -0.0134 -0.3285 -0.2275
#>  0.5184  0.1767  0.7711  0.4077
#> -0.4285 -0.1632 -0.4426 -0.1802
#> [ CPUFloatType{3,4} ]
#> 
#> $optimizer$state$`1`$exp_avg_sq
#> torch_tensor
#> 0.01 *
#>  0.1531  0.0044  1.7816  0.8325
#>   2.6082  0.3350  4.5874  1.2071
#>   2.5581  0.3371  3.7053  0.7826
#> [ CPUFloatType{3,4} ]
#> 
#> 
#> $optimizer$state$`2`
#> $optimizer$state$`2`$step
#> [1] 3
#> 
#> $optimizer$state$`2`$exp_avg
#> torch_tensor
#> -0.0661
#>  0.1280
#> -0.0619
#> [ CPUFloatType{3} ]
#> 
#> $optimizer$state$`2`$exp_avg_sq
#> torch_tensor
#> 0.001 *
#>  0.7107
#>  1.2320
#>  0.8732
#> [ CPUFloatType{3} ]
#> 
#> 
#> 
#> 
#> $epochs
#> [1] 1
#> 
#> $callbacks
#> named list()
#> 
#> $seed
#> [1] 1285070263
#> 
#> $task_col_info
#> Key: <id>
#>              id    type                      levels
#>          <char>  <char>                      <list>
#> 1: Petal.Length numeric                      [NULL]
#> 2:  Petal.Width numeric                      [NULL]
#> 3: Sepal.Length numeric                      [NULL]
#> 4:  Sepal.Width numeric                      [NULL]
#> 5:      Species  factor setosa,versicolor,virginica
#> 
#> attr(,"class")
#> [1] "learner_torch_model" "list"               
#> 
#> $log
#> Empty data.table (0 rows and 3 cols): stage,class,msg
#> 
#> $train_time
#> [1] 0.054
#> 
#> $param_vals
#> $param_vals$epochs
#> [1] 1
#> 
#> $param_vals$device
#> [1] "auto"
#> 
#> $param_vals$num_threads
#> [1] 1
#> 
#> $param_vals$seed
#> [1] "random"
#> 
#> $param_vals$eval_freq
#> [1] 1
#> 
#> $param_vals$measures_train
#> list()
#> 
#> $param_vals$measures_valid
#> list()
#> 
#> $param_vals$patience
#> [1] 0
#> 
#> $param_vals$min_delta
#> [1] 0
#> 
#> $param_vals$batch_size
#> [1] 50
#> 
#> 
#> $task_hash
#> [1] "34a6beb27c3181c3"
#> 
#> $feature_names
#> [1] "Petal.Length" "Petal.Width"  "Sepal.Length" "Sepal.Width" 
#> 
#> $validate
#> NULL
#> 
#> $mlr3_version
#> [1] ‘0.21.0’
#> 
#> $internal_tuned_values
#> named list()
#> 
#> $data_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#> 
#> $task_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#> 
#> $train_task
#> <TaskClassif:iris> (150 x 5): Iris Flowers
#> * Target: Species
#> * Properties: multiclass
#> * Features (4):
#>   - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
#> 
#> attr(,"class")
#> [1] "learner_state" "list"