In this vignette, we show how to build neural network architectures as mlr3pipelines::Graphs. We will create a simple CNN for the tiny-imagenet task, which is a subset of the well-known ImageNet benchmark.

library(mlr3torch)
imagenet = tsk("tiny_imagenet")
imagenet
#> 
#> ── <TaskClassif> (110000x2): ImageNet Subset ───────────────────────────────────
#> • Target: class
#> WARN  [12:51:47.407] Caching (option 'mlr3torch.cache') is disabled, but dataset requires disk storage. This can lead to unexpected behavior.
#> • Target classes: abacus (0%), academic gown, academic robe, judge's robe (0%),
#> acorn (0%), African elephant, Loxodonta africana (0%), albatross, mollymawk
#> (0%), alp (0%), altar (0%), American alligator, Alligator mississipiensis (0%),
#> American lobster, Northern lobster, Maine lobster, Homarus americanus (0%),
#> apron (0%), Arabian camel, dromedary, Camelus dromedarius (0%), baboon (0%),
#> backpack, back pack, knapsack, packsack, rucksack, haversack (0%), banana (0%),
#> bannister, banister, balustrade, balusters, handrail (0%), barbershop (0%),
#> barn (0%), barrel, cask (0%), basketball (0%), bathtub, bathing tub, bath, tub
#> (0%), beach wagon, station wagon, wagon, estate car, beach waggon, station
#> waggon, waggon (0%), beacon, lighthouse, beacon light, pharos (0%), beaker
#> (0%), bee (0%), beer bottle (0%), bell pepper (0%), bighorn, bighorn sheep,
#> cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis (0%),
#> bikini, two-piece (0%), binoculars, field glasses, opera glasses (0%),
#> birdhouse (0%), bison (0%), black stork, Ciconia nigra (0%), black widow,
#> Latrodectus mactans (0%), boa constrictor, Constrictor constrictor (0%), bow
#> tie, bow-tie, bowtie (0%), brain coral (0%), brass, memorial tablet, plaque
#> (0%), broom (0%), brown bear, bruin, Ursus arctos (0%), bucket, pail (0%),
#> bullet train, bullet (0%), bullfrog, Rana catesbeiana (0%), butcher shop, meat
#> market (0%), candle, taper, wax light (0%), cannon (0%), cardigan (0%), cash
#> machine, cash dispenser, automated teller machine, automatic teller machine,
#> automated teller, automatic teller, ATM (0%), cauliflower (0%), CD player (0%),
#> centipede (0%), chain (0%), chest (0%), Chihuahua (0%), chimpanzee, chimp, Pan
#> troglodytes (0%), Christmas stocking (0%), cliff dwelling (0%), cliff, drop,
#> drop-off (0%), cockroach, roach (0%), comic book (0%), computer keyboard,
#> keypad (0%), confectionery, confectionary, candy store (0%), convertible (0%),
#> coral reef (0%), cougar, puma, catamount, mountain lion, painter, panther,
#> Felis concolor (0%), crane (0%), dam, dike, dyke (0%), desk (0%), dining table,
#> board (0%), dragonfly, darning needle, devil's darning needle, sewing needle,
#> snake feeder, snake doctor, mosquito hawk, skeeter hawk (0%), drumstick (0%),
#> dugong, Dugong dugon (0%), dumbbell (0%), Egyptian cat (0%), espresso (0%),
#> European fire salamander, Salamandra salamandra (0%), flagpole, flagstaff (0%),
#> fly (0%), fountain (0%), freight car (0%), frying pan, frypan, skillet (0%),
#> fur coat (0%), gasmask, respirator, gas helmet (0%), gazelle (0%), German
#> shepherd, German shepherd dog, German police dog, alsatian (0%), go-kart (0%),
#> golden retriever (0%), goldfish, Carassius auratus (0%), gondola (0%), goose
#> (0%), grasshopper, hopper (0%), guacamole (0%), guinea pig, Cavia cobaya (0%),
#> hog, pig, grunter, squealer, Sus scrofa (0%), hourglass (0%), ice cream,
#> icecream (0%), ice lolly, lolly, lollipop, popsicle (0%), iPod (0%), jellyfish
#> (0%), jinrikisha, ricksha, rickshaw (0%), kimono (0%), king penguin,
#> Aptenodytes patagonica (0%), koala, koala bear, kangaroo bear, native bear,
#> Phascolarctos cinereus (0%), Labrador retriever (0%), ladybug, ladybeetle, lady
#> beetle, ladybird, ladybird beetle (0%), lakeside, lakeshore (0%), lampshade,
#> lamp shade (0%), lawn mower, mower (0%), lemon (0%), lesser panda, red panda,
#> panda, bear cat, cat bear, Ailurus fulgens (0%), lifeboat (0%), limousine, limo
#> (0%), lion, king of beasts, Panthera leo (0%), magnetic compass (0%), mantis,
#> mantid (0%), mashed potato (0%), maypole (0%), meat loaf, meatloaf (0%),
#> military uniform (0%), miniskirt, mini (0%), monarch, monarch butterfly,
#> milkweed butterfly, Danaus plexippus (0%), moving van (0%), mushroom (0%), nail
#> (0%), neck brace (0%), obelisk (0%), oboe, hautboy, hautbois (0%), orange (0%),
#> orangutan, orang, orangutang, Pongo pygmaeus (0%), organ, pipe organ (0%), ox
#> (0%), parking meter (0%), pay-phone, pay-station (0%), Persian cat (0%), picket
#> fence, paling (0%), pill bottle (0%), pizza, pizza pie (0%), plate (0%),
#> plunger, plumber's helper (0%), pole (0%), police van, police wagon, paddy
#> wagon, patrol wagon, wagon, black Maria (0%), pomegranate (0%), poncho (0%),
#> pop bottle, soda bottle (0%), potpie (0%), potter's wheel (0%), pretzel (0%),
#> projectile, missile (0%), punching bag, punch bag, punching ball, punchball
#> (0%), reel (0%), refrigerator, icebox (0%), remote control, remote (0%),
#> rocking chair, rocker (0%), rugby ball (0%), sandal (0%), school bus (0%),
#> scoreboard (0%), scorpion (0%), sea cucumber, holothurian (0%), sea slug,
#> nudibranch (0%), seashore, coast, seacoast, sea-coast (0%), sewing machine
#> (0%), slug (0%), snail (0%), snorkel (0%), sock (0%), sombrero (0%), space
#> heater (0%), spider web, spider's web (0%), spiny lobster, langouste, rock
#> lobster, crawfish, crayfish, sea crawfish (0%), sports car, sport car (0%),
#> standard poodle (0%), steel arch bridge (0%), stopwatch, stop watch (0%),
#> sulphur butterfly, sulfur butterfly (0%), sunglasses, dark glasses, shades
#> (0%), suspension bridge (0%), swimming trunks, bathing trunks (0%), syringe
#> (0%), tabby, tabby cat (0%), tailed frog, bell toad, ribbed toad, tailed toad,
#> Ascaphus trui (0%), tarantula (0%), teapot (0%), teddy, teddy bear (0%),
#> thatch, thatched roof (0%), torch (0%), tractor (0%), trilobite (0%), triumphal
#> arch (0%), trolleybus, trolley coach, trackless trolley (0%), turnstile (0%),
#> umbrella (0%), vestment (0%), viaduct (0%), volleyball (0%), walking stick,
#> walkingstick, stick insect (0%), water jug (0%), water tower (0%), wok (0%),
#> wooden spoon (0%), Yorkshire terrier (0%)
#> • Properties: multiclass
#> • Features (1):
#>   • lt (1): image

The central ingredients for creating such graphs are PipeOpTorch operators.

To mark the entry-point of the neural network, we use a PipeOpTorchIngress, for which three different flavors exist:

  • po("torch_ingress_num") for numeric data
  • po("torch_ingress_categ") for categorical columns
  • po("torch_ingress_ltnsr") for lazy_tensors

Because the imagenet task contains only one feature of type lazy_tensor, we go for the last option:

architecture = po("torch_ingress_ltnsr")

We now define a relatively simple convolutional neural network. Note that in the code below, po("nn_relu_1") is equivalent to po("nn_relu", id = "nn_relu_1"). This is needed because mlr3pipelines::Graphs require each PipeOp to have a unique ID.
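The two ways of setting an ID can be compared directly. This is a minimal sketch that assumes mlr3torch is installed; the variable names are only illustrative.

```r
library(mlr3torch)

# Shorthand: the "_1" suffix becomes part of the PipeOp's ID
relu_short = po("nn_relu_1")
# Explicit: the same operator with the ID passed by hand
relu_explicit = po("nn_relu", id = "nn_relu_1")

relu_short$id
relu_explicit$id
```

Both objects should report the same ID, so either spelling can be used to keep the IDs in a Graph unique.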

We also don’t have to specify the input dimensions for the convolutional layers; they are inferred from the task during $train()ing. This means that our Learner can be applied to tasks with different image sizes, each time building up the correct network structure.

architecture = architecture %>>%
  po("nn_conv2d_1", out_channels = 64, kernel_size = 11, stride = 4, padding = 2) %>>%
  po("nn_relu_1", inplace = TRUE) %>>%
  po("nn_max_pool2d_1", kernel_size = 3, stride = 2) %>>%
  po("nn_conv2d_2", out_channels = 192, kernel_size = 5, padding = 2) %>>%
  po("nn_relu_2", inplace = TRUE) %>>%
  po("nn_max_pool2d_2", kernel_size = 3, stride = 2)
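To get a feeling for the shapes that are inferred for the layers above, the spatial size can be traced by hand. The sketch below uses only base R and the usual output-size formula for convolution and pooling layers (dilation 1, rounding down). It assumes 64x64 input images, which is the customary Tiny ImageNet resolution but is an assumption here; `conv_out` is a helper defined only for this illustration.

```r
# Output size along one spatial dimension of a convolution or pooling layer
# (standard formula for dilation = 1, rounding down).
conv_out = function(size, kernel_size, stride = 1, padding = 0) {
  floor((size + 2 * padding - kernel_size) / stride) + 1
}

size = 64  # assumed input height and width
size = conv_out(size, kernel_size = 11, stride = 4, padding = 2)  # nn_conv2d_1
size = conv_out(size, kernel_size = 3, stride = 2)                # nn_max_pool2d_1
size = conv_out(size, kernel_size = 5, padding = 2)               # nn_conv2d_2
size = conv_out(size, kernel_size = 3, stride = 2)                # nn_max_pool2d_2

# Number of values per image once the 192 channels are flattened
n_flat = 192 * size * size
c(size = size, n_flat = n_flat)
```

Under this assumption the spatial size shrinks from 64 to 15, 7, 7 and finally 3, so a flattened image would have 192 * 3 * 3 = 1728 values. With a different image size the same steps give different numbers, which is why it is convenient that the shapes are derived automatically.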

We can now continue with specifying the classification part of the network, which is a dense network that repeats a layer twice:

dense_layer = po("nn_dropout") %>>%
  po("nn_linear", out_features = 4096) %>>%
  po("nn_relu_6")

To repeat a segment of a network multiple times, we can use po("nn_block"); here we repeat it twice. We then add the output head of the network, where we don’t have to specify the number of classes, as it can also be inferred from the task.

classifier = po("nn_block", dense_layer, n_blocks = 2L) %>>%
  po("nn_head")

Next, we can combine the convolutional part with the dense head:

architecture = architecture %>>%
  po("nn_flatten") %>>%
  classifier

Below, we display the network:

architecture$plot(html = TRUE)

To turn this network architecture into an mlr3::Learner, what is left to do is to configure the loss, optimizer, callbacks, and training arguments. We use the standard cross-entropy loss, SGD as the optimizer, and checkpoint our model every 20 epochs.

checkpoint = tempfile()
architecture = architecture %>>%
  po("torch_loss", t_loss("cross_entropy")) %>>%
  po("torch_optimizer", t_opt("sgd", lr = 0.01)) %>>%
  po("torch_callbacks",
    t_clbk("checkpoint", freq = 20, path = checkpoint)) %>>%
  po("torch_model_classif",
    batch_size = 32, epochs = 100L, device = "cuda")

cnn = as_learner(architecture)
cnn$id = "cnn"
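The configuration above hard-codes device = "cuda", which presumably only works on a machine with a CUDA-capable GPU. One possible way to make the choice portable is to ask the torch package whether CUDA is available. This is a sketch that assumes torch and its backend are installed; it is not part of the original setup.

```r
library(torch)

# Fall back to the CPU when no CUDA device can be used
device = if (cuda_is_available()) "cuda" else "cpu"
device
```

The resulting string could then be passed as the device argument of po("torch_model_classif") instead of the literal "cuda".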

The resulting Learner now exposes all configuration options of the individual PipeOps in its $param_set, from which we show only a subset for readability:

as.data.table(cnn$param_set)[c(32, 34, 42), 1:4]
#>                       id    class lower upper
#>                   <char>   <char> <num> <num>
#> 1:     nn_block.n_blocks ParamInt     0   Inf
#> 2: nn_block.nn_dropout.p ParamDbl     0     1
#> 3:  torch_loss.reduction ParamFct    NA    NA
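Row positions such as 32 or 42 depend on the exact graph, so looking parameters up by name may be more robust. The sketch below does this on a much smaller stand-in network for numeric features. The graph and the names `small_learner` and `optimizer_ids` are hypothetical and only serve this illustration.

```r
library(mlr3torch)

# A small stand-in network (not the CNN from this vignette)
graph = po("torch_ingress_num") %>>%
  po("nn_linear", out_features = 10) %>>%
  po("nn_relu") %>>%
  po("nn_head") %>>%
  po("torch_loss", t_loss("cross_entropy")) %>>%
  po("torch_optimizer", t_opt("sgd", lr = 0.01)) %>>%
  po("torch_model_classif", batch_size = 32, epochs = 1L)
small_learner = as_learner(graph)

# Parameter IDs are prefixed with the ID of the PipeOp they belong to,
# so all optimizer settings can be selected by that prefix
ids = small_learner$param_set$ids()
optimizer_ids = grep("^torch_optimizer\\.", ids, value = TRUE)
optimizer_ids
```

The same pattern-based lookup should work on the cnn learner, for example with the prefix nn_block. to list the parameters of the repeated dense block.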

We can still change these parameters or, if we wanted to, even tune them. Below, we increase the number of blocks and the latent dimension of the dense part, and change the learning rate of the SGD optimizer.

cnn$param_set$set_values(
  nn_block.n_blocks = 4L,
  nn_block.nn_linear.out_features = 4096 * 2,
  torch_optimizer.lr = 0.2
)

Finally, we train the learner on the task:

cnn$train(imagenet)