This wraps a torch::torch_optimizer_generator and annotates it with metadata, most importantly a ParamSet.
The optimizer is created for the given parameter values by calling the $generate() method.
This class is usually used to configure the optimizer of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.
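For example, the optimizer of a torch learner can be set during construction (a minimal sketch, assuming mlr3torch is attached; see also the Examples below):

learner = lrn("regr.mlp", optimizer = t_opt("sgd"))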
For a list of available optimizers, see mlr3torch_optimizers. Items from this dictionary can be retrieved using t_opt().
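For instance (assuming mlr3torch is attached; mlr3torch_optimizers is a Dictionary from mlr3misc):

# List the keys of all registered optimizers
mlr3torch_optimizers$keys()
# Retrieve one of them, optionally setting parameter values
t_opt("sgd", lr = 0.1)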
Parameters
Defined by the constructor argument param_set.
If no parameter set is provided during construction, it is constructed by creating a parameter for each argument of the wrapped optimizer, where the parameters are then of type ParamUty.
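To obtain typed, range-checked parameters instead, an explicit ParamSet can be passed during construction (a minimal sketch, assuming the paradox and torch packages are attached):

# Declare lr as a typed numeric parameter instead of an untyped ParamUty
param_set = ps(lr = p_dbl(lower = 0))
torch_opt = TorchOptimizer$new(optim_adam, param_set = param_set, id = "adam")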
See also
Other Torch Descriptor:
TorchCallback, TorchDescriptor, TorchLoss, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()
Super class
mlr3torch::TorchDescriptor -> TorchOptimizer
Methods
Inherited methods
Method new()
Creates a new instance of this R6 class.
Usage
TorchOptimizer$new(
torch_optimizer,
param_set = NULL,
id = NULL,
label = NULL,
packages = NULL,
man = NULL
)
Arguments
torch_optimizer
(torch_optimizer_generator)
The torch optimizer.

param_set
(ParamSet or NULL)
The parameter set. If NULL (default) it is inferred from torch_optimizer.

id
(character(1))
The id of the new object.

label
(character(1))
Label for the new instance.

packages
(character())
The R packages this object depends on.

man
(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().
Method generate()
Instantiates the optimizer.
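Usage
TorchOptimizer$generate(params)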
Arguments
params
(named list() of torch_tensors)
The parameters of the network.
Examples
# Create a new torch optimizer
torch_opt = TorchOptimizer$new(optim_adam, label = "adam")
torch_opt
#> <TorchOptimizer:optim_adam> adam
#> * Generator: optim_adam
#> * Parameters: list()
#> * Packages: torch,mlr3torch
# If the param set is not specified, parameters are inferred but are of class ParamUty
torch_opt$param_set
#> <ParamSet(5)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: lr ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 2: betas ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 3: eps ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 4: weight_decay ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 5: amsgrad ParamUty NA NA Inf <NoDefault[0]> [NULL]
# open the help page of the wrapped optimizer
# torch_opt$help()
# Retrieve an optimizer from the dictionary
torch_opt = t_opt("sgd", lr = 0.1)
torch_opt
#> <TorchOptimizer:sgd> Stochastic Gradient Descent
#> * Generator: optim_sgd
#> * Parameters: lr=0.1
#> * Packages: torch,mlr3torch
torch_opt$param_set
#> <ParamSet(5)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: lr ParamDbl 0 Inf Inf <NoDefault[0]> 0.1
#> 2: momentum ParamDbl 0 1 Inf 0 [NULL]
#> 3: dampening ParamDbl 0 1 Inf 0 [NULL]
#> 4: weight_decay ParamDbl 0 1 Inf 0 [NULL]
#> 5: nesterov ParamLgl NA NA 2 FALSE [NULL]
torch_opt$label
#> [1] "Stochastic Gradient Descent"
torch_opt$id
#> [1] "sgd"
# Create the optimizer for a network
net = nn_linear(10, 1)
opt = torch_opt$generate(net$parameters)
# is the same as
optim_sgd(net$parameters, lr = 0.1)
#> <optim_sgd>
#> Inherits from: <torch_optimizer>
#> Public:
#> add_param_group: function (param_group)
#> clone: function (deep = FALSE)
#> defaults: list
#> initialize: function (params, lr = optim_required(), momentum = 0, dampening = 0,
#> load_state_dict: function (state_dict, ..., .refer_to_state_dict = FALSE)
#> param_groups: list
#> state: State, R6
#> state_dict: function ()
#> step: function (closure = NULL)
#> zero_grad: function ()
#> Private:
#> step_helper: function (closure, loop_fun)
# Use in a learner
learner = lrn("regr.mlp", optimizer = t_opt("sgd"))
# The parameters of the optimizer are added to the learner's parameter set
learner$param_set
#> <ParamSetCollection(32)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: epochs ParamInt 0 Inf Inf <NoDefault[0]> [NULL]
#> 2: device ParamFct NA NA 12 <NoDefault[0]> auto
#> 3: num_threads ParamInt 1 Inf Inf <NoDefault[0]> 1
#> 4: seed ParamInt -Inf Inf Inf <NoDefault[0]> random
#> 5: eval_freq ParamInt 1 Inf Inf <NoDefault[0]> 1
#> 6: measures_train ParamUty NA NA Inf <NoDefault[0]> <list[0]>
#> 7: measures_valid ParamUty NA NA Inf <NoDefault[0]> <list[0]>
#> 8: patience ParamInt 0 Inf Inf <NoDefault[0]> 0
#> 9: min_delta ParamDbl 0 Inf Inf <NoDefault[0]> 0
#> 10: batch_size ParamInt 1 Inf Inf <NoDefault[0]> [NULL]
#> 11: shuffle ParamLgl NA NA 2 FALSE [NULL]
#> 12: sampler ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 13: batch_sampler ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 14: num_workers ParamInt 0 Inf Inf 0 [NULL]
#> 15: collate_fn ParamUty NA NA Inf [NULL] [NULL]
#> 16: pin_memory ParamLgl NA NA 2 FALSE [NULL]
#> 17: drop_last ParamLgl NA NA 2 FALSE [NULL]
#> 18: timeout ParamDbl -Inf Inf Inf -1 [NULL]
#> 19: worker_init_fn ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 20: worker_globals ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 21: worker_packages ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 22: neurons ParamUty NA NA Inf <NoDefault[0]>
#> 23: p ParamDbl 0 1 Inf <NoDefault[0]> 0.5
#> 24: activation ParamUty NA NA Inf <NoDefault[0]> <nn_relu[1]>
#> 25: activation_args ParamUty NA NA Inf <NoDefault[0]> <list[0]>
#> 26: shape ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 27: opt.lr ParamDbl 0 Inf Inf <NoDefault[0]> [NULL]
#> 28: opt.momentum ParamDbl 0 1 Inf 0 [NULL]
#> 29: opt.dampening ParamDbl 0 1 Inf 0 [NULL]
#> 30: opt.weight_decay ParamDbl 0 1 Inf 0 [NULL]
#> 31: opt.nesterov ParamLgl NA NA 2 FALSE [NULL]
#> 32: loss.reduction ParamFct NA NA 2 mean [NULL]
#> id class lower upper nlevels default value
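# The optimizer's parameters can be set through the learner's parameter
# set using the "opt." prefix (a sketch; assumes paradox's
# ParamSet$set_values() method)
learner$param_set$set_values(opt.lr = 0.05)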