Unfreezes selected weights (parameters of the network) after a given number of epochs or batches (steps).
Super class
mlr3torch::CallbackSet -> CallbackSetUnfreeze
Methods
Method new()
Creates a new instance of this R6 class.
Usage
CallbackSetUnfreeze$new(starting_weights, unfreeze)
Arguments
starting_weights
(Select)
A Select denoting the weights that are trainable from the start; all other weights are frozen until an entry in unfreeze selects them.
unfreeze
(data.table)
A data.table with a column weights (a list column of Selects) and a column epoch or batch. Each selector indicates which parameters to unfreeze, while the epoch or batch value in the same row indicates when to do so.
Examples
library(mlr3torch)
library(data.table)

task = tsk("iris")
cb = t_clbk("unfreeze")
mlp = lrn("classif.mlp", callbacks = cb,
  # all parameters except these four are trainable from the start
  cb.unfreeze.starting_weights = select_invert(
    select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
  ),
  # unfreeze 0.weight at epoch 2, then 3.weight and 6.weight at epoch 5
  cb.unfreeze.unfreeze = data.table(
    epoch = c(2, 5),
    weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
  ),
  epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
)
mlp$train(task)