Sections 10.9.3 and 10.9.4 of [ISLR2] An Introduction to Statistical Learning - with Applications in R (2nd Edition). The book is freely available for download at https://www.statlearning.com/
To see the help file of a function funcname, type
?funcname.
Getting keras up and running on your computer can be a
challenge. The book website <www.statlearning.com> gives
step-by-step instructions on how to achieve this. Guidance can also be
found at <keras.rstudio.com>.
The torch package has become available as an alternative
to the keras package for deep learning. While
torch does not require a python installation,
the current implementation appears to be a fair bit slower than
keras. We include the torch version of the
implementation at the end of this lab.
In this section, we use the keras package, which
interfaces to the tensorflow package which in turn links to
efficient python code. This code is impressively fast, and
the package is well-structured.
# install.packages("keras3")     # uncomment and run the first time
# install.packages("reticulate") # uncomment and run the first time
library(reticulate)
use_condaenv("r-reticulate", required = TRUE)
library(keras3)
# install_keras(method = "conda", envname = "r-reticulate") # uncomment and run the first time
In this section we fit a CNN to the CIFAR100 data, which is
available in the keras package. It is arranged in a similar
fashion as the MNIST data.
cifar100 <- dataset_cifar100()
names(cifar100)
## [1] "train" "test"
x_train <- cifar100$train$x
g_train <- cifar100$train$y
x_test <- cifar100$test$x
g_test <- cifar100$test$y
dim(x_train)
## [1] 50000 32 32 3
range(x_train[1,,,1])
## [1] 13 255
The array of 50,000 training images has four dimensions: each three-color image is represented as a set of three channels, each of which consists of \(32\times 32\) eight-bit pixels. We standardize as we did for the digits, but keep the array structure. We one-hot encode the response factors to produce a 100-column binary matrix.
x_train <- x_train / 255
x_test <- x_test / 255
y_train <- to_categorical(g_train, 100)
dim(y_train)
## [1] 50000 100
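For intuition, to_categorical() simply builds an indicator matrix with one column per class. A small base-R sketch (one_hot() is a hypothetical helper defined here, not part of keras) reproduces the idea:

```r
# Sketch of what to_categorical() produces: an indicator ("one-hot") matrix
# with one column per class. CIFAR-100 labels are 0-based, hence the + 1.
one_hot <- function(y, num_classes) {
  m <- matrix(0, nrow = length(y), ncol = num_classes)
  m[cbind(seq_along(y), y + 1)] <- 1
  m
}
y <- c(0, 3, 99)       # three example labels
Y <- one_hot(y, 100)
dim(Y)                 # 3 100
rowSums(Y)             # 1 1 1 -- exactly one 1 per row
which(Y[2, ] == 1)     # 4 -- label 3 occupies column 4
```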
Before we start, we look at some of the training images using the
jpeg package.
library(jpeg)
par(mar = c(0, 0, 0, 0), mfrow = c(5, 5))
index <- sample(seq(50000), 25)
for (i in index) plot(as.raster(x_train[i,,, ]))

The as.raster() function converts the feature map so
that it can be plotted as a color image.
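As a small illustration with made-up pixel values, as.raster() maps a numeric array with entries in [0, 1] to a matrix of hex color strings, one per pixel:

```r
# A tiny 2 x 2 RGB image: values in [0, 1], third dimension = channels.
img <- array(0, dim = c(2, 2, 3))
img[1, 1, ] <- c(1, 0, 0)    # red pixel in the top-left corner
img[2, 2, ] <- c(0, 0, 1)    # blue pixel in the bottom-right corner
r <- as.raster(img)
class(r)                     # "raster"
unclass(r)[1, 1]             # "#FF0000"
# plot(r)                    # renders the image, as in the loop above
```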
Here we specify a moderately-sized CNN for demonstration purposes.
model <- keras_model_sequential(
layers = list(
layer_input(shape = c(32L, 32L, 3L)),
layer_conv_2d(filters = 32, kernel_size = c(3L, 3L),
padding = "same", activation = "relu"),
layer_max_pooling_2d(pool_size = c(2L, 2L)),
layer_conv_2d(filters = 64, kernel_size = c(3L, 3L),
padding = "same", activation = "relu"),
layer_max_pooling_2d(pool_size = c(2L, 2L)),
layer_conv_2d(filters = 128, kernel_size = c(3L, 3L),
padding = "same", activation = "relu"),
layer_max_pooling_2d(pool_size = c(2L, 2L)),
layer_conv_2d(filters = 256, kernel_size = c(3L, 3L),
padding = "same", activation = "relu"),
layer_max_pooling_2d(pool_size = c(2L, 2L)),
layer_flatten(),
layer_dropout(rate = 0.5),
layer_dense(units = 512, activation = "relu"),
layer_dense(units = 100, activation = "softmax")
)
)
model |> summary()
## Model: "sequential"
## ┌───────────────────────────────────┬──────────────────────────┬───────────────
## │ Layer (type) │ Output Shape │ Param #
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d (Conv2D) │ (None, 32, 32, 32) │ 896
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 32) │ 0
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d_1 (Conv2D) │ (None, 16, 16, 64) │ 18,496
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 64) │ 0
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d_2 (Conv2D) │ (None, 8, 8, 128) │ 73,856
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d_2 (MaxPooling2D) │ (None, 4, 4, 128) │ 0
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ conv2d_3 (Conv2D) │ (None, 4, 4, 256) │ 295,168
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ max_pooling2d_3 (MaxPooling2D) │ (None, 2, 2, 256) │ 0
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ flatten (Flatten) │ (None, 1024) │ 0
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dropout (Dropout) │ (None, 1024) │ 0
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense (Dense) │ (None, 512) │ 524,800
## ├───────────────────────────────────┼──────────────────────────┼───────────────
## │ dense_1 (Dense) │ (None, 100) │ 51,300
## └───────────────────────────────────┴──────────────────────────┴───────────────
## Total params: 964,516 (3.68 MB)
## Trainable params: 964,516 (3.68 MB)
## Non-trainable params: 0 (0.00 B)
Notice that we used the padding = "same" argument to
layer_conv_2d(), which ensures that the output channels
have the same dimension as the input channels. There are 32 channels in
the first hidden layer, in contrast to the three channels in the input
layer. We use a \(3\times 3\)
convolution filter for each channel in all the layers. Each convolution
is followed by a max-pooling layer over \(2\times2\) blocks. By studying the summary,
we can see that the channels halve in both dimensions after each of
these max-pooling operations. After the last of these we have a layer
with 256 channels of dimension \(2\times
2\). These are then flattened to a dense layer of size 1,024: in
other words, each of the \(2\times 2\)
matrices is turned into a 4-vector, and put side-by-side in one layer.
This is followed by a dropout regularization layer, then another dense
layer of size 512, which finally reaches the softmax output layer.
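The parameter counts in the summary can be checked by hand: a convolution layer with a \(k\times k\) kernel has \((k\cdot k\cdot \text{in\_channels}+1)\cdot \text{out\_channels}\) parameters (the \(+1\) is the bias), and a dense layer has \((\text{inputs}+1)\cdot \text{units}\). A quick base-R check (conv_params() is just a helper defined here):

```r
# Parameters in a conv layer: one k x k x in_channels kernel plus a bias,
# per output channel.
conv_params <- function(k, in_ch, out_ch) (k * k * in_ch + 1) * out_ch
conv_params(3,   3,  32)   # 896     (conv2d)
conv_params(3,  32,  64)   # 18496   (conv2d_1)
conv_params(3,  64, 128)   # 73856   (conv2d_2)
conv_params(3, 128, 256)   # 295168  (conv2d_3)
(1024 + 1) * 512           # 524800  (dense)
(512 + 1) * 100            # 51300   (dense_1)
# These six counts sum to the 964,516 total reported by summary().
```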
Finally, we specify the fitting algorithm, and fit the model.
model |>
compile(
optimizer = "rmsprop",
loss = "categorical_crossentropy",
metrics = list("categorical_accuracy")
)
history <- model |>
fit(
x_train, y_train,
epochs = 10,
batch_size = 128,
validation_split = 0.2,
verbose = 2
)
## Epoch 1/10
## 313/313 - 39s - 124ms/step - categorical_accuracy: 0.0538 - loss: 4.2052 - val_categorical_accuracy: 0.0880 - val_loss: 4.0220
## Epoch 2/10
## 313/313 - 40s - 129ms/step - categorical_accuracy: 0.1355 - loss: 3.6659 - val_categorical_accuracy: 0.1619 - val_loss: 3.5247
## Epoch 3/10
## 313/313 - 35s - 110ms/step - categorical_accuracy: 0.1962 - loss: 3.3239 - val_categorical_accuracy: 0.2211 - val_loss: 3.1846
## Epoch 4/10
## 313/313 - 36s - 115ms/step - categorical_accuracy: 0.2456 - loss: 3.0643 - val_categorical_accuracy: 0.2219 - val_loss: 3.2441
## Epoch 5/10
## 313/313 - 36s - 116ms/step - categorical_accuracy: 0.2799 - loss: 2.8673 - val_categorical_accuracy: 0.2753 - val_loss: 2.9086
## Epoch 6/10
## 313/313 - 36s - 116ms/step - categorical_accuracy: 0.3167 - loss: 2.6836 - val_categorical_accuracy: 0.2692 - val_loss: 2.9713
## Epoch 7/10
## 313/313 - 41s - 132ms/step - categorical_accuracy: 0.3491 - loss: 2.5243 - val_categorical_accuracy: 0.3494 - val_loss: 2.5679
## Epoch 8/10
## 313/313 - 40s - 126ms/step - categorical_accuracy: 0.3798 - loss: 2.3827 - val_categorical_accuracy: 0.3587 - val_loss: 2.5171
## Epoch 9/10
## 313/313 - 35s - 113ms/step - categorical_accuracy: 0.4094 - loss: 2.2470 - val_categorical_accuracy: 0.3662 - val_loss: 2.4625
## Epoch 10/10
## 313/313 - 36s - 114ms/step - categorical_accuracy: 0.4333 - loss: 2.1339 - val_categorical_accuracy: 0.3973 - val_loss: 2.3600
pred <- model |> predict(x_test)
## 313/313 - 5s - 15ms/step
pred_class <- max.col(pred, ties.method = "first")
true_class <- as.numeric(g_test) + 1
mean(pred_class == true_class)
## [1] 0.4027
This model takes several minutes to run and achieves about 40% accuracy on the test data. Although this is not terrible for 100-class data (a random classifier gets 1% accuracy), searching the web we see results around 75%. Typically it takes a lot of architecture carpentry, fiddling with regularization, and time to achieve such results.
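The class-assignment step above is worth seeing in isolation. The following sketch applies max.col() to a hypothetical 3-observation, 4-class probability matrix standing in for pred:

```r
# max.col() returns, for each row, the column index of the largest entry;
# with softmax output, that column index is the predicted class.
probs <- rbind(c(0.10, 0.70, 0.10, 0.10),
               c(0.30, 0.30, 0.20, 0.20),
               c(0.05, 0.05, 0.20, 0.70))
pred_class <- max.col(probs, ties.method = "first")
pred_class                  # 2 1 4  (the tie in row 2 goes to column 1)
truth <- c(2, 2, 4)
mean(pred_class == truth)   # 2 of 3 correct
```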
We now show how to use a CNN pretrained on the imagenet
database to classify natural images. We copied six jpeg images from a
digital photo album into the directory book_images. (These
images are available from the data section of
<www.statlearning.com>, the ISL book website. Download
book_images.zip; when unzipped it creates the
book_images directory.) We first read in the images,
and convert them into the array format expected by the
keras software to match the specifications in
imagenet. Make sure that your working directory in
R is set to the folder in which the images are stored.
img_dir <- "book_images"
image_names <- list.files(img_dir)
num_images <- length(image_names)
x <- array(dim = c(num_images, 224, 224, 3))
for (i in 1:num_images) {
img_path <- paste(img_dir, image_names[i], sep = "/")
img <- image_load(img_path, target_size = c(224, 224))
x[i,,, ] <- image_to_array(img)
}
x <- imagenet_preprocess_input(x)
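For ResNet50, imagenet_preprocess_input() applies the "caffe"-style transformation: channels are reordered RGB to BGR and per-channel ImageNet means are subtracted. A base-R sketch under these assumptions (preprocess_sketch() is for illustration only, and the channel means are the conventional assumed values):

```r
# Hypothetical re-implementation of the "caffe" preprocessing mode.
preprocess_sketch <- function(x) {
  x <- x[, , , c(3, 2, 1), drop = FALSE]    # RGB -> BGR
  means <- c(103.939, 116.779, 123.68)      # assumed per-channel means (BGR)
  for (ch in 1:3) x[, , , ch] <- x[, , , ch] - means[ch]
  x
}
toy <- array(128, dim = c(1, 2, 2, 3))      # one gray 2 x 2 "image"
out <- preprocess_sketch(toy)
round(out[1, 1, 1, ], 3)                    # 24.061 11.221 4.320
```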
We then load the trained network. The model has 50 layers, with a fair bit of complexity.
model <- application_resnet50(weights = "imagenet")
summary(model)
## Model: "resnet50"
## ┌────────────────────┬──────────────────┬────────────┬─────────────────┬───────
## │ Layer (type) │ Output Shape │ Param # │ Connected to │ Trai…
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ input_layer_1 │ (None, 224, 224, │ 0 │ - │ -
## │ (InputLayer) │ 3) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_pad │ (None, 230, 230, │ 0 │ input_layer_1[… │ -
## │ (ZeroPadding2D) │ 3) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_conv │ (None, 112, 112, │ 9,472 │ conv1_pad[0][0] │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_bn │ (None, 112, 112, │ 256 │ conv1_conv[0][… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv1_relu │ (None, 112, 112, │ 0 │ conv1_bn[0][0] │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ pool1_pad │ (None, 114, 114, │ 0 │ conv1_relu[0][… │ -
## │ (ZeroPadding2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ pool1_pool │ (None, 56, 56, │ 0 │ pool1_pad[0][0] │ -
## │ (MaxPooling2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_1_co… │ (None, 56, 56, │ 4,160 │ pool1_pool[0][… │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_1_bn │ (None, 56, 56, │ 256 │ conv2_block1_1… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_1_re… │ (None, 56, 56, │ 0 │ conv2_block1_1… │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_2_co… │ (None, 56, 56, │ 36,928 │ conv2_block1_1… │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_2_bn │ (None, 56, 56, │ 256 │ conv2_block1_2… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_2_re… │ (None, 56, 56, │ 0 │ conv2_block1_2… │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_0_co… │ (None, 56, 56, │ 16,640 │ pool1_pool[0][… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_3_co… │ (None, 56, 56, │ 16,640 │ conv2_block1_2… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_0_bn │ (None, 56, 56, │ 1,024 │ conv2_block1_0… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_3_bn │ (None, 56, 56, │ 1,024 │ conv2_block1_3… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_add │ (None, 56, 56, │ 0 │ conv2_block1_0… │ -
## │ (Add) │ 256) │ │ conv2_block1_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block1_out │ (None, 56, 56, │ 0 │ conv2_block1_a… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_1_co… │ (None, 56, 56, │ 16,448 │ conv2_block1_o… │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_1_bn │ (None, 56, 56, │ 256 │ conv2_block2_1… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_1_re… │ (None, 56, 56, │ 0 │ conv2_block2_1… │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_2_co… │ (None, 56, 56, │ 36,928 │ conv2_block2_1… │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_2_bn │ (None, 56, 56, │ 256 │ conv2_block2_2… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_2_re… │ (None, 56, 56, │ 0 │ conv2_block2_2… │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_3_co… │ (None, 56, 56, │ 16,640 │ conv2_block2_2… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_3_bn │ (None, 56, 56, │ 1,024 │ conv2_block2_3… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_add │ (None, 56, 56, │ 0 │ conv2_block1_o… │ -
## │ (Add) │ 256) │ │ conv2_block2_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block2_out │ (None, 56, 56, │ 0 │ conv2_block2_a… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_1_co… │ (None, 56, 56, │ 16,448 │ conv2_block2_o… │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_1_bn │ (None, 56, 56, │ 256 │ conv2_block3_1… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_1_re… │ (None, 56, 56, │ 0 │ conv2_block3_1… │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_2_co… │ (None, 56, 56, │ 36,928 │ conv2_block3_1… │ Y
## │ (Conv2D) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_2_bn │ (None, 56, 56, │ 256 │ conv2_block3_2… │ Y
## │ (BatchNormalizati… │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_2_re… │ (None, 56, 56, │ 0 │ conv2_block3_2… │ -
## │ (Activation) │ 64) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_3_co… │ (None, 56, 56, │ 16,640 │ conv2_block3_2… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_3_bn │ (None, 56, 56, │ 1,024 │ conv2_block3_3… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_add │ (None, 56, 56, │ 0 │ conv2_block2_o… │ -
## │ (Add) │ 256) │ │ conv2_block3_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv2_block3_out │ (None, 56, 56, │ 0 │ conv2_block3_a… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_1_co… │ (None, 28, 28, │ 32,896 │ conv2_block3_o… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_1_bn │ (None, 28, 28, │ 512 │ conv3_block1_1… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_1_re… │ (None, 28, 28, │ 0 │ conv3_block1_1… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_2_co… │ (None, 28, 28, │ 147,584 │ conv3_block1_1… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_2_bn │ (None, 28, 28, │ 512 │ conv3_block1_2… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_2_re… │ (None, 28, 28, │ 0 │ conv3_block1_2… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_0_co… │ (None, 28, 28, │ 131,584 │ conv2_block3_o… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_3_co… │ (None, 28, 28, │ 66,048 │ conv3_block1_2… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_0_bn │ (None, 28, 28, │ 2,048 │ conv3_block1_0… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_3_bn │ (None, 28, 28, │ 2,048 │ conv3_block1_3… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_add │ (None, 28, 28, │ 0 │ conv3_block1_0… │ -
## │ (Add) │ 512) │ │ conv3_block1_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block1_out │ (None, 28, 28, │ 0 │ conv3_block1_a… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_1_co… │ (None, 28, 28, │ 65,664 │ conv3_block1_o… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_1_bn │ (None, 28, 28, │ 512 │ conv3_block2_1… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_1_re… │ (None, 28, 28, │ 0 │ conv3_block2_1… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_2_co… │ (None, 28, 28, │ 147,584 │ conv3_block2_1… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_2_bn │ (None, 28, 28, │ 512 │ conv3_block2_2… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_2_re… │ (None, 28, 28, │ 0 │ conv3_block2_2… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_3_co… │ (None, 28, 28, │ 66,048 │ conv3_block2_2… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_3_bn │ (None, 28, 28, │ 2,048 │ conv3_block2_3… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_add │ (None, 28, 28, │ 0 │ conv3_block1_o… │ -
## │ (Add) │ 512) │ │ conv3_block2_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block2_out │ (None, 28, 28, │ 0 │ conv3_block2_a… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_1_co… │ (None, 28, 28, │ 65,664 │ conv3_block2_o… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_1_bn │ (None, 28, 28, │ 512 │ conv3_block3_1… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_1_re… │ (None, 28, 28, │ 0 │ conv3_block3_1… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_2_co… │ (None, 28, 28, │ 147,584 │ conv3_block3_1… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_2_bn │ (None, 28, 28, │ 512 │ conv3_block3_2… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_2_re… │ (None, 28, 28, │ 0 │ conv3_block3_2… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_3_co… │ (None, 28, 28, │ 66,048 │ conv3_block3_2… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_3_bn │ (None, 28, 28, │ 2,048 │ conv3_block3_3… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_add │ (None, 28, 28, │ 0 │ conv3_block2_o… │ -
## │ (Add) │ 512) │ │ conv3_block3_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block3_out │ (None, 28, 28, │ 0 │ conv3_block3_a… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_1_co… │ (None, 28, 28, │ 65,664 │ conv3_block3_o… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_1_bn │ (None, 28, 28, │ 512 │ conv3_block4_1… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_1_re… │ (None, 28, 28, │ 0 │ conv3_block4_1… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_2_co… │ (None, 28, 28, │ 147,584 │ conv3_block4_1… │ Y
## │ (Conv2D) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_2_bn │ (None, 28, 28, │ 512 │ conv3_block4_2… │ Y
## │ (BatchNormalizati… │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_2_re… │ (None, 28, 28, │ 0 │ conv3_block4_2… │ -
## │ (Activation) │ 128) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_3_co… │ (None, 28, 28, │ 66,048 │ conv3_block4_2… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_3_bn │ (None, 28, 28, │ 2,048 │ conv3_block4_3… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_add │ (None, 28, 28, │ 0 │ conv3_block3_o… │ -
## │ (Add) │ 512) │ │ conv3_block4_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv3_block4_out │ (None, 28, 28, │ 0 │ conv3_block4_a… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_1_co… │ (None, 14, 14, │ 131,328 │ conv3_block4_o… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_1_bn │ (None, 14, 14, │ 1,024 │ conv4_block1_1… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_1_re… │ (None, 14, 14, │ 0 │ conv4_block1_1… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_2_co… │ (None, 14, 14, │ 590,080 │ conv4_block1_1… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_2_bn │ (None, 14, 14, │ 1,024 │ conv4_block1_2… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_2_re… │ (None, 14, 14, │ 0 │ conv4_block1_2… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_0_co… │ (None, 14, 14, │ 525,312 │ conv3_block4_o… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_3_co… │ (None, 14, 14, │ 263,168 │ conv4_block1_2… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_0_bn │ (None, 14, 14, │ 4,096 │ conv4_block1_0… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_3_bn │ (None, 14, 14, │ 4,096 │ conv4_block1_3… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_add │ (None, 14, 14, │ 0 │ conv4_block1_0… │ -
## │ (Add) │ 1024) │ │ conv4_block1_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block1_out │ (None, 14, 14, │ 0 │ conv4_block1_a… │ -
## │ (Activation) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_1_co… │ (None, 14, 14, │ 262,400 │ conv4_block1_o… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_1_bn │ (None, 14, 14, │ 1,024 │ conv4_block2_1… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_1_re… │ (None, 14, 14, │ 0 │ conv4_block2_1… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_2_co… │ (None, 14, 14, │ 590,080 │ conv4_block2_1… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_2_bn │ (None, 14, 14, │ 1,024 │ conv4_block2_2… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_2_re… │ (None, 14, 14, │ 0 │ conv4_block2_2… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_3_co… │ (None, 14, 14, │ 263,168 │ conv4_block2_2… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_3_bn │ (None, 14, 14, │ 4,096 │ conv4_block2_3… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_add │ (None, 14, 14, │ 0 │ conv4_block1_o… │ -
## │ (Add) │ 1024) │ │ conv4_block2_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block2_out │ (None, 14, 14, │ 0 │ conv4_block2_a… │ -
## │ (Activation) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_1_co… │ (None, 14, 14, │ 262,400 │ conv4_block2_o… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_1_bn │ (None, 14, 14, │ 1,024 │ conv4_block3_1… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_1_re… │ (None, 14, 14, │ 0 │ conv4_block3_1… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_2_co… │ (None, 14, 14, │ 590,080 │ conv4_block3_1… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_2_bn │ (None, 14, 14, │ 1,024 │ conv4_block3_2… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_2_re… │ (None, 14, 14, │ 0 │ conv4_block3_2… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_3_co… │ (None, 14, 14, │ 263,168 │ conv4_block3_2… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_3_bn │ (None, 14, 14, │ 4,096 │ conv4_block3_3… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_add │ (None, 14, 14, │ 0 │ conv4_block2_o… │ -
## │ (Add) │ 1024) │ │ conv4_block3_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block3_out │ (None, 14, 14, │ 0 │ conv4_block3_a… │ -
## │ (Activation) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_1_co… │ (None, 14, 14, │ 262,400 │ conv4_block3_o… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_1_bn │ (None, 14, 14, │ 1,024 │ conv4_block4_1… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_1_re… │ (None, 14, 14, │ 0 │ conv4_block4_1… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_2_co… │ (None, 14, 14, │ 590,080 │ conv4_block4_1… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_2_bn │ (None, 14, 14, │ 1,024 │ conv4_block4_2… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_2_re… │ (None, 14, 14, │ 0 │ conv4_block4_2… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_3_co… │ (None, 14, 14, │ 263,168 │ conv4_block4_2… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_3_bn │ (None, 14, 14, │ 4,096 │ conv4_block4_3… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_add │ (None, 14, 14, │ 0 │ conv4_block3_o… │ -
## │ (Add) │ 1024) │ │ conv4_block4_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block4_out │ (None, 14, 14, │ 0 │ conv4_block4_a… │ -
## │ (Activation) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_1_co… │ (None, 14, 14, │ 262,400 │ conv4_block4_o… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_1_bn │ (None, 14, 14, │ 1,024 │ conv4_block5_1… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_1_re… │ (None, 14, 14, │ 0 │ conv4_block5_1… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_2_co… │ (None, 14, 14, │ 590,080 │ conv4_block5_1… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_2_bn │ (None, 14, 14, │ 1,024 │ conv4_block5_2… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_2_re… │ (None, 14, 14, │ 0 │ conv4_block5_2… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_3_co… │ (None, 14, 14, │ 263,168 │ conv4_block5_2… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_3_bn │ (None, 14, 14, │ 4,096 │ conv4_block5_3… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_add │ (None, 14, 14, │ 0 │ conv4_block4_o… │ -
## │ (Add) │ 1024) │ │ conv4_block5_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block5_out │ (None, 14, 14, │ 0 │ conv4_block5_a… │ -
## │ (Activation) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_1_co… │ (None, 14, 14, │ 262,400 │ conv4_block5_o… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_1_bn │ (None, 14, 14, │ 1,024 │ conv4_block6_1… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_1_re… │ (None, 14, 14, │ 0 │ conv4_block6_1… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_2_co… │ (None, 14, 14, │ 590,080 │ conv4_block6_1… │ Y
## │ (Conv2D) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_2_bn │ (None, 14, 14, │ 1,024 │ conv4_block6_2… │ Y
## │ (BatchNormalizati… │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_2_re… │ (None, 14, 14, │ 0 │ conv4_block6_2… │ -
## │ (Activation) │ 256) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_3_co… │ (None, 14, 14, │ 263,168 │ conv4_block6_2… │ Y
## │ (Conv2D) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_3_bn │ (None, 14, 14, │ 4,096 │ conv4_block6_3… │ Y
## │ (BatchNormalizati… │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_add │ (None, 14, 14, │ 0 │ conv4_block5_o… │ -
## │ (Add) │ 1024) │ │ conv4_block6_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv4_block6_out │ (None, 14, 14, │ 0 │ conv4_block6_a… │ -
## │ (Activation) │ 1024) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_1_co… │ (None, 7, 7, │ 524,800 │ conv4_block6_o… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_1_bn │ (None, 7, 7, │ 2,048 │ conv5_block1_1… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_1_re… │ (None, 7, 7, │ 0 │ conv5_block1_1… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_2_co… │ (None, 7, 7, │ 2,359,808 │ conv5_block1_1… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_2_bn │ (None, 7, 7, │ 2,048 │ conv5_block1_2… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_2_re… │ (None, 7, 7, │ 0 │ conv5_block1_2… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_0_co… │ (None, 7, 7, │ 2,099,200 │ conv4_block6_o… │ Y
## │ (Conv2D) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_3_co… │ (None, 7, 7, │ 1,050,624 │ conv5_block1_2… │ Y
## │ (Conv2D) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_0_bn │ (None, 7, 7, │ 8,192 │ conv5_block1_0… │ Y
## │ (BatchNormalizati… │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_3_bn │ (None, 7, 7, │ 8,192 │ conv5_block1_3… │ Y
## │ (BatchNormalizati… │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_add │ (None, 7, 7, │ 0 │ conv5_block1_0… │ -
## │ (Add) │ 2048) │ │ conv5_block1_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block1_out │ (None, 7, 7, │ 0 │ conv5_block1_a… │ -
## │ (Activation) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_1_co… │ (None, 7, 7, │ 1,049,088 │ conv5_block1_o… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_1_bn │ (None, 7, 7, │ 2,048 │ conv5_block2_1… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_1_re… │ (None, 7, 7, │ 0 │ conv5_block2_1… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_2_co… │ (None, 7, 7, │ 2,359,808 │ conv5_block2_1… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_2_bn │ (None, 7, 7, │ 2,048 │ conv5_block2_2… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_2_re… │ (None, 7, 7, │ 0 │ conv5_block2_2… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_3_co… │ (None, 7, 7, │ 1,050,624 │ conv5_block2_2… │ Y
## │ (Conv2D) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_3_bn │ (None, 7, 7, │ 8,192 │ conv5_block2_3… │ Y
## │ (BatchNormalizati… │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_add │ (None, 7, 7, │ 0 │ conv5_block1_o… │ -
## │ (Add) │ 2048) │ │ conv5_block2_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block2_out │ (None, 7, 7, │ 0 │ conv5_block2_a… │ -
## │ (Activation) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_1_co… │ (None, 7, 7, │ 1,049,088 │ conv5_block2_o… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_1_bn │ (None, 7, 7, │ 2,048 │ conv5_block3_1… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_1_re… │ (None, 7, 7, │ 0 │ conv5_block3_1… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_2_co… │ (None, 7, 7, │ 2,359,808 │ conv5_block3_1… │ Y
## │ (Conv2D) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_2_bn │ (None, 7, 7, │ 2,048 │ conv5_block3_2… │ Y
## │ (BatchNormalizati… │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_2_re… │ (None, 7, 7, │ 0 │ conv5_block3_2… │ -
## │ (Activation) │ 512) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_3_co… │ (None, 7, 7, │ 1,050,624 │ conv5_block3_2… │ Y
## │ (Conv2D) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_3_bn │ (None, 7, 7, │ 8,192 │ conv5_block3_3… │ Y
## │ (BatchNormalizati… │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_add │ (None, 7, 7, │ 0 │ conv5_block2_o… │ -
## │ (Add) │ 2048) │ │ conv5_block3_3… │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ conv5_block3_out │ (None, 7, 7, │ 0 │ conv5_block3_a… │ -
## │ (Activation) │ 2048) │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ avg_pool │ (None, 2048) │ 0 │ conv5_block3_o… │ -
## │ (GlobalAveragePoo… │ │ │ │
## ├────────────────────┼──────────────────┼────────────┼─────────────────┼───────
## │ predictions │ (None, 1000) │ 2,049,000 │ avg_pool[0][0] │ Y
## │ (Dense) │ │ │ │
## └────────────────────┴──────────────────┴────────────┴─────────────────┴───────
## Total params: 25,636,712 (97.80 MB)
## Trainable params: 25,583,592 (97.59 MB)
## Non-trainable params: 53,120 (207.50 KB)
Finally, we classify our six images, and return the top three class choices in terms of predicted probability for each.
pred6 <- model %>% predict(x) %>%
imagenet_decode_predictions(top = 3)
## 1/1 - 3s - 3s/step
names(pred6) <- image_names
print(pred6)
## $Cape_Weaver.jpg
## class_name class_description score
## 1 n01843065 jacamar 0.49795407
## 2 n01818515 macaw 0.22193323
## 3 n02494079 squirrel_monkey 0.04287848
##
## $Flamingo.jpg
## class_name class_description score
## 1 n02007558 flamingo 0.92634964
## 2 n02006656 spoonbill 0.07169954
## 3 n02002556 white_stork 0.00122821
##
## $Hawk_cropped.jpg
## class_name class_description score
## 1 n01608432 kite 0.72270894
## 2 n01622779 great_grey_owl 0.08182577
## 3 n01532829 house_finch 0.04218888
##
## $Hawk_Fountain.jpg
## class_name class_description score
## 1 n03388043 fountain 0.2788652
## 2 n03532672 hook 0.1785546
## 3 n03804744 nail 0.1080733
##
## $Lhasa_Apso.jpg
## class_name class_description score
## 1 n02097474 Tibetan_terrier 0.50929731
## 2 n02098413 Lhasa 0.42209828
## 3 n02098105 soft-coated_wheaten_terrier 0.01695858
##
## $Sleeping_Cat.jpg
## class_name class_description score
## 1 n02105641 Old_English_sheepdog 0.83266014
## 2 n02086240 Shih-Tzu 0.04513866
## 3 n03223299 doormat 0.03299771
In this section we fit a CNN to the CIFAR data, which is
available in the torchvision package. It is arranged in a
similar fashion to the MNIST data.
# install.packages("torch")         # Uncomment to Run it for the first time
# install.packages("luz")           # Uncomment to Run it for the first time
# install.packages("torchvision")   # Uncomment to Run it for the first time
# install.packages("torchdatasets") # Uncomment to Run it for the first time
# install.packages("zeallot")       # Uncomment to Run it for the first time
library(torch)
##
## Attaching package: 'torch'
## The following object is masked from 'package:keras3':
##
## as_iterator
## The following object is masked from 'package:reticulate':
##
## as_iterator
library(luz) # high-level interface for torch
##
## Attaching package: 'luz'
## The following object is masked from 'package:keras3':
##
## evaluate
library(torchvision) # for datasets and image transformation
library(torchdatasets) # for datasets we are going to use
library(zeallot)
torch_manual_seed(6805)
transform <- function(x) {
transform_to_tensor(x)
}
train_ds <- cifar100_dataset(
root = "./",
train = TRUE,
download = TRUE,
transform = transform
)
## Dataset <cifar100_dataset> (~160 MB) will be downloaded and processed if not
## already available.
## Dataset <cifar100_dataset> loaded with 50000 images across 100 classes.
test_ds <- cifar100_dataset(
root = "./",
train = FALSE,
transform = transform
)
## Dataset <cifar100_dataset> loaded with 10000 images across 100 classes.
str(train_ds[1])
## List of 2
## $ x:Float [1:3, 1:32, 1:32]
## $ y: int 20
length(train_ds)
## [1] 50000
The CIFAR100 training set consists of 50,000 images, each
represented by a 3d tensor: each three-color image is a set of
three channels, each of which consists of \(32\times 32\) eight-bit pixels. We
standardize as we did for the digits, but keep the array structure. This
is accomplished with the transform argument.
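As a quick sanity check (a sketch, assuming the torch and torchvision setup above has run), transform_to_tensor() converts each eight-bit image into a channel-first float tensor with values rescaled from [0, 255] to [0, 1]:

```r
# Sketch: inspect the first transformed training image.
x1 <- train_ds[1]$x       # already passed through transform_to_tensor()
x1$dtype                  # a float type, not integer
dim(x1)                   # 3 32 32  (channels first)
as.numeric(x1$max()) <= 1 # TRUE: pixels have been rescaled to [0, 1]
```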
Before we start, we look at some of the training images.
par(mar = c(0, 0, 0, 0), mfrow = c(5, 5))
index <- sample(seq(50000), 25)
for (i in index) plot(as.raster(as.array(train_ds[i][[1]]$permute(c(2,3,1)))))

The as.raster() function converts the feature map so
that it can be plotted as a color image.
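The permute(c(2, 3, 1)) call is needed because as.raster() expects a height × width × channel array with values in [0, 1], whereas torch stores images channel-first. A minimal base-R illustration (with hypothetical random pixel values):

```r
# Sketch: as.raster() on an H x W x 3 array of values in [0, 1]
arr <- array(runif(32 * 32 * 3), dim = c(32, 32, 3))
img <- as.raster(arr)  # each cell becomes a "#RRGGBB" color string
dim(img)               # 32 32
```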
Here we specify a moderately-sized CNN for demonstration purposes.
conv_block <- nn_module(
initialize = function(in_channels, out_channels) {
self$conv <- nn_conv2d(
in_channels = in_channels,
out_channels = out_channels,
kernel_size = c(3,3),
padding = "same"
)
self$relu <- nn_relu()
self$pool <- nn_max_pool2d(kernel_size = c(2,2))
},
forward = function(x) {
x %>%
self$conv() %>%
self$relu() %>%
self$pool()
}
)
model <- nn_module(
initialize = function() {
self$conv <- nn_sequential(
conv_block(3, 32),
conv_block(32, 64),
conv_block(64, 128),
conv_block(128, 256)
)
self$output <- nn_sequential(
nn_dropout(0.5),
nn_linear(2*2*256, 512),
nn_relu(),
nn_linear(512, 100)
)
},
forward = function(x) {
x %>%
self$conv() %>%
torch_flatten(start_dim = 2) %>%
self$output()
}
)
model()
## An `nn_module` containing 964,516 parameters.
##
## ── Modules ─────────────────────────────────────────────────────────────────────
## • conv: <nn_sequential> #388,416 parameters
## • output: <nn_sequential> #576,100 parameters
Notice that we used the padding = "same" argument to
nn_conv2d(), which ensures that the output feature maps have
the same spatial dimensions as the input. There are 32 channels in the
first hidden layer, in contrast to the three channels in the input
layer. We use a \(3\times 3\)
convolution filter for each channel in all the layers. Each convolution
is followed by a max-pooling layer over \(2\times2\) blocks. By studying the summary,
we can see that the feature maps halve in both spatial dimensions after each of
these max-pooling operations. After the last of these we have a layer
with 256 channels of dimension \(2\times
2\). These are then flattened into a vector of length 1,024: in
other words, each of the \(2\times 2\)
matrices is turned into a 4-vector, and the vectors are put side-by-side in one layer.
This is followed by a dropout regularization layer, then another dense
layer of size 512, and finally, the 100-class output layer.
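The dimension bookkeeping above can be verified with a few lines of arithmetic (plain R, no torch required):

```r
# Each conv layer uses padding = "same", so only the 2x2 max-pools
# change the spatial size: 32 -> 16 -> 8 -> 4 -> 2.
size <- 32
for (block in 1:4) size <- size %/% 2
size          # 2
size^2 * 256  # 1024, matching nn_linear(2*2*256, 512) in the model
```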
Finally, we specify the fitting algorithm, and fit the model.
fitted <- model %>%
setup(
loss = nn_cross_entropy_loss(),
optimizer = optim_rmsprop,
metrics = list(luz_metric_accuracy())
) %>%
set_opt_hparams(lr = 0.001) %>%
fit(
train_ds,
epochs = 10, # reduced from 30 to shorten the run time
valid_data = 0.2,
dataloader_options = list(batch_size = 128)
)
print(fitted)
## A `luz_module_fitted`
## ── Time ────────────────────────────────────────────────────────────────────────
## • Total time: 13m 15.4s
## • Avg time per training epoch: 1m 9.4s
##
## ── Results ─────────────────────────────────────────────────────────────────────
## Metrics observed in the last epoch.
##
## ℹ Training:
## loss: 2.3598
## acc: 0.3802
##
## ── Model ───────────────────────────────────────────────────────────────────────
## An `nn_module` containing 964,516 parameters.
##
## ── Modules ─────────────────────────────────────────────────────────────────────
## • conv: <nn_sequential> #388,416 parameters
## • output: <nn_sequential> #576,100 parameters
evaluate(fitted, test_ds)
## A `luz_module_evaluation`
## ── Results ─────────────────────────────────────────────────────────────────────
## loss: 2.4916
## acc: 0.3631
This model takes about 13 minutes to run and achieves 36% accuracy on the test data. Although this is not terrible for 100-class data (a random classifier achieves only 1% accuracy), searching the web we see results around 75%. Typically it takes a lot of architecture carpentry, fiddling with regularization, and time to achieve such results.
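The 1% baseline quoted above is easy to confirm by simulation (a sketch in plain R; the exact value varies with the seed):

```r
# A classifier that guesses uniformly among 100 classes is right
# about 1% of the time on balanced data.
set.seed(1)
y    <- sample(100, 10000, replace = TRUE)  # true labels
yhat <- sample(100, 10000, replace = TRUE)  # random guesses
mean(y == yhat)                             # approximately 0.01
```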
We now show how to use a CNN pretrained on the imagenet
database to classify natural images. We copied six jpeg images from a
digital photo album into the directory book_images. (These
images are available from the data section of www.statlearning.com, the ISLR book
website. Download book_images.zip; when unzipped, it creates
the book_images directory.) We first read in the images,
and convert them into the tensor format expected by
torch, matching the preprocessing used for the
imagenet training images. Make sure that your working directory in
R is set to the folder in which the images are stored.
img_dir <- "book_images"
image_names <- list.files(img_dir)
num_images <- length(image_names)
x <- torch_empty(num_images, 3, 224, 224)
for (i in 1:num_images) {
img_path <- file.path(img_dir, image_names[i])
img <- img_path %>%
base_loader() %>%
transform_to_tensor() %>%
transform_resize(c(224, 224)) %>%
# normalize with imagenet mean and stds.
transform_normalize(
mean = c(0.485, 0.456, 0.406),
std = c(0.229, 0.224, 0.225)
)
x[i,,, ] <- img
}
We then load the trained network. The model has 18 layers, with a fair bit of complexity.
model <- torchvision::model_resnet18(pretrained = TRUE)
model$eval() # put the model in evaluation mode
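Evaluation mode matters here because ResNet-18 contains batch-normalization layers (and many pretrained models contain dropout): $eval() freezes batch norm at its stored running statistics and disables dropout, making predictions deterministic. A small sketch of the dropout effect (assumes torch is loaded; the layer here is hypothetical, not part of the model above):

```r
# In training mode dropout zeroes units at random; in eval mode it is
# the identity, so repeated forward passes give identical outputs.
drop <- nn_dropout(p = 0.5)
drop$eval()
x <- torch_ones(6)
torch_equal(drop(x), x)  # TRUE: dropout is a no-op in eval mode
```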
Finally, we classify our six images, and return the top three class choices in terms of predicted probability for each.
preds <- model(x)
mapping <- jsonlite::read_json("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json") %>%
sapply(function(x) x[[2]])
top3 <- torch_topk(preds, dim = 2, k = 3)
top3_prob <- top3[[1]] %>%
nnf_softmax(dim = 2) %>%
torch_unbind() %>%
lapply(as.numeric)
top3_class <- top3[[2]] %>%
torch_unbind() %>%
lapply(function(x) mapping[as.integer(x)])
result <- purrr::map2(top3_prob, top3_class, function(pr, cl) {
names(pr) <- cl
pr
})
names(result) <- image_names
print(result)
## $Cape_Weaver.jpg
## hummingbird lorikeet bee_eater
## 0.3633287 0.3577293 0.2789420
##
## $Flamingo.jpg
## flamingo spoonbill white_stork
## 0.978211999 0.017045649 0.004742352
##
## $Hawk_cropped.jpg
## kite jay magpie
## 0.6157812 0.2311861 0.1530326
##
## $Hawk_Fountain.jpg
## eel agama common_newt
## 0.5391128 0.2527185 0.2081687
##
## $Lhasa_Apso.jpg
## Lhasa Tibetan_terrier Shih-Tzu
## 0.79760426 0.12013003 0.08226573
##
## $Sleeping_Cat.jpg
## Saint_Bernard guinea_pig Bernese_mountain_dog
## 0.3946672 0.3426990 0.2626339