Configures the callbacks of a deep learning model.
Input and Output Channels
There is one input channel "input" and one output channel "output".
During training, the channels are of class ModelDescriptor.
During prediction, the channels are of class Task.
Parameters
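The parameters are defined dynamically from the callbacks: the parameter set of each callback is included with the id of that callback as its set id, so its parameters appear under the prefix <id>. (for example checkpoint.freq for the callback with id "checkpoint").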
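As a minimal sketch of how the prefixed parameters might be set, assuming mlr3torch is installed and attaches mlr3pipelines; the chosen values and the temporary path are illustrative only:

```r
library(mlr3torch)

# Construct the PipeOp with the checkpoint callback (id "checkpoint")
po_cb = po("torch_callbacks", "checkpoint")

# The callback's parameters are exposed under the prefix "checkpoint."
po_cb$param_set$ids()

# Set the prefixed parameters; tempfile() is a placeholder location
po_cb$param_set$values = list(
  checkpoint.path = tempfile(),
  checkpoint.freq = 2L
)
po_cb$param_set$values$checkpoint.freq
```

Setting both values in a single assignment avoids intermediate states in which only some of the callback's parameters are set.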
Internals
During training, the callbacks are cloned and added to the ModelDescriptor.
See also
Other Model Configuration: ModelDescriptor(), mlr_pipeops_torch_loss, mlr_pipeops_torch_optimizer, model_descriptor_union()
Other PipeOp: mlr_pipeops_module, mlr_pipeops_torch_optimizer
Super class
mlr3pipelines::PipeOp -> PipeOpTorchCallbacks
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpTorchCallbacks$new(
  callbacks = list(),
  id = "torch_callbacks",
  param_vals = list()
)
Arguments
callbacks
(list of TorchCallbacks)
The callbacks (or something convertible via as_torch_callbacks()). Must have unique ids. All callbacks are cloned during construction.
id
(character(1))
Identifier of the resulting object.
param_vals
(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
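The cloning of the callbacks argument can be illustrated with a short sketch. It assumes that the clone is deep, so that the parameter set is copied as well, and that the checkpoint callback's own parameters are named path and freq (inferred from the prefixed ids shown in the Examples). Under these assumptions, later changes to the original callback should not reach the PipeOp:

```r
library(mlr3torch)

# A callback object created outside the PipeOp
cb = t_clbk("checkpoint")

# The PipeOp stores a clone of cb, not cb itself (per the argument description)
po_cb = po("torch_callbacks", cb)

# Modify the original callback after construction;
# the path is a placeholder created with tempfile()
cb$param_set$values = list(path = tempfile(), freq = 5L)

# If the clone is deep, the PipeOp's parameter is expected to remain unset
po_cb$param_set$values$checkpoint.freq
```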
Examples
library(mlr3torch)
# Create a PipeOp that adds the checkpoint callback
po_cb = po("torch_callbacks", "checkpoint")
po_cb$param_set
#> <ParamSetCollection(3)>
#>                      id    class lower upper nlevels        default  value
#>                  <char>   <char> <num> <num>   <num>         <list> <list>
#> 1:      checkpoint.path ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]
#> 2:      checkpoint.freq ParamInt     1   Inf     Inf <NoDefault[0]> [NULL]
#> 3: checkpoint.freq_type ParamFct    NA    NA       2          epoch [NULL]
# Train an ingress PipeOp to obtain a ModelDescriptor that has no callbacks yet
mdin = po("torch_ingress_num")$train(list(tsk("iris")))
mdin[[1L]]$callbacks
#> named list()
# Training the PipeOp adds the callback to the ModelDescriptor
mdout = po_cb$train(mdin)[[1L]]
mdout$callbacks
#> $checkpoint
#> <TorchCallback:checkpoint> Checkpoint
#> * Generator: CallbackSetCheckpoint
#> * Parameters: list()
#> * Packages: mlr3torch,torch
#>
# A second PipeOpTorchCallbacks can be applied to add further callbacks
po_cb1 = po("torch_callbacks", t_clbk("progress"))
mdout1 = po_cb1$train(list(mdout))[[1L]]
mdout1$callbacks
#> $progress
#> <TorchCallback:progress> Progress
#> * Generator: CallbackSetProgress
#> * Parameters: list()
#> * Packages: progress,mlr3torch,torch
#>
#> $checkpoint
#> <TorchCallback:checkpoint> Checkpoint
#> * Generator: CallbackSetCheckpoint
#> * Parameters: list()
#> * Packages: mlr3torch,torch
#>