Create a torch learner from an instantiated nn_module(). For classification, the output of the network must be the raw class scores (the values before the softmax is applied).
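The requirement on raw scores can be illustrated with a minimal sketch. The two-layer network below is a hypothetical example, not taken from the package; the point is only that the last layer is linear and no softmax is applied inside the module. The softmax call at the end is shown purely to illustrate how probabilities relate to the scores.

```r
library(torch)

# Hypothetical network for 4 features and 3 classes.
# The final layer is linear: no softmax at the end of the module.
network = nn_sequential(
  nn_linear(4, 16),
  nn_relu(),
  nn_linear(16, 3)
)

# Forward pass on a dummy batch of 5 observations.
x = torch_randn(5, 4)
scores = network(x)
scores$shape  # one row per observation, one column per class

# Probabilities can be computed outside the network when needed.
probs = nnf_softmax(scores, dim = 2)
```

A network like this could be passed as the network argument; a module that already ends in a softmax would not satisfy the requirement stated above.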

Parameters

See LearnerTorch

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModel

Active bindings

network_stored

(nn_module or NULL)
The network that will be trained. After calling $train(), this is NULL.

ingress_tokens

(named list() with TorchIngressToken or NULL)
The ingress tokens. Must be non-NULL when calling $train().
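The following sketch inspects network_stored before and after training. It reuses the iris setup from the Examples section, and the comments restate what the description above documents rather than verified output.

```r
library(torch)
library(mlr3)
library(mlr3torch)

task = tsk("iris")

learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# Before training, the binding holds the network passed at construction.
stored_before = is.null(learner$network_stored)

learner$train(task)

# After training, the binding is documented to be NULL.
stored_after = is.null(learner$network_stored)
```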

Methods

Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchModel$new(
  network = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)

Arguments

network

(nn_module)
An instantiated nn_module. It is not cloned during construction. For classification, the outputs must be the raw class scores (the values before the softmax is applied).

ingress_tokens

(list of TorchIngressToken())
A list of ingress tokens that defines how the dataloader is constructed.

task_type

(character(1))
The task type, e.g. "classif" or "regr".

properties

(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.

optimizer

(TorchOptimizer)
The torch optimizer.

loss

(TorchLoss)
The loss to use for training.

callbacks

(list() of TorchCallbacks)
The callbacks used during training. Must have unique ids. They are executed in the order in which they are provided.

packages

(character())
The R packages this object depends on.

feature_types

(NULL or character())
The feature types. Defaults to all available feature types.
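The arguments above can be combined in a direct call to the constructor, sketched below under some assumptions. The choice of t_opt("adam"), t_loss("cross_entropy") and t_clbk("history") is illustrative only and is not prescribed by this page; the network and ingress tokens follow the iris setup from the Examples section.

```r
library(torch)
library(mlr3)
library(mlr3torch)

task = tsk("iris")

learner = LearnerTorchModel$new(
  # Linear model: 4 features in, 3 class scores out
  network = nn_linear(4, 3),
  # One ingress token that loads all numeric features as a (batch, 4) tensor
  ingress_tokens = list(
    input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
  ),
  task_type = "classif",
  # Illustrative choices, assumed here for the sketch
  optimizer = t_opt("adam"),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history"))
)

# Training parameters are set through the parameter set of the learner.
learner$param_set$set_values(batch_size = 16, epochs = 1, device = "cpu")
```

Because the network is not cloned during construction, the module passed here is the same object the learner holds; keeping a separate copy is the caller's responsibility.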


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchModel$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(torch)
library(mlr3)
library(mlr3torch)

# We show the learner using a classification task

# The iris task has 4 features and 3 classes
network = nn_linear(4, 3)
task = tsk("iris")

# This defines the dataloader.
# It loads all 4 features, which are also numeric.
# The shape is (NA, 4) because the batch dimension is generally NA
ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)

# Creating the learner and setting required parameters
learner = lrn("classif.torch_model",
  network = network,
  ingress_tokens = ingress_tokens,
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# A simple train-predict
ids = partition(task)
learner$train(task, ids$train)
learner$predict(task, ids$test)
#> <PredictionClassif> for 50 observations:
#>  row_ids     truth   response
#>        3    setosa versicolor
#>        9    setosa versicolor
#>       10    setosa versicolor
#>      ---       ---        ---
#>      147 virginica     setosa
#>      149 virginica     setosa
#>      150 virginica     setosa