The cross_entropy loss function selects either the multi-class (nn_cross_entropy_loss) or the binary (nn_bce_with_logits_loss) cross-entropy loss, depending on the number of classes.
Because of this, the loss arguments are slightly reparameterized; see Parameters.
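The selection rule can be sketched in base R. select_cross_entropy is a hypothetical helper, not part of the package: it only returns the name of the torch constructor and the arguments it would receive, so it runs without torch installed. The ignore_index default of -100 is an assumption borrowed from torch's own default.

```r
# Hypothetical helper: describes which torch loss a cross-entropy wrapper would
# construct and with which arguments, without actually calling torch.
select_cross_entropy <- function(num_classes, class_weight = NULL,
                                 ignore_index = -100L, reduction = "mean") {
  stopifnot(num_classes >= 2, reduction %in% c("mean", "sum"))
  if (num_classes == 2) {
    # Binary: class_weight must be a scalar and becomes pos_weight;
    # ignore_index is not supported for this loss and is therefore dropped.
    stopifnot(is.null(class_weight) || length(class_weight) == 1)
    list(constructor = "nn_bce_with_logits_loss",
         args = list(pos_weight = class_weight, reduction = reduction))
  } else {
    # Multi-class: class_weight must have length num_classes and becomes weight.
    stopifnot(is.null(class_weight) || length(class_weight) == num_classes)
    list(constructor = "nn_cross_entropy_loss",
         args = list(weight = class_weight, ignore_index = ignore_index,
                     reduction = reduction))
  }
}

select_cross_entropy(2, class_weight = 3)$constructor  # "nn_bce_with_logits_loss"
```

Returning a description instead of a module keeps the sketch focused on the argument mapping, which is the reparameterization the text refers to.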
Parameters
class_weight :: torch_tensor
The class weights. For multi-class problems, this must be a torch_tensor of length num_classes (and is passed as argument weight to nn_cross_entropy_loss). For binary problems, this must be a scalar (and is passed as argument pos_weight to nn_bce_with_logits_loss).
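To illustrate how class_weight enters each loss, here is a minimal base-R sketch of the two formulas, assuming the standard definitions used by torch: in nn_cross_entropy_loss the "mean" reduction divides by the summed weights of the targets, while in nn_bce_with_logits_loss pos_weight scales only the positive term and "mean" is a plain average. The helper names are hypothetical, and targets are 1-based class indices following R conventions.

```r
# Row-wise log-softmax, shifted by the row maximum for numerical stability.
log_softmax_rows <- function(logits) {
  shifted <- logits - apply(logits, 1, max)
  shifted - log(rowSums(exp(shifted)))
}

# Multi-class: weight[k] scales the loss of every observation of class k;
# the "mean" reduction divides by the summed weights, not by the row count.
weighted_ce <- function(logits, target, weight = rep(1, ncol(logits))) {
  nll <- -log_softmax_rows(logits)[cbind(seq_along(target), target)]
  w <- weight[target]
  sum(w * nll) / sum(w)
}

# Binary: pos_weight scales only the y = 1 term; "mean" is a plain average.
# Uses -log(sigmoid(x)) = log1p(exp(-x)) and -log(1 - sigmoid(x)) = log1p(exp(x)).
bce_pos_weight <- function(logits, y, pos_weight = 1) {
  mean(pos_weight * y * log1p(exp(-logits)) + (1 - y) * log1p(exp(logits)))
}

logits <- rbind(c(2, 0), c(0, 0))
weighted_ce(logits, target = c(1, 2), weight = c(1, 3))
bce_pos_weight(0, y = 1, pos_weight = 3)  # 3 * log(2)
```

Raising the weight of a class therefore pulls the multi-class average toward that class's per-observation loss, whereas pos_weight > 1 penalizes missed positives more heavily in the binary case.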
ignore_index :: integer(1)
Index of the class to ignore; targets with this index do not contribute to the gradient. This is only available for the multi-class loss.
reduction :: character(1)
The reduction to apply. Either "mean" or "sum"; it is passed as argument reduction to either loss function. The default is "mean".
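The effect of ignore_index and reduction on the multi-class loss can be sketched the same way. ce_loss is again a hypothetical base-R helper, assuming torch's semantics: rows whose target equals ignore_index are dropped before reducing, "sum" adds the weighted per-row losses, and "mean" divides that sum by the total weight of the remaining targets. The default of -100 mirrors torch's default and is an assumption here.

```r
# Hypothetical multi-class cross entropy with ignore_index and reduction.
# Targets are 1-based class indices, following R conventions.
ce_loss <- function(logits, target, weight = rep(1, ncol(logits)),
                    ignore_index = -100L, reduction = c("mean", "sum")) {
  reduction <- match.arg(reduction)
  keep <- target != ignore_index  # ignored rows contribute nothing
  logits <- logits[keep, , drop = FALSE]
  target <- target[keep]
  # Numerically stable row-wise log-softmax.
  m <- apply(logits, 1, max)
  lp <- logits - m - log(rowSums(exp(logits - m)))
  nll <- -lp[cbind(seq_along(target), target)]
  w <- weight[target]
  if (reduction == "sum") sum(w * nll) else sum(w * nll) / sum(w)
}

logits <- matrix(0, nrow = 3, ncol = 2)
ce_loss(logits, c(1, 2, -100), reduction = "sum")  # 2 * log(2), about 1.386
ce_loss(logits, c(1, 2, -100))                      # log(2), about 0.693
```

With all-zero logits every kept row contributes log(2), so the ignored third row only shrinks the sum and leaves the mean unchanged.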