Convenience function to create a custom TorchCallback
.
All arguments that are available in callback_set()
are also available here.
For more information on how to correctly implement a new callback, see CallbackSet
.
Usage
torch_callback(
id,
classname = paste0("CallbackSet", capitalize(id)),
param_set = NULL,
packages = NULL,
label = capitalize(id),
man = NULL,
on_begin = NULL,
on_end = NULL,
on_exit = NULL,
on_epoch_begin = NULL,
on_before_valid = NULL,
on_epoch_end = NULL,
on_batch_begin = NULL,
on_batch_end = NULL,
on_after_backward = NULL,
on_batch_valid_begin = NULL,
on_batch_valid_end = NULL,
on_valid_end = NULL,
state_dict = NULL,
load_state_dict = NULL,
initialize = NULL,
public = NULL,
private = NULL,
active = NULL,
parent_env = parent.frame(),
inherit = CallbackSet,
lock_objects = FALSE
)
Arguments
- id
(
character(1)
)
`
The id for the torch callback.- classname
(
character(1)
)
The class name.- param_set
(
ParamSet
)
The parameter set, if not present it is inferred from the$initialize()
method.- packages
(
character()
)The packages the callback depends on. Default is
NULL`.- label
(
character(1)
)
The label for the torch callback. Defaults to the capitalizedid
.- man
(
character(1)
)
String in the format[pkg]::[topic]
pointing to a manual page for this object. The referenced help package can be opened via method$help()
. The default isNULL
.- on_begin, on_end, on_epoch_begin, on_before_valid, on_epoch_end, on_batch_begin, on_batch_end, on_after_backward, on_batch_valid_begin, on_batch_valid_end, on_valid_end, on_exit
(
function
)
Function to execute at the given stage, see section Stages.- state_dict
(
function()
)
The function that retrieves the state dict from the callback. This is what will be available in the learner after training.- load_state_dict
(
function(state_dict)
)
Function that loads a callback state.- initialize
(
function()
)
The initialization method of the callback.- public, private, active
(
list()
)
Additional public, private, and active fields to add to the callback.- parent_env
(
environment()
)
The parent environment for theR6Class
.- inherit
(
R6ClassGenerator
)
From which class to inherit. This class must either beCallbackSet
(default) or inherit from it.- lock_objects
(
logical(1)
)
Whether to lock the objects of the resultingR6Class
. IfFALSE
(default), values can be freely assigned toself
without declaring them in the class definition.
Internals
It first creates an R6
class inheriting from CallbackSet
(using callback_set()
) and
then wraps this generator in a TorchCallback
that can be passed to a torch learner.
Stages
begin
:: Run before the training loop begins.epoch_begin
:: Run he beginning of each epoch.batch_begin
:: Run before the forward call.after_backward
:: Run after the backward call.batch_end
:: Run after the optimizer step.batch_valid_begin
:: Run before the forward call in the validation loop.batch_valid_end
:: Run after the forward call in the validation loop.valid_end
:: Run at the end of validation.epoch_end
:: Run at the end of each epoch.end
:: Run after last epoch.exit
:: Run at last, usingon.exit()
.
Examples
custom_tcb = torch_callback("custom",
initialize = function(name) {
self$name = name
},
on_begin = function() {
cat("Hello", self$name, ", we will train for ", self$ctx$total_epochs, "epochs.\n")
},
on_end = function() {
cat("Training is done.")
}
)
learner = lrn("classif.torch_featureless",
batch_size = 16,
epochs = 1,
callbacks = custom_tcb,
cb.custom.name = "Marie",
device = "cpu"
)
task = tsk("iris")
learner$train(task)
#> Hello Marie , we will train for 1 epochs.
#> Training is done.