
Feature Tokenizer Transformer (FT-Transformer) for tabular data that works either on lazy_tensor inputs or on standard tabular features.

Some differences from the paper implementation: there is no attention compression and no option to use pre-normalization in the first layer.

If training is unstable, consider a combination of standardizing features (e.g. using po("scale")), using an adaptive optimizer (e.g. Adam), reducing the learning rate, and using a learning rate scheduler (see CallbackSetLRScheduler for options).
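
A minimal sketch of such a setup, combining po("scale") with a reduced learning rate (the hyperparameter values and the task are illustrative, not recommendations):

library(mlr3)
library(mlr3pipelines)
library(mlr3torch)

# Standardize numeric features before they enter the network.
graph = po("scale") %>>%
  lrn("classif.ft_transformer",
    epochs = 10, batch_size = 32, device = "cpu",
    n_blocks = 2, d_token = 32, ffn_d_hidden_multiplier = 4/3,
    opt.lr = 1e-4  # optimizer parameters are prefixed with "opt."
  )
learner = as_learner(graph)
learner$train(tsk("iris"))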

Dictionary

This Learner can be instantiated using the sugar function lrn():

lrn("classif.ft_transformer", ...)
lrn("regr.ft_transformer", ...)

Properties

  • Supported task types: "classif", "regr"

  • Predict Types:

    • classif: "response", "prob"

    • regr: "response"

  • Feature Types: "logical", "integer", "numeric", "factor", "ordered", "lazy_tensor"

  • Required Packages: mlr3, mlr3torch, torch

Parameters

Parameters from LearnerTorch and PipeOpTorchFTTransformerBlock, as well as:

  • n_blocks :: integer(1)
    The number of transformer blocks.

  • d_token :: integer(1)
    The dimension of the embedding.

  • cardinalities :: integer()
    The number of categories for each categorical feature. This only needs to be specified when working with lazy_tensor inputs.

  • init_token :: character(1)
    The initialization method for the embedding weights. Either "uniform" (the default) or "normal".

  • ingress_tokens :: named list() or NULL
    A named list of TorchIngressTokens. Only required when using lazy tensor features. The names are either "num.input" or "categ.input", and the values are lazy tensor ingress tokens constructed by, e.g., ingress_ltnsr(<num_feat_name>); see the sketch below.
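
A minimal sketch of setting ingress_tokens, assuming a classification task whose numeric features arrive in a single lazy_tensor column named "x_num" (a hypothetical column name):

library(mlr3torch)

# Map the lazy_tensor column to the numeric input of the network.
learner = lrn("classif.ft_transformer")
learner$param_set$set_values(
  ingress_tokens = list(num.input = ingress_ltnsr("x_num")),
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_token = 32, ffn_d_hidden_multiplier = 4/3
)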

References

Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). "Revisiting Deep Learning Models for Tabular Data." arXiv preprint arXiv:2106.11959.

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchFTTransformer

Methods

Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchFTTransformer$new(
  task_type,
  optimizer = NULL,
  loss = NULL,
  callbacks = list()
)

Arguments

task_type

(character(1))
The task type, either "classif" or "regr".

optimizer

(TorchOptimizer)
The optimizer to use for training. By default, "adam" is used.

loss

(TorchLoss)
The loss used to train the network. By default, "mse" is used for regression and "cross_entropy" for classification.

callbacks

(list() of TorchCallbacks)
The callbacks. Must have unique ids.
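
For illustration, a minimal sketch of constructing the learner directly with a non-default optimizer (assuming "adamw" is among the optimizers registered in mlr3torch; t_opt() retrieves a TorchOptimizer from the optimizer dictionary):

library(mlr3torch)

# Direct construction, equivalent to lrn("classif.ft_transformer")
# except for the swapped-in optimizer.
learner = LearnerTorchFTTransformer$new(
  task_type = "classif",
  optimizer = t_opt("adamw")
)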


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchFTTransformer$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Define the Learner and set parameter values
learner = lrn("classif.ft_transformer")
learner$param_set$set_values(
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_token = 32, ffn_d_hidden_multiplier = 4/3
)

# Define a Task
task = tsk("iris")

# Create train and test set
ids = partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
#> classif.ce 
#>       0.26
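
The regression variant works analogously; a minimal sketch (hyperparameter values are illustrative):

learner = lrn("regr.ft_transformer")
learner$param_set$set_values(
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_token = 32, ffn_d_hidden_multiplier = 4/3
)
learner$train(tsk("mtcars"))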