Context for training a torch learner.
This is the (mostly read-only) information that callbacks can access through the argument ctx.
For more information on callbacks, see CallbackSet.
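As a sketch of how a callback reads this context, the example below assumes mlr3torch's torch_callback() helper. Its stage methods (here on_begin and on_epoch_end) reach the context through self$ctx. The id "epoch_logger" and the printed messages are illustrative only and not part of the package.

```r
library(mlr3torch)

# A custom callback that only reads from the context (self$ctx).
# The id "epoch_logger" is an arbitrary, illustrative name.
epoch_logger = torch_callback("epoch_logger",
  on_begin = function() {
    # task_train and total_epochs are already set before the first epoch.
    cat("Training on task", self$ctx$task_train$id,
      "for", self$ctx$total_epochs, "epochs.\n")
  },
  on_epoch_end = function() {
    # epoch is the current epoch; step is the current iteration.
    cat("Finished epoch", self$ctx$epoch, "of", self$ctx$total_epochs,
      "(step", self$ctx$step, ")\n")
  }
)
```

The resulting TorchCallback object is meant to be handed to a torch learner through its callbacks argument. The learner then creates the ContextTorch and assigns it to self$ctx during training, so the callback never constructs the context itself. Check the exact learner construction arguments against your installed version.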
See also
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, t_clbk(), torch_callback()
Public fields
learner
(Learner)
The torch learner.

task_train
(Task)
The training task.

task_valid
(Task or NULL)
The validation task.

loader_train
(torch::dataloader)
The data loader for training.

loader_valid
(torch::dataloader or NULL)
The data loader for validation.

measures_train
(list() of Measures)
The measures used for training.

measures_valid
(list() of Measures)
The measures used for validation.

network
(torch::nn_module)
The torch network.

optimizer
(torch::optimizer)
The optimizer.

loss_fn
(torch::nn_module)
The loss function.

total_epochs
(integer(1))
The total number of epochs the learner is trained for.

last_scores_train
(named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs in which the model is not evaluated.

last_scores_valid
(named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs in which the model is not evaluated.

epoch
(integer(1))
The current epoch.

step
(integer(1))
The current iteration.

prediction_encoder
(function())
The learner's prediction encoder.

batch
(named list() of torch_tensors)
The current batch.

terminate
(logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
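The terminate field and the possibly NULL validation scores combine naturally into an early-stopping rule. The following is a minimal sketch that assumes mlr3torch's torch_callback() helper. It also assumes that last_scores_valid is already filled when on_epoch_end runs in an evaluated epoch. The helper should_stop(), the callback id "stop_on_threshold", the measure id "classif.ce" and the threshold 0.05 are all hypothetical choices, and the rule assumes a measure for which lower is better.

```r
library(mlr3torch)

# Hypothetical helper: decide from a context whether training should stop.
# A NULL last_scores_valid (an epoch that was not evaluated, e.g. because
# eval_freq > 1) or a missing measure is treated as "keep training".
should_stop = function(ctx, measure_id, threshold) {
  scores = ctx$last_scores_valid
  if (is.null(scores) || is.null(scores[[measure_id]])) {
    return(FALSE)
  }
  # Assumes lower scores are better (true for classification error).
  scores[[measure_id]] <= threshold
}

# Sketch of a callback that stops training once the validation score
# reaches a fixed threshold. "classif.ce" and 0.05 are illustrative.
stop_on_threshold = torch_callback("stop_on_threshold",
  on_epoch_end = function() {
    # terminate is checked at the end of the epoch, so setting it here
    # prevents any further epochs from running.
    if (should_stop(self$ctx, "classif.ce", 0.05)) {
      self$ctx$terminate = TRUE
    }
  }
)
```

The decision logic is kept in a plain function so that it can be checked against a simple list standing in for the context, without running a training loop.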
Methods
Method new()
Creates a new instance of this R6 class.
Usage
ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L
)
Arguments
learner
(Learner)
The torch learner.

task_train
(Task)
The training task.

task_valid
(Task or NULL)
The validation task.

loader_train
(torch::dataloader)
The data loader for training.

loader_valid
(torch::dataloader or NULL)
The data loader for validation.

measures_train
(list() of Measures or NULL)
Measures used for training. Default is NULL.

measures_valid
(list() of Measures or NULL)
Measures used for validation. Default is NULL.

network
(torch::nn_module)
The torch network.

optimizer
(torch::optimizer)
The optimizer.

loss_fn
(torch::nn_module)
The loss function.

total_epochs
(integer(1))
The total number of epochs the learner is trained for.

prediction_encoder
(function())
The learner's prediction encoder.

eval_freq
(integer(1))
The evaluation frequency, i.e. how often (in epochs) the model is evaluated. Default is 1.