
This base class provides the basic functionality for training and prediction of a neural network. All torch learners should inherit from this class.

Validation

To specify the validation data, you can set the $validate field of the Learner, which can be set to:

  • NULL: no validation

  • ratio: a proportion 1 - ratio of the task is used for training and the proportion ratio is used for validation.

  • "test": the test set of a resampling is used for validation. This is not possible when calling $train() manually.

  • "predefined": the predefined $internal_valid_task of the mlr3::Task is used.

This validation data can also be used for early stopping, see the description of the Learner's parameters.
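A minimal sketch of two of these options when calling $train() directly. It assumes mlr3torch's "classif.mlp" learner and mlr3's "iris" task, and the 30 validation rows are an arbitrary choice. Assigning row ids to $internal_valid_task is an mlr3 feature that moves those rows out of the training data.

```r
library(mlr3)
library(mlr3torch)

# Ratio: 30% of the training rows are held out for validation
learner_ratio = lrn("classif.mlp", epochs = 5, batch_size = 16)
learner_ratio$validate = 0.3

# Predefined: validate on the rows assigned to the task's internal validation task
task = tsk("iris")
task$internal_valid_task = sample(task$row_ids, 30)
learner_pre = lrn("classif.mlp", epochs = 5, batch_size = 16)
learner_pre$validate = "predefined"
learner_pre$train(task)
```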

Saving a Learner

In order to save a LearnerTorch for later usage, it is necessary to call the $marshal() method on the Learner before writing it to disk, as the object will otherwise not be saved correctly: the torch network and optimizer state are backed by external pointers that do not survive plain serialization. After loading a marshaled LearnerTorch into R again, you then need to call $unmarshal() to transform it into a usable state.
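A sketch of the save/load round trip described above, assuming the "classif.mlp" learner and a temporary file created with tempfile(); saveRDS() and readRDS() are base R.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp", epochs = 2, batch_size = 16)
learner$train(tsk("iris"))

path = tempfile(fileext = ".rds")
learner$marshal()          # convert the torch objects into a serializable form
saveRDS(learner, path)

restored = readRDS(path)
restored$unmarshal()       # restore a usable learner
pred = restored$predict(tsk("iris"))
```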

Early Stopping and Tuning

In order to prevent overfitting, the LearnerTorch class supports early stopping via the patience and min_delta parameters, see the Learner's parameters. When tuning a LearnerTorch, it is also possible to combine explicit tuning via mlr3tuning with the LearnerTorch's internal tuning of the epochs via early stopping. To do so, include epochs = to_tune(upper = <upper>, internal = TRUE) in the search space, where <upper> is the maximum allowed number of epochs, and configure early stopping.
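The following sketch combines both kinds of tuning and assumes mlr3tuning is installed. The learning-rate range, the random search tuner, the holdout resampling, and the budget of two evaluations are arbitrary illustrative choices. Setting validate to "test" lets each resampling iteration use its test set as validation data for early stopping.

```r
library(mlr3)
library(mlr3torch)
library(mlr3tuning)

learner = lrn("classif.mlp",
  batch_size = 16,
  # tuned internally: early stopping picks the number of epochs, at most 50
  epochs = to_tune(upper = 50, internal = TRUE),
  # tuned explicitly by the tuner (learning rate of the default Adam optimizer)
  opt.lr = to_tune(1e-4, 1e-1, logscale = TRUE),
  # early stopping configuration
  patience = 3,
  measures_valid = msr("classif.ce")
)
# Validate on the test set of each resampling iteration
learner$validate = "test"

instance = tune(
  tuner = tnr("random_search"),
  task = tsk("iris"),
  learner = learner,
  resampling = rsmp("holdout"),
  measures = msr("classif.ce"),
  term_evals = 2
)
instance$result_learner_param_vals
```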

Model

The Model is a list of class "learner_torch_model" with the following elements:

  • network :: The trained network.

  • optimizer :: The $state_dict() of the optimizer used to train the network.

  • loss_fn :: The $state_dict() of the loss used to train the network.

  • callbacks :: The callbacks used to train the network.

  • seed :: The seed that was / is used for training and prediction.

  • epochs :: The number of epochs the model was trained for. With early stopping, this can be smaller than the epochs parameter.

  • task_col_info :: A data.table() containing information about the train-task.
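A short sketch of inspecting these elements after training. It assumes the "regr.mlp" learner and mlr3's "mtcars" task, and the parameter values are arbitrary. With a fixed seed, model$seed holds that value.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("regr.mlp", epochs = 3, batch_size = 32, seed = 1)
learner$train(tsk("mtcars"))

learner$model$epochs  # number of epochs the model was trained for
learner$model$seed    # seed used for training and prediction
learner$network       # shortcut for learner$model$network
```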

Parameters

General:

The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:

  • epochs :: integer(1)
    The number of epochs.

  • device :: character(1)
    The device. One of "auto", "cpu", or "cuda" or other values defined in mlr_reflections$torch$devices. The value is initialized to "auto", which will select "cuda" if possible, then try "mps" and otherwise fall back to "cpu".

  • num_threads :: integer(1)
    The number of threads for intra-op parallelization (if device is "cpu"). This value is initialized to 1.

  • num_interop_threads :: integer(1)
    The number of threads for inter-op parallelization (if device is "cpu"). This value is initialized to 1. Note that this can only be set once per session; changing the value within an R session will raise a warning.

  • seed :: integer(1) or "random" or NULL
    The torch seed that is used during training and prediction. This value is initialized to "random", which means that a random seed will be sampled at the beginning of the training phase. This seed (either set or randomly sampled) is available via $model$seed after training and is used during prediction. Because the seed is only sampled when training starts, clones of the learner will by default (i.e. when seed is "random") use different seeds. If set to NULL, no seeding will be done.
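A sketch of setting the general parameters at construction time. It assumes the "classif.mlp" learner with its default Adam optimizer, so that "opt.lr" refers to Adam's learning rate. The values are illustrative only.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  epochs = 10,
  batch_size = 32,
  device = "cpu",    # force the CPU instead of "auto"
  num_threads = 2,   # intra-op threads on the CPU
  seed = 123,        # fixed seed instead of "random"
  opt.lr = 0.01      # optimizer parameter, note the "opt." prefix
)

# The seed of a freshly constructed learner is initialized to "random"
lrn("classif.mlp")$param_set$values$seed
```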

Evaluation:

  • measures_train :: Measure or list() of Measures.
    Measures to be evaluated during training.

  • measures_valid :: Measure or list() of Measures.
    Measures to be evaluated during validation.

  • eval_freq :: integer(1)
    How often the train / validation predictions are evaluated using measures_train / measures_valid. This is initialized to 1. Note that the final model is always evaluated.
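A sketch of configuring the evaluation parameters. It assumes the "classif.mlp" learner on the "iris" task with 20% of the rows used for validation. With eval_freq = 2 and 4 epochs, the measures are evaluated after epochs 2 and 4, and the resulting validation scores are available via $internal_valid_scores.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  epochs = 4,
  batch_size = 16,
  measures_train = msr("classif.acc"),
  measures_valid = list(msr("classif.acc"), msr("classif.ce")),
  eval_freq = 2
)
learner$validate = 0.2
learner$train(tsk("iris"))

learner$internal_valid_scores  # named list, one entry per validation measure
```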

Early Stopping:

  • patience :: integer(1)
    This activates early stopping using the validation scores. If the performance of the model does not improve for patience evaluation steps, training is ended. Note that the final model is stored in the learner, not the best model. This is initialized to 0, which means no early stopping. The first entry of measures_valid is used as the metric. Early stopping also requires specifying the $validate field of the Learner as well as measures_valid.

  • min_delta :: double(1)
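    The minimum improvement threshold for early stopping: a change in the validation score only counts as an improvement if it is strictly greater (>) than min_delta. This is initialized to 0.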
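The patience and min_delta rule can be illustrated in base R. This is a simplified sketch of the rule as described above, not mlr3torch's internal code. The helper early_stop_step() is hypothetical, and its minimize argument stands in for the direction of the first measure in measures_valid.

```r
# Hypothetical helper: returns the evaluation step at which training would stop,
# or NA if it runs to completion. `scores` holds one validation score per
# evaluation step.
early_stop_step = function(scores, patience, min_delta = 0, minimize = TRUE) {
  if (patience == 0) return(NA_integer_)  # 0 disables early stopping
  if (!minimize) scores = -scores        # flip so that lower is always better
  best = Inf
  stale = 0
  for (i in seq_along(scores)) {
    if (best - scores[i] > min_delta) {   # strict (>) improvement over the best score
      best = scores[i]
      stale = 0
    } else {
      stale = stale + 1
      if (stale >= patience) return(i)
    }
  }
  NA_integer_
}

ce = c(0.40, 0.30, 0.29, 0.31, 0.30)
early_stop_step(ce, patience = 2)                   # 5: no improvement at steps 4 and 5
early_stop_step(ce, patience = 2, min_delta = 0.05) # 4: the 0.01 gain at step 3 is too small
```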

Dataloader:

  • batch_size :: integer(1)
    The batch size (required).

  • shuffle :: logical(1)
    Whether to shuffle the instances in the dataset. Default is FALSE. This does not impact validation.

  • sampler :: torch::sampler
    Object that defines how the dataloader draws samples.

  • batch_sampler :: torch::sampler
    Object that defines how the dataloader draws batches.

  • num_workers :: integer(1)
    The number of workers for data loading (batches are loaded in parallel). The default is 0, which means that data will be loaded in the main process.

  • collate_fn :: function
    How to merge a list of samples to form a batch.

  • pin_memory :: logical(1)
    Whether the dataloader copies tensors into CUDA pinned memory before returning them.

  • drop_last :: logical(1)
    Whether to drop the last training batch of each epoch if it is incomplete (i.e. smaller than batch_size). Default is FALSE.

  • timeout :: numeric(1)
    The timeout value for collecting a batch from workers. Negative values mean no timeout and the default is -1.

  • worker_init_fn :: function(id)
    A function that receives the worker id (in [1, num_workers]) and is executed after seeding on the worker but before data loading.

  • worker_globals :: list() | character()
    When loading data in parallel, this allows exporting globals to the workers. If this is a character vector, the objects in the global environment with those names are copied to the workers.

  • worker_packages :: character()
    Which packages to load on the workers.

Also see torch::dataloader for more information.
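A sketch of setting dataloader parameters on a learner, together with a base R illustration of how drop_last changes the number of batches per epoch. The learner is assumed to be "classif.mlp", and the helper n_batches() is hypothetical and not part of mlr3torch.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  epochs = 5,
  batch_size = 32,
  shuffle = TRUE,    # shuffle the training rows
  drop_last = TRUE,  # drop an incomplete last training batch
  num_workers = 0    # load data in the main R process
)

# Hypothetical helper: number of training batches per epoch for n rows
n_batches = function(n, batch_size, drop_last = FALSE) {
  if (drop_last) n %/% batch_size else ceiling(n / batch_size)
}
n_batches(150, 32)                    # 5: four full batches plus one of 22 rows
n_batches(150, 32, drop_last = TRUE)  # 4: the 22-row remainder is dropped
```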

Inheriting

There are no separate classes for classification and regression to inherit from. Instead, the task_type must be specified as a construction argument. Currently, only classification and regression are supported.

When inheriting from this class, one should overwrite two private methods:

  • .network(task, param_vals)
    (Task, list()) -> nn_module
    Construct a torch::nn_module object for the given task and parameter values, i.e. the neural network that is trained by the learner. For classification, the output of this network is expected to be the scores before the application of the final softmax layer.

  • .dataset(task, param_vals)
    (Task, list()) -> torch::dataset
    Create the dataset for the task. Must respect the parameter value of the device. Moreover, it must respect the row ids of the provided task.
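As an illustration of what a .network() implementation might return for classification, here is a minimal torch module that outputs raw scores (logits) without a final softmax. The module name net_sketch and its hidden size are hypothetical. How the input tensor reaches forward() depends on the dataset returned by .dataset() and is not shown here.

```r
library(torch)

# Hypothetical network: `d_in` features in, one raw score per class out (no softmax)
net_sketch = nn_module("net_sketch",
  initialize = function(d_in, n_classes, d_hidden = 16) {
    self$hidden = nn_linear(d_in, d_hidden)
    self$out = nn_linear(d_hidden, n_classes)
  },
  forward = function(x) {
    x = nnf_relu(self$hidden(x))
    self$out(x)
  }
)

# Inside a subclass, .network(task, param_vals) could return for example
# net_sketch(length(task$feature_names), length(task$class_names))
net = net_sketch(d_in = 4, n_classes = 3)
scores = net(torch_randn(5, 4))
dim(scores)
```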

It is also possible to overwrite the private .dataloader() method instead of the .dataset() method. By default, a dataloader is constructed using the output from the .dataset() method. An overwritten .dataloader() method should, however, respect the dataloader parameters from the ParamSet.

  • .dataloader(task, param_vals)
    (Task, list()) -> torch::dataloader
    Create a dataloader from the task. Needs to respect at least batch_size and shuffle (otherwise predictions can be permuted).

To change the predict types, the private .encode_prediction() method can be overwritten:

  • .encode_prediction(predict_tensor, task, param_vals)
    (torch_tensor, Task, list()) -> list()
    Take in the raw predictions from self$network (predict_tensor) and encode them into a format that can be converted to valid mlr3 predictions using mlr3::as_prediction_data(). This method must take self$predict_type into account.
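A base R sketch of the kind of conversion an .encode_prediction() override might perform for classification. It assumes the raw predictions have already been turned into an n x k matrix of scores (logits), one column per class, and that a list with "response" and "prob" entries is an acceptable input for mlr3::as_prediction_data(). The helper encode_scores() is hypothetical; check mlr3::as_prediction_data() for the exact format.

```r
# Hypothetical helper: turn raw class scores into a response and probabilities
encode_scores = function(scores, class_names, predict_type = "response") {
  # Numerically stable softmax: subtract each row's maximum before exp()
  shifted = scores - apply(scores, 1, max)
  prob = exp(shifted) / rowSums(exp(shifted))
  colnames(prob) = class_names
  # The response is the class with the highest probability
  response = factor(class_names[max.col(prob, ties.method = "first")],
    levels = class_names)
  if (predict_type == "prob") {
    list(response = response, prob = prob)
  } else {
    list(response = response)
  }
}

scores = matrix(c(2, 0, 0,
                  0, 1, 3), nrow = 2, byrow = TRUE)
enc = encode_scores(scores, c("a", "b", "c"), predict_type = "prob")
enc$response  # factor with values "a" and "c"
```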

While it is possible to add parameters by specifying the param_set construction argument, it is currently not possible to remove existing parameters, i.e. those listed in section Parameters. None of the parameters provided in param_set can have an id that starts with "loss.", "opt.", or "cb.", as these are reserved for the dynamically constructed parameters of the optimizer, the loss function, and the callbacks.
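A sketch of defining an extra parameter with paradox and checking it against the reserved prefixes. The parameter n_hidden and the helper reserved_clashes() are hypothetical and only illustrate the naming rule above.

```r
library(paradox)

# Hypothetical extra parameter for a custom torch learner
custom_ps = ps(
  n_hidden = p_int(lower = 1L, default = 32L, tags = "train")
)

# Hypothetical helper: ids that would clash with the reserved prefixes
reserved_clashes = function(ids, reserved = c("loss.", "opt.", "cb.")) {
  ids[vapply(ids, function(id) any(startsWith(id, reserved)), logical(1))]
}

reserved_clashes(custom_ps$ids())                           # no clash
reserved_clashes(c("n_hidden", "opt.lr", "lossy", "cb.x"))  # "opt.lr" and "cb.x"
```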

To perform additional input checks on the task, the private methods .verify_train_task(task, param_vals) and .verify_predict_task(task, param_vals) can be overwritten.

For learners that have other construction arguments that should change the hash of a learner, it is required to implement the private $.additional_phash_input().

Super class

mlr3::Learner -> LearnerTorch

Active bindings

validate

How to construct the internal validation data. This parameter can be either NULL, a ratio in (0, 1), "test", or "predefined".

loss

(TorchLoss)
The torch loss.

optimizer

(TorchOptimizer)
The torch optimizer.

callbacks

(list() of TorchCallbacks)
List of torch callbacks. The ids will be set as the names.

internal_valid_scores

Retrieves the internal validation scores as a named list(). Specify the $validate field and the measures_valid parameter to configure this. Returns NULL if the learner is not trained yet.

internal_tuned_values

When early stopping is active, this returns a named list with the early-stopped epochs, otherwise an empty list is returned. Returns NULL if the learner is not trained yet.
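A sketch of reading both bindings after training with early stopping. It assumes the "classif.mlp" learner on the "iris" task with 30% of the rows used for validation. The upper bound of 100 epochs and the patience of 5 are arbitrary.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  epochs = 100,       # upper bound on the number of training epochs
  batch_size = 16,
  patience = 5,       # stop after 5 evaluations without improvement
  min_delta = 0,
  measures_valid = msr("classif.ce")
)
learner$validate = 0.3
learner$train(tsk("iris"))

learner$internal_valid_scores  # named list with the classif.ce validation score
learner$internal_tuned_values  # named list with the early-stopped epochs
```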

marshaled

(logical(1))
Whether the learner is marshaled.

network

(nn_module())
Shortcut for learner$model$network.

param_set

(ParamSet)
The parameter set

hash

(character(1))
Hash (unique identifier) for this object.

phash

(character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).

Methods



Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorch$new(
  id,
  task_type,
  param_set,
  properties,
  man,
  label,
  feature_types,
  optimizer = NULL,
  loss = NULL,
  packages = character(),
  predict_types = NULL,
  callbacks = list()
)

Arguments

id

(character(1))
The id of the new object.

task_type

(character(1))
The task type.

param_set

(ParamSet or alist())
Either a parameter set, or an alist() of expressions referring to fields of self, e.g. alist(private$.param_set1, private$.param_set2), from which a ParamSet collection is created.

properties

(character())
The properties of the object. See mlr_reflections$learner_properties for available values.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().

label

(character(1))
Label for the new instance.

feature_types

(character())
The feature types. See mlr_reflections$task_feature_types for available values. Additionally, "lazy_tensor" is supported.

optimizer

(NULL or TorchOptimizer)
The optimizer to use for training. Defaults to adam.

loss

(NULL or TorchLoss)
The loss to use for training. Defaults to MSE for regression and cross entropy for classification.

packages

(character())
The R packages this object depends on.

predict_types

(character())
The predict types. See mlr_reflections$learner_predict_types for available values. For regression, the default is "response". For classification, this defaults to "response" and "prob". To deviate from the defaults, it is necessary to overwrite the private $.encode_prediction() method, see section Inheriting.

callbacks

(list() of TorchCallbacks)
The callbacks to use for training. Defaults to an empty list(), i.e. no callbacks.


Method format()

Helper for print outputs.

Usage

LearnerTorch$format(...)

Arguments

...

(ignored).


Method print()

Prints the object.

Usage

LearnerTorch$print(...)

Arguments

...

(any)
Currently unused.


Method marshal()

Marshal the learner.

Usage

LearnerTorch$marshal(...)

Arguments

...

(any)
Additional parameters.

Returns

self


Method unmarshal()

Unmarshal the learner.

Usage

LearnerTorch$unmarshal(...)

Arguments

...

(any)
Additional parameters.

Returns

self


Method dataset()

Create the dataset for a task.

Usage

LearnerTorch$dataset(task)

Arguments

task

Task
The task

Returns

dataset


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorch$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.