An Introduction to Transfer Learning


Computer vision and natural language processing programs use artificial intelligence to accomplish remarkable things that benefit society in many ways: for example, image classifiers that in some studies detect cancer as well as or better than doctors, or self-driving vehicles designed to make roads safer!

Unfortunately, developing AI models faces barriers that limit how easily they can be engineered to solve new problems. One of these barriers is that the machine learning models behind these AIs take a lot of time and resources to train to the point where they become useful.

Training a good image detector from scratch can take days, even on an optimized computer using GPUs.

However, this raises the question: if training our models from scratch takes so long, why do we have to train them from scratch at all?

The answer is that we don’t, and often it doesn’t even make sense to do so.

Think about how your brain learns. When you learn something new, your brain doesn’t start from scratch. It did when you were a baby and didn’t know anything, but now it applies what it already knows to help it figure out new concepts more easily.

It takes children a long time to learn tasks we consider simple, such as counting to ten and speaking. However, as they get older, they can learn much more complicated things, such as the quadratic formula, at a much faster rate.

This is because people are able to apply their prior knowledge to their new learning experiences rather than always learning from scratch.

If we apply this same methodology to AI, we get transfer learning: reusing what a model has already learned on one task to help it learn a new one.

This works because the features a machine learning model learns for one task are often useful for other tasks. This transferability of learned features is the foundation of transfer learning.

In practice, transfer learning is amazing because it allows us to cut down on the amount of time and resources that go into building a machine learning model.

The diagram below does a really good job of expressing the importance of transfer learning. By using this method, we can create models that converge to a higher accuracy and require less time to train!

[Figure: model training with vs. without transfer learning]

To truly understand why and how this works, we need to first understand the basics of how neural networks work.

Deep neural networks are made up of many layers of nodes (neurons). In a fully connected network, each node in one layer connects to every node in the next. The input layer receives the data, its nodes activate nodes in the next layer, and this chain reaction travels through the network until it reaches the output nodes. How strongly each node activates depends on the weights on its incoming connections and on the node’s own bias. These weights and biases are what get adjusted during training to produce an accurate model. Ideally, each layer in the network learns to detect a specific kind of feature. For example, let’s look at the theoretical 4-layer network below.

[Figure: a 4-layer neural network]

Let’s pretend this network is designed to detect dogs in images. A real dog detector would need to be much more complicated than this, but the general concept is the same, so the following still applies.

The first layer in the network is the input layer, which takes in the pixel values of the image. The second and third layers are hidden layers: they sit between the input and output layers, and we don’t directly choose or observe which features they learn. Finally, the output layer tells us whether a dog was detected in the image.

Even though we don’t know exactly what the second layer does, ideally it might learn to detect edges in the image. In the same way, the third layer might learn to detect shapes (combinations of the edges detected in the second layer). The final layer should then learn to detect dogs, which can be broken down into a specific combination of shapes, which are combinations of edges, which are in turn determined by the pixel values in the image. Is it starting to make sense?

Basically, each layer in the network learns to detect features which are combinations of the features in the previous layers. As an image of a dog is fed through the network it should pick up on its specific features and realize that they combine to form a dog.
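To make this chain reaction concrete, here is a minimal sketch of a forward pass in plain Python. It is an illustration under stated assumptions, not how frameworks such as TensorFlow implement it: the layers are fully connected, every node uses a sigmoid activation, and the weights and biases are arbitrary placeholder numbers rather than trained values. Each layer computes its outputs only from the previous layer’s outputs, which is how later layers end up combining earlier features.

```python
import math


def sigmoid(z: float) -> float:
    """Squash a weighted sum into (0, 1): how strongly the node 'fires'."""
    return 1.0 / (1.0 + math.exp(-z))


def dense_layer(inputs, weights, biases):
    """Fully connected layer: every input node feeds every node in this layer.

    weights[j][i] is the weight on the connection from input node i to node j,
    and biases[j] is node j's own bias.
    """
    return [
        sigmoid(sum(w * x for w, x in zip(row, inputs)) + b)
        for row, b in zip(weights, biases)
    ]


def forward(inputs, layers):
    """Feed the input through each (weights, biases) pair in turn."""
    activations = inputs
    for weights, biases in layers:
        activations = dense_layer(activations, weights, biases)
    return activations


# Hypothetical 4-layer network: 3 inputs -> 4 hidden -> 4 hidden -> 1 output.
# The numbers are arbitrary placeholders, not values from a trained model.
layers = [
    (
        [[0.2, -0.5, 0.1], [0.7, 0.3, -0.2], [-0.4, 0.6, 0.9], [0.1, 0.1, 0.1]],
        [0.0, 0.1, -0.1, 0.0],
    ),
    (
        [[0.5, -0.3, 0.8, 0.2], [-0.6, 0.4, 0.1, 0.3],
         [0.2, 0.2, -0.7, 0.5], [0.9, -0.1, 0.3, -0.4]],
        [0.0, 0.0, 0.0, 0.0],
    ),
    ([[1.0, -1.0, 0.5, 0.5]], [-0.2]),
]

# Three "pixel values" in, one "dog score" between 0 and 1 out.
score = forward([0.9, 0.1, 0.4], layers)[0]
print(f"dog score: {score:.3f}")
```

Training is the process of adjusting the numbers in `layers` until the score is high for dog images and low for everything else.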

[Figure: A dog is a combination of shapes and lines]

Now imagine that we want to train a new model that can detect cats instead of dogs. Rather than training a new network, we can simply retrain the final layer of the old network to recognize the shape combinations that correspond to cats rather than dogs.

[Figure: A cat is also a combination of shapes and lines]

This works efficiently because we already trained the network on how to recognize shapes. This means that now all the network needs to do during training is leverage its prior knowledge of shapes and features, and learn to recognize the new combinations of shapes, corresponding to the presence of cats.

There are a few different ways to implement transfer learning, but they all follow the same methodology: start from a pre-trained network, reuse what it has already learned, and adapt it to the new task.

A common approach is to copy the early layers of a pre-trained network as they are and randomly initialize the later ones. This keeps the pattern recognition learned in the early layers and remaps those patterns to your output classes in the final layers. In effect, the pre-trained layers act as a feature extractor for your new model.
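Here is a minimal sketch of the feature-extractor approach in plain Python. The "pre-trained" layer is a hypothetical stand-in (a fixed linear layer with ReLU and made-up weights), and the dataset is a tiny invented cat / not-cat example. The point it shows is the division of labour: the transferred layer is only read, never updated, while a new, randomly initialized output layer (here a logistic-regression head trained by gradient descent) learns to map those features to the new classes.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Stand-in for the early layers of a pre-trained network. These weights are
# hypothetical placeholders; in this approach they are copied over and frozen.
PRETRAINED_WEIGHTS = [[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]]


def extract_features(x):
    """Frozen feature extractor: a fixed linear layer followed by ReLU."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in PRETRAINED_WEIGHTS]


def train_head(examples, epochs=200, lr=0.5, seed=0):
    """Train a new, randomly initialized logistic-regression output layer.

    Only the head's weights and bias change; PRETRAINED_WEIGHTS is only read.
    """
    rng = random.Random(seed)
    w = [rng.uniform(-0.1, 0.1) for _ in PRETRAINED_WEIGHTS]
    b = 0.0
    for _ in range(epochs):
        for x, y in examples:
            f = extract_features(x)
            p = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b)
            err = p - y  # gradient of the log loss w.r.t. the head's logit
            w = [wi - lr * err * fi for wi, fi in zip(w, f)]
            b -= lr * err
    return w, b


def predict(x, w, b):
    f = extract_features(x)
    return sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b)


# Tiny made-up dataset for the new task (label 1 = "cat", 0 = "not cat").
examples = [([1.0, 0.0], 1), ([0.9, 0.2], 1), ([0.0, 1.0], 0), ([0.2, 0.8], 0)]
w, b = train_head(examples)
print(predict([1.0, 0.0], w, b), predict([0.0, 1.0], w, b))
```

Because the frozen layer never changes, its features could even be computed once and cached, which is a large part of why this approach is cheap to train.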

A related approach is to start by training only the later layers and then fine-tune all of the layers. This preserves the model’s ability to recognize features while also fine-tuning that feature extraction to your situation.
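The two-phase schedule can be sketched with a tiny network in plain Python. This is an illustration under assumptions, not a framework recipe: the "body" is one tanh hidden layer with hypothetical pre-trained weights, the "head" is a single sigmoid output node, and the learning rates and epoch counts are arbitrary. Phase 1 updates only the head; phase 2 updates every layer with a smaller learning rate so the pre-trained features are nudged rather than overwritten. In Keras, a comparable schedule is commonly built by setting `trainable = False` on the pre-trained base, training, then setting it back to `True` and recompiling with a lower learning rate.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


class TinyNet:
    """A 'body' (one tanh hidden layer) plus a 'head' (one sigmoid output node)."""

    def __init__(self, w1, b1, w2, b2):
        self.w1, self.b1 = w1, b1  # body: stands in for pre-trained layers
        self.w2, self.b2 = w2, b2  # head: new, randomly initialized layer

    def forward(self, x):
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
             for row, b in zip(self.w1, self.b1)]
        p = sigmoid(sum(w * hj for w, hj in zip(self.w2, h)) + self.b2)
        return h, p

    def train_step(self, x, y, lr, train_body):
        """One gradient-descent step on the log loss for a single example."""
        h, p = self.forward(x)
        dz2 = p - y  # gradient w.r.t. the output logit
        # Backpropagate to the hidden layer using the head weights before updating them.
        dz1 = [dz2 * w * (1.0 - hj * hj) for w, hj in zip(self.w2, h)]
        self.w2 = [w - lr * dz2 * hj for w, hj in zip(self.w2, h)]
        self.b2 -= lr * dz2
        if train_body:
            self.w1 = [[w - lr * d * xi for w, xi in zip(row, x)]
                       for row, d in zip(self.w1, dz1)]
            self.b1 = [b - lr * d for b, d in zip(self.b1, dz1)]


def fine_tune(net, examples, head_epochs=100, full_epochs=50, head_lr=0.5, full_lr=0.05):
    # Phase 1: body frozen, only the head learns.
    for _ in range(head_epochs):
        for x, y in examples:
            net.train_step(x, y, head_lr, train_body=False)
    # Phase 2: every layer learns, but with a smaller learning rate.
    for _ in range(full_epochs):
        for x, y in examples:
            net.train_step(x, y, full_lr, train_body=True)
    return net


def make_net(seed=0):
    rng = random.Random(seed)
    return TinyNet(
        w1=[[1.0, -1.0], [-1.0, 1.0]],  # hypothetical "pre-trained" body weights
        b1=[0.0, 0.0],
        w2=[rng.uniform(-0.1, 0.1) for _ in range(2)],
        b2=0.0,
    )


examples = [([1.0, 0.0], 1), ([0.9, 0.2], 1), ([0.0, 1.0], 0), ([0.2, 0.8], 0)]
net = fine_tune(make_net(), examples)
print(net.forward([1.0, 0.0])[1], net.forward([0.0, 1.0])[1])
```

Setting `full_epochs=0` reduces this to the feature-extractor approach, since the body is then never updated.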

Even though transfer learning can have many benefits when it’s applied in the right situation, it’s not always useful.

Transfer learning is only useful if the original and new tasks are similar, because the patterns learned by the pre-trained model have to be similar to the patterns you want your new model to learn.

For example, a model trained on ImageNet for object detection would be essentially useless for building a model that predicts stock prices. Those tasks are completely unrelated! It would, however, be useful if you wanted to detect objects like traffic lights, because that is a very similar task.

Intuitively, it’s like trying to apply your knowledge of one subject to a completely unrelated one. Your prior knowledge of algebra is not going to help you learn French. It can, however, help you learn calculus.

Another important limitation is that you generally can’t load a pre-trained model’s weights into a network with a different architecture, because the weight shapes won’t match. This means that if you need an architecture for which no pre-trained models are available, you will have to train your network from scratch.
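Why shapes matter can be shown with a minimal sketch in plain Python. It assumes each model is a dictionary mapping a hypothetical layer name to a weight matrix stored as a list of rows; real frameworks store and load weights differently. A weight matrix can only be copied into a layer with exactly the same shape, so matching layers transfer while mismatched ones keep their fresh initialization.

```python
def shape(matrix):
    """(rows, columns) of a weight matrix stored as a list of rows."""
    return (len(matrix), len(matrix[0]) if matrix else 0)


def transfer_weights(pretrained, target):
    """Copy pre-trained weights into `target` wherever a same-named layer has the same shape.

    Layers that are missing from `pretrained` or whose shapes differ are left
    as initialized, and their names are returned.
    """
    skipped = []
    for name, weights in list(target.items()):
        source = pretrained.get(name)
        if source is not None and shape(source) == shape(weights):
            target[name] = [row[:] for row in source]  # copy rows so the models don't share lists
        else:
            skipped.append(name)
    return skipped


# Hypothetical models: the new one keeps the early layers but needs a
# 2-class output layer instead of the original 4-class one.
pretrained = {
    "layer1": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    "layer2": [[0.7, 0.8], [0.9, 1.0]],
    "output": [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
}
new_model = {
    "layer1": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    "layer2": [[0.0, 0.0], [0.0, 0.0]],
    "output": [[0.0, 0.0], [0.0, 0.0]],
}
print(transfer_weights(pretrained, new_model))  # ['output']
```

This is also why, when the number of classes changes, the output layer is replaced and trained fresh rather than transferred.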

However, a clever way to deal with this is to first pre-train your network architecture on a large dataset, such as ImageNet for object detection, and then retrain the model on a dataset collected specifically for your problem.

This is especially useful if your dataset is quite small and a similar large dataset is available which can be used in the pre-training step.

Having a small dataset can often make it difficult for networks to learn patterns and produce accurate models. Training on a large dataset first would allow your model to learn the applicable patterns and then apply them to your problem.

This should result in shorter training time and higher accuracy, just like what was shown in the first graph.

Overall, transfer learning is an incredible way to take advantage of machine learning resources and use them to solve your own problems. Below are some of the other key takeaways!

  • Transfer learning can be used to reduce training time and improve model accuracy
  • Transfer learning is particularly easy to implement with computer vision problems. These problems already have many models trained on huge datasets such as ImageNet which are publicly available and can be easily used
  • Transfer learning only works well when the original and new tasks are similar
  • Transfer learning doesn’t need to be used in every situation
  • Transfer learning is particularly useful when you have a small dataset for your problem
  • If no compatible pre-trained model exists, you can pre-train your own, provided you have access to a large enough dataset
  • If you want to try out transfer learning for yourself, I recommend you start by looking into TensorFlow, which offers amazing support for it. There are even tutorials that walk you through the entire process
