This function creates an S3 object of class "TorchIngressToken", which is an internal data structure. It contains the (meta-)information on how a batch is generated from a Task and fed into an entry point of the neural network. It is stored in the ingress field of a ModelDescriptor.

Usage

TorchIngressToken(features, batchgetter, shape)

Arguments

features

(character)
Features on which the batchgetter will operate.

batchgetter

(function)
Function with two arguments: data and device. This function is given the output of Task$data(rows = batch_indices, cols = features) and should return a tensor whose shape matches the shape argument.

shape

(integer)
Shape of the tensor that batchgetter produces. The batch dimension should be included as NA.

Value

TorchIngressToken object.

Examples

# Define a task for which we want to define an ingress token
task = tsk("iris")

# We create an ingress token for two features, Sepal.Length and Petal.Length.
# We have to specify the features, the batchgetter and the shape
features = c("Sepal.Length", "Petal.Length")
# As a batchgetter we use batchgetter_num

batch_dt = task$data(rows = 1:10, cols = features)
batch_dt
#>     Sepal.Length Petal.Length
#>            <num>        <num>
#>  1:          5.1          1.4
#>  2:          4.9          1.4
#>  3:          4.7          1.3
#>  4:          4.6          1.5
#>  5:          5.0          1.4
#>  6:          5.4          1.7
#>  7:          4.6          1.4
#>  8:          5.0          1.5
#>  9:          4.4          1.4
#> 10:          4.9          1.5
batch_tensor = batchgetter_num(batch_dt, "cpu")
batch_tensor
#> torch_tensor
#>  5.1000  1.4000
#>  4.9000  1.4000
#>  4.7000  1.3000
#>  4.6000  1.5000
#>  5.0000  1.4000
#>  5.4000  1.7000
#>  4.6000  1.4000
#>  5.0000  1.5000
#>  4.4000  1.4000
#>  4.9000  1.5000
#> [ CPUFloatType{10,2} ]

# The shape is unknown in the first dimension (batch dimension)

ingress_token = TorchIngressToken(
  features = features,
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
ingress_token
#> Ingress: Task[Sepal.Length,Petal.Length] --> Tensor(NA, 2)
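
The batchgetter does not have to be batchgetter_num: any function with the arguments data and device that returns a tensor of the declared shape should work. The following is a minimal sketch under that assumption; batchgetter_custom and custom_token are hypothetical names introduced here for illustration, and the body simply converts the batch to a float tensor on the requested device.

```r
library(mlr3)
library(mlr3torch)
library(torch)

# Hypothetical custom batchgetter (not part of mlr3torch):
# `data` is the data.table returned by Task$data(), `device` is e.g. "cpu".
batchgetter_custom = function(data, device) {
  torch_tensor(as.matrix(data), dtype = torch_float(), device = device)
}

task = tsk("iris")
features = c("Sepal.Length", "Petal.Length")

# Check the batchgetter on a small batch before building the token
batch_dt = task$data(rows = 1:10, cols = features)
batch_tensor = batchgetter_custom(batch_dt, "cpu")

# Apart from the batch dimension, this should agree with the declared shape
batch_tensor$shape

custom_token = TorchIngressToken(
  features = features,
  batchgetter = batchgetter_custom,
  shape = c(NA, 2)
)
```

Running the batchgetter on a few rows first is a cheap way to confirm that the non-batch dimensions agree with the shape passed to TorchIngressToken.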