Creates a torch dataset from an mlr3 Task.
The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:
- x is a list of tensors whose content is defined by the parameter feature_ingress_tokens.
- y is the target variable; its content is defined by the parameter target_batchgetter.
- .index is the index of the batch in the task's data.
The data is returned on the device specified by the parameter device.
Arguments
- task
- feature_ingress_tokens
  (named list() of TorchIngressToken)
  Each ingress token defines one item in the $x value of a batch with corresponding names.
- target_batchgetter
  (function(data, device))
  A function taking the arguments data, which is a data.table containing only the target variable, and device. It must return the target as a torch tensor on the selected device.
- device
  (character())
  The device, e.g. "cuda" or "cpu".
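The target_batchgetter contract described above can be illustrated in isolation. The following is a minimal sketch, not part of the package: the name my_classif_target_getter is hypothetical, and the data.table passed to it is a hand-built stand-in for the single-column target table that the dataset would supply. It assumes a factor target that should be returned as integer class codes.

```r
library(torch)
library(data.table)

# Hypothetical target batchgetter for a factor (classification) target.
# `data` is a data.table holding only the target column; `device` is a
# device name such as "cpu" or "cuda".
my_classif_target_getter = function(data, device) {
  # Convert factor levels to their integer codes before building the tensor
  codes = as.integer(data[[1L]])
  torch_tensor(codes, dtype = torch_long(), device = device)
}

# Hand-built stand-in for the target column of a task (an assumption,
# used here only to exercise the function)
target = data.table(Species = factor(c("setosa", "virginica", "setosa")))
y = my_classif_target_getter(target, "cpu")
y$shape
y$dtype
```

Returning a long tensor of class codes is one possible choice; whether a one-dimensional long tensor or a float column (as in the example below) is appropriate depends on the loss function used downstream.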
Examples
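library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")
# One ingress token per feature group; each yields a tensor with two columns
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
# The list names become the names of the items in batch$x
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)
# Return the target as a float tensor of shape (batch, 1) on the requested device
target_batchgetter = function(data, device) {
  torch_tensor(data = data[[1L]], dtype = torch_float32(), device = device)$unsqueeze(2)
}
dataset = task_dataset(task, ingress_tokens, target_batchgetter, "cpu")
# Fetch the first ten rows of the task as a single batch
batch = dataset$.getbatch(1:10)
batch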
#> $x
#> $x$sepal
#> torch_tensor
#> 5.1000 3.5000
#> 4.9000 3.0000
#> 4.7000 3.2000
#> 4.6000 3.1000
#> 5.0000 3.6000
#> 5.4000 3.9000
#> 4.6000 3.4000
#> 5.0000 3.4000
#> 4.4000 2.9000
#> 4.9000 3.1000
#> [ CPUFloatType{10,2} ]
#>
#> $x$petal
#> torch_tensor
#> 1.4000 0.2000
#> 1.4000 0.2000
#> 1.3000 0.2000
#> 1.5000 0.2000
#> 1.4000 0.2000
#> 1.7000 0.4000
#> 1.4000 0.3000
#> 1.5000 0.2000
#> 1.4000 0.2000
#> 1.5000 0.1000
#> [ CPUFloatType{10,2} ]
#>
#>
#> $.index
#> torch_tensor
#> 1
#> 2
#> 3
#> 4
#> 5
#> 6
#> 7
#> 8
#> 9
#> 10
#> [ CPULongType{10} ]
#>
#> $y
#> torch_tensor
#> 1
#> 1
#> 1
#> 1
#> 1
#> 1
#> 1
#> 1
#> 1
#> 1
#> [ CPUFloatType{10,1} ]
#>
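In practice a dataset like this is usually consumed through a dataloader rather than by calling $.getbatch() directly. The following is a minimal sketch under stated assumptions: it assumes task_dataset() has the four-argument signature documented above, that the resulting dataset can be passed to torch::dataloader() like any other torch dataset, and that a long tensor of class codes is an acceptable target. The batch size of 32 is arbitrary.

```r
library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")
ingress_tokens = list(
  sepal = TorchIngressToken(
    features = c("Sepal.Length", "Sepal.Width"),
    batchgetter = batchgetter_num,
    shape = c(NA, 2)
  ),
  petal = TorchIngressToken(
    features = c("Petal.Length", "Petal.Width"),
    batchgetter = batchgetter_num,
    shape = c(NA, 2)
  )
)
# Assumption: class codes as a long tensor are a suitable target here
target_batchgetter = function(data, device) {
  torch_tensor(as.integer(data[[1L]]), dtype = torch_long(), device = device)
}
dataset = task_dataset(task, ingress_tokens, target_batchgetter, "cpu")

# Wrap the dataset in a dataloader and pull one shuffled batch from it
loader = dataloader(dataset, batch_size = 32, shuffle = TRUE)
iter = dataloader_make_iter(loader)
batch = dataloader_next(iter)
names(batch)
batch$x$sepal$shape
```

Each batch drawn this way should have the same structure as the output of $.getbatch() shown above, with the rows selected by the dataloader's sampler.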