Create a torch learner from an instantiated nn_module().
For classification, the output of the network must be the scores (before the softmax).
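The sketch below shows what "scores before the softmax" means in practice. The network is hypothetical: its layer sizes are chosen only to fit the 4-feature, 3-class iris task. It ends in a plain linear layer instead of `nn_softmax()`. Probabilities can still be derived from its output with `nnf_softmax()` when needed. Leaving the softmax out is presumably required because training applies a loss such as torch's cross-entropy, which expects unnormalized scores.

```r
library(torch)

torch_manual_seed(1)

# Hypothetical MLP for 4 input features and 3 classes.
# The last layer is linear: no nn_softmax(), so the outputs are raw scores (logits).
network = nn_sequential(
  nn_linear(4, 16),
  nn_relu(),
  nn_linear(16, 3)
)

# Forward pass on a dummy batch of 2 observations: one score per class
scores = network(torch_randn(2, 4))

# If probabilities are needed, apply the softmax over the class dimension
# (dim = 2, since dimensions are 1-based in R torch)
probs = nnf_softmax(scores, dim = 2)
```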
Parameters
See LearnerTorch
See also
Other Learner:
mlr_learners.ft_transformer,
mlr_learners.mlp,
mlr_learners.module,
mlr_learners.tab_resnet,
mlr_learners.torch_featureless,
mlr_learners_torch,
mlr_learners_torch_image
Other Graph Network:
ModelDescriptor(),
TorchIngressToken(),
mlr_pipeops_module,
mlr_pipeops_torch,
mlr_pipeops_torch_ingress,
mlr_pipeops_torch_ingress_categ,
mlr_pipeops_torch_ingress_ltnsr,
mlr_pipeops_torch_ingress_num,
model_descriptor_to_learner(),
model_descriptor_to_module(),
model_descriptor_union(),
nn_graph()
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModel
Active bindings
ingress_tokens (named list() with TorchIngressToken or NULL)
The ingress tokens. Must be non-NULL when calling $train().
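The following is a minimal sketch of supplying the ingress tokens after construction. It assumes the `ingress_tokens` binding accepts assignment, which the requirement that it only be non-NULL at `$train()` time suggests. The network and token definition mirror the iris example on this page.

```r
library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")

# Construct the learner without ingress tokens
learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# Assign the ingress tokens through the active binding (assumed writable)
# before calling $train()
learner$ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)
learner$train(task)
```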
Methods
Inherited methods
mlr3::Learner$base_learner(),
mlr3::Learner$configure(),
mlr3::Learner$encapsulate(),
mlr3::Learner$help(),
mlr3::Learner$predict(),
mlr3::Learner$predict_newdata(),
mlr3::Learner$reset(),
mlr3::Learner$selected_features(),
mlr3::Learner$train(),
mlr3torch::LearnerTorch$dataset(),
mlr3torch::LearnerTorch$format(),
mlr3torch::LearnerTorch$marshal(),
mlr3torch::LearnerTorch$print(),
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchModel$new(
network = NULL,
ingress_tokens = NULL,
task_type,
properties = NULL,
optimizer = NULL,
loss = NULL,
callbacks = list(),
packages = character(0),
feature_types = NULL
)
Arguments
network (nn_module)
An instantiated nn_module. It is not cloned during construction. For classification, the outputs must be the scores (before the softmax).
ingress_tokens (list of TorchIngressToken())
A list of ingress tokens that defines how the dataloader is constructed.
task_type (character(1))
The task type.
properties (NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.
optimizer (TorchOptimizer)
The torch optimizer.
loss (TorchLoss)
The loss to use for training.
callbacks (list() of TorchCallbacks)
The callbacks used during training. They must have unique ids and are executed in the order in which they are provided.
packages (character())
The R packages this object depends on.
feature_types (NULL or character())
The feature types. Defaults to all available feature types.
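This sketch calls the constructor directly instead of going through `lrn()`. The choice of `t_opt("adam")`, `t_loss("cross_entropy")` and the `"history"` callback is illustrative and is not a default stated on this page. Hyperparameters such as `batch_size`, `epochs` and `device` are then set through the inherited `configure()` method.

```r
library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")

learner = LearnerTorchModel$new(
  network = nn_linear(4, 3),  # outputs raw scores for the 3 classes
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  task_type = "classif",
  optimizer = t_opt("adam"),             # illustrative choice
  loss = t_loss("cross_entropy"),        # illustrative choice
  callbacks = t_clbks("history")         # callback ids must be unique
)

# Set the training hyperparameters, then train and predict
learner$configure(batch_size = 16, epochs = 2, device = "cpu")
learner$train(task)
pred = learner$predict(task)
```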
Examples
library(mlr3)
library(mlr3torch)
library(torch)

# We show the learner using a classification task.
# The iris task has 4 features and 3 classes.
network = nn_linear(4, 3)
task = tsk("iris")
# This defines the dataloader.
# It loads all 4 features, all of which are numeric.
# The shape is (NA, 4) because the batch dimension is not known in advance and is therefore NA.
ingress_tokens = list(
input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)
# Creating the learner and setting required parameters
learner = lrn("classif.torch_model",
network = network,
ingress_tokens = ingress_tokens,
batch_size = 16,
epochs = 1,
device = "cpu"
)
# A simple train-predict
ids = partition(task)
learner$train(task, ids$train)
learner$predict(task, ids$test)
#>
#> ── <PredictionClassif> for 50 observations: ────────────────────────────────────
#> row_ids truth response
#> 2 setosa setosa
#> 4 setosa setosa
#> 8 setosa setosa
#> --- --- ---
#> 144 virginica versicolor
#> 146 virginica versicolor
#> 148 virginica versicolor