
Builds a torch regression model and trains it.

Parameters

See LearnerTorch

Input and Output Channels

There is one input channel "input", which takes a ModelDescriptor during training and a Task of the specified task_type during prediction. The output is NULL during training and a Prediction of the given task_type during prediction.
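The following minimal sketch checks what each channel carries. It assumes that mlr3, mlr3pipelines and mlr3torch are installed and rebuilds the ModelDescriptor from the example section. Training receives the ModelDescriptor and returns NULL. Prediction receives a Task and returns a regression Prediction. The variable names (train_out, pred_out) are illustrative.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

task = tsk("mtcars")

# Training input: a ModelDescriptor produced by a small torch graph
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam"))$train(task)[[1L]]

po_model = po("torch_model_regr", batch_size = 20, epochs = 1)

# Training: the single output channel carries NULL
train_out = po_model$train(list(md))
print(is.null(train_out$output))

# Prediction: the input is a Task, and the output is a Prediction of the same task type
pred_out = po_model$predict(list(task))
print(class(pred_out$output))
```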

State

The state of the trained LearnerTorchModel (a list of class learner_state).

Internals

A LearnerTorchModel is created by calling model_descriptor_to_learner() on the ModelDescriptor received through the input channel. Its parameters are then set to those specified in the PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.
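The steps above can be approximated by hand. This is a sketch, not the PipeOp's actual source. It assumes that model_descriptor_to_learner() is exported and that a ModelDescriptor exposes its task as md$task. The parameter values (batch_size = 20, epochs = 1) are taken from the example section.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam"))$train(tsk("mtcars"))[[1L]]

# 1. Convert the ModelDescriptor into a LearnerTorchModel
learner = model_descriptor_to_learner(md)

# 2. Set the parameters that the PipeOp would otherwise pass on
learner$param_set$values$batch_size = 20
learner$param_set$values$epochs = 1

# 3. Train on the Task stored in the ModelDescriptor
learner$train(md$task)
print(class(learner))
```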

Methods



Method new()

Creates a new instance of this R6 class.

Usage

PipeOpTorchModelRegr$new(id = "torch_model_regr", param_vals = list())

Arguments

id

(character(1))
Identifier of the resulting object.

param_vals

(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
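Passing param_vals to the constructor should have the same effect as passing hyperparameters through the po() shortcut used in the example section. The following minimal sketch assumes that mlr3, mlr3pipelines and mlr3torch are installed. The names po_a and po_b are illustrative.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

# Construct via the R6 class, supplying hyperparameters through param_vals
po_a = PipeOpTorchModelRegr$new(param_vals = list(batch_size = 20, epochs = 1))

# Construct via the dictionary shortcut; named extra arguments become hyperparameters
po_b = po("torch_model_regr", batch_size = 20, epochs = 1)

print(po_a$id)
print(po_a$param_set$values$batch_size)
print(po_b$param_set$values$batch_size)
```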


Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpTorchModelRegr$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
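PipeOps are R6 objects with reference semantics, so plain assignment only copies the reference. To get a copy whose hyperparameters can be changed independently, use clone(deep = TRUE). The following sketch assumes the standard mlr3pipelines cloning behaviour. The names po_model and po_copy are illustrative.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

po_model = po("torch_model_regr", batch_size = 20, epochs = 1)

# A deep clone copies the object, including its parameter set
po_copy = po_model$clone(deep = TRUE)
po_copy$param_set$values$epochs = 5

# The original keeps its own value, and the copy has the new one
print(po_model$param_set$values$epochs)
print(po_copy$param_set$values$epochs)
```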

Examples

# simple linear regression

# build the model descriptor
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "mse") %>>%
  po("torch_optimizer", "adam"))$train(tsk("mtcars"))[[1L]]

print(md)
#> <ModelDescriptor: 2 ops>
#> * Ingress:  torch_ingress_num.input: [(NA,10)]
#> * Task:  mtcars [regr]
#> * Callbacks:  N/A
#> * Optimizer:  Adaptive Moment Estimation
#> * Loss:  Mean Squared Error
#> * pointer:  nn_head.output [(NA,1)]

# build the learner from the model descriptor and train it
po_model = po("torch_model_regr", batch_size = 20, epochs = 1)
po_model$train(list(md))
#> $output
#> NULL
#> 
po_model$state
#> $model
#> $network
#> An `nn_module` containing 11 parameters.
#> 
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • module_list: <nn_module_list> #11 parameters
#> 
#> $internal_valid_scores
#> NULL
#> 
#> $loss_fn
#> list()
#> 
#> $optimizer
#> $optimizer$param_groups
#> $optimizer$param_groups[[1]]
#> $optimizer$param_groups[[1]]$params
#> [1] 1 2
#> 
#> $optimizer$param_groups[[1]]$lr
#> [1] 0.001
#> 
#> $optimizer$param_groups[[1]]$betas
#> [1] 0.900 0.999
#> 
#> $optimizer$param_groups[[1]]$eps
#> [1] 1e-08
#> 
#> $optimizer$param_groups[[1]]$weight_decay
#> [1] 0
#> 
#> $optimizer$param_groups[[1]]$amsgrad
#> [1] FALSE
#> 
#> 
#> 
#> $optimizer$state
#> $optimizer$state$`1`
#> $optimizer$state$`1`$step
#> [1] 2
#> 
#> $optimizer$state$`1`$exp_avg
#> torch_tensor
#> -4.5992 -35.7493 -88.8332 -3538.9309 -48.4815 -49.0613 -2061.8740 -244.7690 -5.1170 -45.9065
#> [ CPUFloatType{1,10} ]
#> 
#> $optimizer$state$`1`$exp_avg_sq
#> torch_tensor
#> Columns 1 to 6 1.1836e+00  7.4080e+01  4.4728e+02  7.1508e+05  1.3268e+02  1.3489e+02
#> 
#> Columns 7 to 10 2.3861e+05  3.4465e+03  1.6658e+00  1.2360e+02
#> [ CPUFloatType{1,10} ]
#> 
#> 
#> $optimizer$state$`2`
#> $optimizer$state$`2`$step
#> [1] 2
#> 
#> $optimizer$state$`2`$exp_avg
#> torch_tensor
#> -13.7939
#> [ CPUFloatType{1} ]
#> 
#> $optimizer$state$`2`$exp_avg_sq
#> torch_tensor
#>  10.7954
#> [ CPUFloatType{1} ]
#> 
#> 
#> 
#> 
#> $epochs
#> [1] 1
#> 
#> $callbacks
#> named list()
#> 
#> $seed
#> [1] 1984753776
#> 
#> $task_col_info
#>         id    type levels
#>     <char>  <char> <list>
#>  1:     am numeric [NULL]
#>  2:   carb numeric [NULL]
#>  3:    cyl numeric [NULL]
#>  4:   disp numeric [NULL]
#>  5:   drat numeric [NULL]
#>  6:   gear numeric [NULL]
#>  7:     hp numeric [NULL]
#>  8:   qsec numeric [NULL]
#>  9:     vs numeric [NULL]
#> 10:     wt numeric [NULL]
#> 11:    mpg numeric [NULL]
#> 
#> attr(,"class")
#> [1] "learner_torch_model" "list"               
#> 
#> $log
#> Empty data.table (0 rows and 3 cols): stage,class,msg
#> 
#> $train_time
#> [1] 0.04
#> 
#> $param_vals
#> $param_vals$epochs
#> [1] 1
#> 
#> $param_vals$device
#> [1] "auto"
#> 
#> $param_vals$num_threads
#> [1] 1
#> 
#> $param_vals$seed
#> [1] "random"
#> 
#> $param_vals$eval_freq
#> [1] 1
#> 
#> $param_vals$measures_train
#> list()
#> 
#> $param_vals$measures_valid
#> list()
#> 
#> $param_vals$patience
#> [1] 0
#> 
#> $param_vals$min_delta
#> [1] 0
#> 
#> $param_vals$batch_size
#> [1] 20
#> 
#> 
#> $task_hash
#> [1] "6fbb93eed42adb31"
#> 
#> $feature_names
#>  [1] "am"   "carb" "cyl"  "disp" "drat" "gear" "hp"   "qsec" "vs"   "wt"  
#> 
#> $validate
#> NULL
#> 
#> $mlr3_version
#> [1] ‘0.21.0’
#> 
#> $internal_tuned_values
#> named list()
#> 
#> $data_prototype
#> Empty data.table (0 rows and 11 cols): mpg,am,carb,cyl,disp,drat...
#> 
#> $task_prototype
#> Empty data.table (0 rows and 11 cols): mpg,am,carb,cyl,disp,drat...
#> 
#> $train_task
#> <TaskRegr:mtcars> (32 x 11): Motor Trends
#> * Target: mpg
#> * Properties: -
#> * Features (10):
#>   - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
#> 
#> attr(,"class")
#> [1] "learner_state" "list"