
Unfreezes selected weights (parameters of the network) after a given number of batches (steps) or epochs.

Super class

mlr3torch::CallbackSet -> CallbackSetUnfreeze

Methods



Method new()

Creates a new instance of this R6 class.

Usage

CallbackSetUnfreeze$new(starting_weights, unfreeze)

Arguments

starting_weights

(Select)
A Select denoting the weights that are trainable from the start. All other weights start out frozen.

unfreeze

(data.table)
A data.table with a list column weights, whose entries are Selects, and a column epoch or batch. In each row, the Select specifies which parameters to unfreeze, and the epoch or batch value specifies when to unfreeze them.
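The following minimal base-R sketch shows the shapes of the two constructor arguments. It models a Select as a plain function that maps parameter names to the selected subset. The helpers `sketch_select_name()` and `sketch_select_invert()` are hypothetical stand-ins for mlr3torch's `select_name()` and `select_invert()`, and the parameter names are made up. A `data.frame` with a list column replaces the `data.table` so the sketch runs without extra packages.

```r
# Hypothetical stand-ins for mlr3torch's select_name() and select_invert().
# Here a "Select" is modeled as a function mapping parameter names to a subset.
sketch_select_name = function(keep) {
  function(param_names) intersect(param_names, keep)
}
sketch_select_invert = function(select) {
  function(param_names) setdiff(param_names, select(param_names))
}

# Made-up parameter names, for illustration only
param_names = c("0.weight", "0.bias", "3.weight", "3.bias")

# starting_weights: everything except the two weight matrices
starting_weights = sketch_select_invert(sketch_select_name(c("0.weight", "3.weight")))
initially_trainable = starting_weights(param_names)

# unfreeze: one row per event; `weights` is a list column holding selectors
# (data.frame used instead of data.table to avoid dependencies)
schedule = data.frame(epoch = c(2, 5))
schedule$weights = list(sketch_select_name("0.weight"), sketch_select_name("3.weight"))

# Resolve each row's selector against the parameter names
resolved = lapply(schedule$weights, function(sel) sel(param_names))
print(initially_trainable)
print(resolved)
```

A list column is used because each cell has to hold a function rather than an atomic value.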


Method on_begin()

Sets which weights are trainable at the start of training, as specified by starting_weights.

Usage

CallbackSetUnfreeze$on_begin()
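The next sketch is a hedged model of the state that `on_begin()` sets up. It assumes that parameters selected by `starting_weights` become trainable and all others are frozen. A named logical vector stands in for torch's per-parameter `requires_grad` flag. `sketch_select_name()` and `initial_trainable()` are hypothetical helpers, not mlr3torch functions, and the parameter names are made up.

```r
# Hypothetical selector: returns the requested names that are present
sketch_select_name = function(keep) function(param_names) intersect(param_names, keep)

# Made-up parameter names and a starting selection of the two biases
param_names = c("0.weight", "0.bias", "3.weight", "3.bias")
starting_weights = sketch_select_name(c("0.bias", "3.bias"))

# Assumed behavior: selected parameters are trainable (TRUE), the rest frozen (FALSE)
initial_trainable = function(param_names, starting_weights) {
  trainable = setNames(rep(FALSE, length(param_names)), param_names)
  trainable[starting_weights(param_names)] = TRUE
  trainable
}

trainable = initial_trainable(param_names, starting_weights)
print(trainable)
```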


Method on_epoch_begin()

Unfreezes the weights scheduled for the current epoch, if any.

Usage

CallbackSetUnfreeze$on_epoch_begin()
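The epoch check can be sketched in the same way. The sketch assumes that `on_epoch_begin()` unfreezes exactly the parameters selected by rows whose `epoch` equals the current epoch. The hypothetical helper `unfreeze_at_epoch()` applies that rule to a named logical vector of trainable flags. This illustrates the schedule semantics only and does not reproduce the callback's source.

```r
# Hypothetical selector: returns the requested names that are present
sketch_select_name = function(keep) function(param_names) intersect(param_names, keep)

# Hypothetical epoch-based schedule (data.frame with a list column)
schedule = data.frame(epoch = c(2, 5))
schedule$weights = list(sketch_select_name("0.weight"), sketch_select_name("3.weight"))

# Set the flag of every parameter selected by a row matching the current epoch
unfreeze_at_epoch = function(trainable, epoch, schedule) {
  for (i in which(schedule$epoch == epoch)) {
    trainable[schedule$weights[[i]](names(trainable))] = TRUE
  }
  trainable
}

# Assumed starting state: only the biases are trainable
trainable = c("0.weight" = FALSE, "0.bias" = TRUE, "3.weight" = FALSE, "3.bias" = TRUE)
for (epoch in 1:6) {
  trainable = unfreeze_at_epoch(trainable, epoch, schedule)
  cat("epoch", epoch, "trainable:", names(trainable)[trainable], "\n")
}
```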


Method on_batch_begin()

Unfreezes the weights scheduled for the current batch, if any.

Usage

CallbackSetUnfreeze$on_batch_begin()
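A batch-based schedule works the same way but keys on a batch counter. The sketch below assumes that batches are counted cumulatively across epochs. This is an assumption and is not stated on this page. Under that assumption, with two batches per epoch, batch 3 is the first batch of epoch 2. `sketch_select_name()` and `batch_num` are hypothetical names.

```r
# Hypothetical selector: returns the requested names that are present
sketch_select_name = function(keep) function(param_names) intersect(param_names, keep)

# Hypothetical batch-based schedule: unfreeze "0.weight" at batch 3
schedule = data.frame(batch = 3)
schedule$weights = list(sketch_select_name("0.weight"))

trainable = c("0.weight" = FALSE, "0.bias" = TRUE)
batch_num = 0          # assumed cumulative counter, never reset between epochs
unfrozen_at = NULL

for (epoch in 1:3) {
  for (b in 1:2) {     # two batches per epoch in this toy run
    batch_num = batch_num + 1
    for (i in which(schedule$batch == batch_num)) {
      trainable[schedule$weights[[i]](names(trainable))] = TRUE
      unfrozen_at = c(epoch = epoch, batch = batch_num)
    }
  }
}
print(unfrozen_at)
```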


Method clone()

The objects of this class are cloneable with this method.

Usage

CallbackSetUnfreeze$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

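library(mlr3torch)
library(data.table)

task = tsk("iris")
cb = t_clbk("unfreeze")
mlp = lrn("classif.mlp", callbacks = cb,
  # Freeze these four parameters at the start; all others are trainable
  cb.unfreeze.starting_weights = select_invert(
    select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
  ),
  # Unfreeze "0.weight" at epoch 2, and "3.weight" and "6.weight" at epoch 5;
  # "6.bias" is never unfrozen, so it stays frozen for the whole run
  cb.unfreeze.unfreeze = data.table(
    epoch = c(2, 5),
    weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
  ),
  epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
)

mlp$train(task)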
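To see what the example's schedule does epoch by epoch, the following base-R trace tracks which parameters stay frozen. It runs without torch. It assumes that, for `neurons = c(1, 1, 1)`, the MLP's parameters are named `"0.*"`, `"3.*"` and `"6.*"` for the hidden linear layers and `"9.*"` for the output layer. It also assumes the epoch rule modeled above. This is a hypothetical illustration, not output from mlr3torch.

```r
# Assumed parameter names of the example MLP (hidden layers 0, 3, 6; output 9)
param_names = c("0.weight", "0.bias", "3.weight", "3.bias",
                "6.weight", "6.bias", "9.weight", "9.bias")

# starting_weights = select_invert(select_name(...)): all but these four are trainable
initially_frozen = c("0.weight", "3.weight", "6.weight", "6.bias")
trainable = setNames(!(param_names %in% initially_frozen), param_names)

# The example's unfreeze schedule, written as plain lists of names
schedule = list(
  list(epoch = 2, weights = "0.weight"),
  list(epoch = 5, weights = c("3.weight", "6.weight"))
)

for (epoch in 1:6) {
  for (row in schedule) {
    if (row$epoch == epoch) trainable[row$weights] = TRUE
  }
  cat(sprintf("epoch %d frozen: %s\n", epoch,
              paste(names(trainable)[!trainable], collapse = ", ")))
}
```

Under these assumptions, only `"6.bias"` remains frozen after epoch 5, because no row of the schedule selects it.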