
Classic image classification networks from torchvision.

Parameters

Parameters from LearnerTorchImage and the following:

  • pretrained :: logical(1)
    Whether to load pretrained weights for the network. If TRUE, the final linear layer is replaced with a new nn_linear whose number of outputs equals the number of classes inferred from the Task.
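As a minimal sketch of what replacing the final linear layer amounts to, the R code below swaps the `fc` layer of a torchvision ResNet-18 for a new `nn_linear` with ten outputs and runs a forward pass. The helper `replace_resnet_head()` and the choice of ten classes are hypothetical illustrations, not the learner's internal code. The sketch assumes that torchvision's ResNet stores its final classifier as `fc`; other architectures may keep their last linear layer elsewhere. It uses `pretrained = FALSE` only to avoid downloading weights.

```r
library(torch)
library(torchvision)

# Hypothetical helper (not part of mlr3torch): replace the final fully
# connected layer `fc` of a torchvision ResNet with a fresh nn_linear
# that has `num_classes` outputs, keeping the number of input features.
replace_resnet_head = function(net, num_classes) {
  # nn_linear stores its weight with shape (out_features, in_features)
  in_features = net$fc$weight$shape[2]
  net$fc = nn_linear(in_features, num_classes)
  net
}

# pretrained = FALSE skips the weight download; with pretrained = TRUE the
# backbone would keep its learned weights and only the new head is fresh.
net = model_resnet18(pretrained = FALSE)
net = replace_resnet_head(net, num_classes = 10)

# Forward pass on a batch of two random 64x64 RGB images
x = torch_randn(2, 3, 64, 64)
out = with_no_grad(net(x))
print(out$shape)
```

Only the new head is randomly initialized, so the number of outputs adapts to the task while the rest of the network is left unchanged.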

Properties

  • Supported task types: "classif"

  • Predict Types: "response" and "prob"

  • Feature Types: "lazy_tensor"

  • Required packages: "mlr3torch", "torch", "torchvision"

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> mlr3torch::LearnerTorchImage -> LearnerTorchVision

Methods

Inherited methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchVision$new(
  name,
  module_generator,
  label,
  optimizer = NULL,
  loss = NULL,
  callbacks = list()
)

Arguments

name

(character(1))
The name of the network.

module_generator

(function(pretrained, num_classes))
Function that generates the network.
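A sketch of a function with the documented `function(pretrained, num_classes)` signature, built on `torchvision::model_resnet18()`. The name `resnet18_generator` is hypothetical. The sketch assumes that `num_classes` only matters when training from scratch, because with `pretrained = TRUE` the final linear layer is replaced as described for the `pretrained` parameter.

```r
library(torchvision)

# Hypothetical module generator matching the signature
# function(pretrained, num_classes).
resnet18_generator = function(pretrained, num_classes) {
  if (pretrained) {
    # Pretrained weights come with their original output layer;
    # num_classes is unused here (assumption: the head is replaced later).
    model_resnet18(pretrained = TRUE)
  } else {
    # Untrained network built directly with the required number of outputs
    model_resnet18(pretrained = FALSE, num_classes = num_classes)
  }
}

# Build an untrained network for a five-class task
net = resnet18_generator(pretrained = FALSE, num_classes = 5)
```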

label

(character(1))
The label of the network.

optimizer

(TorchOptimizer)
The optimizer to use for training. By default, adam is used.

loss

(TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.

callbacks

(list() of TorchCallbacks)
The callbacks to use. Their ids must be unique.


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchVision$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.