Builds a Torch Learner from a ModelDescriptor and trains it with the given parameter specification. The task type must be specified during construction.
Input and Output Channels
There is one input channel "input" that takes in a ModelDescriptor during training and a Task of the specified task_type during prediction.
The output is NULL during training and a Prediction of the given task_type during prediction.
State
A trained LearnerTorchModel.
Parameters
General:
The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:
epochs :: integer(1)
The number of epochs.
device :: character(1)
The device. One of "auto", "cpu", "cuda", or another value defined in mlr_reflections$torch$devices. The value is initialized to "auto", which selects "cuda" if possible, then tries "mps", and otherwise falls back to "cpu".
num_threads :: integer(1)
The number of threads for intraop parallelization (if device is "cpu"). This value is initialized to 1.
num_interop_threads :: integer(1)
The number of threads for interop parallelization (if device is "cpu"). This value is initialized to 1. Note that this can only be set once during a session, and changing the value within an R session will raise a warning.
seed :: integer(1) or "random" or NULL
The torch seed that is used during training and prediction. This value is initialized to "random", which means that a random seed is sampled at the beginning of the training phase. This seed (either set or randomly sampled) is available via $model$seed after training and is used during prediction. Because the seed is set during the training phase, clones of the learner will by default (i.e. when seed is "random") use a different seed. If set to NULL, no seeding is done.
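As a hedged illustration, the general parameters can be set when the PipeOp is constructed. This sketch assumes that mlr3torch is installed and uses the classification variant "torch_model_classif" (listed under See also); the concrete values are arbitrary, and the "opt." / "loss." / "cb." parameters are not shown.

```r
library(mlr3pipelines)
library(mlr3torch)

# Classification variant of PipeOpTorchModel; the values are arbitrary
# examples, not recommended defaults.
po_model = po("torch_model_classif",
  batch_size = 32,
  epochs = 10,
  device = "cpu",   # bypass the "auto" device selection
  num_threads = 2,
  seed = 42L        # fixed seed instead of the initial "random"
)

# Inspect a few of the configured values
po_model$param_set$values[c("epochs", "device", "seed")]
```

Fixing seed to an integer makes clones of the resulting learner share the same seed, whereas the initial value "random" samples a new one at the start of each training.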
Evaluation:
measures_train :: Measure or list() of Measures
Measures to be evaluated during training.
measures_valid :: Measure or list() of Measures
Measures to be evaluated during validation.
eval_freq :: integer(1)
How often the train / validation predictions are evaluated using measures_train / measures_valid. This is initialized to 1. Note that the final model is always evaluated.
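The interaction between eval_freq and the rule that the final model is always evaluated can be illustrated with a small base-R sketch. The helper evaluated_epochs() is hypothetical and not part of mlr3torch; it assumes that an evaluation happens after every eval_freq-th epoch.

```r
# Hypothetical helper (not part of mlr3torch): epochs after which the
# measures would be evaluated, assuming one evaluation every `eval_freq`
# epochs plus a final evaluation after the last epoch.
evaluated_epochs = function(epochs, eval_freq = 1L) {
  stopifnot(epochs >= 1, eval_freq >= 1)
  regular = which(seq_len(epochs) %% eval_freq == 0)
  # The final model is always evaluated, even if `epochs` is not a
  # multiple of `eval_freq`.
  sort(unique(c(regular, epochs)))
}

evaluated_epochs(10, eval_freq = 3)
#> [1]  3  6  9 10
```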
Early Stopping:
patience :: integer(1)
This activates early stopping using the validation scores. If the performance of a model does not improve for patience evaluation steps, training is ended. Note that the final model is stored in the learner, not the best model. This is initialized to 0, which means no early stopping. The first entry from measures_valid is used as the metric. This also requires specifying the $validate field of the Learner, as well as measures_valid.
min_delta :: double(1)
The minimum improvement threshold (>) for early stopping. This is initialized to 0.
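The patience and min_delta rules can be sketched in base R as follows. This is a hypothetical illustration of the described logic, not mlr3torch's internal implementation; it assumes one validation score per evaluation step and a flag saying whether the measure is minimized. It only determines when training ends; as noted above, the learner keeps the final model, not the best one.

```r
# Hypothetical sketch (not mlr3torch's implementation) of patience-based
# early stopping. `scores` holds one validation score per evaluation step.
# Returns the evaluation step after which training would end.
early_stop_step = function(scores, patience = 0L, min_delta = 0, minimize = TRUE) {
  n = length(scores)
  if (patience == 0L || n == 0L) return(n)  # patience = 0 disables early stopping
  best = scores[1L]
  waited = 0L
  for (i in seq_len(n)[-1L]) {
    # Positive values mean the score moved in the desired direction.
    improvement = if (minimize) best - scores[i] else scores[i] - best
    if (improvement > min_delta) {  # strict ">" threshold
      best = scores[i]
      waited = 0L
    } else {
      waited = waited + 1L
      if (waited >= patience) return(i)
    }
  }
  n
}

# Made-up validation losses, one per evaluation step
loss = c(1.00, 0.80, 0.79, 0.81, 0.82)
early_stop_step(loss, patience = 2, min_delta = 0.05)
#> [1] 4
```

With min_delta = 0.05, the drop from 0.80 to 0.79 does not count as an improvement, so two non-improving steps are reached at step 4.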
Dataloader:
batch_size :: integer(1)
The batch size (required).
shuffle :: logical(1)
Whether to shuffle the instances in the dataset. Default is FALSE. This does not impact validation.
sampler :: torch::sampler
Object that defines how the dataloader draws samples.
batch_sampler :: torch::sampler
Object that defines how the dataloader draws batches.
num_workers :: integer(1)
The number of workers for data loading (batches are loaded in parallel). The default is 0, which means that data is loaded in the main process.
collate_fn :: function
How to merge a list of samples to form a batch.
pin_memory :: logical(1)
Whether the dataloader copies tensors into CUDA pinned memory before returning them.
drop_last :: logical(1)
Whether to drop the last, incomplete training batch in each epoch (when the number of training instances is not divisible by batch_size). Default is FALSE.
timeout :: numeric(1)
The timeout value for collecting a batch from workers. Negative values mean no timeout, and the default is -1.
worker_init_fn :: function(id)
A function that receives the worker id (in [1, num_workers]) and is executed after seeding on the worker but before data loading.
worker_globals :: list() | character()
When loading data in parallel, this allows exporting globals to the workers. If this is a character vector, the objects in the global environment with those names are copied to the workers.
worker_packages :: character()
Which packages to load on the workers.
Also see torch::dataloader for more information.
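The effect of batch_size and drop_last on the number of training batches per epoch is simple arithmetic, sketched here in base R. The helper n_batches() is hypothetical and not part of mlr3torch or torch; the example uses 150 instances, the size of the iris data.

```r
# Hypothetical helper (not part of mlr3torch or torch): number of
# training batches per epoch for `n` instances.
n_batches = function(n, batch_size, drop_last = FALSE) {
  if (drop_last) {
    n %/% batch_size         # the incomplete last batch is discarded
  } else {
    ceiling(n / batch_size)  # the last batch may be smaller than batch_size
  }
}

n_batches(150, 16)                    # 9 full batches + 1 batch of 6
#> [1] 10
n_batches(150, 16, drop_last = TRUE)  # the 6 leftover instances are skipped
#> [1] 9
```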
Internals
A LearnerTorchModel is created by calling model_descriptor_to_learner() on the provided ModelDescriptor that is received through the input channel.
Then the parameters are set according to the parameters specified in PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.
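A minimal end-to-end sketch of the two channels: a ModelDescriptor goes in during training, and a Task goes in during prediction. It assumes mlr3torch is installed and uses the classification variant "torch_model_classif"; the architecture (a numeric ingress followed directly by an output head), the cross-entropy loss and the Adam optimizer are illustrative choices only.

```r
library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

task = tsk("iris")

# Build a ModelDescriptor: numeric ingress, output head, loss, optimizer.
# The architecture and the loss/optimizer are illustrative choices.
graph = po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", loss = t_loss("cross_entropy")) %>>%
  po("torch_optimizer", optimizer = t_opt("adam"))
md = graph$train(task)[[1L]]

# Training: the input is the ModelDescriptor, the output is NULL.
po_model = po("torch_model_classif", batch_size = 50, epochs = 1)
out = po_model$train(list(md))

# The state is the trained LearnerTorchModel.
class(po_model$state)

# Prediction: the input is a Task, the output is a Prediction.
pred = po_model$predict(list(task))[[1L]]
pred$score(msr("classif.acc"))
```

With a single epoch the accuracy is not meaningful; the point is only the shape of the inputs and outputs.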
See also
Other PipeOps:
mlr_pipeops_nn_adaptive_avg_pool1d,
mlr_pipeops_nn_adaptive_avg_pool2d,
mlr_pipeops_nn_adaptive_avg_pool3d,
mlr_pipeops_nn_avg_pool1d,
mlr_pipeops_nn_avg_pool2d,
mlr_pipeops_nn_avg_pool3d,
mlr_pipeops_nn_batch_norm1d,
mlr_pipeops_nn_batch_norm2d,
mlr_pipeops_nn_batch_norm3d,
mlr_pipeops_nn_block,
mlr_pipeops_nn_celu,
mlr_pipeops_nn_conv1d,
mlr_pipeops_nn_conv2d,
mlr_pipeops_nn_conv3d,
mlr_pipeops_nn_conv_transpose1d,
mlr_pipeops_nn_conv_transpose2d,
mlr_pipeops_nn_conv_transpose3d,
mlr_pipeops_nn_dropout,
mlr_pipeops_nn_elu,
mlr_pipeops_nn_flatten,
mlr_pipeops_nn_gelu,
mlr_pipeops_nn_glu,
mlr_pipeops_nn_hardshrink,
mlr_pipeops_nn_hardsigmoid,
mlr_pipeops_nn_hardtanh,
mlr_pipeops_nn_head,
mlr_pipeops_nn_layer_norm,
mlr_pipeops_nn_leaky_relu,
mlr_pipeops_nn_linear,
mlr_pipeops_nn_log_sigmoid,
mlr_pipeops_nn_max_pool1d,
mlr_pipeops_nn_max_pool2d,
mlr_pipeops_nn_max_pool3d,
mlr_pipeops_nn_merge,
mlr_pipeops_nn_merge_cat,
mlr_pipeops_nn_merge_prod,
mlr_pipeops_nn_merge_sum,
mlr_pipeops_nn_prelu,
mlr_pipeops_nn_relu,
mlr_pipeops_nn_relu6,
mlr_pipeops_nn_reshape,
mlr_pipeops_nn_rrelu,
mlr_pipeops_nn_selu,
mlr_pipeops_nn_sigmoid,
mlr_pipeops_nn_softmax,
mlr_pipeops_nn_softplus,
mlr_pipeops_nn_softshrink,
mlr_pipeops_nn_softsign,
mlr_pipeops_nn_squeeze,
mlr_pipeops_nn_tanh,
mlr_pipeops_nn_tanhshrink,
mlr_pipeops_nn_threshold,
mlr_pipeops_nn_unsqueeze,
mlr_pipeops_torch_ingress,
mlr_pipeops_torch_ingress_categ,
mlr_pipeops_torch_ingress_ltnsr,
mlr_pipeops_torch_ingress_num,
mlr_pipeops_torch_loss,
mlr_pipeops_torch_model_classif,
mlr_pipeops_torch_model_regr
Super classes
mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> PipeOpTorchModel
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpTorchModel$new(task_type, id = "torch_model", param_vals = list())
Arguments
task_type (character(1))
The task type of the model.
id (character(1))
Identifier of the resulting object.
param_vals (list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
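A short, hedged sketch of calling the constructor directly, assuming PipeOpTorchModel is exported by mlr3torch as the Usage section suggests. The id "my_torch_model" is a hypothetical name, and param_vals overrides values that would otherwise be initialized during construction (for example seed, initialized to "random", and device, initialized to "auto").

```r
library(mlr3torch)

# Direct construction for a regression task; "my_torch_model" is a
# hypothetical id. param_vals overrides construction-time values.
op = PipeOpTorchModel$new(
  task_type = "regr",
  id = "my_torch_model",
  param_vals = list(batch_size = 16, epochs = 5, seed = 1L, device = "cpu")
)

op$id
op$param_set$values$seed
```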