Context for training a torch learner.
This is the (mostly read-only) information that callbacks can access through the argument ctx.
For more information on callbacks, see CallbackSet.
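ContextTorch is an R6 class, so a callback receives the same object that the training loop updates and always sees the current state. The sketch below does not use mlr3torch. It assumes a plain R environment as a stand-in for the context, because environments share this reference behavior. The helpers make_mock_ctx() and make_logger() are hypothetical, and the loop only imitates training with a placeholder loss. The callback only reads fields such as epoch, total_epochs and last_loss.

```r
# Hypothetical stand-in for a ContextTorch object. An environment, like an
# R6 object, is passed by reference, so callbacks see the current state.
make_mock_ctx <- function(total_epochs) {
  list2env(list(
    total_epochs = total_epochs,
    epoch = 0L,
    last_loss = NA_real_,
    terminate = FALSE
  ))
}

# Hypothetical read-only callback: it records progress but changes nothing
# in ctx. The closure keeps its log between calls.
make_logger <- function() {
  lines <- character(0)
  list(
    callback = function(ctx) {
      lines <<- c(lines, sprintf("epoch %d/%d: loss %.3f",
                                 ctx$epoch, ctx$total_epochs, ctx$last_loss))
      invisible(NULL)
    },
    get = function() lines
  )
}

ctx <- make_mock_ctx(total_epochs = 3L)
logger <- make_logger()
for (epoch in seq_len(ctx$total_epochs)) {
  ctx$epoch <- epoch
  ctx$last_loss <- 1 / epoch  # placeholder, not a real training step
  logger$callback(ctx)        # the callback reads the updated fields
}
logger$get()
```

Because the log lives in the closure rather than in ctx, the callback leaves the context untouched, which matches its mostly read-only role.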
Public fields
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader)
The data loader for validation.
measures_train (list() of Measures)
Measures used for training.
measures_valid (list() of Measures)
Measures used for validation.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
last_scores_train (named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs in which the model is not evaluated.
last_scores_valid (named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures. If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs in which the model is not evaluated.
last_loss (numeric(1))
The loss from the last training batch.
y_hat (torch_tensor)
The model's prediction for the current batch.
epoch (integer(1))
The current epoch.
step (integer(1))
The current iteration.
prediction_encoder (function())
The learner's prediction encoder.
batch (named list() of torch_tensors)
The current batch.
terminate (logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
device (torch::torch_device)
The device.
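terminate is a field that a callback is meant to write to: setting it to TRUE at the end of an epoch stops training. The sketch below is a minimal patience-based early-stopping rule built on that behavior. It is not the package's own implementation. The mock context (again a plain environment), the loss sequence, and the names make_early_stopping(), patience and on_epoch_end are hypothetical. The callback tracks the best last_loss seen so far and sets terminate once patience consecutive epochs pass without improvement.

```r
# Hypothetical early-stopping callback, called at the end of every epoch.
make_early_stopping <- function(patience = 2L) {
  best <- Inf
  bad_epochs <- 0L
  function(ctx) {
    if (ctx$last_loss < best) {
      best <<- ctx$last_loss
      bad_epochs <<- 0L
    } else {
      bad_epochs <<- bad_epochs + 1L
    }
    # Setting terminate to TRUE at the end of an epoch stops training.
    if (bad_epochs >= patience) ctx$terminate <- TRUE
    invisible(NULL)
  }
}

# Mock context: an environment, so the callback's write is visible here.
ctx <- list2env(list(epoch = 0L, total_epochs = 10L,
                     last_loss = NA_real_, terminate = FALSE))
fake_losses <- c(0.9, 0.7, 0.6, 0.65, 0.62, 0.5, 0.4, 0.3, 0.2, 0.1)
on_epoch_end <- make_early_stopping(patience = 2L)

for (epoch in seq_len(ctx$total_epochs)) {
  ctx$epoch <- epoch
  ctx$last_loss <- fake_losses[epoch]  # stands in for a real training epoch
  on_epoch_end(ctx)
  if (ctx$terminate) break             # the loop honors the flag
}
ctx$epoch
```

With these placeholder losses the loss stops improving after epoch 3, so training halts at epoch 5 instead of running all 10 epochs.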
Methods
Method new()
Creates a new instance of this R6 class.
Usage
ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L,
  device
)
Arguments
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation.
measures_train (list() of Measures or NULL)
Measures used for training. Default is NULL.
measures_valid (list() of Measures or NULL)
Measures used for validation. Default is NULL.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
prediction_encoder (function())
The learner's prediction encoder. See section Inheriting of LearnerTorch.
eval_freq (integer(1))
The evaluation frequency, i.e., the model is evaluated every eval_freq epochs. Default is 1L.
device (character(1))
The device.
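The interaction between eval_freq, the current epoch and the last_scores_* fields can be made concrete with a small simulation. The sketch below assumes that the model is evaluated in every epoch whose number is a multiple of eval_freq. The exact rule used by LearnerTorch (for example, whether the final epoch is always evaluated) is not stated here and may differ. The helper is_eval_epoch() is hypothetical, classif.acc is used only as an example measure id, and the scores are placeholders. The real context overwrites a single field each epoch. The sketch instead keeps one snapshot per epoch so the NULL entries are easy to inspect.

```r
# Hypothetical helper, assuming the model is evaluated in every epoch whose
# number is a multiple of eval_freq.
is_eval_epoch <- function(epoch, eval_freq = 1L) {
  epoch %% eval_freq == 0L
}

total_epochs <- 6L
eval_freq <- 2L
# One snapshot slot per epoch; vector("list", n) starts with n NULL entries.
last_scores_valid <- vector("list", total_epochs)
for (epoch in seq_len(total_epochs)) {
  if (is_eval_epoch(epoch, eval_freq)) {
    # Named list keyed by measure id; the value is a placeholder score.
    last_scores_valid[[epoch]] <- list(classif.acc = 0.5 + epoch / 20)
  }
  # In all other epochs the slot stays NULL.
}
which(vapply(last_scores_valid, Negate(is.null), logical(1)))
```

With eval_freq = 2L only epochs 2, 4 and 6 carry scores under this assumption. With the default eval_freq = 1L every epoch would be evaluated.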