
Builds a Torch Learner from a ModelDescriptor and trains it with the given parameter specification. The task type must be specified during construction.

Input and Output Channels

There is one input channel "input". It takes a ModelDescriptor during training and a Task of the specified task_type during prediction. The output is NULL during training and a Prediction of the given task_type during prediction.

State

A trained LearnerTorchModel.

Parameters

General:

The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:

  • epochs :: integer(1)
    The number of epochs.

  • device :: character(1)
    The device. One of "auto", "cpu", "cuda", or another value defined in mlr_reflections$torch$devices. The value is initialized to "auto", which selects "cuda" if possible, then tries "mps", and otherwise falls back to "cpu".

  • num_threads :: integer(1)
    The number of threads for intra-op parallelization (if device is "cpu"). This value is initialized to 1.

  • seed :: integer(1) or "random" or NULL
    The torch seed that is used during training and prediction. This value is initialized to "random", which means that a random seed will be sampled at the beginning of the training phase. This seed (either set or randomly sampled) is available via $model$seed after training and is used during prediction. Note that because the seed is only sampled when training starts, clones of the learner will by default (i.e. when seed is "random") use different seeds. If set to NULL, no seeding will be done.
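The resolution rules for the "auto" device and the "random" seed can be sketched in plain R as follows. The helpers resolve_device() and resolve_seed() are hypothetical and only mirror the behavior described above; they are not the package's internal code. Device availability is passed in as flags so that the sketch runs without torch (in a real session these might come from torch::cuda_is_available() and torch::backends_mps_is_available()).

```r
# Hypothetical helper: resolve the "auto" device following the documented
# order cuda -> mps -> cpu. Availability flags are supplied by the caller.
resolve_device <- function(device, cuda_available, mps_available) {
  if (device != "auto") return(device)  # explicit choices are kept as-is
  if (cuda_available) return("cuda")    # prefer CUDA
  if (mps_available) return("mps")      # then Apple MPS
  "cpu"                                 # otherwise fall back to the CPU
}

# Hypothetical helper: resolve the seed parameter at the start of training.
resolve_seed <- function(seed) {
  if (is.null(seed)) return(NULL)       # NULL: no seeding at all
  if (identical(seed, "random")) {
    # sample a fresh seed once; it would then be stored and reused at prediction
    return(sample.int(.Machine$integer.max, 1L))
  }
  as.integer(seed)                      # a user-supplied seed is used directly
}

resolve_device("auto", cuda_available = FALSE, mps_available = TRUE)  # "mps"
```

Because resolve_seed("random") draws a new value on every call, two clones trained separately end up with different seeds, which is the cloning caveat noted above.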

Evaluation:

  • measures_train :: Measure or list() of Measures.
    Measures to be evaluated during training.

  • measures_valid :: Measure or list() of Measures.
    Measures to be evaluated during validation.

  • eval_freq :: integer(1)
    How often the train / validation predictions are evaluated using measures_train / measures_valid. This is initialized to 1. Note that the final model is always evaluated.
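As an illustration of eval_freq, the sketch below lists the epochs at which measures would be evaluated. It assumes, as an illustration only, that evaluation happens whenever the epoch number is a multiple of eval_freq, plus once after the final epoch. The helper eval_epochs() is hypothetical.

```r
# Hypothetical helper: epochs at which measures_train / measures_valid are
# evaluated, assuming every eval_freq-th epoch plus the final epoch.
eval_epochs <- function(epochs, eval_freq = 1L) {
  scheduled <- which(seq_len(epochs) %% eval_freq == 0)
  sort(unique(c(scheduled, epochs)))  # the final epoch is always evaluated
}

eval_epochs(10, 3)  # 3 6 9 10
```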

Early Stopping:

  • patience :: integer(1)
    This activates early stopping using the validation scores. If the performance of a model does not improve for patience evaluation steps, training is ended. Note that the final model, not the best model, is stored in the learner. This is initialized to 0, which means no early stopping. The first entry of measures_valid is used as the metric. Early stopping also requires specifying the $validate field of the Learner, as well as measures_valid.

  • min_delta :: double(1)
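    The minimum improvement threshold for early stopping: a validation score only counts as an improvement if it beats the best score so far by strictly more than min_delta. This is initialized to 0.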
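The interplay of patience and min_delta can be sketched as a loop over a sequence of validation scores. The helper early_stop_step() is hypothetical; it returns the evaluation step at which training would stop. The minimize flag stands in for the direction of the first measure in measures_valid and is an assumption of this sketch.

```r
# Hypothetical sketch of patience-based early stopping on validation scores.
early_stop_step <- function(scores, patience, min_delta = 0, minimize = TRUE) {
  if (patience == 0) return(length(scores))  # 0 disables early stopping
  best <- if (minimize) Inf else -Inf
  n_bad <- 0
  for (i in seq_along(scores)) {
    # positive values mean the score moved in the "better" direction
    improvement <- if (minimize) best - scores[i] else scores[i] - best
    if (improvement > min_delta) {           # strict ">" as documented
      best <- scores[i]
      n_bad <- 0
    } else {
      n_bad <- n_bad + 1
      if (n_bad >= patience) return(i)       # stop after patience bad steps
    }
  }
  length(scores)                             # ran through all steps
}

scores <- c(1.00, 0.80, 0.79, 0.85, 0.90)    # illustrative validation losses
early_stop_step(scores, patience = 2, min_delta = 0.05)  # 4
early_stop_step(scores, patience = 2, min_delta = 0)     # 5
```

With min_delta = 0.05 the small drop from 0.80 to 0.79 does not count as an improvement, so training stops one step earlier.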

Dataloader:

  • batch_size :: integer(1)
    The batch size (required).

  • shuffle :: logical(1)
    Whether to shuffle the instances in the dataset. Default is FALSE. This does not impact validation.

  • sampler :: torch::sampler
    Object that defines how the dataloader draws samples.

  • batch_sampler :: torch::sampler
    Object that defines how the dataloader draws batches.

  • num_workers :: integer(1)
    The number of workers for data loading (batches are loaded in parallel). The default is 0, which means that data will be loaded in the main process.

  • collate_fn :: function
    How to merge a list of samples to form a batch.

  • pin_memory :: logical(1)
    Whether the dataloader copies tensors into CUDA pinned memory before returning them.

  • drop_last :: logical(1)
    Whether to drop the last training batch in each epoch during training. Default is FALSE.

  • timeout :: numeric(1)
    The timeout value for collecting a batch from workers. Negative values mean no timeout and the default is -1.

  • worker_init_fn :: function(id)
    A function that receives the worker id (in [1, num_workers]) and is executed after seeding on the worker but before data loading.

  • worker_globals :: list() | character()
    When loading data in parallel, this allows exporting globals to the workers. If this is a character vector, the objects in the global environment with those names are copied to the workers.

  • worker_packages :: character()
    Which packages to load on the workers.

Also see torch::dataloader for more information.
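To show how batch_size, shuffle and drop_last interact, the plain-R sketch below forms the index batches of one training epoch. The helper make_batches() is hypothetical and only mimics the described behavior; actual batching is done by the torch dataloader, and custom sampler or batch_sampler objects are not modelled here.

```r
# Hypothetical sketch: split n instance indices into one epoch's batches.
make_batches <- function(n, batch_size, shuffle = FALSE, drop_last = FALSE) {
  idx <- if (shuffle) sample.int(n) else seq_len(n)   # instance order
  batches <- split(idx, ceiling(seq_along(idx) / batch_size))
  if (drop_last && n %% batch_size != 0) {
    batches <- batches[-length(batches)]  # drop the incomplete final batch
  }
  unname(batches)
}

lengths(make_batches(10, batch_size = 4))                    # 4 4 2
lengths(make_batches(10, batch_size = 4, drop_last = TRUE))  # 4 4
```

With drop_last = TRUE, instances that fall into the incomplete final batch are skipped in that epoch; when shuffle = TRUE, different instances are affected in each epoch.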

Internals

A LearnerTorchModel is created by calling model_descriptor_to_learner() on the ModelDescriptor received through the input channel. Its parameters are then set to those specified in PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.

Super classes

mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> PipeOpTorchModel

Methods



Method new()

Creates a new instance of this R6 class.

Usage

PipeOpTorchModel$new(task_type, id = "torch_model", param_vals = list())

Arguments

task_type

(character(1))
The task type of the model.

id

(character(1))
Identifier of the resulting object.

param_vals

(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.


Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpTorchModel$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.