
Create a torch learner from a torch module.

Dictionary

This Learner can be instantiated using the sugar function lrn():

lrn("classif.module", ...)
lrn("regr.module", ...)
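The classification key is demonstrated in the Examples section. The sketch below shows the regression key instead. It uses the built-in mtcars task and a hypothetical module `nn_linear_regr` that has a single output unit. It assumes that mlr3torch shapes regression targets to match this one-column output under the default mse loss. This page does not state that.

```r
library(mlr3torch)

# Hypothetical module: one linear layer mapping all features to a single output
nn_linear_regr = nn_module("nn_linear_regr",
  initialize = function(task) {
    self$linear = nn_linear(task$n_features, 1)
  },
  # argument x corresponds to the ingress token x
  forward = function(x) {
    self$linear(x)
  }
)

task = tsk("mtcars")
learner = lrn("regr.module",
  module_generator = nn_linear_regr,
  ingress_tokens = list(x = ingress_num()),
  epochs = 5,
  batch_size = 8
)
learner$train(task)  # no loss given, so the regression default (mse) is used
prediction = learner$predict(task)
```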

Properties

  • Supported task types: 'classif', 'regr'

  • Predict Types:

    • classif: 'response', 'prob'

    • regr: 'response'

  • Feature Types: 'logical', 'integer', 'numeric', 'character', 'factor', 'ordered', 'POSIXct', 'Date', 'lazy_tensor'

  • Required Packages: mlr3, mlr3torch, torch
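For classification, you can request the 'prob' predict type through lrn(), as with any other mlr3 learner field. The sketch below reuses the one-hidden-layer network from the Examples section on the iris task and returns one probability column per class.

```r
library(mlr3torch)

# One-hidden-layer network, as in the Examples section
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

task = tsk("iris")
learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  predict_type = "prob",  # instead of the default "response"
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
learner$train(task)
prediction = learner$predict(task)
head(prediction$prob)  # one column per class
```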

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModule

Methods

Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchModule$new(
  module_generator = NULL,
  param_set = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)
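You can also call the constructor directly instead of lrn(). In that case `task_type` has no default and must be passed explicitly. Hyperparameters are then set on the learner's `param_set`. These include `epochs`, `batch_size`, and the module's own `size_hidden`. The sketch assumes that paradox's `ParamSet$set_values()` method is available.

```r
library(mlr3torch)

# One-hidden-layer network, as in the Examples section
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

# Direct construction; task_type must be given
learner = LearnerTorchModule$new(
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  task_type = "classif"
)
# param_set is NULL, so size_hidden is inferred from nn_one_layer
learner$param_set$set_values(epochs = 10, batch_size = 16, size_hidden = 20)
learner$train(tsk("iris"))
```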

Arguments

module_generator

(function or nn_module_generator)
Either an nn_module_generator or a function returning an nn_module. In both cases, it must take the task for which the network is constructed as an argument. Any further arguments (for an nn_module_generator, those of its initialize method) can be provided as parameters of the learner.
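The generator does not have to be an nn_module_generator. The sketch below uses a hypothetical plain function, `make_mlp`, that builds an nn_sequential from the task. It assumes that the forward() argument of nn_sequential is named `input`, which fixes the name of the ingress token. It also assumes that `size_hidden` is inferred from the function's formals, as described for `param_set`.

```r
library(mlr3torch)

# Hypothetical generator written as a plain function: it receives the task
# and returns an nn_module built from it
make_mlp = function(task, size_hidden) {
  nn_sequential(
    nn_linear(task$n_features, size_hidden),
    nn_relu(),
    nn_linear(size_hidden, length(task$class_names))
  )
}

learner = lrn("classif.module",
  module_generator = make_mlp,
  # assumption: the forward() argument of nn_sequential is named `input`
  ingress_tokens = list(input = ingress_num()),
  epochs = 10,
  size_hidden = 20,  # inferred as a parameter from the formals of make_mlp
  batch_size = 16
)
learner$train(tsk("iris"))
```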

param_set

(NULL or ParamSet)
If provided, contains the parameters for the module_generator. If NULL, parameters will be inferred from the module_generator.
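An explicit parameter set lets you attach types and bounds to the module's arguments. The sketch below declares `size_hidden` as a positive integer with paradox's `ps()` and `p_int()` and stores it in a hypothetical object `module_params`. The "train" tag is an assumption, chosen to mark the parameter as used during training in the way mlr3 learner parameters usually are.

```r
library(mlr3torch)
library(paradox)

# One-hidden-layer network, as in the Examples section
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

# Typed parameter set for the module's own argument
# (assumption: the "train" tag is appropriate here)
module_params = ps(size_hidden = p_int(lower = 1L, tags = "train"))

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  param_set = module_params,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20L,
  batch_size = 16
)
learner$train(tsk("iris"))

# size_hidden is now checked: a value below the lower bound is rejected
res = try(learner$param_set$set_values(size_hidden = 0L), silent = TRUE)
inherits(res, "try-error")
```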

ingress_tokens

(list of TorchIngressToken())
A list of ingress tokens that defines how the dataset is constructed. The names must correspond to the arguments of the network's forward method. For numeric, categorical, and lazy tensor features, ingress_num(), ingress_categ(), and ingress_ltnsr() can be used to create them.
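The name-matching rule means that renaming the forward() argument requires renaming the token as well. The sketch below uses a hypothetical single-layer module, `nn_renamed`, whose forward() argument is called `features`.

```r
library(mlr3torch)

# Hypothetical module whose forward() argument is called `features` ...
nn_renamed = nn_module("nn_renamed",
  initialize = function(task) {
    self$linear = nn_linear(task$n_features, length(task$class_names))
  },
  forward = function(features) {
    self$linear(features)
  }
)

learner = lrn("classif.module",
  module_generator = nn_renamed,
  # ... so the ingress token must also be named `features`
  ingress_tokens = list(features = ingress_num()),
  epochs = 5,
  batch_size = 16
)
learner$train(tsk("iris"))
```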

task_type

(character(1))
The task type, either "classif" or "regr".

properties

(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.

optimizer

(TorchOptimizer)
The optimizer to use for training. By default, the "adam" optimizer is used.

loss

(TorchLoss)
The loss used to train the network. By default, "mse" is used for regression and "cross_entropy" for classification.
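You can override both defaults with objects created by mlr3torch's sugar functions t_opt() and t_loss(). The sketch below replaces Adam with SGD at a learning rate of 0.1 and passes cross-entropy explicitly. It assumes that the optimizer key "sgd" and its `lr` parameter are available in your mlr3torch version.

```r
library(mlr3torch)

# One-hidden-layer network, as in the Examples section
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  optimizer = t_opt("sgd", lr = 0.1),  # replaces the default "adam"
  loss = t_loss("cross_entropy"),      # the classification default, made explicit
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
learner$train(tsk("iris"))
```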

callbacks

(list() of TorchCallback)
The callbacks to use during training. Their ids must be unique.
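The sketch below passes callbacks as a list of objects created with t_clbk(). It records training accuracy with the history callback. The sketch makes three assumptions. First, the history callback has the key and id "history". Second, `measures_train` controls which metrics are recorded. Third, the recorded table is exposed under `learner$model$callbacks`.

```r
library(mlr3torch)

# One-hidden-layer network, as in the Examples section
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  callbacks = list(t_clbk("history")),   # ids must be unique across the list
  measures_train = msrs("classif.acc"),  # metrics to record during training
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
learner$train(tsk("iris"))
# assumption: the recorded history is stored under learner$model$callbacks
learner$model$callbacks$history
```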

packages

(character())
The R packages this object depends on.

feature_types

(NULL or character())
The feature types. Defaults to all available feature types.


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchModule$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
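A deep clone is useful when you need a variant of a configured learner without modifying the original. In the sketch below, changing `epochs` on the hypothetical copy `copy` leaves the original learner's parameter set untouched.

```r
library(mlr3torch)

# One-hidden-layer network, as in the Examples section
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)

# The deep clone gets its own copy of the parameter set
copy = learner$clone(deep = TRUE)
copy$param_set$set_values(epochs = 20)

learner$param_set$values$epochs  # the original keeps its own value
copy$param_set$values$epochs
```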

Examples

library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  # argument x corresponds to the ingress token x
  forward = function(x) {
    x = self$first(x)
    x = nnf_relu(x)
    self$second(x)
  }
)
learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
task = tsk("iris")
learner$train(task)
learner$network
#> An `nn_module` containing 163 parameters.
#> 
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • first: <nn_linear> #100 parameters
#> • second: <nn_linear> #63 parameters
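After training, you can use the usual mlr3 predict and score methods. The sketch below repeats the setup so that it runs on its own. It then holds out part of iris with mlr3's partition() and reports accuracy on the held-out rows only.

```r
library(mlr3torch)

# One-hidden-layer network, as in the example above
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, length(task$class_names))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

task = tsk("iris")
learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)

# Train on one subset, evaluate on the remaining rows
split = partition(task)
learner$train(task, row_ids = split$train)
prediction = learner$predict(task, row_ids = split$test)
acc = prediction$score(msr("classif.acc"))
acc
```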