
Context for training a torch learner. This is the (mostly read-only) information that callbacks can access through the argument ctx. For more information on callbacks, see CallbackSet.

Public fields

learner

(Learner)
The torch learner.

task_train

(Task)
The training task.

task_valid

(Task or NULL)
The validation task.

loader_train

(torch::dataloader)
The data loader for training.

loader_valid

(torch::dataloader or NULL)
The data loader for validation.

measures_train

(list() of Measures)
Measures used for training.

measures_valid

(list() of Measures)
Measures used for validation.

network

(torch::nn_module)
The torch network.

optimizer

(torch::optimizer)
The optimizer.

loss_fn

(torch::nn_module)
The loss function.

total_epochs

(integer(1))
The total number of epochs the learner is trained for.

last_scores_train

(named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures. If the LearnerTorch's eval_freq is set to a value other than 1, this field is NULL in all epochs in which the model is not evaluated.

last_scores_valid

(named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures. If the LearnerTorch's eval_freq is set to a value other than 1, this field is NULL in all epochs in which the model is not evaluated.
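To illustrate which epochs leave last_scores_train and last_scores_valid set to NULL, the sketch below lists the epochs in which the model would be evaluated. It assumes that evaluation happens every eval_freq epochs and always in the final epoch; that final-epoch rule is an assumption, not something stated above. eval_epochs() is a hypothetical helper, not part of mlr3torch.

```r
# Hypothetical helper (not part of mlr3torch): returns the epochs in which
# measures are computed, ASSUMING evaluation every `eval_freq` epochs plus
# always in the final epoch.
eval_epochs <- function(total_epochs, eval_freq = 1L) {
  epochs <- seq_len(total_epochs)
  epochs[epochs %% eval_freq == 0L | epochs == total_epochs]
}

# With eval_freq = 3 over 10 epochs, evaluation happens in epochs 3, 6, 9, 10;
# in every other epoch the last_scores_* fields would be NULL.
eval_epochs(10L, 3L)

# Logical mask of epochs whose scores would be non-NULL.
seq_len(10L) %in% eval_epochs(10L, 3L)
```

Callbacks that read these fields should therefore check for NULL (for example with is.null(ctx$last_scores_valid)) before using them.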

epoch

(integer(1))
The current epoch.

step

(integer(1))
The current iteration.

prediction_encoder

(function())
The learner's prediction encoder.

batch

(named list() of torch_tensors)
The current batch.

terminate

(logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
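The terminate field is only consulted at the end of an epoch. The base-R sketch below mimics that behaviour with a plain environment standing in for ContextTorch; run_epochs() and stop_after_3() are hypothetical names, and the loop is a simplified stand-in, not mlr3torch's actual training loop.

```r
# Hypothetical, simplified stand-in for the training loop (not mlr3torch code).
# A plain environment plays the role of the ContextTorch object `ctx`.
run_epochs <- function(total_epochs, on_epoch_end) {
  ctx <- new.env()
  ctx$total_epochs <- total_epochs
  ctx$terminate <- FALSE
  for (epoch in seq_len(total_epochs)) {
    ctx$epoch <- epoch
    # ... training (and possibly validation) batches would run here ...
    on_epoch_end(ctx)                  # callback hook at the end of the epoch
    if (isTRUE(ctx$terminate)) break   # terminate is checked only here
  }
  ctx$epoch                            # last epoch that was run
}

# Hypothetical callback logic: request termination once epoch 3 is reached.
# Environments have reference semantics, so the assignment is visible to the loop.
stop_after_3 <- function(ctx) {
  if (ctx$epoch >= 3L) ctx$terminate <- TRUE
}

run_epochs(10L, stop_after_3)   # stops after epoch 3
run_epochs(5L, function(ctx) NULL)  # never terminates early, runs all 5 epochs
```

In mlr3torch itself, logic like stop_after_3() would typically live in a callback's epoch-end stage; see CallbackSet for the available stages.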

Methods


Method new()

Creates a new instance of this R6 class.

Usage

ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L
)

Arguments

learner

(Learner)
The torch learner.

task_train

(Task)
The training task.

task_valid

(Task or NULL)
The validation task.

loader_train

(torch::dataloader)
The data loader for training.

loader_valid

(torch::dataloader or NULL)
The data loader for validation.

measures_train

(list() of Measures or NULL)
Measures used for training. Default is NULL.

measures_valid

(list() of Measures or NULL)
Measures used for validation.

network

(torch::nn_module)
The torch network.

optimizer

(torch::optimizer)
The optimizer.

loss_fn

(torch::nn_module)
The loss function.

total_epochs

(integer(1))
The total number of epochs the learner is trained for.

prediction_encoder

(function())
The learner's prediction encoder.

eval_freq

(integer(1))
The evaluation frequency, i.e. how often (in epochs) the measures are computed. Default is 1 (every epoch).


Method clone()

The objects of this class are cloneable with this method.

Usage

ContextTorch$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
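The difference between deep = FALSE and deep = TRUE follows the general R6 semantics. The sketch below uses two hypothetical toy classes (Inner, Outer) to show it: a shallow clone shares nested R6 objects with the original, while a deep clone gets its own copies. How ContextTorch's own fields (for example network or optimizer) behave under cloning is not covered here.

```r
library(R6)

# Hypothetical toy classes, used only to show generic R6 clone() behaviour.
Inner <- R6Class("Inner", public = list(x = 1))
Outer <- R6Class("Outer", public = list(
  inner = NULL,
  initialize = function() {
    # Create the nested object per instance (not as a shared class default).
    self$inner <- Inner$new()
  }
))

original <- Outer$new()
shallow <- original$clone()            # deep = FALSE: nested R6 object is shared
deep <- original$clone(deep = TRUE)    # nested R6 object is cloned as well

original$inner$x <- 2
shallow$inner$x  # 2: sees the change, same Inner object as `original`
deep$inner$x     # 1: unaffected, has its own copy
```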