This wraps a CallbackSet and annotates it with metadata, most importantly a ParamSet. The callback is created for the given parameter values by calling the $generate() method.
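The descriptor pattern can be pictured with a minimal base-R sketch. A wrapper stores a generator together with parameter values, and generate() instantiates the wrapped object from those values. make_descriptor() and checkpoint_generator() are hypothetical stand-ins; they are not part of mlr3torch and do not reflect how TorchCallback is implemented internally.

```r
# Hypothetical sketch of the "descriptor" idea (NOT the mlr3torch implementation):
# store a generator plus parameter values; generate() builds the object on demand.
make_descriptor <- function(generator, values = list()) {
  force(generator)  # capture the generator now rather than lazily
  list(
    values = values,
    generate = function() do.call(generator, values)
  )
}

# Hypothetical stand-in for a callback constructor with one defaulted argument
checkpoint_generator <- function(path, freq, freq_type = "epoch") {
  structure(
    list(path = path, freq = freq, freq_type = freq_type),
    class = "fake_checkpoint"
  )
}

desc <- make_descriptor(checkpoint_generator, list(path = "ckpt", freq = 1))
cb <- desc$generate()  # a fresh object built from the stored values
```

Deferring construction this way lets the descriptor carry configurable parameters while each call to generate() yields a new, independent object.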
This class is usually used to configure the callback of a torch learner, e.g. when constructing a learner or a ModelDescriptor.
For a list of available callbacks, see mlr3torch_callbacks. To conveniently retrieve a TorchCallback, use t_clbk().
Parameters
Defined by the constructor argument param_set. If no parameter set is provided during construction, one is inferred by creating a parameter for each argument of the wrapped callback's constructor; these parameters are of type ParamUty.
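The inference rule (one untyped parameter per constructor argument) can be illustrated in base R with formals(). This is a hypothetical sketch, not mlr3torch's code. checkpoint_init() mimics the argument list that the example below infers for CallbackSetCheckpoint (path, freq, freq_type), and the data frame only mirrors the "id" and "class" columns of a printed ParamSet.

```r
# Hypothetical stand-in for a callback's constructor arguments
checkpoint_init <- function(path, freq, freq_type = "epoch") NULL

# Collect one parameter id per argument, all marked as untyped ("ParamUty").
# This only illustrates the rule; it does not build a real paradox ParamSet.
param_ids <- names(formals(checkpoint_init))
param_table <- data.frame(id = param_ids, class = "ParamUty")
print(param_table)
```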
See also
Other Callback:
as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_context_torch, t_clbk(), torch_callback()
Other Torch Descriptor:
TorchDescriptor, TorchLoss, TorchOptimizer, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()
Super class
mlr3torch::TorchDescriptor -> TorchCallback
Methods
Method new()
Creates a new instance of this R6 class.
Usage
TorchCallback$new(
callback_generator,
param_set = NULL,
id = NULL,
label = NULL,
packages = NULL,
man = NULL
)
Arguments
callback_generator (R6ClassGenerator)
The class generator for the callback that is being wrapped.
param_set (ParamSet or NULL)
The parameter set. If NULL (default), it is inferred from callback_generator.
id (character(1))
The id of the new object.
label (character(1))
Label for the new instance.
packages (character())
The R packages this object depends on.
man (character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().
Examples
# Create a new torch callback from an existing callback set
torch_callback = TorchCallback$new(CallbackSetCheckpoint)
# The parameters are inferred
torch_callback$param_set
#> <ParamSet(3)>
#> id class lower upper nlevels default value
#> <char> <char> <num> <num> <num> <list> <list>
#> 1: path ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 2: freq ParamUty NA NA Inf <NoDefault[0]> [NULL]
#> 3: freq_type ParamUty NA NA Inf <NoDefault[0]> [NULL]
# Retrieve a torch callback from the dictionary
torch_callback = t_clbk("checkpoint",
path = tempfile(), freq = 1
)
torch_callback
#> <TorchCallback:checkpoint> Checkpoint
#> * Generator: CallbackSetCheckpoint
#> * Parameters: path=/tmp/RtmpPL7hlH/file1a38410f6bff, freq=1
#> * Packages: mlr3torch,torch
torch_callback$label
#> [1] "Checkpoint"
torch_callback$id
#> [1] "checkpoint"
# open the help page of the wrapped callback set
# torch_callback$help()
# Create the callback set
callback = torch_callback$generate()
callback
#> <CallbackSetCheckpoint>
#> * Stages: on_batch_end, on_epoch_end, on_exit
# is the same as
CallbackSetCheckpoint$new(
path = tempfile(), freq = 1
)
#> <CallbackSetCheckpoint>
#> * Stages: on_batch_end, on_epoch_end, on_exit
# Use in a learner
learner = lrn("regr.mlp", callbacks = t_clbk("checkpoint"))
# the parameters of the callback are added to the learner's parameter set
learner$param_set
#> <ParamSetCollection(36)>
#> id class lower upper nlevels default
#> <char> <char> <num> <num> <num> <list>
#> 1: epochs ParamInt 0e+00 Inf Inf <NoDefault[0]>
#> 2: device ParamFct NA NA 12 <NoDefault[0]>
#> 3: num_threads ParamInt 1e+00 Inf Inf <NoDefault[0]>
#> 4: num_interop_threads ParamInt 1e+00 Inf Inf <NoDefault[0]>
#> 5: seed ParamInt -Inf Inf Inf <NoDefault[0]>
#> 6: eval_freq ParamInt 1e+00 Inf Inf <NoDefault[0]>
#> 7: measures_train ParamUty NA NA Inf <NoDefault[0]>
#> 8: measures_valid ParamUty NA NA Inf <NoDefault[0]>
#> 9: patience ParamInt 0e+00 Inf Inf <NoDefault[0]>
#> 10: min_delta ParamDbl 0e+00 Inf Inf <NoDefault[0]>
#> 11: batch_size ParamInt 1e+00 Inf Inf <NoDefault[0]>
#> 12: shuffle ParamLgl NA NA 2 FALSE
#> 13: sampler ParamUty NA NA Inf <NoDefault[0]>
#> 14: batch_sampler ParamUty NA NA Inf <NoDefault[0]>
#> 15: num_workers ParamInt 0e+00 Inf Inf 0
#> 16: collate_fn ParamUty NA NA Inf [NULL]
#> 17: pin_memory ParamLgl NA NA 2 FALSE
#> 18: drop_last ParamLgl NA NA 2 FALSE
#> 19: timeout ParamDbl -Inf Inf Inf -1
#> 20: worker_init_fn ParamUty NA NA Inf <NoDefault[0]>
#> 21: worker_globals ParamUty NA NA Inf <NoDefault[0]>
#> 22: worker_packages ParamUty NA NA Inf <NoDefault[0]>
#> 23: neurons ParamUty NA NA Inf <NoDefault[0]>
#> 24: p ParamDbl 0e+00 1e+00 Inf <NoDefault[0]>
#> 25: activation ParamUty NA NA Inf <NoDefault[0]>
#> 26: activation_args ParamUty NA NA Inf <NoDefault[0]>
#> 27: shape ParamUty NA NA Inf <NoDefault[0]>
#> 28: opt.lr ParamDbl 0e+00 Inf Inf 0.001
#> 29: opt.betas ParamUty NA NA Inf 0.900,0.999
#> 30: opt.eps ParamDbl 1e-16 1e-04 Inf 1e-08
#> 31: opt.weight_decay ParamDbl 0e+00 1e+00 Inf 0
#> 32: opt.amsgrad ParamLgl NA NA 2 FALSE
#> 33: loss.reduction ParamFct NA NA 2 mean
#> 34: cb.checkpoint.path ParamUty NA NA Inf <NoDefault[0]>
#> 35: cb.checkpoint.freq ParamInt 1e+00 Inf Inf <NoDefault[0]>
#> 36: cb.checkpoint.freq_type ParamFct NA NA 2 epoch
#> id class lower upper nlevels default
#> value
#> <list>
#> 1: [NULL]
#> 2: auto
#> 3: 1
#> 4: 1
#> 5: random
#> 6: 1
#> 7: <list[0]>
#> 8: <list[0]>
#> 9: 0
#> 10: 0
#> 11: [NULL]
#> 12: [NULL]
#> 13: [NULL]
#> 14: [NULL]
#> 15: [NULL]
#> 16: [NULL]
#> 17: [NULL]
#> 18: [NULL]
#> 19: [NULL]
#> 20: [NULL]
#> 21: [NULL]
#> 22: [NULL]
#> 23:
#> 24: 0.5
#> 25: <nn_relu[1]>
#> 26: <list[0]>
#> 27: [NULL]
#> 28: [NULL]
#> 29: [NULL]
#> 30: [NULL]
#> 31: [NULL]
#> 32: [NULL]
#> 33: [NULL]
#> 34: [NULL]
#> 35: [NULL]
#> 36: [NULL]
#> value
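The last three rows of the learner's parameter set above show that the callback's parameters appear under the prefix "cb.checkpoint.". The sketch below reproduces those ids by prefixing each parameter id with "cb." and the callback id. This is an assumption about the naming scheme based on the printed output, not a description of how mlr3torch assembles the ParamSetCollection.

```r
# Reproduce the prefixed ids seen in the printed parameter set (illustrative only)
callback_id <- "checkpoint"
param_ids <- c("path", "freq", "freq_type")
prefixed <- paste0("cb.", callback_id, ".", param_ids)
print(prefixed)
```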