This base class provides the basic functionality for training and prediction of a neural network. All torch learners should inherit from this class.
Validation
To specify the validation data, you can set the $validate field of the Learner, which can be set to:
NULL: no validation.
ratio: only proportion 1 - ratio of the task is used for training and ratio is used for validation.
"test": the "test" task of a resampling is used. This is not possible when calling $train() manually.
"predefined": this will use the predefined $internal_valid_task of an mlr3::Task.
This validation data can also be used for early stopping; see the description of the Learner's parameters.
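The following is a minimal sketch of configuring a validation split. It assumes that mlr3 and mlr3torch are installed, and that the mlr3torch learner "classif.mlp" and the built-in "iris" task are available. The small values for epochs and batch_size are chosen only to keep the example fast.

```r
library(mlr3)
library(mlr3torch)

task = tsk("iris")
learner = lrn("classif.mlp",
  epochs = 5, batch_size = 16,
  # measure used to score the validation data
  measures_valid = msr("classif.ce")
)

# Use 30% of the training rows for validation, the remaining 70% for training
learner$validate = 0.3
learner$train(task)

# Named list with the final validation score(s)
learner$internal_valid_scores
```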
Saving a Learner
In order to save a LearnerTorch for later usage, it is necessary to call the $marshal() method on the Learner before writing it to disk, as the object will otherwise not be saved correctly.
After loading a marshaled LearnerTorch into R again, you then need to call $unmarshal() to transform it into a usable state.
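The round trip described above can be sketched as follows. The sketch assumes mlr3torch's "classif.mlp" learner and base R's saveRDS()/readRDS(). The file path is a temporary file created only for illustration.

```r
library(mlr3)
library(mlr3torch)

task = tsk("iris")
learner = lrn("classif.mlp", epochs = 2, batch_size = 16)
learner$train(task)

# Convert the torch network into a serializable form before writing it
learner$marshal()
path = tempfile(fileext = ".rds")
saveRDS(learner, path)

# Later (possibly in a new R session): read it back and restore it
learner2 = readRDS(path)
learner2$unmarshal()
prediction = learner2$predict(task)
```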
Early Stopping and Tuning
In order to prevent overfitting, the LearnerTorch class supports early stopping via the patience and min_delta parameters; see the Learner's parameters.
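As a minimal sketch (assuming mlr3torch's "classif.mlp" learner and the "iris" task), early stopping combines patience and measures_valid with a $validate setting. The number of epochs actually trained can then be read back after training.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  epochs = 50, batch_size = 16,
  measures_valid = msr("classif.ce"),
  # stop if the validation error has not improved for 3 evaluation steps
  patience = 3,
  min_delta = 0
)
# Early stopping needs validation data
learner$validate = 0.3
learner$train(tsk("iris"))

# Epochs after which training stopped (at most 50)
learner$internal_tuned_values
learner$model$epochs
```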
When tuning a LearnerTorch, it is also possible to combine explicit tuning via mlr3tuning with the LearnerTorch's internal tuning of the epochs via early stopping.
To do so, include epochs = to_tune(upper = <upper>, internal = TRUE) in the search space, where <upper> is the maximum allowed number of epochs, and configure early stopping.
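Below is a small sketch of combined tuning. It assumes mlr3tuning is installed and uses a random search over the learning rate opt.lr of the default optimizer, while epochs are tuned internally. The upper bound of 20 epochs, the holdout resampling and term_evals = 2 are illustrative choices only.

```r
library(mlr3)
library(mlr3torch)
library(mlr3tuning)

learner = lrn("classif.mlp",
  batch_size = 16,
  # tuned internally: early stopping picks the number of epochs, at most 20
  epochs = to_tune(upper = 20, internal = TRUE),
  patience = 3,
  measures_valid = msr("classif.ce"),
  # tuned explicitly by the tuner
  opt.lr = to_tune(1e-4, 1e-1, logscale = TRUE)
)
# Validate on the test set of each resampling iteration
learner$validate = "test"

instance = tune(
  tuner = tnr("random_search"),
  task = tsk("iris"),
  learner = learner,
  resampling = rsmp("holdout"),
  measures = msr("classif.ce"),
  term_evals = 2
)
instance$result_learner_param_vals
```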
Model
The model is a list of class "learner_torch_model" with the following elements:
network :: The trained network.
optimizer :: The $state_dict() of the optimizer used to train the network.
loss_fn :: The $state_dict() of the loss used to train the network.
callbacks :: The callbacks used to train the network.
seed :: The seed that was / is used for training and prediction.
epochs :: How many epochs the model was trained for (early stopping).
task_col_info :: A data.table() containing information about the train task.
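To inspect these elements after training, a sketch like the following can be used. It assumes mlr3torch's "regr.mlp" learner and the built-in "mtcars" task.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("regr.mlp", epochs = 3, batch_size = 8)
learner$train(tsk("mtcars"))

model = learner$model
class(model)
names(model)
model$epochs  # number of epochs trained; all 3 here, as early stopping is off
model$seed    # the (sampled) seed used for training and prediction
```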
Parameters
General:
The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:
epochs :: integer(1)
The number of epochs.
device :: character(1)
The device. One of "auto", "cpu", or "cuda", or other values defined in mlr_reflections$torch$devices. The value is initialized to "auto", which will select "cuda" if possible, then try "mps" and otherwise fall back to "cpu".
num_threads :: integer(1)
The number of threads for intraop parallelization (if device is "cpu"). This value is initialized to 1.
num_interop_threads :: integer(1)
The number of threads for interop parallelization (if device is "cpu"). This value is initialized to 1. Note that this can only be set once per session; changing the value within an R session will raise a warning.
seed :: integer(1) or "random" or NULL
The torch seed that is used during training and prediction. This value is initialized to "random", which means that a random seed will be sampled at the beginning of the training phase. This seed (either set or randomly sampled) is available via $model$seed after training and is used during prediction. Because the seed is sampled during the training phase, clones of the learner will by default (i.e. when seed is "random") use a different seed. If set to NULL, no seeding will be done.
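The sketch below sets some of these general parameters, together with a prefixed optimizer parameter (opt.lr, the learning rate of the default optimizer). It assumes mlr3torch's "classif.mlp" learner. Fixing the seed makes it visible in $model$seed after training.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp",
  epochs = 2, batch_size = 16,
  device = "cpu",     # force the CPU even if "cuda" or "mps" is available
  num_threads = 1L,   # intraop threads on the CPU
  seed = 42L,         # fixed seed instead of the default "random"
  opt.lr = 0.01       # "opt." prefix: learning rate of the optimizer
)
learner$train(tsk("iris"))
learner$model$seed
```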
Evaluation:
measures_train :: Measure or list() of Measures
Measures to be evaluated during training.
measures_valid :: Measure or list() of Measures
Measures to be evaluated during validation.
eval_freq :: integer(1)
How often the train / validation predictions are evaluated using measures_train / measures_valid. This is initialized to 1. Note that the final model is always evaluated.
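The evaluation schedule implied by eval_freq can be sketched in base R. The helper evaluated_epochs() is hypothetical and not part of mlr3torch. It assumes that evaluation happens at every epoch divisible by eval_freq, plus the final epoch.

```r
# Hypothetical helper: epochs at which measures_train / measures_valid would be
# evaluated, assuming evaluation every `eval_freq` epochs plus the final epoch
evaluated_epochs = function(epochs, eval_freq = 1L) {
  regular = which(seq_len(epochs) %% eval_freq == 0L)
  sort(unique(c(regular, epochs)))
}

evaluated_epochs(5, eval_freq = 2)  # 2, 4 and the final epoch 5
evaluated_epochs(4, eval_freq = 2)  # 2 and 4 (4 is already the final epoch)
```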
Early Stopping:
patience :: integer(1)
This activates early stopping using the validation scores. If the performance of a model does not improve for patience evaluation steps, training is ended. Note that the final model is stored in the learner, not the best model. This is initialized to 0, which means no early stopping. The first entry from measures_valid is used as the metric. This also requires specifying the $validate field of the Learner, as well as measures_valid.
min_delta :: double(1)
The minimum improvement threshold (>) for early stopping, i.e. a score must improve by more than min_delta to count as an improvement. This is initialized to 0.
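The patience / min_delta rule can be illustrated with the hypothetical base R helper stop_step() below. It is not mlr3torch's implementation. It assumes a measure where lower is better (such as classification error), with one score per evaluation step.

```r
# Hypothetical helper (not part of mlr3torch): given the validation scores
# observed at successive evaluation steps, return the step at which training
# would stop. Assumes lower scores are better.
stop_step = function(scores, patience, min_delta = 0) {
  if (patience == 0) return(length(scores))  # patience = 0: no early stopping
  best = Inf
  steps_without_improvement = 0
  for (i in seq_along(scores)) {
    if (best - scores[i] > min_delta) {
      # improvement strictly larger than min_delta: reset the counter
      best = scores[i]
      steps_without_improvement = 0
    } else {
      steps_without_improvement = steps_without_improvement + 1
      if (steps_without_improvement >= patience) return(i)
    }
  }
  length(scores)  # never triggered: all steps were run
}

scores = c(0.50, 0.40, 0.41, 0.42, 0.30)
stop_step(scores, patience = 2)  # 4: stops before the late improvement
stop_step(scores, patience = 0)  # 5: early stopping disabled
```

The example also shows why the stored model is the final one rather than the best one. Training stops after step 4, although the best score so far was reached at step 2.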
Dataloader:
batch_size :: integer(1)
The batch size (required).
shuffle :: logical(1)
Whether to shuffle the instances in the dataset. Default is FALSE. This does not impact validation.
sampler :: torch::sampler
Object that defines how the dataloader draws samples.
batch_sampler :: torch::sampler
Object that defines how the dataloader draws batches.
num_workers :: integer(1)
The number of workers for data loading (batches are loaded in parallel). The default is 0, which means that data will be loaded in the main process.
collate_fn :: function
How to merge a list of samples to form a batch.
pin_memory :: logical(1)
Whether the dataloader copies tensors into CUDA pinned memory before returning them.
drop_last :: logical(1)
Whether to drop the last incomplete training batch in each epoch. Default is FALSE.
timeout :: numeric(1)
The timeout value for collecting a batch from workers. Negative values mean no timeout and the default is -1.
worker_init_fn :: function(id)
A function that receives the worker id (in [1, num_workers]) and is executed after seeding on the worker but before data loading.
worker_globals :: list() | character()
When loading data in parallel, this allows exporting globals to the workers. If this is a character vector, the objects in the global environment with those names are copied to the workers.
worker_packages :: character()
Which packages to load on the workers.
Also see torch::dataloader for more information.
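The effect of batch_size and drop_last on the number of training batches per epoch is a short calculation, shown here as a hypothetical base R helper. It assumes the usual torch dataloader behaviour: rows are split into consecutive batches of batch_size, and with drop_last = TRUE an incomplete final batch is discarded.

```r
# Hypothetical helper: number of training batches per epoch for n rows
n_batches = function(n, batch_size, drop_last = FALSE) {
  if (drop_last) n %/% batch_size else ceiling(n / batch_size)
}

n_batches(150, 32)                    # 5: the last batch has only 22 rows
n_batches(150, 32, drop_last = TRUE)  # 4: the incomplete batch is dropped
```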
Inheriting
There are no separate classes for classification and regression to inherit from.
Instead, the task_type must be specified as a construction argument.
Currently, only classification and regression are supported.
When inheriting from this class, one should overload two private methods:
.network(task, param_vals)
(Task, list()) -> nn_module
Construct a torch::nn_module object for the given task and parameter values, i.e. the neural network that is trained by the learner. For classification, the output of this network is expected to be the scores before the final softmax layer is applied.
.dataset(task, param_vals)
(Task, list()) -> torch::dataset
Create the dataset for the task. It must respect the value of the device parameter. Moreover, it needs to respect the row ids of the provided task.
It is also possible to overwrite the private .dataloader() method instead of the .dataset() method.
By default, a dataloader is constructed using the output from the .dataset() method.
However, an overwritten .dataloader() method should respect the dataloader parameters from the ParamSet.
.dataloader(task, param_vals)
(Task, list()) -> torch::dataloader
Create a dataloader from the task. It needs to respect at least batch_size and shuffle (otherwise predictions can be permuted).
To change the predict types, the private .encode_prediction() method can be overwritten:
.encode_prediction(predict_tensor, task, param_vals)
(torch_tensor, Task, list()) -> list()
Take in the raw predictions from self$network (predict_tensor) and encode them into a format that can be converted to valid mlr3 predictions using mlr3::as_prediction_data(). This method must take self$predict_type into account.
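Conceptually, for classification such a method has to turn raw scores into class probabilities and a response. The hypothetical base R helper encode_scores() sketches this on a plain matrix (one row per observation, one column per class) rather than a torch_tensor. Its output list with response / prob elements is an assumption modelled on mlr3's classification prediction data.

```r
# Hypothetical helper: convert raw network scores (before softmax) into
# a response and, if requested, class probabilities
encode_scores = function(scores, class_names, predict_type = "response") {
  # numerically stable row-wise softmax
  shifted = exp(scores - apply(scores, 1, max))
  prob = shifted / rowSums(shifted)
  colnames(prob) = class_names
  # response: class with the highest probability in each row
  response = factor(class_names[max.col(prob, ties.method = "first")],
    levels = class_names)
  if (predict_type == "prob") {
    list(response = response, prob = prob)
  } else {
    list(response = response)
  }
}

scores = rbind(c(2, 1, 0), c(0, 0, 3))
encode_scores(scores, c("a", "b", "c"), predict_type = "prob")
```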
While it is possible to add parameters by specifying the param_set construction argument, it is currently not possible to remove existing parameters, i.e. those listed in section Parameters.
None of the parameters provided in param_set can have an id that starts with "loss.", "opt.", or "cb.", as these are reserved for the dynamically constructed parameters of the optimizer, the loss function, and the callbacks.
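The following sketch shows how a subclass might define extra parameters with paradox, and checks that none of their ids use a reserved prefix. The parameter names neurons and dropout and the helper uses_reserved_prefix() are hypothetical and chosen only for illustration.

```r
library(paradox)

# Extra parameters a hypothetical subclass might add via `param_set`
extra_params = ps(
  neurons = p_int(lower = 1L, tags = "train"),
  dropout = p_dbl(lower = 0, upper = 1, tags = "train")
)

# Ids starting with these prefixes are reserved for optimizer, loss and callbacks
reserved = c("loss.", "opt.", "cb.")
uses_reserved_prefix = function(ids) {
  vapply(ids, function(id) any(startsWith(id, reserved)), logical(1))
}

uses_reserved_prefix(extra_params$ids())                 # all FALSE
uses_reserved_prefix(c("opt.momentum", "cb.history.x"))  # all TRUE
```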
To perform additional input checks on the task, the private methods .verify_train_task(task, param_vals) and .verify_predict_task(task, param_vals) can be overwritten.
For learners that have other construction arguments that should change the hash of a learner, it is required to implement the private method .additional_phash_input().
Super class
mlr3::Learner -> LearnerTorch
Active bindings
validate
How to construct the internal validation data. This parameter can be either NULL, a ratio in $(0, 1)$, "test", or "predefined".
loss (TorchLoss)
The torch loss.
optimizer (TorchOptimizer)
The torch optimizer.
callbacks (list() of TorchCallbacks)
List of torch callbacks. The ids will be set as the names.
internal_valid_scores
Retrieves the internal validation scores as a named list(). Specify the $validate field and the measures_valid parameter to configure this. Returns NULL if the learner is not trained yet.
internal_tuned_values
When early stopping is active, this returns a named list with the early-stopped epochs; otherwise an empty list is returned. Returns NULL if the learner is not trained yet.
marshaled (logical(1))
Whether the learner is marshaled.
network (nn_module())
Shortcut for learner$model$network.
param_set (ParamSet)
The parameter set.
hash (character(1))
Hash (unique identifier) for this object.
phash (character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).
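A brief sketch of accessing some of these active bindings, assuming mlr3torch's "classif.mlp" learner and the "iris" task. The loss and optimizer are available right after construction, while network only exists after training.

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp", epochs = 2, batch_size = 16)

# Available before training
learner$loss
learner$optimizer
learner$internal_valid_scores  # NULL, the learner is not trained yet

learner$train(tsk("iris"))
learner$network  # shortcut for learner$model$network
learner$marshaled
```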
Methods
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorch$new(
id,
task_type,
param_set,
properties,
man,
label,
feature_types,
optimizer = NULL,
loss = NULL,
packages = character(),
predict_types = NULL,
callbacks = list()
)
Arguments
id (character(1))
The id of the new object.
task_type (character(1))
The task type.
param_set (ParamSet or alist())
Either a parameter set, or an alist() containing different values of self, e.g. alist(private$.param_set1, private$.param_set2), from which a ParamSet collection should be created.
properties (character())
The properties of the object. See mlr_reflections$learner_properties for available values.
man (character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help package can be opened via method $help().
label (character(1))
Label for the new instance.
feature_types (character())
The feature types. See mlr_reflections$task_feature_types for available values. Additionally, "lazy_tensor" is supported.
optimizer (NULL or TorchOptimizer)
The optimizer to use for training. Defaults to adam.
loss (NULL or TorchLoss)
The loss to use for training. Defaults to MSE for regression and cross entropy for classification.
packages (character())
The R packages this object depends on.
predict_types (character())
The predict types. See mlr_reflections$learner_predict_types for available values. For regression, the default is "response". For classification, this defaults to "response" and "prob". To deviate from the defaults, it is necessary to overwrite the private .encode_prediction() method; see section Inheriting.
callbacks (list() of TorchCallbacks)
The callbacks to use for training. Defaults to an empty list(), i.e. no callbacks.
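Since LearnerTorch itself is a base class, the following sketch passes the optimizer, loss and callbacks construction arguments through a concrete subclass, mlr3torch's "classif.mlp", via lrn(). The SGD learning rate of 0.1 is an arbitrary illustrative value.

```r
library(mlr3)
library(mlr3torch)

# Construction arguments are forwarded to the learner's constructor by lrn()
learner = lrn("classif.mlp",
  optimizer = t_opt("sgd", lr = 0.1),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history")),
  epochs = 2, batch_size = 16
)

# The optimizer's parameters are exposed with the "opt." prefix
learner$param_set$values$opt.lr
# Callbacks are named by their ids
names(learner$callbacks)
```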
Method print()
Prints the object.