
Configures the callbacks of a deep learning model.

Input and Output Channels

There is one input channel "input" and one output channel "output". During training, the channels are of class ModelDescriptor. During prediction, the channels are of class Task.
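A minimal sketch of the two modes, assuming mlr3, mlr3pipelines and mlr3torch are installed with a working torch backend, and using the built-in iris task. The PipeOp is first trained on a ModelDescriptor and then, once trained, receives a plain Task during prediction. The check below relies only on the documented channel classes.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

po_cb = po("torch_callbacks", "checkpoint")

# Training: the "input" channel receives a ModelDescriptor
md = po("torch_ingress_num")$train(list(tsk("iris")))
trained = po_cb$train(md)
class(trained[[1L]])

# Prediction: the already trained PipeOp receives a Task
predicted = po_cb$predict(list(tsk("iris")))
class(predicted[[1L]])
```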

State

The state is the value calculated by the public method shapes_out().

Parameters

The parameters are defined dynamically from the callbacks: each callback contributes its own parameters, prefixed with the callback's id (for example, checkpoint.path for the callback with id "checkpoint").
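The sketch below, assuming mlr3torch with a working torch backend, lists the prefixed parameter ids of a PipeOp holding two callbacks and sets one value through its prefixed id. Which parameters the "progress" callback contributes (if any) is not assumed.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

po_cb = po("torch_callbacks",
  callbacks = list(t_clbk("checkpoint"), t_clbk("progress"))
)

# Parameter ids carry the callback id as prefix, e.g. "checkpoint.freq"
po_cb$param_set$ids()

# Values are set and read through the prefixed id
po_cb$param_set$values$checkpoint.freq = 2L
po_cb$param_set$values$checkpoint.freq
```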

Internals

During training, the callbacks are cloned and added to the ModelDescriptor.

Super class

mlr3pipelines::PipeOp -> PipeOpTorchCallbacks

Methods


Method new()

Creates a new instance of this R6 class.

Usage

PipeOpTorchCallbacks$new(
  callbacks = list(),
  id = "torch_callbacks",
  param_vals = list()
)

Arguments

callbacks

(list of TorchCallbacks)
The callbacks (or something convertible via as_torch_callbacks()). Must have unique ids. All callbacks are cloned during construction.

id

(character(1))
Identifier of the resulting object.

param_vals

(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
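A sketch of calling the constructor directly with all three arguments, assuming mlr3torch with a working torch backend. The id "my_callbacks" is purely illustrative. The last part assumes that a violation of the unique-id requirement stated for callbacks is signalled as an error at construction time.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

po_cb = PipeOpTorchCallbacks$new(
  callbacks = list(t_clbk("checkpoint"), t_clbk("progress")),
  id = "my_callbacks",                      # hypothetical custom id
  param_vals = list(checkpoint.freq = 5L)   # prefixed parameter id
)
po_cb$id
po_cb$param_set$values$checkpoint.freq

# Two callbacks with the same id are expected to be rejected
res = tryCatch(
  PipeOpTorchCallbacks$new(callbacks = list(t_clbk("checkpoint"), t_clbk("checkpoint"))),
  error = function(e) e
)
inherits(res, "error")
```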


Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpTorchCallbacks$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
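A small sketch, assuming mlr3torch with a working torch backend, showing that a clone can be renamed and trained without touching the original PipeOp. The id "callbacks_copy" is hypothetical.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

po_cb = po("torch_callbacks", "checkpoint")
po_copy = po_cb$clone(deep = TRUE)

# Rename and train only the copy
po_copy$id = "callbacks_copy"
md = po("torch_ingress_num")$train(list(tsk("iris")))
out = po_copy$train(md)

# The original keeps its id and remains untrained
po_cb$id
is.null(po_cb$state)
```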

Examples

library(mlr3torch)
po_cb = po("torch_callbacks", "checkpoint")
po_cb$param_set
#> <ParamSetCollection(3)>
#>                      id    class lower upper nlevels        default  value
#>                  <char>   <char> <num> <num>   <num>         <list> <list>
#> 1:      checkpoint.path ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]
#> 2:      checkpoint.freq ParamInt     1   Inf     Inf <NoDefault[0]> [NULL]
#> 3: checkpoint.freq_type ParamFct    NA    NA       2          epoch [NULL]
mdin = po("torch_ingress_num")$train(list(tsk("iris")))
mdin[[1L]]$callbacks
#> named list()
mdout = po_cb$train(mdin)[[1L]]
mdout$callbacks
#> $checkpoint
#> <TorchCallback:checkpoint> Checkpoint
#> * Generator: CallbackSetCheckpoint
#> * Parameters: list()
#> * Packages: mlr3torch,torch
#> 
# Another PipeOpTorchCallbacks can be applied; its callbacks are added to the existing ones
po_cb1 = po("torch_callbacks", t_clbk("progress"))
mdout1 = po_cb1$train(list(mdout))[[1L]]
mdout1$callbacks
#> $progress
#> <TorchCallback:progress> Progress
#> * Generator: CallbackSetProgress
#> * Parameters: list()
#> * Packages: progress,mlr3torch,torch
#> 
#> $checkpoint
#> <TorchCallback:checkpoint> Checkpoint
#> * Generator: CallbackSetCheckpoint
#> * Parameters: list()
#> * Packages: mlr3torch,torch
#>
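Instead of calling $train() on each PipeOp by hand, the same steps can be chained in a Graph. A minimal sketch, assuming mlr3, mlr3pipelines and mlr3torch with a working torch backend; the graph's single output is the ModelDescriptor produced by the callbacks PipeOp.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

# Ingress followed by the callbacks PipeOp
graph = po("torch_ingress_num") %>>%
  po("torch_callbacks", "checkpoint")

# Graph$train() takes the Task directly and returns a list of outputs
result = graph$train(tsk("iris"))
md = result[[1L]]
names(md$callbacks)
```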