Create a torch learner from an instantiated nn_module().
For classification, the output of the network must be the scores (before the softmax).
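The following sketch illustrates the scores requirement. It uses plain torch (no mlr3torch) and a hypothetical two-layer network for the 4-feature, 3-class iris setting used in the Examples. The final layer is an ordinary nn_linear() with no nn_softmax() appended. torch's cross-entropy loss applies a log-softmax internally, so it expects exactly these raw scores. When probabilities are needed, they are computed separately with nnf_softmax().

```r
library(torch)

# Hypothetical network for 4 numeric features and 3 classes.
# The last layer returns unnormalized scores (logits); no nn_softmax() is added.
network = nn_sequential(
  nn_linear(4, 16),
  nn_relu(),
  nn_linear(16, 3)
)

x = torch_randn(8, 4)   # a batch of 8 observations
scores = network(x)     # shape (8, 3): raw scores, one column per class

# Cross-entropy applies log-softmax internally, so it is fed the raw scores.
# R torch uses 1-based class labels; randint's upper bound is exclusive (labels 1..3).
y = torch_randint(1, 4, size = c(8), dtype = torch_long())
loss = nnf_cross_entropy(scores, y)

# Probabilities, if needed, are derived from the scores explicitly.
probs = nnf_softmax(scores, dim = 2)
```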
Parameters
See LearnerTorch
See also
Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image
Other Graph Network: ModelDescriptor(), TorchIngressToken(), mlr_pipeops_module, mlr_pipeops_torch, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, model_descriptor_to_learner(), model_descriptor_to_module(), model_descriptor_union(), nn_graph()
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModel
Active bindings
network_stored
(nn_module or NULL)
The network that will be trained. After calling $train(), this is NULL.

ingress_tokens
(named list() with TorchIngressToken or NULL)
The ingress tokens. Must be non-NULL when calling $train().
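The following minimal sketch checks the documented network_stored behavior. It assumes the iris setup from the Examples section: a single numeric ingress token named input and an nn_linear(4, 3) network. It records the binding before and after $train().

```r
library(mlr3)
library(mlr3torch)
library(torch)

# Same setup as in the Examples section (iris: 4 numeric features, 3 classes).
task = tsk("iris")
learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# Before training, the untrained network is held in network_stored.
before = learner$network_stored
learner$train(task)
# After training, network_stored is documented to be NULL.
after = learner$network_stored
```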
Methods
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$encapsulate()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchModel$new(
network = NULL,
ingress_tokens = NULL,
task_type,
properties = NULL,
optimizer = NULL,
loss = NULL,
callbacks = list(),
packages = character(0),
feature_types = NULL
)
Arguments
network
(nn_module)
An instantiated nn_module. It is not cloned during construction. For classification, the outputs must be the scores (before the softmax).

ingress_tokens
(list of TorchIngressToken())
A list of ingress tokens that defines how the dataloader is constructed.

task_type
(character(1))
The task type.

properties
(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.

optimizer
(TorchOptimizer)
The torch optimizer.

loss
(TorchLoss)
The loss to use for training.

callbacks
(list() of TorchCallbacks)
The callbacks used during training. They must have unique ids and are executed in the order in which they are provided.

packages
(character())
The R packages this object depends on.

feature_types
(NULL or character())
The feature types. Defaults to all available feature types.
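The Examples section builds the learner through lrn(). As a sketch, the same kind of learner can also be created by calling LearnerTorchModel$new() directly with the arguments listed above. The choices of t_opt("adam"), t_loss("cross_entropy") and the "history" callback are illustrative assumptions, not defaults stated on this page. Training parameters such as batch_size, epochs and device are then set on the parameter set. The sketch assumes paradox >= 1.0, which provides ParamSet$set_values().

```r
library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")
ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)

# Direct construction via $new(); optimizer, loss and callbacks are
# illustrative choices (assumptions), passed explicitly.
learner = LearnerTorchModel$new(
  network = nn_linear(4, 3),
  ingress_tokens = ingress_tokens,
  task_type = "classif",
  optimizer = t_opt("adam"),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history"))
)

# Training parameters are set afterwards through the parameter set
# (assumes paradox >= 1.0 for set_values()).
learner$param_set$set_values(batch_size = 16, epochs = 1, device = "cpu")
learner$train(task)
```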
Examples
library(mlr3)
library(mlr3torch)
library(torch)

# We show the learner using a classification task
# The iris task has 4 features and 3 classes
network = nn_linear(4, 3)
task = tsk("iris")
# This defines the dataloader.
# It loads all 4 features, which are also numeric.
# The shape is (NA, 4) because the batch dimension is not known in advance and is therefore NA
ingress_tokens = list(
input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)
# Creating the learner and setting required parameters
learner = lrn("classif.torch_model",
network = network,
ingress_tokens = ingress_tokens,
batch_size = 16,
epochs = 1,
device = "cpu"
)
# A simple train-predict
ids = partition(task)
learner$train(task, ids$train)
learner$predict(task, ids$test)
#> <PredictionClassif> for 50 observations:
#>  row_ids     truth   response
#>        3    setosa versicolor
#>        9    setosa versicolor
#>       10    setosa versicolor
#>      ---       ---        ---
#>      147 virginica     setosa
#>      149 virginica     setosa
#>      150 virginica     setosa
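Because the network returns pre-softmax scores, it does not produce probabilities itself. The sketch below is a variation of the example above that requests probability predictions. It changes two assumptions: 10 epochs instead of 1, and predict_type set to "prob". It assumes the learner normalizes the scores internally, so each row of prediction$prob should sum to one. It also scores the result with classification accuracy via msr("classif.acc").

```r
library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")
learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  batch_size = 16,
  epochs = 10,   # assumption: more epochs than in the example above
  device = "cpu"
)
# Request class probabilities in addition to the response.
learner$predict_type = "prob"

ids = partition(task)
learner$train(task, ids$train)
prediction = learner$predict(task, ids$test)

# The network outputs raw scores; the returned probabilities are normalized.
prediction$prob
acc = prediction$score(msr("classif.acc"))
acc
```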