Create a torch learner from a torch module.
See also
Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image, mlr_learners_torch_model
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModule
Methods
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$configure()
mlr3::Learner$encapsulate()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$selected_features()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchModule$new(
  module_generator = NULL,
  param_set = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)
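The learner can also be constructed by calling `LearnerTorchModule$new()` directly, as in the usage above, instead of going through `lrn()`. The following is a minimal sketch rather than a package example. It passes a plain function instead of an `nn_module` generator as `module_generator`, and relies on two assumptions: that `size_hidden` is inferred as a hyperparameter from the function's arguments (as described for `param_set` below), and that the `forward()` method of torch's `nn_sequential()` takes a single argument named `input`, which is why the ingress token carries that name. The function name `make_net` is hypothetical.

```r
library(mlr3torch)
library(torch)

# Hypothetical builder: a plain function (not an nn_module generator) that
# receives the task and returns an nn_module. Because param_set is left NULL,
# size_hidden is assumed to be inferred as a hyperparameter.
make_net = function(task, size_hidden) {
  nn_sequential(
    nn_linear(task$n_features, size_hidden),
    nn_relu(),
    nn_linear(size_hidden, length(task$class_names))
  )
}

learner = LearnerTorchModule$new(
  module_generator = make_net,
  # assumed: nn_sequential's forward() argument is named `input`
  ingress_tokens = list(input = ingress_num()),
  task_type = "classif"
)

# Hyperparameters are set after construction when not using lrn()
learner$param_set$set_values(epochs = 10, batch_size = 16, size_hidden = 20)
learner$train(tsk("iris"))
learner$network
```

Using `lrn("classif.module", ...)` is usually more convenient because it forwards constructor arguments and hyperparameters in a single call.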
Arguments
module_generator
(function or nn_module_generator)
An nn_module_generator or function returning an nn_module. Both must take as argument the task for which to construct the network. Other arguments to its initialize method can be provided as parameters.
param_set
(NULL or ParamSet)
If provided, contains the parameters for the module_generator. If NULL, parameters will be inferred from the module_generator.
ingress_tokens
(list of TorchIngressToken())
A list of ingress tokens that defines how the dataset is constructed. The names must correspond to the arguments of the network's forward method. For numeric, categorical, and lazy tensor features, you can use ingress_num(), ingress_categ(), and ingress_ltnsr() to create them.
task_type
(character(1))
The task type, either "classif" or "regr".
properties
(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.
optimizer
(TorchOptimizer)
The optimizer to use for training. By default, adam is used.
loss
(TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
callbacks
(list() of TorchCallbacks)
The callbacks. Must have unique ids.
packages
(character())
The R packages this object depends on.
feature_types
(NULL or character())
The feature types. Defaults to all available feature types.
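As a sketch of how the `optimizer`, `loss`, and `callbacks` arguments fit together, the following replaces the default adam optimizer with SGD and adds a history callback. It assumes the mlr3torch sugar functions `t_opt()`, `t_loss()`, and `t_clbk()` with the keys "sgd", "cross_entropy", and "history", and that optimizer settings are then exposed as hyperparameters prefixed with `opt.`. The module name `net` is hypothetical.

```r
library(mlr3torch)
library(torch)

# Hypothetical single-layer classifier; it takes no arguments besides task
net = nn_module("net",
  initialize = function(task) {
    self$linear = nn_linear(task$n_features, length(task$class_names))
  },
  forward = function(x) self$linear(x)
)

learner = lrn("classif.module",
  module_generator = net,
  ingress_tokens = list(x = ingress_num()),
  # constructor arguments overriding the defaults
  optimizer = t_opt("sgd", lr = 0.1),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history")),
  # ordinary training hyperparameters
  epochs = 5,
  batch_size = 16
)
learner$train(tsk("iris"))

# The optimizer's learning rate is available as the hyperparameter opt.lr
learner$param_set$values$opt.lr
```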
Examples
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  # argument x corresponds to the ingress token x
  forward = function(x) {
    x = self$first(x)
    x = nnf_relu(x)
    self$second(x)
  }
)
learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
task = tsk("iris")
learner$train(task)
learner$network
#> An `nn_module` containing 163 parameters.
#>
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • first: <nn_linear> #100 parameters
#> • second: <nn_linear> #63 parameters
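The example above builds a classification network. A regression variant is sketched below under two assumptions: that the learner is also registered under the key "regr.module", and that a network with a single output unit matches the shape of the regression target. Following the loss argument's documented default, mse is then used for training. The module name `nn_regr` is hypothetical.

```r
library(mlr3torch)
library(torch)

# Hypothetical regression network: same structure as nn_one_layer,
# but with one output unit for the numeric target
nn_regr = nn_module("nn_regr",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, 1)
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("regr.module",
  module_generator = nn_regr,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
task = tsk("mtcars")
learner$train(task)

# Predict on the training data and compute the RMSE
prediction = learner$predict(task)
prediction$score(msr("regr.rmse"))
```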