Create a torch learner from a torch module.
See also
Other Learner:
mlr_learners.ft_transformer,
mlr_learners.mlp,
mlr_learners.tab_resnet,
mlr_learners.torch_featureless,
mlr_learners_torch,
mlr_learners_torch_image,
mlr_learners_torch_model
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModule
Methods
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$configure()
mlr3::Learner$encapsulate()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$selected_features()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchModule$new(
  module_generator = NULL,
  param_set = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL,
  predict_types = NULL
)
Arguments
module_generator
(function or nn_module_generator)
An nn_module_generator or a function returning an nn_module. Both must take the task for which to construct the network as an argument. Other arguments to its initialize method can be provided as parameters.
param_set
(NULL or ParamSet)
If provided, contains the parameters for the module_generator. If NULL, parameters will be inferred from the module_generator.
ingress_tokens
(list of TorchIngressToken())
A list of ingress tokens that defines how the dataset is constructed. The names must correspond to the arguments of the network's forward method. For numeric, categorical, and lazy tensor features, you can use ingress_num(), ingress_categ(), and ingress_ltnsr() to create them.
task_type
(character(1))
The task type, either "classif" or "regr".
properties
(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.
optimizer
(TorchOptimizer)
The optimizer to use for training. By default, adam is used.
loss
(TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
callbacks
(list() of TorchCallbacks)
The callbacks. Must have unique ids.
packages
(character())
The R packages this object depends on.
feature_types
(NULL or character())
The feature types. Defaults to all available feature types.
predict_types
(character())
The predict types. See mlr_reflections$learner_predict_types for available values.
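The optimizer, loss, and callbacks arguments can also be set explicitly rather than left at their defaults. The following is a minimal sketch, not a recommended configuration: the module name nn_small, the hidden size, and the learning rate are illustrative choices, and it assumes that the "adam" optimizer, the "cross_entropy" loss, and the "history" callback are available through t_opt(), t_loss(), and t_clbk().

```r
library(mlr3torch)

# Illustrative module; the name nn_small and the argument size_hidden are made up for this sketch.
nn_small = nn_module("nn_small",
  initialize = function(task, size_hidden) {
    self$hidden = nn_linear(task$n_features, size_hidden)
    self$out = nn_linear(size_hidden, output_dim_for(task))
  },
  # The argument name x must match the name of the ingress token below.
  forward = function(x) {
    self$out(nnf_relu(self$hidden(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_small,
  ingress_tokens = list(x = ingress_num()),
  # Explicit optimizer, loss, and callbacks instead of the defaults.
  optimizer = t_opt("adam", lr = 0.01),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history")),
  # size_hidden is forwarded to the module's initialize method.
  size_hidden = 16,
  epochs = 2,
  batch_size = 32
)

# List the ids of all parameters exposed by the learner.
learner$param_set$ids()
```

Because param_set is left as NULL here, the parameters of the module (size_hidden) are inferred from the module_generator and appear in the learner's parameter set next to the training parameters.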
Examples
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  # task is supplied by the learner; size_hidden is set as a learner parameter
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  # argument x corresponds to the ingress token x
  forward = function(x) {
    x = self$first(x)
    x = nnf_relu(x)
    self$second(x)
  }
)
learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
task = tsk("iris")
learner$train(task)
learner$network
#> An `nn_module` containing 163 parameters.
#>
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • first: <nn_linear> #100 parameters
#> • second: <nn_linear> #63 parameters
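Since LearnerTorchModule inherits from mlr3::Learner, a trained learner can be used for prediction like any other mlr3 learner. The sketch below repeats the module definition so that it runs on its own; the holdout split via partition() and the accuracy measure are illustrative choices, not part of this class's documentation.

```r
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)

task = tsk("iris")

# Hold out part of the data for evaluation (illustrative split).
split = partition(task)

# Train on the training rows, then predict on the held-out rows.
learner$train(task, row_ids = split$train)
prediction = learner$predict(task, row_ids = split$test)

# Classification accuracy on the held-out rows.
prediction$score(msr("classif.acc"))
```

The resulting score varies between runs unless a seed is fixed, so no expected value is shown.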