References

Sections 10.9.3 and 10.9.4 of [ISLR2] An Introduction to Statistical Learning, with Applications in R (2nd Edition). The book is freely available for download at https://www.statlearning.com/.

To see the help file of a function funcname, type ?funcname.

Getting keras up and running on your computer can be a challenge. The book website <www.statlearning.com> gives step-by-step instructions on how to achieve this. Guidance can also be found at <keras.rstudio.com>.

The torch package has become available as an alternative to the keras package for deep learning. While torch does not require a Python installation, its current implementation appears to be a fair bit slower than keras. We include the torch version of the implementation at the end of this lab.

Convolutional Neural Networks with Keras

In this section, we use the keras3 package, which interfaces to TensorFlow, which in turn links to efficient Python code. This code is impressively fast, and the package is well structured.

# install.packages("keras3")      # Uncomment to run it for the first time
# install.packages("reticulate")  # Uncomment to run it for the first time
library(reticulate)
# On a first run, the "r-reticulate" environment must exist (see install_keras() below) before this call
use_condaenv("r-reticulate", required = TRUE)
library(keras3)
# install_keras(method = "conda", envname = "r-reticulate") # Uncomment to run it for the first time

In this section we fit a CNN to the CIFAR100 data, which is available through the keras3 package. It is arranged in a similar fashion to the MNIST data.

cifar100 <- dataset_cifar100()
names(cifar100)
## [1] "train" "test"
x_train <- cifar100$train$x
g_train <- cifar100$train$y
x_test <- cifar100$test$x
g_test <- cifar100$test$y
dim(x_train)
## [1] 50000    32    32     3
range(x_train[1,,,1])
## [1]  13 255

The array of 50,000 training images has four dimensions: each three-color image is represented as a set of three channels, each of which consists of \(32\times 32\) eight-bit pixels (integers from 0 to 255). As we did for the digits, we rescale the pixel values to the unit interval by dividing by 255, but keep the array structure. We one-hot encode the response labels, which are coded 0 to 99, to produce a 100-column binary matrix.

x_train <- x_train / 255
x_test <- x_test / 255
y_train <- to_categorical(g_train, 100)
dim(y_train)
## [1] 50000   100
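As a small check on what these two steps do, the sketch below reproduces them in base R on a few made-up labels, so it runs without Python or keras. The function one_hot() is our own hypothetical stand-in for to_categorical(), assuming integer labels coded 0 to K - 1 as in CIFAR100.

```r
# Hypothetical base-R equivalent of to_categorical() for labels coded 0, ..., K - 1
one_hot <- function(labels, K) {
  labels <- as.integer(labels)
  out <- matrix(0, nrow = length(labels), ncol = K)
  # label k goes in column k + 1, because R indexes columns from 1
  out[cbind(seq_along(labels), labels + 1L)] <- 1
  out
}

g_demo <- c(0, 3, 99, 3)          # made-up labels
y_demo <- one_hot(g_demo, 100)
dim(y_demo)                       # 4 rows, 100 columns
rowSums(y_demo)                   # exactly one 1 in every row
which(y_demo[2, ] == 1)           # label 3 sits in column 4

# Dividing 8-bit pixel values by 255 maps them into [0, 1]
pixels <- c(0, 13, 128, 255)
pixels / 255
```

Each row of the result has a single 1, in the column matching its label, which is the format expected by the categorical cross-entropy loss used below.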

Before we start, we look at some of the training images. (The jpeg package is loaded here, but the plotting itself relies only on base R's as.raster() and plot().)

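library(jpeg)
par(mar = c(0, 0, 0, 0), mfrow = c(5, 5))  # 5 x 5 grid of panels with no margins
index <- sample(seq(50000), 25)            # 25 randomly chosen training images
for (i in index) plot(as.raster(x_train[i,,, ]))  # pixel values are already scaled to [0, 1]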

The as.raster() function converts each \(32\times 32\times 3\) image array into a raster object (one color per pixel), so that it can be plotted as a color image.
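A minimal illustration on a made-up \(2\times 3\) image: base R's as.raster() (from the grDevices package) turns a height-by-width-by-3 array with values in [0, 1], its default max = 1, into one hex color string per pixel, which plot() can then draw.

```r
# A made-up 2 x 3 "image" with three color channels, values already in [0, 1]
img <- array(0, dim = c(2, 3, 3))
img[, , 1] <- 1                 # red channel fully on, green and blue off

r <- as.raster(img)             # grDevices::as.raster(), default max = 1
class(r)                        # "raster"
dim(as.matrix(r))               # 2 rows, 3 columns: one color per pixel
as.matrix(r)                    # every entry is "#FF0000" (pure red)

# plot(r) would draw it, just as the loop over x_train does
```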

Here we specify a moderately sized CNN for demonstration purposes.

model <- keras_model_sequential(
  layers = list(
    layer_input(shape = c(32L, 32L, 3L)),
    layer_conv_2d(filters = 32, kernel_size = c(3L, 3L),
                  padding = "same", activation = "relu"),
    layer_max_pooling_2d(pool_size = c(2L, 2L)),
    layer_conv_2d(filters = 64, kernel_size = c(3L, 3L),
                  padding = "same", activation = "relu"),
    layer_max_pooling_2d(pool_size = c(2L, 2L)),
    layer_conv_2d(filters = 128, kernel_size = c(3L, 3L),
                  padding = "same", activation = "relu"),
    layer_max_pooling_2d(pool_size = c(2L, 2L)),
    layer_conv_2d(filters = 256, kernel_size = c(3L, 3L),
                  padding = "same", activation = "relu"),
    layer_max_pooling_2d(pool_size = c(2L, 2L)),
    layer_flatten(),
    layer_dropout(rate = 0.5),
    layer_dense(units = 512, activation = "relu"),
    layer_dense(units = 100, activation = "softmax")
  )
)

model |> summary()
## Model: "sequential"
## ┌───────────────────────────────────┬──────────────────────────┬───────────────
## │ Layer (type)                      │ Output Shape             │       Param # 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d (Conv2D)                   │ (None, 32, 32, 32)       │           896 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d (MaxPooling2D)      │ (None, 16, 16, 32)       │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d_1 (Conv2D)                 │ (None, 16, 16, 64)       │        18,496 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d_1 (MaxPooling2D)    │ (None, 8, 8, 64)         │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d_2 (Conv2D)                 │ (None, 8, 8, 128)        │        73,856 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d_2 (MaxPooling2D)    │ (None, 4, 4, 128)        │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d_3 (Conv2D)                 │ (None, 4, 4, 256)        │       295,168 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d_3 (MaxPooling2D)    │ (None, 2, 2, 256)        │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ flatten (Flatten)                 │ (None, 1024)             │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dropout (Dropout)                 │ (None, 1024)             │             0 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense (Dense)                     │ (None, 512)              │       524,800 
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense_1 (Dense)                   │ (None, 100)              │        51,300 
## └───────────────────────────────────┴──────────────────────────┴───────────────
##  Total params: 964,516 (3.68 MB)
##  Trainable params: 964,516 (3.68 MB)
##  Non-trainable params: 0 (0.00 B)

Notice that we used the padding = "same" argument to layer_conv_2d(), which ensures that each output channel has the same spatial dimensions as the input channels. There are 32 channels in the first hidden layer, in contrast to the three channels in the input layer. We use a \(3\times 3\) convolution filter for each channel in all the layers. Each convolution is followed by a max-pooling layer over \(2\times2\) blocks. By studying the summary, we can see that the channels halve in both dimensions after each of these max-pooling operations. After the last of these we have a layer with 256 channels, each of dimension \(2\times 2\). These are then flattened into a vector of length 1,024 (\(2\times 2\times 256\)): in other words, each of the \(2\times 2\) matrices is turned into a 4-vector, and these vectors are placed side by side in one layer. This is followed by a dropout regularization layer, then a dense layer of size 512, and finally the softmax output layer with 100 units.
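The shapes and parameter counts in the summary follow from simple arithmetic, sketched below in base R. The helper functions are our own, and they assume stride-1 convolutions with padding = "same" and non-overlapping \(2\times 2\) pooling, as in the model above.

```r
# Parameters of a conv layer: one k x k x c_in kernel per output channel, plus one bias each
conv_params <- function(k, c_in, c_out) k * k * c_in * c_out + c_out
# Parameters of a dense layer: weight matrix plus one bias per unit
dense_params <- function(n_in, n_out) n_in * n_out + n_out

# padding = "same" keeps the size through each conv; each 2 x 2 pooling halves it
sizes <- 32 / 2^(0:4)              # 32 16 8 4 2

channels <- c(3, 32, 64, 128, 256)
conv <- conv_params(3, channels[-5], channels[-1])
flat <- sizes[5]^2 * 256           # 2 * 2 * 256 = 1024 inputs to the first dense layer
dense <- c(dense_params(flat, 512), dense_params(512, 100))

conv                               # 896 18496 73856 295168
dense                              # 524800 51300
sum(conv, dense)                   # 964516, the total reported by summary()
```

Dropout, pooling and flattening layers contribute no parameters, which is why they show 0 in the summary.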

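Finally, we specify the fitting algorithm (RMSprop with the categorical cross-entropy loss, tracking classification accuracy) and fit the model for 10 epochs with minibatches of 128 images, holding out 20% of the training images for validation.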
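With validation_split = 0.2, keras holds out 20% of the 50,000 training images, so each epoch passes over the remaining 40,000 in minibatches of 128. The short calculation below shows why the training log reports 313 steps per epoch.

```r
n_train    <- 50000
val_frac   <- 0.2                       # validation_split = 0.2
batch_size <- 128

n_val <- n_train * val_frac             # images held out for validation
n_fit <- n_train - n_val                # images used for gradient updates
steps <- ceiling(n_fit / batch_size)    # the last, partial batch still counts as a step
c(n_fit = n_fit, steps = steps)         # 40000 images, 313 steps per epoch
```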

model |>
  compile(
    optimizer = "rmsprop",
    loss = "categorical_crossentropy",
    metrics = list("categorical_accuracy")
  )

history <- model |>
  fit(
    x_train, y_train,
    epochs = 10,
    batch_size = 128,
    validation_split = 0.2,
    verbose = 2
  )
## Epoch 1/10
## 313/313 - 39s - 124ms/step - categorical_accuracy: 0.0538 - loss: 4.2052 - val_categorical_accuracy: 0.0880 - val_loss: 4.0220
## Epoch 2/10
## 313/313 - 40s - 129ms/step - categorical_accuracy: 0.1355 - loss: 3.6659 - val_categorical_accuracy: 0.1619 - val_loss: 3.5247
## Epoch 3/10
## 313/313 - 35s - 110ms/step - categorical_accuracy: 0.1962 - loss: 3.3239 - val_categorical_accuracy: 0.2211 - val_loss: 3.1846
## Epoch 4/10
## 313/313 - 36s - 115ms/step - categorical_accuracy: 0.2456 - loss: 3.0643 - val_categorical_accuracy: 0.2219 - val_loss: 3.2441
## Epoch 5/10
## 313/313 - 36s - 116ms/step - categorical_accuracy: 0.2799 - loss: 2.8673 - val_categorical_accuracy: 0.2753 - val_loss: 2.9086
## Epoch 6/10
## 313/313 - 36s - 116ms/step - categorical_accuracy: 0.3167 - loss: 2.6836 - val_categorical_accuracy: 0.2692 - val_loss: 2.9713
## Epoch 7/10
## 313/313 - 41s - 132ms/step - categorical_accuracy: 0.3491 - loss: 2.5243 - val_categorical_accuracy: 0.3494 - val_loss: 2.5679
## Epoch 8/10
## 313/313 - 40s - 126ms/step - categorical_accuracy: 0.3798 - loss: 2.3827 - val_categorical_accuracy: 0.3587 - val_loss: 2.5171
## Epoch 9/10
## 313/313 - 35s - 113ms/step - categorical_accuracy: 0.4094 - loss: 2.2470 - val_categorical_accuracy: 0.3662 - val_loss: 2.4625
## Epoch 10/10
## 313/313 - 36s - 114ms/step - categorical_accuracy: 0.4333 - loss: 2.1339 - val_categorical_accuracy: 0.3973 - val_loss: 2.3600
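pred <- model |> predict(x_test)
## 313/313 - 5s - 15ms/step
pred_class <- max.col(pred, ties.method = "first")  # column of the largest predicted probability (1 to 100)
true_class <- as.numeric(g_test) + 1                 # labels run from 0 to 99, so shift them to 1 to 100
mean(pred_class == true_class)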
## [1] 0.4027

This model takes a little over 6 minutes to run on our machine (about 37 seconds per epoch in the log above) and achieves about 40% accuracy on the test data. Although this is not terrible for 100-class data (a random classifier gets 1% accuracy), searching the web turns up reported results of around 75%. Typically it takes a lot of architecture carpentry, fiddling with regularization, and time to achieve such results.
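Two details of the accuracy computation are easy to get wrong: max.col() returns a 1-based column index, whereas the CIFAR100 labels run from 0 to 99, which is why 1 is added to g_test. The sketch below checks this on a tiny made-up matrix, then simulates a "classifier" that outputs random probabilities, to confirm the roughly 1% chance-level accuracy quoted above. All data here are simulated and purely illustrative.

```r
# Tiny made-up "softmax" output: 2 images, 3 classes
p <- matrix(c(0.1, 0.7, 0.2,
              0.5, 0.3, 0.2), nrow = 2, byrow = TRUE)
max.col(p, ties.method = "first")       # column of the largest probability: 2 1

# Chance-level baseline for 100 classes
set.seed(1)
n <- 10000; K <- 100
g_sim <- sample(0:(K - 1), n, replace = TRUE)        # 0-based labels, like g_test
probs <- matrix(runif(n * K), nrow = n, ncol = K)
probs <- probs / rowSums(probs)                       # each row sums to 1
pred_sim <- max.col(probs, ties.method = "first")    # 1-based predicted class
acc_sim <- mean(pred_sim == g_sim + 1)                # shift labels to 1..K before comparing
acc_sim                                               # close to 1 / K = 0.01
```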

Using Pretrained CNN Models

We now show how to use a CNN pretrained on the ImageNet database to classify natural images. We copied six jpeg images from a digital photo album into the directory book_images. (These images are available from the data section of <www.statlearning.com>, the ISL book website. Download book_images.zip; unzipping it creates the book_images directory.) We first read in the images and convert them into the array format that keras expects, matching the preprocessing used for ImageNet. Make sure that your working directory in R is set to the folder that contains the book_images directory.

img_dir <- "book_images"                        # relative to the working directory
image_names <- list.files(img_dir)
num_images <- length(image_names)
x <- array(dim = c(num_images, 224, 224, 3))    # images x height x width x RGB channels
for (i in seq_len(num_images)) {
  img_path <- file.path(img_dir, image_names[i])
  img <- image_load(img_path, target_size = c(224, 224))  # resize to the 224 x 224 input size
  x[i,,, ] <- image_to_array(img)
}
x <- imagenet_preprocess_input(x)               # match the preprocessing used to train the network
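As far as we understand the Keras documentation, imagenet_preprocess_input() with its default mode = "caffe" (the convention used by ResNet-50) does two things: it reorders the color channels from RGB to BGR, and it subtracts the mean ImageNet value of each channel, without rescaling. The function caffe_preprocess() below is a hypothetical base-R sketch of that convention, for illustration only; in practice, call imagenet_preprocess_input() as above.

```r
# Hypothetical sketch of "caffe"-style preprocessing for an array with
# dimensions (images, height, width, 3) holding RGB values in 0..255
caffe_preprocess <- function(x) {
  bgr_means <- c(103.939, 116.779, 123.68)   # ImageNet channel means, in B, G, R order
  x <- x[, , , 3:1, drop = FALSE]            # reorder channels: RGB -> BGR
  for (ch in 1:3) {
    x[, , , ch] <- x[, , , ch] - bgr_means[ch]  # zero-center each channel
  }
  x
}

# One made-up 2 x 2 image whose pixels all equal the ImageNet mean color (R, G, B)
x_demo <- array(0, dim = c(1, 2, 2, 3))
x_demo[1, , , 1] <- 123.68    # R
x_demo[1, , , 2] <- 116.779   # G
x_demo[1, , , 3] <- 103.939   # B
out <- caffe_preprocess(x_demo)
range(out)                    # all zeros: the mean color maps to the origin
```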

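We then load the pretrained ResNet-50 network. The model has 50 layers in the usual counting (the convolutional and dense layers along its main path), with a fair bit of complexity; the Keras summary lists many more entries because padding, batch normalization, activation, and addition steps each appear as separate layers.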
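Much of that complexity is bookkeeping that can be checked by hand. The sketch below, with helper names of our own, reproduces a few entries from the summary: the output sizes and parameter counts of the stem (conv1 and pool1), the first bottleneck block of stage conv2, and the usual count of 50 layers (one stem convolution, three convolutions in each of the 3 + 4 + 6 + 3 bottleneck blocks, and the final dense layer).

```r
conv_params <- function(k, c_in, c_out) k * k * c_in * c_out + c_out  # weights + biases
bn_params   <- function(ch) 4 * ch   # gamma, beta, moving mean, moving variance per channel
# Output size of a convolution or pooling window of size k with stride s (padding added beforehand)
out_size    <- function(n, k, s) floor((n - k) / s) + 1

# Stem: 224 -> pad 3 -> 7 x 7 conv, stride 2 -> pad 1 -> 3 x 3 max-pool, stride 2
out_size(224 + 2 * 3, k = 7, s = 2)   # 112, as in conv1_conv
out_size(112 + 2 * 1, k = 3, s = 2)   # 56, as in pool1_pool
conv_params(7, 3, 64)                 # 9472
bn_params(64)                         # 256

# First bottleneck block of stage conv2: 1 x 1 (64) -> 3 x 3 (64) -> 1 x 1 (256),
# plus a 1 x 1 projection shortcut (conv2_block1_0_conv) from 64 to 256 channels
c(conv_params(1, 64, 64), conv_params(3, 64, 64),
  conv_params(1, 64, 256), conv_params(1, 64, 256))   # 4160 36928 16640 16640

# Layer count: stem conv + 3 convs per bottleneck block + final dense layer
1 + 3 * sum(c(3, 4, 6, 3)) + 1        # 50
```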

model <- application_resnet50(weights = "imagenet")
summary(model)
## Model: "resnet50"
## ┌────────────────────┬──────────────────┬────────────┬─────────────────┬───────
## │ Layer (type)       │ Output Shape     │    Param # │ Connected to    │ Trai… 
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ input_layer_1      │ (None, 224, 224, │          0 │ -               │   -   
## │ (InputLayer)       │ 3)               │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_pad          │ (None, 230, 230, │          0 │ input_layer_1[… │   -   
## │ (ZeroPadding2D)    │ 3)               │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_conv         │ (None, 112, 112, │      9,472 │ conv1_pad[0][0] │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_bn           │ (None, 112, 112, │        256 │ conv1_conv[0][… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_relu         │ (None, 112, 112, │          0 │ conv1_bn[0][0]  │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ pool1_pad          │ (None, 114, 114, │          0 │ conv1_relu[0][… │   -   
## │ (ZeroPadding2D)    │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ pool1_pool         │ (None, 56, 56,   │          0 │ pool1_pad[0][0] │   -   
## │ (MaxPooling2D)     │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_1_co… │ (None, 56, 56,   │      4,160 │ pool1_pool[0][… │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_1_bn  │ (None, 56, 56,   │        256 │ conv2_block1_1… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_1_re… │ (None, 56, 56,   │          0 │ conv2_block1_1… │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_2_co… │ (None, 56, 56,   │     36,928 │ conv2_block1_1… │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_2_bn  │ (None, 56, 56,   │        256 │ conv2_block1_2… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_2_re… │ (None, 56, 56,   │          0 │ conv2_block1_2… │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_0_co… │ (None, 56, 56,   │     16,640 │ pool1_pool[0][… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_3_co… │ (None, 56, 56,   │     16,640 │ conv2_block1_2… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_0_bn  │ (None, 56, 56,   │      1,024 │ conv2_block1_0… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_3_bn  │ (None, 56, 56,   │      1,024 │ conv2_block1_3… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_add   │ (None, 56, 56,   │          0 │ conv2_block1_0… │   -   
## │ (Add)              │ 256)             │            │ conv2_block1_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_out   │ (None, 56, 56,   │          0 │ conv2_block1_a… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_1_co… │ (None, 56, 56,   │     16,448 │ conv2_block1_o… │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_1_bn  │ (None, 56, 56,   │        256 │ conv2_block2_1… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_1_re… │ (None, 56, 56,   │          0 │ conv2_block2_1… │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_2_co… │ (None, 56, 56,   │     36,928 │ conv2_block2_1… │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_2_bn  │ (None, 56, 56,   │        256 │ conv2_block2_2… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_2_re… │ (None, 56, 56,   │          0 │ conv2_block2_2… │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_3_co… │ (None, 56, 56,   │     16,640 │ conv2_block2_2… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_3_bn  │ (None, 56, 56,   │      1,024 │ conv2_block2_3… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_add   │ (None, 56, 56,   │          0 │ conv2_block1_o… │   -   
## │ (Add)              │ 256)             │            │ conv2_block2_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_out   │ (None, 56, 56,   │          0 │ conv2_block2_a… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_1_co… │ (None, 56, 56,   │     16,448 │ conv2_block2_o… │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_1_bn  │ (None, 56, 56,   │        256 │ conv2_block3_1… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_1_re… │ (None, 56, 56,   │          0 │ conv2_block3_1… │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_2_co… │ (None, 56, 56,   │     36,928 │ conv2_block3_1… │   Y   
## │ (Conv2D)           │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_2_bn  │ (None, 56, 56,   │        256 │ conv2_block3_2… │   Y   
## │ (BatchNormalizati… │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_2_re… │ (None, 56, 56,   │          0 │ conv2_block3_2… │   -   
## │ (Activation)       │ 64)              │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_3_co… │ (None, 56, 56,   │     16,640 │ conv2_block3_2… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_3_bn  │ (None, 56, 56,   │      1,024 │ conv2_block3_3… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_add   │ (None, 56, 56,   │          0 │ conv2_block2_o… │   -   
## │ (Add)              │ 256)             │            │ conv2_block3_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_out   │ (None, 56, 56,   │          0 │ conv2_block3_a… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_1_co… │ (None, 28, 28,   │     32,896 │ conv2_block3_o… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_1_bn  │ (None, 28, 28,   │        512 │ conv3_block1_1… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_1_re… │ (None, 28, 28,   │          0 │ conv3_block1_1… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_2_co… │ (None, 28, 28,   │    147,584 │ conv3_block1_1… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_2_bn  │ (None, 28, 28,   │        512 │ conv3_block1_2… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_2_re… │ (None, 28, 28,   │          0 │ conv3_block1_2… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_0_co… │ (None, 28, 28,   │    131,584 │ conv2_block3_o… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_3_co… │ (None, 28, 28,   │     66,048 │ conv3_block1_2… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_0_bn  │ (None, 28, 28,   │      2,048 │ conv3_block1_0… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_3_bn  │ (None, 28, 28,   │      2,048 │ conv3_block1_3… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_add   │ (None, 28, 28,   │          0 │ conv3_block1_0… │   -   
## │ (Add)              │ 512)             │            │ conv3_block1_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_out   │ (None, 28, 28,   │          0 │ conv3_block1_a… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_1_co… │ (None, 28, 28,   │     65,664 │ conv3_block1_o… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_1_bn  │ (None, 28, 28,   │        512 │ conv3_block2_1… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_1_re… │ (None, 28, 28,   │          0 │ conv3_block2_1… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_2_co… │ (None, 28, 28,   │    147,584 │ conv3_block2_1… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_2_bn  │ (None, 28, 28,   │        512 │ conv3_block2_2… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_2_re… │ (None, 28, 28,   │          0 │ conv3_block2_2… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_3_co… │ (None, 28, 28,   │     66,048 │ conv3_block2_2… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_3_bn  │ (None, 28, 28,   │      2,048 │ conv3_block2_3… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_add   │ (None, 28, 28,   │          0 │ conv3_block1_o… │   -   
## │ (Add)              │ 512)             │            │ conv3_block2_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_out   │ (None, 28, 28,   │          0 │ conv3_block2_a… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_1_co… │ (None, 28, 28,   │     65,664 │ conv3_block2_o… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_1_bn  │ (None, 28, 28,   │        512 │ conv3_block3_1… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_1_re… │ (None, 28, 28,   │          0 │ conv3_block3_1… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_2_co… │ (None, 28, 28,   │    147,584 │ conv3_block3_1… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_2_bn  │ (None, 28, 28,   │        512 │ conv3_block3_2… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_2_re… │ (None, 28, 28,   │          0 │ conv3_block3_2… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_3_co… │ (None, 28, 28,   │     66,048 │ conv3_block3_2… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_3_bn  │ (None, 28, 28,   │      2,048 │ conv3_block3_3… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_add   │ (None, 28, 28,   │          0 │ conv3_block2_o… │   -   
## │ (Add)              │ 512)             │            │ conv3_block3_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_out   │ (None, 28, 28,   │          0 │ conv3_block3_a… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_1_co… │ (None, 28, 28,   │     65,664 │ conv3_block3_o… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_1_bn  │ (None, 28, 28,   │        512 │ conv3_block4_1… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_1_re… │ (None, 28, 28,   │          0 │ conv3_block4_1… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_2_co… │ (None, 28, 28,   │    147,584 │ conv3_block4_1… │   Y   
## │ (Conv2D)           │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_2_bn  │ (None, 28, 28,   │        512 │ conv3_block4_2… │   Y   
## │ (BatchNormalizati… │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_2_re… │ (None, 28, 28,   │          0 │ conv3_block4_2… │   -   
## │ (Activation)       │ 128)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_3_co… │ (None, 28, 28,   │     66,048 │ conv3_block4_2… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_3_bn  │ (None, 28, 28,   │      2,048 │ conv3_block4_3… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_add   │ (None, 28, 28,   │          0 │ conv3_block3_o… │   -   
## │ (Add)              │ 512)             │            │ conv3_block4_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_out   │ (None, 28, 28,   │          0 │ conv3_block4_a… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_1_co… │ (None, 14, 14,   │    131,328 │ conv3_block4_o… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_1_bn  │ (None, 14, 14,   │      1,024 │ conv4_block1_1… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_1_re… │ (None, 14, 14,   │          0 │ conv4_block1_1… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_2_co… │ (None, 14, 14,   │    590,080 │ conv4_block1_1… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_2_bn  │ (None, 14, 14,   │      1,024 │ conv4_block1_2… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_2_re… │ (None, 14, 14,   │          0 │ conv4_block1_2… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_0_co… │ (None, 14, 14,   │    525,312 │ conv3_block4_o… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_3_co… │ (None, 14, 14,   │    263,168 │ conv4_block1_2… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_0_bn  │ (None, 14, 14,   │      4,096 │ conv4_block1_0… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_3_bn  │ (None, 14, 14,   │      4,096 │ conv4_block1_3… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_add   │ (None, 14, 14,   │          0 │ conv4_block1_0… │   -   
## │ (Add)              │ 1024)            │            │ conv4_block1_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_out   │ (None, 14, 14,   │          0 │ conv4_block1_a… │   -   
## │ (Activation)       │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_1_co… │ (None, 14, 14,   │    262,400 │ conv4_block1_o… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_1_bn  │ (None, 14, 14,   │      1,024 │ conv4_block2_1… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_1_re… │ (None, 14, 14,   │          0 │ conv4_block2_1… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_2_co… │ (None, 14, 14,   │    590,080 │ conv4_block2_1… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_2_bn  │ (None, 14, 14,   │      1,024 │ conv4_block2_2… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_2_re… │ (None, 14, 14,   │          0 │ conv4_block2_2… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_3_co… │ (None, 14, 14,   │    263,168 │ conv4_block2_2… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_3_bn  │ (None, 14, 14,   │      4,096 │ conv4_block2_3… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_add   │ (None, 14, 14,   │          0 │ conv4_block1_o… │   -   
## │ (Add)              │ 1024)            │            │ conv4_block2_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_out   │ (None, 14, 14,   │          0 │ conv4_block2_a… │   -   
## │ (Activation)       │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_1_co… │ (None, 14, 14,   │    262,400 │ conv4_block2_o… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_1_bn  │ (None, 14, 14,   │      1,024 │ conv4_block3_1… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_1_re… │ (None, 14, 14,   │          0 │ conv4_block3_1… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_2_co… │ (None, 14, 14,   │    590,080 │ conv4_block3_1… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_2_bn  │ (None, 14, 14,   │      1,024 │ conv4_block3_2… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_2_re… │ (None, 14, 14,   │          0 │ conv4_block3_2… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_3_co… │ (None, 14, 14,   │    263,168 │ conv4_block3_2… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_3_bn  │ (None, 14, 14,   │      4,096 │ conv4_block3_3… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_add   │ (None, 14, 14,   │          0 │ conv4_block2_o… │   -   
## │ (Add)              │ 1024)            │            │ conv4_block3_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_out   │ (None, 14, 14,   │          0 │ conv4_block3_a… │   -   
## │ (Activation)       │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_1_co… │ (None, 14, 14,   │    262,400 │ conv4_block3_o… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_1_bn  │ (None, 14, 14,   │      1,024 │ conv4_block4_1… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_1_re… │ (None, 14, 14,   │          0 │ conv4_block4_1… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_2_co… │ (None, 14, 14,   │    590,080 │ conv4_block4_1… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_2_bn  │ (None, 14, 14,   │      1,024 │ conv4_block4_2… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_2_re… │ (None, 14, 14,   │          0 │ conv4_block4_2… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_3_co… │ (None, 14, 14,   │    263,168 │ conv4_block4_2… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_3_bn  │ (None, 14, 14,   │      4,096 │ conv4_block4_3… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_add   │ (None, 14, 14,   │          0 │ conv4_block3_o… │   -   
## │ (Add)              │ 1024)            │            │ conv4_block4_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_out   │ (None, 14, 14,   │          0 │ conv4_block4_a… │   -   
## │ (Activation)       │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_1_co… │ (None, 14, 14,   │    262,400 │ conv4_block4_o… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_1_bn  │ (None, 14, 14,   │      1,024 │ conv4_block5_1… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_1_re… │ (None, 14, 14,   │          0 │ conv4_block5_1… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_2_co… │ (None, 14, 14,   │    590,080 │ conv4_block5_1… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_2_bn  │ (None, 14, 14,   │      1,024 │ conv4_block5_2… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_2_re… │ (None, 14, 14,   │          0 │ conv4_block5_2… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_3_co… │ (None, 14, 14,   │    263,168 │ conv4_block5_2… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_3_bn  │ (None, 14, 14,   │      4,096 │ conv4_block5_3… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_add   │ (None, 14, 14,   │          0 │ conv4_block4_o… │   -   
## │ (Add)              │ 1024)            │            │ conv4_block5_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_out   │ (None, 14, 14,   │          0 │ conv4_block5_a… │   -   
## │ (Activation)       │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_1_co… │ (None, 14, 14,   │    262,400 │ conv4_block5_o… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_1_bn  │ (None, 14, 14,   │      1,024 │ conv4_block6_1… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_1_re… │ (None, 14, 14,   │          0 │ conv4_block6_1… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_2_co… │ (None, 14, 14,   │    590,080 │ conv4_block6_1… │   Y   
## │ (Conv2D)           │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_2_bn  │ (None, 14, 14,   │      1,024 │ conv4_block6_2… │   Y   
## │ (BatchNormalizati… │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_2_re… │ (None, 14, 14,   │          0 │ conv4_block6_2… │   -   
## │ (Activation)       │ 256)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_3_co… │ (None, 14, 14,   │    263,168 │ conv4_block6_2… │   Y   
## │ (Conv2D)           │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_3_bn  │ (None, 14, 14,   │      4,096 │ conv4_block6_3… │   Y   
## │ (BatchNormalizati… │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_add   │ (None, 14, 14,   │          0 │ conv4_block5_o… │   -   
## │ (Add)              │ 1024)            │            │ conv4_block6_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_out   │ (None, 14, 14,   │          0 │ conv4_block6_a… │   -   
## │ (Activation)       │ 1024)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_1_co… │ (None, 7, 7,     │    524,800 │ conv4_block6_o… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_1_bn  │ (None, 7, 7,     │      2,048 │ conv5_block1_1… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_1_re… │ (None, 7, 7,     │          0 │ conv5_block1_1… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_2_co… │ (None, 7, 7,     │  2,359,808 │ conv5_block1_1… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_2_bn  │ (None, 7, 7,     │      2,048 │ conv5_block1_2… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_2_re… │ (None, 7, 7,     │          0 │ conv5_block1_2… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_0_co… │ (None, 7, 7,     │  2,099,200 │ conv4_block6_o… │   Y   
## │ (Conv2D)           │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_3_co… │ (None, 7, 7,     │  1,050,624 │ conv5_block1_2… │   Y   
## │ (Conv2D)           │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_0_bn  │ (None, 7, 7,     │      8,192 │ conv5_block1_0… │   Y   
## │ (BatchNormalizati… │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_3_bn  │ (None, 7, 7,     │      8,192 │ conv5_block1_3… │   Y   
## │ (BatchNormalizati… │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_add   │ (None, 7, 7,     │          0 │ conv5_block1_0… │   -   
## │ (Add)              │ 2048)            │            │ conv5_block1_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_out   │ (None, 7, 7,     │          0 │ conv5_block1_a… │   -   
## │ (Activation)       │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_1_co… │ (None, 7, 7,     │  1,049,088 │ conv5_block1_o… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_1_bn  │ (None, 7, 7,     │      2,048 │ conv5_block2_1… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_1_re… │ (None, 7, 7,     │          0 │ conv5_block2_1… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_2_co… │ (None, 7, 7,     │  2,359,808 │ conv5_block2_1… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_2_bn  │ (None, 7, 7,     │      2,048 │ conv5_block2_2… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_2_re… │ (None, 7, 7,     │          0 │ conv5_block2_2… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_3_co… │ (None, 7, 7,     │  1,050,624 │ conv5_block2_2… │   Y   
## │ (Conv2D)           │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_3_bn  │ (None, 7, 7,     │      8,192 │ conv5_block2_3… │   Y   
## │ (BatchNormalizati… │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_add   │ (None, 7, 7,     │          0 │ conv5_block1_o… │   -   
## │ (Add)              │ 2048)            │            │ conv5_block2_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_out   │ (None, 7, 7,     │          0 │ conv5_block2_a… │   -   
## │ (Activation)       │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_1_co… │ (None, 7, 7,     │  1,049,088 │ conv5_block2_o… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_1_bn  │ (None, 7, 7,     │      2,048 │ conv5_block3_1… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_1_re… │ (None, 7, 7,     │          0 │ conv5_block3_1… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_2_co… │ (None, 7, 7,     │  2,359,808 │ conv5_block3_1… │   Y   
## │ (Conv2D)           │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_2_bn  │ (None, 7, 7,     │      2,048 │ conv5_block3_2… │   Y   
## │ (BatchNormalizati… │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_2_re… │ (None, 7, 7,     │          0 │ conv5_block3_2… │   -   
## │ (Activation)       │ 512)             │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_3_co… │ (None, 7, 7,     │  1,050,624 │ conv5_block3_2… │   Y   
## │ (Conv2D)           │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_3_bn  │ (None, 7, 7,     │      8,192 │ conv5_block3_3… │   Y   
## │ (BatchNormalizati… │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_add   │ (None, 7, 7,     │          0 │ conv5_block2_o… │   -   
## │ (Add)              │ 2048)            │            │ conv5_block3_3… │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_out   │ (None, 7, 7,     │          0 │ conv5_block3_a… │   -   
## │ (Activation)       │ 2048)            │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ avg_pool           │ (None, 2048)     │          0 │ conv5_block3_o… │   -   
## │ (GlobalAveragePoo… │                  │            │                 │       
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ predictions        │ (None, 1000)     │  2,049,000 │ avg_pool[0][0]  │   Y   
## │ (Dense)            │                  │            │                 │       
## └────────────────────┴──────────────────┴────────────┴─────────────────┴───────
##  Total params: 25,636,712 (97.80 MB)
##  Trainable params: 25,583,592 (97.59 MB)
##  Non-trainable params: 53,120 (207.50 KB)
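The parameter counts in this summary follow directly from the layer shapes. A Conv2D layer with a \(k\times k\) kernel mapping \(C_{in}\) to \(C_{out}\) channels has \(k^2 C_{in} C_{out} + C_{out}\) parameters, assuming the layer has a bias term (the counts above are consistent with that). A batch normalization layer stores four vectors of length \(C\). Two of them (scale and shift) are trainable. The other two (moving mean and variance) are not fitted by gradient descent, which is why the summary reports separate non-trainable parameters. The base-R sketch below checks a few rows of the table. The helper functions are hypothetical and exist only for this check.

```r
# Parameter count of a 2-D convolution with a k x k kernel: one weight per
# (input channel, output channel, kernel position), plus one bias per
# output channel when bias = TRUE.
conv2d_params <- function(k, c_in, c_out, bias = TRUE) {
  bias_terms <- if (bias) c_out else 0
  k * k * c_in * c_out + bias_terms
}

# Batch normalization keeps four length-C vectors: the trainable scale
# (gamma) and shift (beta), and the non-trainable moving mean and variance.
batchnorm_params <- function(channels) {
  c(trainable = 2 * channels, non_trainable = 2 * channels)
}

# Dense layer: weight matrix plus bias vector.
dense_params <- function(n_in, n_out) n_in * n_out + n_out

conv2d_params(1, 512, 256)    # conv4_block1_1_conv: 131,328
conv2d_params(3, 256, 256)    # conv4_block1_2_conv: 590,080
conv2d_params(1, 512, 1024)   # conv4_block1_0_conv (shortcut): 525,312
conv2d_params(3, 512, 512)    # conv5_block1_2_conv: 2,359,808
sum(batchnorm_params(256))    # conv4_block1_1_bn: 1,024
dense_params(2048, 1000)      # predictions: 2,049,000
```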

Finally, we classify our six images, and return the top three class choices in terms of predicted probability for each.

pred6 <- model %>% predict(x) %>%
  imagenet_decode_predictions(top = 3)
## 1/1 - 3s - 3s/step
names(pred6) <- image_names
print(pred6)
## $Cape_Weaver.jpg
##   class_name class_description      score
## 1  n01843065           jacamar 0.49795407
## 2  n01818515             macaw 0.22193323
## 3  n02494079   squirrel_monkey 0.04287848
## 
## $Flamingo.jpg
##   class_name class_description      score
## 1  n02007558          flamingo 0.92634964
## 2  n02006656         spoonbill 0.07169954
## 3  n02002556       white_stork 0.00122821
## 
## $Hawk_cropped.jpg
##   class_name class_description      score
## 1  n01608432              kite 0.72270894
## 2  n01622779    great_grey_owl 0.08182577
## 3  n01532829       house_finch 0.04218888
## 
## $Hawk_Fountain.jpg
##   class_name class_description     score
## 1  n03388043          fountain 0.2788652
## 2  n03532672              hook 0.1785546
## 3  n03804744              nail 0.1080733
## 
## $Lhasa_Apso.jpg
##   class_name           class_description      score
## 1  n02097474             Tibetan_terrier 0.50929731
## 2  n02098413                       Lhasa 0.42209828
## 3  n02098105 soft-coated_wheaten_terrier 0.01695858
## 
## $Sleeping_Cat.jpg
##   class_name    class_description      score
## 1  n02105641 Old_English_sheepdog 0.83266014
## 2  n02086240             Shih-Tzu 0.04513866
## 3  n03223299              doormat 0.03299771

Convolutional Neural Networks with Torch

In this section we fit a CNN to the CIFAR-100 data, which can be downloaded with the cifar100_dataset() function in the torchvision package. It is arranged in a similar fashion to the MNIST data.

# install.packages("torch")
# install.packages("luz")
# install.packages("torchvision")
# install.packages("torchdatasets")
# install.packages("zeallot")
library(torch)
## 
## Attaching package: 'torch'
## The following object is masked from 'package:keras3':
## 
##     as_iterator
## The following object is masked from 'package:reticulate':
## 
##     as_iterator
library(luz) # high-level interface for torch
## 
## Attaching package: 'luz'
## The following object is masked from 'package:keras3':
## 
##     evaluate
library(torchvision) # for datasets and image transformation
library(torchdatasets) # for datasets we are going to use
library(zeallot)
torch_manual_seed(6805)
transform <- function(x) {
  transform_to_tensor(x)
}

train_ds <- cifar100_dataset(
  root = "./", 
  train = TRUE, 
  download = TRUE, 
  transform = transform
)
## Dataset <cifar100_dataset> (~160 MB) will be downloaded and processed if not
## already available.
## Dataset <cifar100_dataset> loaded with 50000 images across 100 classes.
test_ds <- cifar100_dataset(
  root = "./", 
  train = FALSE, 
  transform = transform
)
## Dataset <cifar100_dataset> loaded with 10000 images across 100 classes.
str(train_ds[1])
## List of 2
##  $ x:Float [1:3, 1:32, 1:32]
##  $ y: int 20
length(train_ds)
## [1] 50000

The CIFAR-100 training set consists of 50,000 images, each represented by a 3d tensor of dimension \(3\times 32\times 32\). Each color image is a set of three channels, and each channel consists of \(32\times 32\) eight-bit pixels. As we did for the digits, we rescale the pixel values to the unit interval but keep the array structure. This is accomplished with the transform argument: transform_to_tensor() converts each image to a channels-first float tensor with values in \([0, 1]\).
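transform_to_tensor() performs two steps: it rescales the eight-bit values to \([0, 1]\) and puts the channel dimension first. The base-R sketch below mimics those two steps on a simulated image, so torch is not needed. The random array stands in for a real CIFAR image and is purely illustrative.

```r
# A simulated 32 x 32 RGB image stored height x width x channel,
# with eight-bit pixel values 0..255 (the layout of the keras array).
set.seed(1)
img_hwc <- array(sample(0:255, 32 * 32 * 3, replace = TRUE), dim = c(32, 32, 3))

# Rescale to the unit interval and move the channel axis to the front,
# giving the 3 x 32 x 32 layout used by torch.
img_chw <- aperm(img_hwc / 255, c(3, 1, 2))

dim(img_chw)    # 3 32 32
range(img_chw)  # lies within [0, 1]
```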

Before we start, we look at some of the training images.

par(mar = c(0, 0, 0, 0), mfrow = c(5, 5))
index <- sample(seq(50000), 25)
for (i in index) plot(as.raster(as.array(train_ds[i][[1]]$permute(c(2,3,1)))))

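The $permute(c(2, 3, 1)) call moves the channel dimension last, and as.raster() then converts the resulting \(32\times 32\times 3\) array into a raster object that plot() can display as a color image.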
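A small base-R sketch shows the same conversion. It uses a hypothetical \(3\times 2\times 2\) channels-first image in which only the red channel is on. aperm() plays the role of $permute(c(2, 3, 1)), and as.raster() (from the grDevices package, attached by default) maps each pixel to a color string.

```r
# Hypothetical channels-first image: red channel fully on, green and blue off.
img_chw <- array(0, dim = c(3, 2, 2))
img_chw[1, , ] <- 1

# Move channels last (height x width x channel), the layout as.raster()
# expects for a 3-d array, mirroring $permute(c(2, 3, 1)) on the tensor.
img_hwc <- aperm(img_chw, c(2, 3, 1))

# Each pixel becomes an "#RRGGBB" colour string.
r <- as.raster(img_hwc)
as.matrix(r)
```

Calling plot(r) would then draw the 2 by 2 image as solid red.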

Here we specify a moderately-sized CNN for demonstration purposes.

conv_block <- nn_module(
  initialize = function(in_channels, out_channels) {
    self$conv <- nn_conv2d(
      in_channels = in_channels, 
      out_channels = out_channels, 
      kernel_size = c(3,3), 
      padding = "same"
    )
    self$relu <- nn_relu()
    self$pool <- nn_max_pool2d(kernel_size = c(2,2))
  },
  forward = function(x) {
    x %>% 
      self$conv() %>% 
      self$relu() %>% 
      self$pool()
  }
)

model <- nn_module(
  initialize = function() {
    self$conv <- nn_sequential(
      conv_block(3, 32),
      conv_block(32, 64),
      conv_block(64, 128),
      conv_block(128, 256)
    )
    self$output <- nn_sequential(
      nn_dropout(0.5),
      nn_linear(2*2*256, 512),
      nn_relu(),
      nn_linear(512, 100)
    )
  },
  forward = function(x) {
    x %>% 
      self$conv() %>% 
      torch_flatten(start_dim = 2) %>% 
      self$output()
  }
)
model()
## An `nn_module` containing 964,516 parameters.
## 
## ── Modules ─────────────────────────────────────────────────────────────────────
## • conv: <nn_sequential> #388,416 parameters
## • output: <nn_sequential> #576,100 parameters

Notice that we used the padding = "same" argument to nn_conv2d(). This ensures that each output feature map has the same spatial dimensions as its input; for a \(3\times 3\) filter with stride one, this amounts to padding one pixel on each side. There are 32 channels in the first hidden layer, in contrast to the three channels in the input layer. We use a \(3\times 3\) convolution filter for each channel in all the layers. Each convolution is followed by a ReLU activation and a max-pooling layer over \(2\times2\) blocks, which halves both spatial dimensions. Starting from \(32\times 32\) images, after four such blocks we have a layer with 256 channels of dimension \(2\times 2\). These are then flattened into a vector of length \(2\times 2\times 256\) = 1,024. In other words, each of the \(2\times 2\) matrices is turned into a 4-vector, and the 256 of them are placed side by side in one layer. This is followed by a dropout regularization layer, then a dense layer of size 512 with ReLU activation, and finally the output layer with 100 units, one per class.
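The shapes and parameter counts described above can be checked by hand. The base-R sketch below uses the standard output-size formula \(\lfloor (n + 2p - k)/s \rfloor + 1\) for a layer with kernel size \(k\), padding \(p\) and stride \(s\). It assumes padding 1 for the "same" convolutions and stride 2 for the pooling layers, which is the nn_max_pool2d() default when the stride equals the kernel size. It then reproduces the 964,516 parameters reported by model(). Dropout and ReLU layers contribute no parameters.

```r
# Spatial size after a convolution or pooling layer along one axis.
out_size <- function(n, kernel, stride = 1, padding = 0) {
  floor((n + 2 * padding - kernel) / stride) + 1
}

n <- 32  # CIFAR images are 32 x 32
for (block in 1:4) {
  n <- out_size(n, kernel = 3, padding = 1)  # 3 x 3 conv, padding = "same"
  n <- out_size(n, kernel = 2, stride = 2)   # 2 x 2 max pooling
  cat("after block", block, ":", n, "x", n, "\n")
}
n_flat <- n * n * 256  # flattened length: 2 * 2 * 256 = 1024

# Conv blocks: 3 x 3 filters for every (input, output) channel pair plus biases.
channels <- c(3, 32, 64, 128, 256)
c_in  <- head(channels, -1)
c_out <- channels[-1]
conv_params   <- sum(3 * 3 * c_in * c_out + c_out)         # 388,416
# Output head: nn_linear(1024, 512) and nn_linear(512, 100).
output_params <- (n_flat * 512 + 512) + (512 * 100 + 100)  # 576,100
conv_params + output_params                                # 964,516
```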

Finally, we specify the fitting algorithm, and fit the model.

fitted <- model %>% 
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_rmsprop, 
    metrics = list(luz_metric_accuracy())
  ) %>% 
  set_opt_hparams(lr = 0.001) %>% 
  fit(
    train_ds,
    epochs = 10, #30,
    valid_data = 0.2,
    dataloader_options = list(batch_size = 128)
  )
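The model's forward() method ends with a linear layer and no softmax. nn_cross_entropy_loss() therefore receives raw scores (logits) together with the integer class labels. With its default mean reduction, it averages over the batch the negative log of the softmax probability assigned to each true class. A base-R sketch of that computation, using hypothetical scores for two images and four classes:

```r
# Hypothetical logits for a batch of two images over four classes,
# and their true class labels (1-based, as in R).
logits <- rbind(c(1.0, 2.0, 0.5, -1.0),
                c(0.2, 0.1, 3.0,  0.0))
y <- c(2, 3)

# Row-wise softmax; subtracting each row's maximum avoids overflow.
softmax_rows <- function(z) {
  e <- exp(z - apply(z, 1, max))
  e / rowSums(e)
}

probs  <- softmax_rows(logits)
p_true <- probs[cbind(seq_along(y), y)]  # probability given to each true class
loss   <- mean(-log(p_true))             # cross-entropy averaged over the batch
loss
```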

print(fitted)
## A `luz_module_fitted`
## ── Time ────────────────────────────────────────────────────────────────────────
## • Total time: 13m 15.4s
## • Avg time per training epoch: 1m 9.4s
## 
## ── Results ─────────────────────────────────────────────────────────────────────
## Metrics observed in the last epoch.
## 
## ℹ Training:
## loss: 2.3598
## acc: 0.3802
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## An `nn_module` containing 964,516 parameters.
## 
## ── Modules ─────────────────────────────────────────────────────────────────────
## • conv: <nn_sequential> #388,416 parameters
## • output: <nn_sequential> #576,100 parameters
evaluate(fitted, test_ds)
## A `luz_module_evaluation`
## ── Results ─────────────────────────────────────────────────────────────────────
## loss: 2.4916
## acc: 0.3631

This model took about 13 minutes to fit (ten epochs) and achieves 36% accuracy on the test data. Although this is not terrible for 100-class data (a random classifier gets 1% accuracy), a web search turns up results of around 75%. Typically it takes a lot of architecture carpentry, fiddling with regularization, and time to achieve such results.

Using Pretrained CNN Models

We now show how to use a CNN pretrained on the ImageNet database to classify natural images. We copied six JPEG images from a digital photo album into the directory book_images. (These images are available from the data section of www.statlearning.com, the ISLR book website: download book_images.zip and unzip it to create the book_images directory.) We first read in the images and convert them into the tensor format expected by the pretrained network. Each image is resized to \(224\times 224\) pixels and normalized with the ImageNet channel means and standard deviations. Make sure that your working directory in R is set to the folder that contains the book_images directory.

img_dir <- "book_images"
image_names <- list.files(img_dir)
num_images <- length(image_names)
x <- torch_empty(num_images, 3, 224, 224)
for (i in 1:num_images) {
   img_path <- file.path(img_dir, image_names[i])
   img <- img_path %>% 
     base_loader() %>% 
     transform_to_tensor() %>% 
     transform_resize(c(224, 224)) %>% 
     # normalize with imagenet mean and stds.
     transform_normalize(
       mean = c(0.485, 0.456, 0.406),
       std = c(0.229, 0.224, 0.225)
     )
   x[i,,, ] <- img
}
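transform_normalize() standardizes each color channel separately. It subtracts the given ImageNet channel mean and divides by the channel standard deviation, which matches how inputs are prepared for the pretrained weights. The base-R sketch below applies the same per-channel operation with sweep() to a hypothetical all-white \(3\times 2\times 2\) image.

```r
imagenet_mean <- c(0.485, 0.456, 0.406)
imagenet_std  <- c(0.229, 0.224, 0.225)

# Hypothetical channels-first image with every pixel white (1.0).
img <- array(1, dim = c(3, 2, 2))

# MARGIN = 1 applies one statistic per channel (the first dimension).
centered   <- sweep(img, 1, imagenet_mean, "-")
normalized <- sweep(centered, 1, imagenet_std, "/")

normalized[, 1, 1]  # one value per channel: (1 - mean) / std
```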

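We then load ResNet-18, a network pretrained on ImageNet; the weights are downloaded the first time it is used. The model has 18 layers with weights and a fair bit of complexity, although it is much smaller than the ResNet-50 model used in the keras version above.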

model <- torchvision::model_resnet18(pretrained = TRUE)
model$eval() # put the model in evaluation mode

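Finally, we classify our six images and return the top three class choices for each. Note that in the code below the softmax is applied to the three largest scores only. As a result, the three values reported for each image are renormalized to sum to one, rather than being probabilities over all 1,000 ImageNet classes.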

preds <- model(x)

mapping <- jsonlite::read_json("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json") %>% 
  sapply(function(x) x[[2]])

top3 <- torch_topk(preds, dim = 2, k = 3)

top3_prob <- top3[[1]] %>% 
  nnf_softmax(dim = 2) %>% 
  torch_unbind() %>% 
  lapply(as.numeric)

top3_class <- top3[[2]] %>% 
  torch_unbind() %>% 
  lapply(function(x) mapping[as.integer(x)])

result <- purrr::map2(top3_prob, top3_class, function(pr, cl) {
  names(pr) <- cl
  pr
})
names(result) <- image_names
print(result)
## $Cape_Weaver.jpg
## hummingbird    lorikeet   bee_eater 
##   0.3633287   0.3577293   0.2789420 
## 
## $Flamingo.jpg
##    flamingo   spoonbill white_stork 
## 0.978211999 0.017045649 0.004742352 
## 
## $Hawk_cropped.jpg
##      kite       jay    magpie 
## 0.6157812 0.2311861 0.1530326 
## 
## $Hawk_Fountain.jpg
##         eel       agama common_newt 
##   0.5391128   0.2527185   0.2081687 
## 
## $Lhasa_Apso.jpg
##           Lhasa Tibetan_terrier        Shih-Tzu 
##      0.79760426      0.12013003      0.08226573 
## 
## $Sleeping_Cat.jpg
##        Saint_Bernard           guinea_pig Bernese_mountain_dog 
##            0.3946672            0.3426990            0.2626339
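A base-R sketch with made-up scores for five classes illustrates the difference between applying the softmax to the top three scores and applying it to all scores. The renormalized values are always larger, because they share the probability mass of three classes instead of all of them. To report probabilities over all 1,000 ImageNet classes in the torch code, one would apply nnf_softmax(preds, dim = 2) before calling torch_topk().

```r
softmax <- function(z) exp(z - max(z)) / sum(exp(z - max(z)))

# Hypothetical scores for five classes.
scores <- c(a = 4.0, b = 3.9, c = 3.6, d = 3.0, e = 2.5)
top3 <- sort(scores, decreasing = TRUE)[1:3]

renormalized <- softmax(top3)           # softmax over the top three only
full <- softmax(scores)[names(top3)]    # softmax over all five classes

round(rbind(renormalized, full), 3)
```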