Is Transfer Learning Magic or Pure Genius?
The practice of leveraging pre-trained model weights to save time and increase computational efficiency.
At a high level, Transfer Learning is the practice of reusing what an existing model has learned as the starting point for a new model. This is common practice with large vision and language models pre-trained on massive datasets. These models often have millions or billions of learned parameters and require tremendous computational power to train.
In 2018, Google open-sourced BERT (Bidirectional Encoder Representations from Transformers), a revolutionary language model pre-trained on the entirety of the English Wikipedia and Google’s BooksCorpus, roughly 3.3 billion words in total. Google trained BERT over 4 days on 64 TPUs (Tensor Processing Units). TPU v4 chips on Google Cloud have an on-demand cost of $3.22 per chip-hour, which gives a ballpark figure of about $19,800 to train BERT from scratch, a significant sum for most individuals and organizations.
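As a quick sanity check, that ballpark figure is just the three quoted numbers multiplied together; a tiny Python snippet makes the arithmetic explicit:

```python
# Back-of-the-envelope cost of the BERT training run described above.
tpu_count = 64            # TPUs used for training
hours = 4 * 24            # 4 days of training, in hours
rate_per_tpu_hour = 3.22  # USD, on-demand TPU v4 price quoted above

total_cost = tpu_count * hours * rate_per_tpu_hour
print(f"Estimated training cost: ${total_cost:,.2f}")  # ~$19,783.68
```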
We are in luck, though! Using transfer learning, BERT can be adapted to the ever-growing body of searchable content and queries and fine-tuned to a user’s specific task. The best part is that transfer learning requires significantly less training data and typically updates the weights of only a subset of the model’s layers, which reduces both training time and compute costs.
Let’s Review Some High-Level Concepts
- Learning from Scratch is the traditional deep learning scheme where you initialize the weights of your network to zeros, to random values, or to some predefined values. Training then uses backpropagation to update those weights according to an objective function and optimizer. This type of training is generally more computationally expensive than transfer learning, often requiring accelerators like GPUs and HPUs, and can take hours, days, or weeks to complete.
Example: You acquire a large dataset of cat and dog images, design a neural network, and train it from scratch to classify each image as a cat or a dog.
- Domain Adaptive Transfer Learning starts with a model pre-trained on an original dataset, introduces a completely different dataset, and uses it to retrain the model. This adapts the model to a new problem while transferring what it learned from the previous dataset.
Example: You load the weights of a pre-trained model such as VGG-16, which has been trained on ImageNet, freeze its shallow layers, and retrain it for an entirely different domain, such as detecting defective products on a manufacturing production line (see the VGG-16 sketch after this list).
- Fine-tuning is a sub-category of transfer learning that doesn’t switch domains but updates an existing model with new data. For this to make sense, the original dataset should be significantly larger than the new data so that preserving most of the pre-trained weights is justified. If the two datasets are of comparable size, it might be worth training a new model from scratch instead.
Example: You want to retrofit the original BERT model to classify tweets as humorous or offensive. After collecting thousands of sample tweets, you can use them to fine-tune the original model on data from the new source (see the BERT sketch after this list).
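To make the VGG-16 example concrete, here is a minimal PyTorch/torchvision sketch of domain adaptive transfer learning. The two-class defect-detection head, the choice of Adam, and the learning rate are illustrative assumptions, not values from a specific project:

```python
import torch
import torch.nn as nn
from torchvision import models

# Load VGG-16 with its ImageNet pre-trained weights.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# Freeze the convolutional feature extractor so the ImageNet-learned
# filters are preserved during retraining.
for param in model.features.parameters():
    param.requires_grad = False

# Replace the final classifier layer with a new 2-class head,
# e.g. "defective" vs. "not defective".
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 2)

# Only the parameters that still require gradients are optimized.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
```

From here, a standard training loop over the defect-detection images updates only the new classifier head (and any layers you chose to leave unfrozen).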
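Similarly, for the tweet-classification example, a sketch using the Hugging Face transformers library shows how the pre-trained BERT weights can be loaded with a fresh classification head; the two labels and the sample tweets below are placeholders:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load pre-trained BERT with a new, randomly initialized 2-label head
# (e.g. "humorous" vs. "offensive").
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

# Placeholder tweets; in practice these come from your collected dataset.
tweets = ["sample tweet one", "sample tweet two"]
inputs = tokenizer(tweets, padding=True, truncation=True, return_tensors="pt")

# A forward pass produces logits over the two new labels; fine-tuning
# wraps this in a training loop (or the Trainer API) over labeled tweets.
outputs = model(**inputs)
print(outputs.logits.shape)  # torch.Size([2, 2])
```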
In practice, transfer learning usually requires that you freeze some base layers of a pre-trained model and then adjust the output layers (also called head/classifier layers) to suit your needs. Sometimes, very small learning rates can be used instead of or in tandem with layer freezing. Lower learning rates help preserve the pre-trained weights but allow minor adjustments from the new dataset. Applying these techniques is straightforward but varies from framework to framework.
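As a rough PyTorch illustration of the small-learning-rate approach, one common pattern is to give the pre-trained backbone a much lower learning rate than the new head via optimizer parameter groups. The ResNet-18 backbone, the 5-class head, and the learning rates here are illustrative assumptions:

```python
import torch
import torch.nn as nn
from torchvision import models

# Pre-trained backbone with a freshly initialized head for a new 5-class task.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 5)

# Separate the new head's parameters from the pre-trained backbone's.
head_params = list(model.fc.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in model.parameters() if id(p) not in head_ids]

# A tiny learning rate on the backbone barely perturbs the pre-trained
# weights, while the new head gets a larger rate so it can learn quickly.
optimizer = torch.optim.SGD(
    [
        {"params": backbone_params, "lr": 1e-5},
        {"params": head_params, "lr": 1e-2},
    ],
    momentum=0.9,
)
```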
Check out these resources to get started with some Keras and PyTorch examples.
Conclusion and Discussion
Transfer Learning isn’t magic; it’s pure genius. It allows us to leverage the “skill” of a model previously trained on other datasets. However, it’s important to remember that it is not a silver bullet and requires thoughtful planning.
Key Benefits of transfer learning:
- Reduced training times
- Reduced training data requirements
- Flexibility to leverage general hardware rather than costly accelerated hardware
- No requirement for a significant redesign of architecture for new applications
Key Challenges of transfer learning:
- Selecting a pre-trained model with the appropriate level of relevance and complexity for the target application
- Determining how much data is enough data
- Deciding how much to perturb the pre-trained weights
Disclaimer: The terms transfer learning, domain adaptation, and fine-tuning are not defined consistently across the professional and academic community. People will sometimes say “fine-tuning” when referring to both fine-tuning and domain adaptive transfer learning. It is ultimately up to the individual to communicate clearly which machine learning techniques they intend to apply.