
This class wraps a torch::nn_loss and annotates it with metadata, most importantly a ParamSet. The loss function is created for the given parameter values by calling the $generate() method.

This class is usually used to configure the loss function of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.

For a list of available losses, see mlr3torch_losses. Items from this dictionary can be retrieved using t_loss().
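As a minimal sketch, the dictionary can also be queried directly. This assumes mlr3torch is installed and relies on the `$keys()` method that mlr3torch_losses inherits from the mlr3misc Dictionary class. As in the examples further below, extra arguments passed to t_loss() set parameter values of the retrieved loss.

```r
library(mlr3torch)

# List the ids of all losses registered in the dictionary
keys = mlr3torch_losses$keys()
print(keys)

# Retrieve one loss by id; the extra argument sets the 'reduction' parameter
mse = t_loss("mse", reduction = "sum")
print(mse$param_set$values$reduction)
```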

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, one is constructed by creating a parameter of type ParamUty for each argument of the wrapped loss function.
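The sketch below supplies an explicit parameter set instead of relying on inference, so that `reduction` is typed as a factor parameter rather than ParamUty. It assumes the paradox package, which provides ps() and p_fct(), is available. The id "my_mse" and label "My MSE" are hypothetical values chosen for illustration and are not defined by mlr3torch.

```r
library(mlr3torch)
library(paradox)

# Explicit parameter set: 'reduction' becomes a typed factor parameter
# instead of the untyped ParamUty that would otherwise be inferred
param_set = ps(
  reduction = p_fct(levels = c("none", "mean", "sum"), default = "mean")
)

# Hypothetical id and label, used only for this example
my_loss = TorchLoss$new(
  torch_loss = torch::nn_mse_loss,
  task_types = "regr",
  param_set = param_set,
  id = "my_mse",
  label = "My MSE"
)

# Set a parameter value, then create the loss module with that value
my_loss$param_set$values$reduction = "sum"
loss_fn = my_loss$generate()
```

Typing the parameter this way means an invalid value such as `reduction = "avg"` is rejected when it is set, rather than only when the wrapped torch loss is constructed.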

Super class

mlr3torch::TorchDescriptor -> TorchLoss

Public fields

task_types

(character())
The task types this loss supports.

Methods


Method new()

Creates a new instance of this R6 class.

Usage

TorchLoss$new(
  torch_loss,
  task_types = NULL,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL
)

Arguments

torch_loss

(nn_loss)
The loss module.

task_types

(character())
The task types supported by this loss.

param_set

(ParamSet or NULL)
The parameter set. If NULL (default) it is inferred from torch_loss.

id

(character(1))
The id of the new object.

label

(character(1))
Label for the new instance.

packages

(character())
The R packages this object depends on.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().


Method print()

Prints the object.

Usage

TorchLoss$print(...)

Arguments

...

any


Method clone()

The objects of this class are cloneable with this method.

Usage

TorchLoss$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3torch)

# Create a new torch loss
torch_loss = TorchLoss$new(torch_loss = nn_mse_loss, task_types = "regr")
torch_loss
#> <TorchLoss:nn_mse_loss> nn_mse_loss
#> * Generator: nn_mse_loss
#> * Parameters: list()
#> * Packages: torch,mlr3torch
#> * Task Types: regr
# the parameters are inferred
torch_loss$param_set
#> <ParamSet(1)>
#>           id    class lower upper nlevels        default  value
#>       <char>   <char> <num> <num>   <num>         <list> <list>
#> 1: reduction ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]

# Retrieve a loss from the dictionary:
torch_loss = t_loss("mse", reduction = "mean")
# Print the retrieved loss
torch_loss
#> <TorchLoss:mse> Mean Squared Error
#> * Generator: nn_mse_loss
#> * Parameters: reduction=mean
#> * Packages: torch,mlr3torch
#> * Task Types: regr
torch_loss$param_set
#> <ParamSet(1)>
#>           id    class lower upper nlevels default  value
#>       <char>   <char> <num> <num>   <num>  <list> <list>
#> 1: reduction ParamFct    NA    NA       2    mean   mean
torch_loss$label
#> [1] "Mean Squared Error"
torch_loss$task_types
#> [1] "regr"
torch_loss$id
#> [1] "mse"

# Create the loss function
loss_fn = torch_loss$generate()
loss_fn
#> An `nn_module` containing 0 parameters.
# Is the same as
nn_mse_loss(reduction = "mean")
#> An `nn_module` containing 0 parameters.

# open the help page of the wrapped loss function
# torch_loss$help()

# Use in a learner
learner = lrn("regr.mlp", loss = t_loss("mse"))
# The parameters of the loss are added to the learner's parameter set
learner$param_set
#> <ParamSetCollection(32)>
#>                   id    class lower upper nlevels        default        value
#>               <char>   <char> <num> <num>   <num>         <list>       <list>
#>  1:           epochs ParamInt 0e+00   Inf     Inf <NoDefault[0]>       [NULL]
#>  2:           device ParamFct    NA    NA      12 <NoDefault[0]>         auto
#>  3:      num_threads ParamInt 1e+00   Inf     Inf <NoDefault[0]>            1
#>  4:             seed ParamInt  -Inf   Inf     Inf <NoDefault[0]>       random
#>  5:        eval_freq ParamInt 1e+00   Inf     Inf <NoDefault[0]>            1
#>  6:   measures_train ParamUty    NA    NA     Inf <NoDefault[0]>    <list[0]>
#>  7:   measures_valid ParamUty    NA    NA     Inf <NoDefault[0]>    <list[0]>
#>  8:         patience ParamInt 0e+00   Inf     Inf <NoDefault[0]>            0
#>  9:        min_delta ParamDbl 0e+00   Inf     Inf <NoDefault[0]>            0
#> 10:       batch_size ParamInt 1e+00   Inf     Inf <NoDefault[0]>       [NULL]
#> 11:          shuffle ParamLgl    NA    NA       2          FALSE       [NULL]
#> 12:          sampler ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
#> 13:    batch_sampler ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
#> 14:      num_workers ParamInt 0e+00   Inf     Inf              0       [NULL]
#> 15:       collate_fn ParamUty    NA    NA     Inf         [NULL]       [NULL]
#> 16:       pin_memory ParamLgl    NA    NA       2          FALSE       [NULL]
#> 17:        drop_last ParamLgl    NA    NA       2          FALSE       [NULL]
#> 18:          timeout ParamDbl  -Inf   Inf     Inf             -1       [NULL]
#> 19:   worker_init_fn ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
#> 20:   worker_globals ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
#> 21:  worker_packages ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
#> 22:          neurons ParamUty    NA    NA     Inf <NoDefault[0]>             
#> 23:                p ParamDbl 0e+00 1e+00     Inf <NoDefault[0]>          0.5
#> 24:       activation ParamUty    NA    NA     Inf <NoDefault[0]> <nn_relu[1]>
#> 25:  activation_args ParamUty    NA    NA     Inf <NoDefault[0]>    <list[0]>
#> 26:            shape ParamUty    NA    NA     Inf <NoDefault[0]>       [NULL]
#> 27:           opt.lr ParamDbl 0e+00   Inf     Inf          0.001       [NULL]
#> 28:        opt.betas ParamUty    NA    NA     Inf    0.900,0.999       [NULL]
#> 29:          opt.eps ParamDbl 1e-16 1e-04     Inf          1e-08       [NULL]
#> 30: opt.weight_decay ParamDbl 0e+00 1e+00     Inf              0       [NULL]
#> 31:      opt.amsgrad ParamLgl    NA    NA       2          FALSE       [NULL]
#> 32:   loss.reduction ParamFct    NA    NA       2           mean       [NULL]
#>                   id    class lower upper nlevels        default        value