Transfer learning with Keras – Part 1
Transfer learning is a hot topic at the moment. Because architectures such as VGG16/19 and InceptionV3 ship by default with frameworks like Keras, applying Transfer Learning (TL) techniques has become “easy” as a first step to gain some intuition about a problem.
I’ll pick VGG16 as the starting point, including the weights pre-trained on ImageNet. Before diving into the details, we need to go through some theory to understand where the magic happens.
First, let’s dive into the ConvNet (CNN / convolutional neural network) architecture VGG16, proposed by Karen Simonyan & Andrew Zisserman in their paper “Very Deep Convolutional Networks for Large-Scale Image Recognition”.
I’ll be referring to the image above for a better understanding, since it highlights the key points that we need. VGG16 is composed of 16 weight layers: 13 convolutional layers, grouped into five blocks that each end in a max pooling layer with a 2×2 window and a stride of 2, followed by 3 fully-connected layers. Max pooling picks the maximum value from each 2×2 grid, so the first pooling layer takes us from a 224×224 feature map to a 112×112 one. Conceptually, we are forwarding the most relevant values through the network. It is worth mentioning that each convolution applies a Rectified Linear Unit (ReLU) as its activation function. Check the Keras implementation below:
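To make that layout concrete, here is a minimal sketch of the VGG16 stack written with the Keras functional API. It is an illustration rather than the gist referenced in the text, so the line numbers quoted later (loc. 28, 26) do not refer to it. The layer names follow the ones Keras uses for its own VGG16; `VGG16_BLOCKS` and `build_vgg16_sketch` are names made up for this example, and the weights are randomly initialised.

```python
from tensorflow import keras
from tensorflow.keras import layers

# (number of 3x3 convolutions, filters) for each of the five blocks.
# Hypothetical constant, introduced only for this sketch.
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def build_vgg16_sketch(include_top=True, num_classes=1000):
    """Build an untrained model with the VGG16 layout (random weights)."""
    inputs = keras.Input(shape=(224, 224, 3))  # 224x224 RGB image
    x = inputs
    for block, (num_convs, filters) in enumerate(VGG16_BLOCKS, start=1):
        for conv in range(1, num_convs + 1):
            # Every convolution uses ReLU; "same" padding keeps height/width.
            x = layers.Conv2D(filters, (3, 3), activation="relu",
                              padding="same",
                              name=f"block{block}_conv{conv}")(x)
        # 2x2 window with stride 2: halves height and width.
        x = layers.MaxPooling2D((2, 2), strides=(2, 2),
                                name=f"block{block}_pool")(x)
    if include_top:
        # The classifier: flatten 7x7x512 and apply 3 fully-connected layers.
        x = layers.Flatten(name="flatten")(x)
        x = layers.Dense(4096, activation="relu", name="fc1")(x)
        x = layers.Dense(4096, activation="relu", name="fc2")(x)
        x = layers.Dense(num_classes, activation="softmax",
                         name="predictions")(x)
    return keras.Model(inputs, x, name="vgg16_sketch")
```

Calling `build_vgg16_sketch().summary()` shows the spatial size halving after every pooling layer: 224, 112, 56, 28, 14 and finally 7.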
Don’t panic if this doesn’t sound familiar to you. At the end of the day, what VGG16 does up to the flatten layer (loc. 28) is try to condense the most relevant features/parts of the image into a 7×7 map with a depth (channels) of 512. Remember that the network input was 224×224 with 3 channels (RGB). Why is this relevant? This first set of layers does not classify anything; you can think of it as an eye. The layers that perform the classification are the final fully-connected layers, the last of which has 1000 outputs. With softmax as its activation, we get one probability per label, and the highest one can be mapped to the labels of ImageNet. If you have played with this at any point, you are probably already getting the intuition for the next steps.
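As a small illustration of that last step, the sketch below applies softmax to a handful of raw scores and maps the highest probability back to a label. The three labels and the scores are invented for the example (the real network produces 1000 scores, one per ImageNet class).

```python
import numpy as np


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    shifted = scores - np.max(scores)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


# Hypothetical labels and scores, standing in for the 1000 ImageNet outputs.
labels = ["tabby", "golden_retriever", "sports_car"]
scores = np.array([2.0, 1.0, 0.1])

probs = softmax(scores)                       # roughly [0.66, 0.24, 0.10]
best_label = labels[int(np.argmax(probs))]    # "tabby"
print(best_label, float(probs.max()))
```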
One more thing before diving into transfer learning: see the parameters of the out-of-the-box VGG16 model from Keras:
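A minimal sketch of creating the stock model, assuming TensorFlow’s bundled Keras (`tensorflow.keras`). `load_vgg16` and `describe` are helper names invented here; `weights`, `include_top` and `input_shape` are the arguments of the Keras `VGG16` function.

```python
from tensorflow.keras.applications import VGG16


def load_vgg16(include_top=True, weights="imagenet"):
    """Create the stock Keras VGG16.

    weights="imagenet" downloads the pre-trained weights on first use;
    weights=None gives the same architecture with random weights.
    """
    return VGG16(weights=weights, include_top=include_top,
                 input_shape=(224, 224, 3))


def describe(model):
    """Return (last layer name, output shape, parameter count)."""
    return (model.layers[-1].name,
            tuple(model.output.shape),
            model.count_params())
```

`load_vgg16().summary()` prints the layer-by-layer table. With the top included the model has 138,357,544 parameters; with `include_top=False` it has 14,714,688.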
Remember the include_top parameter used when creating the model. We will get there 😉
There is a reason why include_top is one of the parameters. Its description reads: “include_top: whether to include the 3 fully-connected layers at the top of the network”. Setting it to False means that we get the output of the layer block5_pool (26) in my gist. This is where we find the intuition about where the real classification is done: in those 3 fully-connected layers, where the softmax over the last one, together with the 1000 labels from the image data set, gives us the probability of x being a given object.
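To get a feeling for how much of the network those 3 fully-connected layers represent, the parameter counts can be worked out by hand. This is plain arithmetic over the VGG16 layout (3×3 convolutions with biases, a 7×7×512 feature map flattened into two 4096-unit layers and a 1000-unit output); the function and constant names are made up for this sketch.

```python
# (number of 3x3 convolutions, output channels) for the five blocks.
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of one convolution layer."""
    return kernel * kernel * in_channels * out_channels + out_channels


def dense_params(in_units, out_units):
    """Weights plus biases of one fully-connected layer."""
    return in_units * out_units + out_units


def vgg16_param_counts(input_channels=3, num_classes=1000):
    """Return (convolutional base parameters, top classifier parameters)."""
    conv_total = 0
    channels = input_channels
    for num_layers, out_channels in VGG16_BLOCKS:
        for _ in range(num_layers):
            conv_total += conv_params(channels, out_channels)
            channels = out_channels
    flat = 7 * 7 * channels  # 25088 values after flattening block5_pool
    top_total = (dense_params(flat, 4096)
                 + dense_params(4096, 4096)
                 + dense_params(4096, num_classes))
    return conv_total, top_total


conv_total, top_total = vgg16_param_counts()
print(conv_total)              # 14714688  (what include_top=False keeps)
print(top_total)               # 123642856 (the 3 fully-connected layers)
print(conv_total + top_total)  # 138357544 (the full network)
```

So roughly 89% of the parameters live in the part that include_top=False removes.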
One of the good things about Transfer Learning, or the idea behind it, is to pre-load the weights learned on the ImageNet data set into the model. If we remove the top layers and freeze the rest of the model, we can reuse the network’s ability to focus on the most relevant features of the image, and add our own network on top to handle the new features or, simply, a different classification (there are tons of examples about cats :D).
Therefore, we can follow three main approaches:
- Retrain the whole network. In some cases this works if the nature of the images varies from the original ImageNet dataset (xD 14,197,122 images).
- Freeze the first blocks of the VGG16 network. In some cases, if you have a large data set, you can freeze the layers up to the 56x56x256 block and train the rest of the network. With a small data set you can end up with a badly fitted network, because such a large CNN requires a huge amount of data to be trained properly.
- Freeze everything up to the top. This is the most used approach for quick tests and to start gaining some intuition around the answers. With a decent dataset (~5000 images per label) it is possible to achieve results above 95% with 2 to 3 blocks of fully connected layers.
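In Keras, the three approaches above differ only in which layers have `trainable` set to False. Below is a minimal sketch, assuming the layer names of the stock Keras VGG16; `vgg16_base`, `freeze_up_to` and `trainable_layer_names` are hypothetical helpers, and picking `block3_conv3` as the end of the “56x56x256 block” is my reading of the second approach, not something the text pins down.

```python
from tensorflow.keras.applications import VGG16


def vgg16_base(weights="imagenet"):
    """VGG16 without the top; weights=None skips the ImageNet download."""
    return VGG16(weights=weights, include_top=False,
                 input_shape=(224, 224, 3))


def freeze_up_to(model, last_frozen_layer=None):
    """Freeze every layer up to and including `last_frozen_layer`.

    With last_frozen_layer=None nothing is frozen (retrain everything).
    Call this before model.compile() so the change takes effect.
    """
    names = [layer.name for layer in model.layers]
    cutoff = -1 if last_frozen_layer is None else names.index(last_frozen_layer)
    for index, layer in enumerate(model.layers):
        layer.trainable = index > cutoff
    return model


def trainable_layer_names(model):
    """Names of the layers that still have weights to train."""
    return [layer.name for layer in model.layers
            if layer.trainable and layer.weights]


# The three approaches, in order (the cut-off in 2. is an assumption):
#   1. freeze_up_to(base)                  -> retrain the whole network
#   2. freeze_up_to(base, "block3_conv3")  -> freeze the early blocks only
#   3. freeze_up_to(base, "block5_pool")   -> freeze everything up to the top
```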
In this case we are freezing the layers of the VGG16 network and passing their output to two fully-connected layers plus the output one (predictions). For this case we are using stochastic gradient descent with a standard learning rate.
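A minimal sketch of such a model follows. Several details are assumptions on my side, since the text does not state them: 256 units per fully-connected layer, ReLU activations, the categorical cross-entropy loss, and 0.01 as the “standard” learning rate (the Keras default for SGD). `build_transfer_model` is a name invented for this example.

```python
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import VGG16


def build_transfer_model(num_classes, weights="imagenet",
                         hidden_units=256, learning_rate=0.01):
    """Frozen VGG16 base + two fully-connected layers + predictions."""
    base = VGG16(weights=weights, include_top=False,
                 input_shape=(224, 224, 3))
    base.trainable = False  # freeze every VGG16 layer

    inputs = keras.Input(shape=(224, 224, 3))
    x = base(inputs)                       # 7x7x512 feature map
    x = layers.Flatten()(x)
    x = layers.Dense(hidden_units, activation="relu")(x)  # assumed size
    x = layers.Dense(hidden_units, activation="relu")(x)  # assumed size
    outputs = layers.Dense(num_classes, activation="softmax",
                           name="predictions")(x)

    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=learning_rate),
        loss="categorical_crossentropy",   # assumes one-hot encoded labels
        metrics=["accuracy"],
    )
    return model
```

For a cats-vs-dogs style problem this would be `build_transfer_model(num_classes=2)`; only the three new dense layers are updated during training.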
For this case, I will try to use a small data set (~500 images per label). In the next post I’m going to use augmentation and the Keras ImageDataGenerator to extend the data set x10.