In this vignette, we will show how to build neural network
architectures as mlr3pipelines::Graphs. We will create a
simple CNN for the tiny-imagenet task, which is a subset of the well-known
ImageNet benchmark.
library(mlr3torch)
imagenet = tsk("tiny_imagenet")
imagenet
#>
#> ── <TaskClassif> (110000x2): ImageNet Subset ───────────────────────────────────
#> • Target: class
#> WARN [12:51:47.407] Caching (option 'mlr3torch.cache') is disabled, but dataset requires disk storage. This can lead to unexpected behavior.
#> • Target classes: abacus (0%), academic gown, academic robe, judge's robe (0%),
#> acorn (0%), African elephant, Loxodonta africana (0%), albatross, mollymawk
#> (0%), alp (0%), altar (0%), American alligator, Alligator mississipiensis (0%),
#> American lobster, Northern lobster, Maine lobster, Homarus americanus (0%),
#> apron (0%), Arabian camel, dromedary, Camelus dromedarius (0%), baboon (0%),
#> backpack, back pack, knapsack, packsack, rucksack, haversack (0%), banana (0%),
#> bannister, banister, balustrade, balusters, handrail (0%), barbershop (0%),
#> barn (0%), barrel, cask (0%), basketball (0%), bathtub, bathing tub, bath, tub
#> (0%), beach wagon, station wagon, wagon, estate car, beach waggon, station
#> waggon, waggon (0%), beacon, lighthouse, beacon light, pharos (0%), beaker
#> (0%), bee (0%), beer bottle (0%), bell pepper (0%), bighorn, bighorn sheep,
#> cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis (0%),
#> bikini, two-piece (0%), binoculars, field glasses, opera glasses (0%),
#> birdhouse (0%), bison (0%), black stork, Ciconia nigra (0%), black widow,
#> Latrodectus mactans (0%), boa constrictor, Constrictor constrictor (0%), bow
#> tie, bow-tie, bowtie (0%), brain coral (0%), brass, memorial tablet, plaque
#> (0%), broom (0%), brown bear, bruin, Ursus arctos (0%), bucket, pail (0%),
#> bullet train, bullet (0%), bullfrog, Rana catesbeiana (0%), butcher shop, meat
#> market (0%), candle, taper, wax light (0%), cannon (0%), cardigan (0%), cash
#> machine, cash dispenser, automated teller machine, automatic teller machine,
#> automated teller, automatic teller, ATM (0%), cauliflower (0%), CD player (0%),
#> centipede (0%), chain (0%), chest (0%), Chihuahua (0%), chimpanzee, chimp, Pan
#> troglodytes (0%), Christmas stocking (0%), cliff dwelling (0%), cliff, drop,
#> drop-off (0%), cockroach, roach (0%), comic book (0%), computer keyboard,
#> keypad (0%), confectionery, confectionary, candy store (0%), convertible (0%),
#> coral reef (0%), cougar, puma, catamount, mountain lion, painter, panther,
#> Felis concolor (0%), crane (0%), dam, dike, dyke (0%), desk (0%), dining table,
#> board (0%), dragonfly, darning needle, devil's darning needle, sewing needle,
#> snake feeder, snake doctor, mosquito hawk, skeeter hawk (0%), drumstick (0%),
#> dugong, Dugong dugon (0%), dumbbell (0%), Egyptian cat (0%), espresso (0%),
#> European fire salamander, Salamandra salamandra (0%), flagpole, flagstaff (0%),
#> fly (0%), fountain (0%), freight car (0%), frying pan, frypan, skillet (0%),
#> fur coat (0%), gasmask, respirator, gas helmet (0%), gazelle (0%), German
#> shepherd, German shepherd dog, German police dog, alsatian (0%), go-kart (0%),
#> golden retriever (0%), goldfish, Carassius auratus (0%), gondola (0%), goose
#> (0%), grasshopper, hopper (0%), guacamole (0%), guinea pig, Cavia cobaya (0%),
#> hog, pig, grunter, squealer, Sus scrofa (0%), hourglass (0%), ice cream,
#> icecream (0%), ice lolly, lolly, lollipop, popsicle (0%), iPod (0%), jellyfish
#> (0%), jinrikisha, ricksha, rickshaw (0%), kimono (0%), king penguin,
#> Aptenodytes patagonica (0%), koala, koala bear, kangaroo bear, native bear,
#> Phascolarctos cinereus (0%), Labrador retriever (0%), ladybug, ladybeetle, lady
#> beetle, ladybird, ladybird beetle (0%), lakeside, lakeshore (0%), lampshade,
#> lamp shade (0%), lawn mower, mower (0%), lemon (0%), lesser panda, red panda,
#> panda, bear cat, cat bear, Ailurus fulgens (0%), lifeboat (0%), limousine, limo
#> (0%), lion, king of beasts, Panthera leo (0%), magnetic compass (0%), mantis,
#> mantid (0%), mashed potato (0%), maypole (0%), meat loaf, meatloaf (0%),
#> military uniform (0%), miniskirt, mini (0%), monarch, monarch butterfly,
#> milkweed butterfly, Danaus plexippus (0%), moving van (0%), mushroom (0%), nail
#> (0%), neck brace (0%), obelisk (0%), oboe, hautboy, hautbois (0%), orange (0%),
#> orangutan, orang, orangutang, Pongo pygmaeus (0%), organ, pipe organ (0%), ox
#> (0%), parking meter (0%), pay-phone, pay-station (0%), Persian cat (0%), picket
#> fence, paling (0%), pill bottle (0%), pizza, pizza pie (0%), plate (0%),
#> plunger, plumber's helper (0%), pole (0%), police van, police wagon, paddy
#> wagon, patrol wagon, wagon, black Maria (0%), pomegranate (0%), poncho (0%),
#> pop bottle, soda bottle (0%), potpie (0%), potter's wheel (0%), pretzel (0%),
#> projectile, missile (0%), punching bag, punch bag, punching ball, punchball
#> (0%), reel (0%), refrigerator, icebox (0%), remote control, remote (0%),
#> rocking chair, rocker (0%), rugby ball (0%), sandal (0%), school bus (0%),
#> scoreboard (0%), scorpion (0%), sea cucumber, holothurian (0%), sea slug,
#> nudibranch (0%), seashore, coast, seacoast, sea-coast (0%), sewing machine
#> (0%), slug (0%), snail (0%), snorkel (0%), sock (0%), sombrero (0%), space
#> heater (0%), spider web, spider's web (0%), spiny lobster, langouste, rock
#> lobster, crawfish, crayfish, sea crawfish (0%), sports car, sport car (0%),
#> standard poodle (0%), steel arch bridge (0%), stopwatch, stop watch (0%),
#> sulphur butterfly, sulfur butterfly (0%), sunglasses, dark glasses, shades
#> (0%), suspension bridge (0%), swimming trunks, bathing trunks (0%), syringe
#> (0%), tabby, tabby cat (0%), tailed frog, bell toad, ribbed toad, tailed toad,
#> Ascaphus trui (0%), tarantula (0%), teapot (0%), teddy, teddy bear (0%),
#> thatch, thatched roof (0%), torch (0%), tractor (0%), trilobite (0%), triumphal
#> arch (0%), trolleybus, trolley coach, trackless trolley (0%), turnstile (0%),
#> umbrella (0%), vestment (0%), viaduct (0%), volleyball (0%), walking stick,
#> walkingstick, stick insect (0%), water jug (0%), water tower (0%), wok (0%),
#> wooden spoon (0%), Yorkshire terrier (0%)
#> • Properties: multiclass
#> • Features (1):
#> • lt (1): image
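The WARN line in the output above reports that caching through the option `mlr3torch.cache` is disabled. Here is a minimal sketch of enabling it before the task is created. It assumes that setting the option to `TRUE` switches on caching in mlr3torch's default cache directory. The option name comes from the warning itself, and the value is our assumption about its semantics.

```r
library(mlr3torch)

# Assumption: TRUE enables mlr3torch's on-disk cache in its default
# directory; set this before calling tsk("tiny_imagenet")
options(mlr3torch.cache = TRUE)
getOption("mlr3torch.cache")
```

With the option set, `tsk("tiny_imagenet")` can be created as before.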
The central ingredients for creating such graphs are
PipeOpTorch operators.
To mark the entry point of the neural network, we use a
PipeOpTorchIngress, which comes in three flavors:
- po("torch_ingress_num") for numeric data
- po("torch_ingress_categ") for categorical columns
- po("torch_ingress_ltnsr") for lazy_tensors
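The three flavors can be constructed side by side to confirm that they share a common base class. This is a small sketch: the list names are purely illustrative, and we print the concrete class of each object rather than assuming its exact name.

```r
library(mlr3torch)

# One ingress PipeOp per input type: numeric, categorical, lazy_tensor
ingresses = list(
  num   = po("torch_ingress_num"),
  categ = po("torch_ingress_categ"),
  ltnsr = po("torch_ingress_ltnsr")
)

# Concrete class of each flavor
sapply(ingresses, function(p) class(p)[1L])

# All of them inherit from PipeOpTorchIngress
vapply(ingresses, inherits, logical(1), what = "PipeOpTorchIngress")
```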
Because the imagenet task contains only one feature of type
lazy_tensor, we go for the last option:
architecture = po("torch_ingress_ltnsr")
We now define a relatively simple convolutional neural network. Note
that in the code below, po("nn_relu_1") is equivalent to
po("nn_relu", id = "nn_relu_1"). This is needed because
mlr3pipelines::Graphs require each PipeOp to have a unique ID.
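A quick check of this equivalence and of the uniqueness requirement might look as follows. The object names are illustrative, and the sketch assumes only that mlr3torch (and with it mlr3pipelines) is attached.

```r
library(mlr3torch)

# The "_1" suffix shorthand versus setting the id explicitly
relu_a = po("nn_relu_1")
relu_b = po("nn_relu", id = "nn_relu_1")
relu_a$id
relu_b$id

# Two PipeOps with the same id cannot live in one Graph, so chaining
# two plain po("nn_relu") is expected to fail
res = try(po("nn_relu") %>>% po("nn_relu"), silent = TRUE)
inherits(res, "try-error")
```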
Note further that we don’t have to specify the input dimensions of
the convolutional layers, such as the number of input channels: they
are inferred from the task during $train()ing. This means that our
Learner can be applied to tasks with different image sizes,
building the correct network structure each time.
architecture = architecture %>>%
  po("nn_conv2d_1", out_channels = 64, kernel_size = 11, stride = 4, padding = 2) %>>%
  po("nn_relu_1", inplace = TRUE) %>>%
  po("nn_max_pool2d_1", kernel_size = 3, stride = 2) %>>%
  po("nn_conv2d_2", out_channels = 192, kernel_size = 5, padding = 2) %>>%
  po("nn_relu_2", inplace = TRUE) %>>%
  po("nn_max_pool2d_2", kernel_size = 3, stride = 2)
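mlr3torch infers the shapes flowing through these layers for us. The arithmetic below only illustrates what that inference amounts to for this stack. It assumes 64x64 RGB inputs, the standard Tiny ImageNet resolution, so the inferred number of input channels of the first convolution would be 3. It also assumes PyTorch's default output-size rule for convolutions and pooling (dilation 1, floor rounding). `out_size()` is a hypothetical helper, not part of mlr3torch.

```r
# Output size along one spatial dimension for a conv or pooling layer
# (dilation 1, floor rounding, as in PyTorch's defaults)
out_size = function(n, kernel_size, stride = 1, padding = 0) {
  floor((n + 2 * padding - kernel_size) / stride) + 1
}

n = 64                                        # assumed input: 64x64 pixels
n = out_size(n, 11, stride = 4, padding = 2)  # after nn_conv2d_1
n = out_size(n, 3, stride = 2)                # after nn_max_pool2d_1
n = out_size(n, 5, padding = 2)               # after nn_conv2d_2
n = out_size(n, 3, stride = 2)                # after nn_max_pool2d_2

# 192 output channels of nn_conv2d_2 times the remaining spatial grid
n_flat = 192 * n * n
n_flat
```

Under these assumptions, the feature maps shrink from 64 to 15, 7, 7 and finally 3 pixels per side. A dense layer placed after flattening would therefore receive 1728 inputs, and this is exactly the kind of number we never have to write down by hand.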
We can now continue with specifying the classification part of the
network, which is a dense network that repeats a block of layers twice.
To repeat a segment of a network multiple times, we can use
po("nn_block"). We then follow with the output head of the network,
po("nn_head"), for which we don’t have to specify the number of
classes, as it is also inferred from the task.
Next, we can combine the convolutional part with the dense head:
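Here is a minimal sketch of such a dense head. It assumes that the repeated block consists of dropout, a linear layer and a ReLU. This choice matches the parameter names `nn_block.nn_dropout.p` and `nn_block.nn_linear.out_features` that appear later in this vignette, but the dropout rate and the width of 4096 are illustrative assumptions.

```r
library(mlr3torch)

# Assumed building block repeated by nn_block: dropout, linear, ReLU
# (p = 0.5 and out_features = 4096 are illustrative choices)
block = po("nn_dropout", p = 0.5) %>>%
  po("nn_linear", out_features = 4096) %>>%
  po("nn_relu")

# Dense head: flatten the feature maps, repeat the block twice, then add
# the output layer, whose number of units is inferred from the task
dense = po("nn_flatten") %>>%
  po("nn_block", block, n_blocks = 2) %>>%
  po("nn_head")

dense$ids()
```

Appending this head to the convolutional part is then a single `architecture = architecture %>>% dense`.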
Below, we display the network:
architecture$plot(html = TRUE)
To turn this network architecture into an mlr3::Learner, what is
left to do is to configure the loss, optimizer, callbacks, and
training arguments. We use the standard cross-entropy loss and SGD as
the optimizer, and we checkpoint our model every 20 epochs.
checkpoint = tempfile()
architecture = architecture %>>%
  po("torch_loss", t_loss("cross_entropy")) %>>%
  po("torch_optimizer", t_opt("sgd", lr = 0.01)) %>>%
  po("torch_callbacks",
    t_clbk("checkpoint", freq = 20, path = checkpoint)) %>>%
  po("torch_model_classif",
    batch_size = 32, epochs = 100L, device = "cuda")
cnn = as_learner(architecture)
cnn$id = "cnn"
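To get a feel for what these settings imply, we can do a quick back-of-the-envelope calculation in base R. It uses the 110000 rows of the task printed above. It also assumes that the last, incomplete batch of each epoch is kept rather than dropped, so the step count is an estimate, not a value reported by mlr3torch.

```r
epochs = 100
freq = 20          # checkpoint frequency in epochs, as passed to t_clbk()
batch_size = 32
n_obs = 110000     # rows of the tiny_imagenet task

# Epochs after which a checkpoint is written
checkpoint_epochs = seq(freq, epochs, by = freq)
checkpoint_epochs

# Optimizer steps per epoch, assuming the incomplete last batch is kept
steps_per_epoch = ceiling(n_obs / batch_size)
total_steps = epochs * steps_per_epoch
total_steps
```

This gives five checkpoints (after epochs 20, 40, 60, 80 and 100) and roughly 3438 SGD steps per epoch.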
The created Learner now exposes all configuration options of the
individual PipeOps in its $param_set, of which we show only a
subset for readability:
as.data.table(cnn$param_set)[id %in% c("nn_block.n_blocks", "nn_block.nn_dropout.p", "torch_loss.reduction"), 1:4]
#> id class lower upper
#> <char> <char> <num> <num>
#> 1: nn_block.n_blocks ParamInt 0 Inf
#> 2: nn_block.nn_dropout.p ParamDbl 0 1
#> 3: torch_loss.reduction ParamFct NA NA
We can still change these values or, if we wanted to, even tune them! Below, we increase the number of blocks and the width (the output features of the linear layer) of the dense part, and change the learning rate of the SGD optimizer.
cnn$param_set$set_values(
nn_block.n_blocks = 4L,
nn_block.nn_linear.out_features = 4096 * 2,
torch_optimizer.lr = 0.2
)
Finally, we train the learner on the task:
cnn$train(imagenet)