
This wraps a torch::torch_optimizer_generator and annotates it with metadata, most importantly a ParamSet. The optimizer is created for the given parameter values by calling the $generate() method.
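
For illustration, a minimal sketch of the wrap-then-generate workflow (assuming the torch and mlr3torch packages are attached; optim_adam and nn_linear come from torch):

torch_opt = TorchOptimizer$new(optim_adam, id = "adam")
net = nn_linear(4, 1)
adam = torch_opt$generate(net$parameters)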

This class is usually used to configure the optimizer of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.
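
For example, the optimizer can be set when constructing a learner (a sketch using the "regr.mlp" learner that also appears in the Examples below):

learner = lrn("regr.mlp", optimizer = t_opt("adam", lr = 0.01))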

For a list of available optimizers, see mlr3torch_optimizers. Items from this dictionary can be retrieved using t_opt().
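
A quick way to inspect the dictionary (a sketch; mlr3torch_optimizers is an mlr3misc dictionary, so $keys() lists its entries):

mlr3torch_optimizers$keys()
# items are retrieved (and optionally configured) via t_opt()
t_opt("sgd")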

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, the parameter set is constructed by creating a parameter for each argument of the wrapped optimizer, where the parameters are then of type ParamUty.
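
To get properly typed parameters instead, an explicit ParamSet can be passed to the constructor (a sketch using paradox's ps() and p_dbl() helpers; the single lr parameter is illustrative, not the complete set for optim_adam):

library(paradox)
param_set = ps(lr = p_dbl(lower = 0))
torch_opt = TorchOptimizer$new(optim_adam, param_set = param_set, id = "adam")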

Super class

mlr3torch::TorchDescriptor -> TorchOptimizer

Methods

Inherited methods


Method new()

Creates a new instance of this R6 class.

Usage

TorchOptimizer$new(
  torch_optimizer,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL
)

Arguments

torch_optimizer

(torch_optimizer_generator)
The torch optimizer.

param_set

(ParamSet or NULL)
The parameter set. If NULL (default) it is inferred from torch_optimizer.

id

(character(1))
The id of the new object.

label

(character(1))
Label for the new instance.

packages

(character())
The R packages this object depends on.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().


Method generate()

Instantiates the optimizer.

Usage

TorchOptimizer$generate(params)

Arguments

params

(named list() of torch_tensors)
The parameters of the network.

Returns

torch_optimizer


Method clone()

The objects of this class are cloneable with this method.

Usage

TorchOptimizer$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Create a new torch optimizer
torch_opt = TorchOptimizer$new(optim_adam, label = "adam")
torch_opt
#> <TorchOptimizer:optim_adam> adam
#> * Generator: optim_adam
#> * Parameters: list()
#> * Packages: torch,mlr3torch
# If the param set is not specified, parameters are inferred but are of class ParamUty
torch_opt$param_set
#> <ParamSet(5)>
#>              id    class lower upper nlevels        default  value
#>          <char>   <char> <num> <num>   <num>         <list> <list>
#> 1:           lr ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]
#> 2:        betas ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]
#> 3:          eps ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]
#> 4: weight_decay ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]
#> 5:      amsgrad ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]

# open the help page of the wrapped optimizer
# torch_opt$help()

# Retrieve an optimizer from the dictionary
torch_opt = t_opt("sgd", lr = 0.1)
torch_opt
#> <TorchOptimizer:sgd> Stochastic Gradient Descent
#> * Generator: optim_sgd
#> * Parameters: lr=0.1
#> * Packages: torch,mlr3torch
torch_opt$param_set
#> <ParamSet(5)>
#>              id    class lower upper nlevels        default  value
#>          <char>   <char> <num> <num>   <num>         <list> <list>
#> 1:           lr ParamDbl     0   Inf     Inf <NoDefault[0]>    0.1
#> 2:     momentum ParamDbl     0     1     Inf              0 [NULL]
#> 3:    dampening ParamDbl     0     1     Inf              0 [NULL]
#> 4: weight_decay ParamDbl     0     1     Inf              0 [NULL]
#> 5:     nesterov ParamLgl    NA    NA       2          FALSE [NULL]
torch_opt$label
#> [1] "Stochastic Gradient Descent"
torch_opt$id
#> [1] "sgd"

# Create the optimizer for a network
net = nn_linear(10, 1)
opt = torch_opt$generate(net$parameters)

# is the same as
optim_sgd(net$parameters, lr = 0.1)
#> <optim_sgd>
#>   Inherits from: <torch_optimizer>
#>   Public:
#>     add_param_group: function (param_group) 
#>     clone: function (deep = FALSE) 
#>     defaults: list
#>     initialize: function (params, lr = optim_required(), momentum = 0, dampening = 0, 
#>     load_state_dict: function (state_dict, ..., .refer_to_state_dict = FALSE) 
#>     param_groups: list
#>     state: State, R6
#>     state_dict: function () 
#>     step: function (closure = NULL) 
#>     zero_grad: function () 
#>   Private:
#>     step_helper: function (closure, loop_fun) 

# Use in a learner
learner = lrn("regr.mlp", optimizer = t_opt("sgd"))
# The parameters of the optimizer are added to the learner's parameter set
learner$param_set
#> <ParamSetCollection(33)>
#>                      id    class lower upper nlevels        default
#>                  <char>   <char> <num> <num>   <num>         <list>
#>  1:              epochs ParamInt     0   Inf     Inf <NoDefault[0]>
#>  2:              device ParamFct    NA    NA      12 <NoDefault[0]>
#>  3:         num_threads ParamInt     1   Inf     Inf <NoDefault[0]>
#>  4: num_interop_threads ParamInt     1   Inf     Inf <NoDefault[0]>
#>  5:                seed ParamInt  -Inf   Inf     Inf <NoDefault[0]>
#>  6:           eval_freq ParamInt     1   Inf     Inf <NoDefault[0]>
#>  7:      measures_train ParamUty    NA    NA     Inf <NoDefault[0]>
#>  8:      measures_valid ParamUty    NA    NA     Inf <NoDefault[0]>
#>  9:            patience ParamInt     0   Inf     Inf <NoDefault[0]>
#> 10:           min_delta ParamDbl     0   Inf     Inf <NoDefault[0]>
#> 11:          batch_size ParamInt     1   Inf     Inf <NoDefault[0]>
#> 12:             shuffle ParamLgl    NA    NA       2          FALSE
#> 13:             sampler ParamUty    NA    NA     Inf <NoDefault[0]>
#> 14:       batch_sampler ParamUty    NA    NA     Inf <NoDefault[0]>
#> 15:         num_workers ParamInt     0   Inf     Inf              0
#> 16:          collate_fn ParamUty    NA    NA     Inf         [NULL]
#> 17:          pin_memory ParamLgl    NA    NA       2          FALSE
#> 18:           drop_last ParamLgl    NA    NA       2          FALSE
#> 19:             timeout ParamDbl  -Inf   Inf     Inf             -1
#> 20:      worker_init_fn ParamUty    NA    NA     Inf <NoDefault[0]>
#> 21:      worker_globals ParamUty    NA    NA     Inf <NoDefault[0]>
#> 22:     worker_packages ParamUty    NA    NA     Inf <NoDefault[0]>
#> 23:             neurons ParamUty    NA    NA     Inf <NoDefault[0]>
#> 24:                   p ParamDbl     0     1     Inf <NoDefault[0]>
#> 25:          activation ParamUty    NA    NA     Inf <NoDefault[0]>
#> 26:     activation_args ParamUty    NA    NA     Inf <NoDefault[0]>
#> 27:               shape ParamUty    NA    NA     Inf <NoDefault[0]>
#> 28:              opt.lr ParamDbl     0   Inf     Inf <NoDefault[0]>
#> 29:        opt.momentum ParamDbl     0     1     Inf              0
#> 30:       opt.dampening ParamDbl     0     1     Inf              0
#> 31:    opt.weight_decay ParamDbl     0     1     Inf              0
#> 32:        opt.nesterov ParamLgl    NA    NA       2          FALSE
#> 33:      loss.reduction ParamFct    NA    NA       2           mean
#>                      id    class lower upper nlevels        default
#>            value
#>           <list>
#>  1:       [NULL]
#>  2:         auto
#>  3:            1
#>  4:            1
#>  5:       random
#>  6:            1
#>  7:    <list[0]>
#>  8:    <list[0]>
#>  9:            0
#> 10:            0
#> 11:       [NULL]
#> 12:       [NULL]
#> 13:       [NULL]
#> 14:       [NULL]
#> 15:       [NULL]
#> 16:       [NULL]
#> 17:       [NULL]
#> 18:       [NULL]
#> 19:       [NULL]
#> 20:       [NULL]
#> 21:       [NULL]
#> 22:       [NULL]
#> 23:             
#> 24:          0.5
#> 25: <nn_relu[1]>
#> 26:    <list[0]>
#> 27:       [NULL]
#> 28:       [NULL]
#> 29:       [NULL]
#> 30:       [NULL]
#> 31:       [NULL]
#> 32:       [NULL]
#> 33:       [NULL]
#>            value