Classic image classification networks from torchvision.
Parameters
Parameters from LearnerTorchImage and
pretrained :: logical(1)
Whether to use the pretrained model. If TRUE, the final linear layer is replaced with a new nn_linear whose number of outputs equals the number of classes inferred from the Task.
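A minimal sketch of setting and changing pretrained. It assumes that mlr3torch registers one of these networks under the key "classif.resnet18", and it sets the generic LearnerTorch hyperparameters epochs and batch_size, which are not described in this section. The values are arbitrary and nothing is trained.

```r
library(mlr3)
library(mlr3torch)

# Assumed key "classif.resnet18"; epochs and batch_size are LearnerTorch
# hyperparameters, pretrained is the parameter documented above.
learner <- lrn("classif.resnet18",
  epochs = 1,
  batch_size = 16,
  pretrained = TRUE
)

# pretrained is an ordinary hyperparameter, so it can be switched off later,
# e.g. to train the same architecture from randomly initialized weights.
learner$param_set$values$pretrained <- FALSE
learner$param_set$values$pretrained
```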
Properties
Supported task types: "classif"
Predict Types: "response" and "prob"
Feature Types: "lazy_tensor"
Required packages: "mlr3torch", "torch", "torchvision"
Super classes
mlr3::Learner -> mlr3torch::LearnerTorch -> mlr3torch::LearnerTorchImage -> LearnerTorchVision
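The properties and the class chain listed above can be inspected on a constructed learner. This sketch again assumes the key "classif.resnet18". The fields task_type, predict_types, feature_types and packages are standard fields of mlr3::Learner.

```r
library(mlr3)
library(mlr3torch)

# Assumed key; predict_type is set through lrn() at construction.
learner <- lrn("classif.resnet18", predict_type = "prob")

# Values corresponding to the "Properties" section.
learner$task_type
learner$predict_types
learner$feature_types
learner$packages

# R6 objects carry their inheritance chain in class(), so the
# "Super classes" section can be checked with inherits().
class(learner)
inherits(learner, "LearnerTorchImage")
```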
Methods
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$encapsulate()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage
LearnerTorchVision$new(
name,
module_generator,
label,
optimizer = NULL,
loss = NULL,
callbacks = list()
)
Arguments
name (character(1))
The name of the network.
module_generator (function(pretrained, num_classes))
Function that generates the network.
label (character(1))
The label of the network.
optimizer (TorchOptimizer)
The optimizer to use for training. By default, adam is used.
loss (TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
callbacks (list() of TorchCallbacks)
The callbacks. Must have unique ids.
References
Krizhevsky, Alex, Sutskever, Ilya, Hinton, Geoffrey E. (2017). “ImageNet classification with deep convolutional neural networks.” Communications of the ACM, 60(6), 84–90.
Sandler, Mark, Howard, Andrew, Zhu, Menglong, Zhmoginov, Andrey, Chen, Liang-Chieh (2018). “MobileNetV2: Inverted residuals and linear bottlenecks.” In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 4510–4520.
He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian (2016). “Deep residual learning for image recognition.” In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 770–778.
Simonyan, Karen, Zisserman, Andrew (2014). “Very deep convolutional networks for large-scale image recognition.” arXiv preprint arXiv:1409.1556.
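The module_generator argument expects a function with arguments pretrained and num_classes. The constructors in the torchvision R package, such as torchvision::model_resnet18(), have this shape, with num_classes forwarded through ... to the underlying network. The sketch below assumes torch and torchvision are installed. It defines a hypothetical wrapper resnet_generator, builds an untrained ResNet-18 with 10 output classes, and checks the output shape on a random batch. Only the pretrained = FALSE path is exercised; loading pretrained weights together with a non-default num_classes is not shown.

```r
library(torch)
library(torchvision)

# Hypothetical generator with the signature expected by module_generator;
# it simply forwards both arguments to torchvision's ResNet-18 constructor.
resnet_generator <- function(pretrained, num_classes) {
  model_resnet18(pretrained = pretrained, num_classes = num_classes)
}

net <- resnet_generator(pretrained = FALSE, num_classes = 10)

# A random batch of 2 RGB images of size 64 x 64 (N x C x H x W).
x <- torch_randn(2, 3, 64, 64)
net$eval()
out <- net(x)

# One row of class scores per image.
out$shape
```

Such a function plays the role of module_generator in the Usage shown above, e.g. LearnerTorchVision$new("resnet18", resnet_generator, "ResNet-18").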