Deep Transfer Learning for Image Classification

The following tutorial covers how to set up a state-of-the-art deep learning model for image classification. The approach is based on the machine learning frameworks TensorFlow and Keras, and includes all the code needed to replicate the results in this tutorial.

The main prerequisite for setting up the model is access to labelled data, and as an example case I have used images of various traffic signs (which can be downloaded here). The task of the model is thus to predict what kind of traffic sign it sees. To make the example case more realistic, I have reduced the amount of data to a maximum of 200 images per class (as a limited amount of data is usually the case in practical applications of machine learning).

These images are of course only included as an example to get you started. They can easily be replaced with your own images as long as you follow the same folder structure as the current setup, as explained below.

Replacing images with your own data:

First, split your images into training, validation and test data. Then place them in subfolders under the main folder “data/”, with the name of the image category as the subfolder name, as in the example folder structure shown below. The images in “data/train” are the actual images used to train the model, whereas the images in “data/val” are used to tune the training process and the model hyper-parameters. The test data in “data/test” is then used as the final assessment, to evaluate the accuracy of the model on a completely independent set of images.

Example folder structure with included dataset:

Training data:

  • data/train/category_1 : Images of signs from category 1
  • data/train/category_2 : Images of signs from category 2
  • …………………………………..

Validation data:

  • data/val/category_1 : Images of signs from category 1
  • data/val/category_2 : Images of signs from category 2
  • …………………………………..

Test data:

  • data/test/category_1 : Images of signs from category 1
  • data/test/category_2 : Images of signs from category 2
  • …………………………………..

One way of splitting the images between “train”, “validation” and “test” is, for example, to use 80% of the images for training the model and 10% each for validation and testing. For a brief introduction to the importance of separating “train”, “validation” and “test” data, you can also have a read here.
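As an illustration, here is a minimal sketch of how such a split could be scripted. This is my own addition, not part of the tutorial’s code, and it assumes your raw images start out in one folder per category, e.g. “raw_data/category_1”, “raw_data/category_2”, etc.:

import os
import random
import shutil

random.seed(1337)

# Assumed layout (a placeholder, not part of the tutorial's dataset):
# raw_data/<category>/<image files>
source_dir = "raw_data"

for category in sorted(os.listdir(source_dir)):
    files = sorted(os.listdir(os.path.join(source_dir, category)))
    random.shuffle(files)
    n_train = int(0.8 * len(files))   # 80% for training
    n_val = int(0.1 * len(files))     # 10% for validation, the rest for testing
    subsets = {"train": files[:n_train],
               "val": files[n_train:n_train + n_val],
               "test": files[n_train + n_val:]}
    for split_name, split_files in subsets.items():
        target_dir = os.path.join("data", split_name, category)
        os.makedirs(target_dir, exist_ok=True)
        for filename in split_files:
            shutil.copy(os.path.join(source_dir, category, filename), target_dir)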

Libraries and packages necessary for defining and running the models

These are some useful Python libraries/packages that make our lives a lot easier, as we do not have to write all the code and functionality from scratch. Building a deep learning model without them would be a tremendous task!

import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

from numpy.random import seed
seed(1337)
from tensorflow import set_random_seed
set_random_seed(42)

from tensorflow.python.keras.applications import vgg16
from tensorflow.python.keras.applications.vgg16 import preprocess_input
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras import layers, models, Model, optimizers

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from plot_conf_matr import plot_confusion_matrix

Define train/val/test data and the different classes:

Here, we define the location of train/val/test images and the names of all the different categories we want to classify. We then plot the number of images per category in the training set.


train_data_dir = "data/train"
val_data_dir = "data/val"
test_data_dir = "data/test"

category_names = sorted(os.listdir('data/train'))
nb_categories = len(category_names)

img_pr_cat = []
for category in category_names:
    folder = 'data/train' + '/' + category
    img_pr_cat.append(len(os.listdir(folder)))

sns.barplot(y=category_names, x=img_pr_cat).set_title("Number of training images per category:")

Overview of the number of training images per class

Let us also plot a few example images from the various sign categories, to visualize the typical image quality:

for subdir, dirs, files in os.walk('data/train'):
    for file in files:
        img_file = subdir + '/' + file
        image = load_img(img_file)
        plt.figure()
        plt.title(subdir)
        plt.imshow(image)
        break

Some example images

As you can see from the example images above, the resolution and quality are not great. However, both image quality and amount of data are often quite limited in practical applications of machine learning. As such, low-quality images limited to a maximum of 200 training images per class represent a more realistic example than using thousands of “perfect” high-quality images.

Transfer learning

There is no need at this stage to understand all the details of various types of deep learning models, but a summary of some common ways of building models can be found here for those interested.

In this tutorial, we use a pre-trained deep learning model (VGG16) as the basis for our image classifier model, and then retrain the model on our own data, i.e. transfer learning.

img_height, img_width = 224,224
conv_base = vgg16.VGG16(weights='imagenet', include_top=False, pooling='max', input_shape = (img_width, img_height, 3))

You might notice the parameter pooling='max' above. The reason is that, rather than connecting the convolutional base of the VGG16 model to a couple of fully connected layers before the final output layer (as is done in the original VGG16 model), we use a max-pooling output (one can also use average pooling; which approach works best depends on the use case). This approach is an alternative to using fully connected layers to transition from feature maps to an output prediction. In my experience it usually works very well and makes the model less prone to overfitting, as also described in this paper:

Conventional convolutional neural networks perform convolution in the lower layers of the network. For classification, the feature maps of the last convolutional layer are vectorized and fed into fully connected layers followed by a softmax logistic regression layer. This structure bridges the convolutional structure with traditional neural network classifiers. It treats the convolutional layers as feature extractors, and the resulting feature is classified in a traditional way.

However, the fully connected layers are prone to overfitting, thus hampering the generalization ability of the overall network. In this paper, we propose another strategy called global average pooling to replace the traditional fully connected layers in CNN. Instead of adding fully connected layers on top of the feature maps, we take the average of each feature map, and the resulting vector is fed directly into the softmax layer. One advantage of global average pooling over the fully connected layers is that it is more native to the convolution structure by enforcing correspondences between feature maps and categories. Thus the feature maps can be easily interpreted as categories confidence maps. Another advantage is that there is no parameter to optimize in the global average pooling thus overfitting is avoided at this layer. Furthermore, global average pooling sums out the spatial information, thus it is more robust to spatial translations of the input. We can see global average pooling as a structural regularizer that explicitly enforces feature maps to be confidence maps of concepts (categories).
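For reference, here is a minimal sketch (my own addition, not part of the original code) of how the convolutional base could be loaded with global average pooling instead, reusing the img_width and img_height defined above:

# Hypothetical alternative: use global average pooling instead of max pooling
# when loading the convolutional base. Which one works best depends on the data.
conv_base_avg = vgg16.VGG16(weights='imagenet',
                            include_top=False,
                            pooling='avg',   # 'avg' -> global average pooling of the last feature maps
                            input_shape=(img_width, img_height, 3))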

Having loaded the pre-trained VGG16 model, we can also choose to freeze the earlier layers of the model in the code block below, and only re-train the last few layers on our own data. This is a common transfer learning strategy, and is often a good approach when the amount of data available for training is limited.


This option is currently commented out in the code (using the # symbol), and we are thus retraining all layers of the model. The number of layers to retrain is a parameter you can experiment with yourself: how does the number of trainable layers affect model performance? A small sketch of such an experiment is included after the code block below.

#for layer in conv_base.layers[:-13]:
# layer.trainable = False
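If you want to try this out, here is a small sketch (my own addition, not part of the original setup) that freezes all but the last n_trainable layers of the convolutional base before the model is compiled; the value 13 is just one possible choice:

# Hypothetical experiment: freeze all but the last `n_trainable` layers of the
# convolutional base and see how the validation accuracy changes.
n_trainable = 13  # e.g. try 4, 8, 13, or len(conv_base.layers) for full retraining

for layer in conv_base.layers[:-n_trainable]:
    layer.trainable = False
for layer in conv_base.layers[-n_trainable:]:
    layer.trainable = True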

As a check we can also print a list of all layers of the model, and whether they are trainable or not (True/False).

for layer in conv_base.layers:
    print(layer, layer.trainable)


Using the VGG16 model as a basis, we now build a final classification layer on top to predict our defined classes. We then print a model summary, listing the number of parameters of the model. If you decide to “freeze” some of the layers, you will notice that the number of “Trainable parameters” below will be lower.

model = models.Sequential()
model.add(conv_base)
model.add(layers.Dense(nb_categories, activation='softmax'))
model.summary()


As you can see from the model summary, the output shape of the final layer corresponds to the number of classes, which in our case is 10.

Generators for reading and processing images

We then need to define some generators that read images from our folders and feed them to the image classifier model. As part of this we also add some basic image preprocessing, where the input images are scaled to have pixel values in the range [0, 1] (from 0–255 in the original images).


# Number of images to load at each iteration
batch_size = 32

# Only rescaling as preprocessing
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# These are generators for train/val/test data that will read pictures
# found in the defined subfolders of 'data/'
print('Total number of images for "training":')
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode="categorical")

print('Total number of images for "validation":')
val_generator = test_datagen.flow_from_directory(
    val_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode="categorical",
    shuffle=False)

print('Total number of images for "testing":')
test_generator = test_datagen.flow_from_directory(
    test_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode="categorical",
    shuffle=False)

Output from running the above code block
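As a quick sanity check (my own addition, not part of the original walkthrough), you can also print the mapping from class names to label indices that the generators infer from the folder names:

# The generators assign label indices alphabetically by folder name.
# Printing the mapping makes it easy to verify the label order later on.
print(train_generator.class_indices)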

Define model parameters and start training

Here, we define some of the parameters that control the training process of the model. Important parameters are e.g. the learning rate, how many epochs to train the model for, and which optimizer to use. You do not need to understand all these terms to follow the tutorial, but those interested can have a quick read here.

We also define a checkpoint parameter, where we keep track of the validation accuracy after each epoch during training. Using this, we always keep a copy of the model that performs best during the training process.


learning_rate = 5e-5
epochs = 10

checkpoint = ModelCheckpoint("sign_classifier.h5", monitor='val_acc', verbose=1,
                             save_best_only=True, save_weights_only=False, mode='auto', period=1)

model.compile(loss="categorical_crossentropy",
              optimizer=optimizers.Adam(lr=learning_rate, clipnorm=1.),
              metrics=['acc'])

We are now ready to start training the model on our own data, and for each “epoch” we print the training and validation loss and accuracy. The model accuracy, as measured on the training data, is given by “acc”, and the accuracy on the images in the validation set is given by “val_acc”. This is the most important quantity, as it tells us how accurate the model is on images it has not already seen during the training process.

Ideally, the “val_acc” should increase for each epoch as we keep training the model, and eventually reach a steady value when our model is not able to learn any more useful information from our training data.
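If the validation accuracy stops improving well before the final epoch, one optional addition (not used in this tutorial) is Keras’ EarlyStopping callback; a minimal sketch, which could be passed to the callbacks list alongside the checkpoint:

# Optional: stop training early if the validation accuracy has not improved
# for a few epochs, to save time.
from tensorflow.python.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_acc', patience=3, mode='auto', verbose=1)
# It could then be used as: callbacks=[checkpoint, early_stop]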

history = model.fit_generator(train_generator,
                              epochs=epochs,
                              shuffle=True,
                              validation_data=val_generator,
                              callbacks=[checkpoint])

Output during training

From the output shown above, we see that the loss decreases while the accuracy increases during the training process. Each time the validation accuracy reaches a new maximum value, the checkpoint file is saved (output: “saving model to sign_classifier.h5”). After the training has completed, we then load the checkpoint file which had the best validation accuracy during training:


model = models.load_model("sign_classifier.h5")

Evaluating model accuracy

We first visualize the changes in model accuracy and loss during the training process, as this gives us important information to evaluate what we can do to improve accuracy. For a nice introduction to this topic, you can also have a look at this video:

Code for plotting and saving the learning curves:

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.figure()
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()
plt.savefig('Accuracy.jpg')

plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.savefig('Loss.jpg')

Training/validation accuracy (left) and loss (right) during training

Starting with the left figure, showing the training/validation accuracy: The blue line represents the model accuracy as measured on the training images, and we see that this quickly reaches a value of almost 1 (which represents classifying 100% of the training images correctly). However, the validation accuracy is the accuracy measured on the validation set, which is the accuracy we really care about. In this case, the accuracy leveled off at around 97–98%, meaning that we successfully classified almost all of the images in our validation set to the correct category.

To learn more about the accuracy for the different categories, we can calculate and plot the “confusion matrix”. This represents an illustrative way of evaluating model accuracy, as it compares the “true” vs. “predicted” class for all images in the test set. Note: do not worry if you do not get exactly the same numbers when re-running the code! There is some inherent randomness in model initialization etc. which makes the results differ slightly from time to time.

The code to calculate and plot the confusion matrix is included below, followed by the resulting figure.

Y_pred = model.predict_generator(test_generator)
y_pred = np.argmax(Y_pred, axis=1)

cm = confusion_matrix(test_generator.classes, y_pred)
plot_confusion_matrix(cm, classes=category_names, title='Confusion Matrix',
                      normalize=False, figname='Confusion_matrix_concrete.jpg')

Confusion matrix for the test set

Here is the code from the script “plot_conf_matr.py”, which contains the function for plotting the confusion matrix, “plot_confusion_matrix”.

import numpy as np
import matplotlib.pyplot as plt
import itertools


def plot_confusion_matrix(cm, classes, figname, normalize=False,
                          title='Confusion matrix', cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    #plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(figname)

As seen from the confusion matrix above, the main category the model misclassified was “Intersection”, which it mistook for “Yield” in 10 of the images. As a final metric, we can also calculate the total accuracy evaluated on the test set.

accuracy = accuracy_score(test_generator.classes, y_pred)
print("Accuracy in test set: %0.1f%% " % (accuracy * 100))

This gives an output accuracy of 98%, which is not bad! But can we do better? We have a limited amount of data, so how about trying to improve on that using image augmentation?
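Before moving on, one optional addition (not in the original walkthrough): since classification_report was already imported from scikit-learn above, we can also print per-class precision and recall, which complements the confusion matrix; a minimal sketch:

# Per-class precision/recall/F1 on the test set, using the predictions
# computed for the confusion matrix above.
print(classification_report(test_generator.classes, y_pred,
                            target_names=category_names))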

Model with image augmentation

In our case, the model already performs very well, with an accuracy of 97–98%. However, one strategy for dealing with a limited amount of training data is “image augmentation”. That is, we make a collection of copies of the existing images, but with some minor changes. Those changes could be transformations such as slight rotations, zooming, horizontal flipping, and so on. Further examples of image augmentation are also covered here.

In the following, we define the same model as before, but here we also incorporate image augmentation as a way of artificially increasing the amount of training data.


Code to build a new model, using the same convolutional base and model structure as before:

conv_base = vgg16.VGG16(weights='imagenet', include_top=False, pooling='max', input_shape=(img_width, img_height, 3))

#for layer in conv_base.layers[:-13]:
#    layer.trainable = False

model = models.Sequential()
model.add(conv_base)
model.add(layers.Dense(nb_categories, activation='softmax'))

Augmentations

The only thing we need to change in our code is the definition of the training data generator shown below. Here we can add some data augmentation strategies, such as random rotations in the range [-10, 10] degrees, a random zoom and width/height shift of up to ±10%, and changes in brightness of up to ±10%.

To inspect examples of the augmented images, we can save them to a specified folder, “augm_images”, as defined in the “train_generator” definition below. This option is currently commented out (to avoid saving thousands of images), but you can change that if you want to visualize the augmentations you incorporate. This is often a good idea, just to make sure that the augmented images still make sense for the use case you are working on.

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=False,
    brightness_range=(0.9, 1.1),
    fill_mode='nearest')

# This is a generator that will read pictures found in
# subfolders of 'data/train', and indefinitely generate
# batches of augmented image data
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    #save_to_dir='augm_images', save_prefix='aug', save_format='jpg',
    class_mode="categorical")

Train new model using augmented data

We are now ready to train the same model using additional augmented data, which should hopefully increase model accuracy.

learning_rate = 5e-5
epochs = 20

checkpoint = ModelCheckpoint("sign_classifier_augm.h5", monitor='val_acc', verbose=1,
                             save_best_only=True, save_weights_only=False, mode='auto', period=1)

model.compile(loss="categorical_crossentropy",
              optimizer=optimizers.Adam(lr=learning_rate, clipnorm=1.),
              metrics=['acc'])

history = model.fit_generator(train_generator,
                              epochs=epochs,
                              shuffle=True,
                              validation_data=val_generator,
                              callbacks=[checkpoint])

After the training has completed, we again load the checkpoint file which had the best validation accuracy during training:

model = models.load_model("sign_classifier_augm.h5")

Plot the learning curves:

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.figure()
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()
#plt.savefig('Accuracy_Augmented.jpg')

plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
#plt.savefig('Loss_Augmented.jpg')

Training/validation accuracy and loss for the model trained with image augmentation

Calculate and plot confusion matrix:

Y_pred = model.predict_generator(test_generator)
y_pred = np.argmax(Y_pred, axis=1)

cm_aug = confusion_matrix(test_generator.classes, y_pred)
plot_confusion_matrix(cm_aug, classes=category_names, title='Confusion Matrix',
                      normalize=False, figname='Confusion_matrix_Augm.jpg')

Confusion matrix for the model trained with image augmentation

Calculate the final accuracy, as evaluated on the test set:


accuracy = accuracy_score(test_generator.classes, y_pred)
print("Accuracy in test set: %0.1f%% " % (accuracy * 100))

This gives an accuracy of 99.3% on the test set, which is an improvement compared to our initial model without augmented images!

Evaluation of model accuracy

As seen from the above results for model accuracy, data augmentation indeed increased the accuracy of our model. In the current example, we obtained a final accuracy of approximately 99%. In addition, by inspecting the confusion matrix above, we can check which of the sign categories the model classifies incorrectly. Here, we notice that the model still misclassifies “Intersection” as “Yield” in a few cases, but significantly less often than the model without image augmentation.

Note: Do not worry if you do not get exactly the same numbers when re-running the code! There is some inherent randomness in the model initialization etc. which could make the results differ slightly from time to time.

Plot a few images from the test set, and compare model prediction with ground truth

As a final visualization of model accuracy, we can plot a subset of the test images along with the corresponding model prediction.

We first define a generator for the folder “data/test_subset”, where I have included 50 of the images from the test set:

test_subset_data_dir = "data/test_subset"

test_subset_generator = test_datagen.flow_from_directory(
    test_subset_data_dir,
    batch_size=batch_size,
    target_size=(img_height, img_width),
    class_mode="categorical",
    shuffle=False)

Make predictions for the images contained in this folder, and visualize the images along with the predicted and actual class. Do you agree with the classifications?

Y_pred = model.predict_generator(test_subset_generator)
y_pred = np.argmax(Y_pred, axis=1)

img_nr = 0
for subdir, dirs, files in os.walk('data/test_subset'):
    for file in files:
        img_file = subdir + '/' + file
        image = load_img(img_file, target_size=(img_height, img_width))
        pred_class = category_names[y_pred[img_nr]]
        true_class = category_names[test_subset_generator.classes[img_nr]]
        plt.figure()
        plt.title('Predicted: ' + pred_class + '\n' + 'Actual: ' + true_class)
        plt.imshow(image)
        img_nr = img_nr + 1

Example output from classification model
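If you want to classify a single new image of your own (for example after replacing the dataset), here is a minimal sketch; this is my own addition, the file name 'my_sign.jpg' is just a placeholder, and it reuses the same resizing and rescaling applied by the generators above:

# Hypothetical example: predict the class of one new image file.
image = load_img('my_sign.jpg', target_size=(img_height, img_width))
x = np.array(image) / 255.0          # same rescaling as in the generators
x = np.expand_dims(x, axis=0)        # add a batch dimension: (1, height, width, 3)

probabilities = model.predict(x)
predicted_class = category_names[np.argmax(probabilities[0])]
print('Predicted class:', predicted_class)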

Summary

If you managed to run through the entire tutorial using the included dataset, you have hopefully gotten a feeling for how deep learning and image recognition can be used to solve a real-world problem such as traffic sign classification. Best of luck exploring the model further with other images, either from your company or from resources such as Kaggle, or simply a Google image search!

If you want some more detailed information regarding this image classification tutorial (and machine learning in general), I also cover the material in the workshop presentation below: “From hype to real-world applications” (the tutorial walkthrough starts approx. 35 minutes into the video). Good luck!

This article was originally published on Towards Data Science and re-published to TOPBOTS with permission from the author.

