The torch callback mechanism allows you to customize the training loop of a neural network. While mlr3torch comes with predefined callbacks for common use cases, this vignette shows you how to write your own custom callback.
Building Blocks
At the heart of the callback mechanism are three R6
classes:
- CallbackSet contains the methods that are executed at different stages of the training loop.
- TorchCallback wraps a CallbackSet and annotates it with meta information, including a ParamSet.
- ContextTorch defines which information of the training process the CallbackSet has access to.
When using predefined callbacks, one usually only interacts with the TorchCallback class, which is constructed using t_clbk(<id>):
tc_hist = t_clbk("history")
tc_hist
## <TorchCallback:history> History
## * Generator: CallbackSetHistory
## * Parameters: list()
## * Packages: mlr3torch,torch
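The predefined callbacks are stored in a dictionary from which t_clbk() retrieves them. Printing that dictionary gives an overview of the available callbacks; here we assume it is exported under the name mlr3torch_callbacks, as described in the t_clbk() documentation.
# overview of the predefined callbacks (assumed dictionary name)
mlr3torch_callbacks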
The wrapped CallbackSet is accessible via the field $generator, and we can create a new instance by calling $new():
cbs_hist = tc_hist$generator$new()
Within the training loop of a torch model, different stages exist at which a callback can execute code. Below, we list all stages that are available; they are also described in the documentation of CallbackSet.
mlr_reflections$torch$callback_stages
## [1] "on_begin" "on_epoch_begin" "on_batch_begin"
## [4] "on_after_backward" "on_batch_end" "on_before_valid"
## [7] "on_batch_valid_begin" "on_batch_valid_end" "on_valid_end"
## [10] "on_epoch_end" "on_end" "on_exit"
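A callback uses a stage by defining a method of the same name. As a small preview of the torch_callback() helper introduced in the next section, the following hypothetical callback implements only on_epoch_end and prints the current epoch, which it reads from the context:
# minimal sketch: a callback implementing a single stage
announce_epoch = torch_callback("announce_epoch",
  on_epoch_end = function() {
    cat(sprintf("finished epoch %s\n", self$ctx$epoch))
  }
)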
For the history callback, we see that it runs code at the beginning of training, before validation starts, and at the end of each epoch.
cbs_hist
## <CallbackSetHistory>
## * Stages: on_begin, on_before_valid, on_epoch_end
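To see the history callback in use, the sketch below attaches it to a learner, scores a validation set after every epoch, and reads the recorded scores back from the trained model. It assumes that the learner supports internal validation via the $validate field, that validation measures are set through the measures_valid parameter, and that the recorded scores end up in $model$callbacks$history; consult the mlr3torch documentation for the exact interface.
# hedged sketch: record validation accuracy per epoch with the history callback
mlp_hist = lrn("classif.mlp",
  callbacks = t_clbk("history"),
  measures_valid = msrs("classif.acc"),  # assumed parameter for validation measures
  epochs = 5, batch_size = 64, neurons = 20)
mlp_hist$validate = 0.3  # hold out 30% of the training data for validation
mlp_hist$train(tsk("iris"))
mlp_hist$model$callbacks$history  # assumed location of the recorded scores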
Writing a Custom Logger
In order to define our own custom callback, we can use the torch_callback() helper function. As an example, we will create a custom logger that keeps track of an exponential moving average of the training loss and prints it at the end of every epoch. This callback takes a single argument alpha, the smoothing parameter, and stores the moving average in self$moving_loss. The value of alpha can later be configured in the Learner.
Then, we implement the main logic of the callback using two stages (a standalone sketch of the update rule follows this list):
- on_batch_end(), which is called after the optimizer updates the network parameters using a mini-batch. Here, we access the last loss via the ContextTorch's $last_loss field (the context is accessible as self$ctx) and update the value self$moving_loss.
- on_before_valid(), which is run before the validation loop. At this point, we simply print the exponential moving average of the training loss.
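The update rule itself is plain exponential smoothing. The following self-contained snippet, using made-up loss values, illustrates what on_batch_end() computes over the course of training:
# exponential moving average over a stream of losses (illustrative values only)
alpha = 0.1
losses = c(1.5, 1.2, 1.3, 0.9)
moving_loss = NULL
for (loss in losses) {
  moving_loss = if (is.null(moving_loss)) {
    loss  # first batch: initialize with the raw loss
  } else {
    alpha * loss + (1 - alpha) * moving_loss
  }
}
moving_loss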
Finally, in order to make the final value of the moving average accessible in the Learner after the training finishes, we implement
- the state_dict() method, which returns this value, and
- the load_state_dict(state_dict) method, which takes a previously retrieved state dict and sets it in the callback.
Note that we only do this here to have access to the final $moving_loss via the learner; otherwise, we would not have to implement these methods.
custom_logger = torch_callback("custom_logger",
  initialize = function(alpha = 0.1) {
    self$alpha = alpha
    self$moving_loss = NULL
  },
  # update the exponential moving average after every optimizer step
  on_batch_end = function() {
    if (is.null(self$moving_loss)) {
      self$moving_loss = self$ctx$last_loss
    } else {
      self$moving_loss = self$alpha * self$ctx$last_loss + (1 - self$alpha) * self$moving_loss
    }
  },
  # print the smoothed training loss once per epoch, right before validation
  on_before_valid = function() {
    cat(sprintf("Epoch %s: %.2f\n", self$ctx$epoch, self$moving_loss))
  },
  # make the final moving average retrievable from the trained learner
  state_dict = function() {
    self$moving_loss
  },
  load_state_dict = function(state_dict) {
    self$moving_loss = state_dict
  }
)
This created a TorchCallback object and an associated CallbackSet R6 class for us:
custom_logger
## <TorchCallback:custom_logger> Custom_logger
## * Generator: CallbackSetCustom_logger
## * Parameters: list()
## * Packages: mlr3torch,torch
custom_logger$generator
## <CallbackSetCustom_logger> object generator
## Inherits from: <inherit>
## Public:
## initialize: function (alpha = 0.1)
## state_dict: function ()
## load_state_dict: function (state_dict)
## on_before_valid: function ()
## on_batch_end: function ()
## Parent env: <environment: 0x55d884044cc8>
## Locked objects: FALSE
## Locked class: FALSE
## Portable: TRUE
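Because state_dict() and load_state_dict() are ordinary methods of the generated class, we can also exercise them on a standalone instance outside of any training loop; the value below is made up for illustration.
cbs = custom_logger$generator$new(alpha = 0.05)
cbs$moving_loss = 1.09  # pretend this value was computed during training
state = cbs$state_dict()
cbs_restored = custom_logger$generator$new()
cbs_restored$load_state_dict(state)
cbs_restored$moving_loss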
Using the Custom Logger
In order to showcase its effect, we now train a simple multi-layer perceptron and pass it the callback. We set the smoothing parameter alpha to 0.05.
task = tsk("iris")
mlp = lrn("classif.mlp",
callbacks = custom_logger, cb.custom_logger.alpha = 0.05,
epochs = 5, batch_size = 64, neurons = 20)
## Warning in warn_deprecated("Learner$initialize argument 'data_formats'"):
## Learner$initialize argument 'data_formats' is deprecated and will be removed in
## the future.
mlp$train(task)
## Epoch 1: 1.22
## Epoch 2: 1.27
## Epoch 3: 1.48
## Epoch 4: 1.07
## Epoch 5: 1.09
The information returned by state_dict() is now accessible via the Learner's $model slot:
mlp$model$callbacks$custom_logger
## [1] 1.090038
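Since alpha was registered as a hyperparameter of the learner (see cb.custom_logger.alpha above), it can also be reconfigured later like any other parameter; a sketch of changing it and retraining would be:
# reconfigure the smoothing parameter and retrain (sketch)
mlp$param_set$set_values(cb.custom_logger.alpha = 0.5)
mlp$train(task)
mlp$model$callbacks$custom_logger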