DEV Community

Manoj Kumar Patra


Model Training Patterns - Transfer Learning

Transfer learning, in short, means taking a previously trained model, removing its final layer, freezing its weights (making them non-trainable), and incorporating it into a new model that solves a similar, more specialized problem. In place of the removed layer, we add an output layer for our specialized task and then continue training.

Some common use cases where transfer learning can be useful include:

  1. Image object detection
  2. Image style transfer
  3. Image generation
  4. Text classification
  5. Machine translation

Bottleneck layer

The bottleneck layer is the layer that represents the input in the lowest-dimensional space.

Typically, it's the last layer before a flattening operation.

This is the last layer of the pre-trained model that we want to load and attach the new custom layers to.
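To make "lowest-dimensional" concrete, here is a small shape calculation for VGG19 with a 150 × 150 × 3 input. It assumes the standard VGG19 layout: 3 × 3 convolutions with 'same' padding (which keep the spatial size), five 2 × 2 max-pooling layers with stride 2, and 512 channels in the last block. The function name `vgg19_bottleneck_shape` is hypothetical.

```python
def vgg19_bottleneck_shape(height, width, channels=512, num_pools=5):
    """Return the shape of VGG19's last feature map (include_top=False).

    Assumes 'same'-padded 3x3 convolutions, which keep the spatial size,
    and `num_pools` 2x2 stride-2 max-pooling layers, each of which halves
    height and width (rounding down).
    """
    for _ in range(num_pools):
        height, width = height // 2, width // 2
    return (height, width, channels)


input_size = 150 * 150 * 3
bottleneck = vgg19_bottleneck_shape(150, 150)
bottleneck_size = bottleneck[0] * bottleneck[1] * bottleneck[2]
print(bottleneck)                    # (4, 4, 512)
print(input_size, bottleneck_size)   # 67500 8192
```

Under these assumptions, each image (67,500 values) is represented by 8,192 values at the bottleneck, and this feature map is what the new layers are attached to.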

Here is how we would use a VGG19 model with TensorFlow for transfer learning on a specialized image dataset in which each example has the shape 150 × 150 × 3:

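```python
import tensorflow as tf

vgg_model = tf.keras.applications.VGG19(
    # include_top=False drops the classifier, so the last layer
    # loaded is the bottleneck layer
    include_top=False,
    weights='imagenet',
    input_shape=(150, 150, 3)
)

# Feature extraction: freeze the pre-trained model
# (0 trainable parameters in the pre-trained model)
vgg_model.trainable = False

# Placeholder batch of 4 random images, only to illustrate the shapes;
# real images should first go through
# tf.keras.applications.vgg19.preprocess_input
image_batch = tf.random.uniform((4, 150, 150, 3))
feature_batch = vgg_model(image_batch)  # shape (4, 4, 4, 512)

# Average each feature map over its spatial dimensions
global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_avg = global_avg_layer(feature_batch)  # shape (4, 512)

# New output layer for the specialized task with 8 classes
prediction_layer = tf.keras.layers.Dense(8, activation='softmax')
prediction_batch = prediction_layer(feature_batch_avg)  # shape (4, 8)

specialized_model = tf.keras.Sequential([
    vgg_model,
    global_avg_layer,
    prediction_layer
])
```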
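To see what the two new layers compute on top of the bottleneck, here is a NumPy sketch with the same shapes. `feature_batch` is random data standing in for `vgg_model(image_batch)`, and `kernel` and `bias` are untrained placeholder weights, not values Keras would learn.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for vgg_model(image_batch): 2 random bottleneck feature maps with
# the shape VGG19 produces for 150 x 150 x 3 inputs
feature_batch = rng.normal(size=(2, 4, 4, 512))

# GlobalAveragePooling2D: average each of the 512 channels over height and width
feature_batch_avg = feature_batch.mean(axis=(1, 2))        # (2, 512)

# Dense(8, activation='softmax'): affine map, then softmax over the 8 classes.
# kernel and bias are untrained placeholders.
kernel = rng.normal(scale=0.01, size=(512, 8))
bias = np.zeros(8)
logits = feature_batch_avg @ kernel + bias                  # (2, 8)
exp = np.exp(logits - logits.max(axis=1, keepdims=True))    # subtract max for stability
prediction_batch = exp / exp.sum(axis=1, keepdims=True)     # each row sums to 1

print(feature_batch_avg.shape, prediction_batch.shape)      # (2, 512) (2, 8)
```

Averaging instead of flattening keeps the new head small: the Dense layer needs 512 × 8 + 8 = 4,104 parameters, versus 8,192 × 8 + 8 = 65,544 if the 4 × 4 × 512 map were flattened.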

The embedding analogy

The bottleneck layer in a pre-trained model can be compared to an embedding.

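Consider an encoder-decoder architecture. The encoder compresses the original data into a lower-dimensional representation at the bottleneck layer, and this representation acts as an embedding; the decoder then maps the embedding back to the original data. In transfer learning, the pre-trained model up to the bottleneck layer plays the role of the encoder, and the new custom network after the bottleneck layer plays the role of the decoder, except that it decodes the embedding into predictions for the specialized task rather than back into the original input.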
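A minimal NumPy sketch of the analogy, using a truncated SVD (PCA without mean-centering) as a stand-in for a learned encoder. This is a swapped-in technique for illustration only, not how VGG or the models in this post are trained. Toy data lying on a 2-D subspace of a 6-D space is compressed to a 2-D embedding, a "decoder" maps it back, and a new linear head is then fitted on the same embedding for a different target.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 100 points in 6-D that lie on a 2-D subspace through the origin
basis = rng.normal(size=(2, 6))
X = rng.normal(size=(100, 2)) @ basis

# Encoder: project onto the top-2 right singular vectors
_, _, vt = np.linalg.svd(X, full_matrices=False)
encoder = vt[:2].T                 # (6, 2)
embedding = X @ encoder            # bottleneck representation, (100, 2)

# Decoder: map the 2-D embedding back to the original 6-D space
reconstruction = embedding @ encoder.T
print(np.allclose(reconstruction, X))    # True

# Transfer: keep the encoder fixed and fit a new linear head on the
# embedding for a different target that depends linearly on X
y = X @ rng.normal(size=6)
head, *_ = np.linalg.lstsq(embedding, y, rcond=None)
print(np.allclose(embedding @ head, y))  # True
```

The reconstruction is exact here only because the toy data has rank 2; real data loses information at the bottleneck, which is why the embedding has to keep the features that matter for the downstream task.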

Transfer learning with TF Hub

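A TF Hub model can be wrapped as a Keras layer with hub.KerasLayer and stacked with new layers. Here the hub model takes one raw string per example as input (for example, a text-embedding model):

```python
import tensorflow as tf
import tensorflow_hub as hub

hub_layer = hub.KerasLayer(
    "url_to_model",   # placeholder for the TF Hub model handle
    input_shape=[],   # each example is a single string
    dtype=tf.string,
    # Fine-tuning: the pre-trained weights are updated during training
    trainable=True
)

model = tf.keras.Sequential([
    hub_layer,
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')  # binary classification output
])
```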

To summarize, transfer learning involves the following:

  1. Learning new things (the combination of the pre-trained model and the network after the bottleneck layer)
  2. Learning how to learn new things (with the pre-trained model)

What happens when we modify the weights of the pre-trained model?

In the first example above, we set vgg_model.trainable to False. This is feature extraction: only the custom layers after the bottleneck layer are trained.

In the second example, however, we set trainable to True. This is known as fine-tuning: the weights of the pre-trained model are updated as well. We can update the weights of all layers or of only some layers of the pre-trained model.
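To make the difference concrete, here is a NumPy sketch of one gradient-descent step on a toy two-layer linear model, y_hat = (x @ W1) @ w2, where W1 plays the role of the pre-trained weights and w2 the new head. The function `train_step`, the mean-squared-error loss, and the random weights are hypothetical placeholders; in Keras, the `trainable` flag and the optimizer take care of this.

```python
import numpy as np

def train_step(x, y, W1, w2, lr=0.01, base_trainable=False):
    """One gradient-descent step on MSE for y_hat = (x @ W1) @ w2 (toy model)."""
    n = len(y)
    h = x @ W1                              # bottleneck features
    err = h @ w2 - y                        # residuals, shape (n,)
    grad_w2 = 2 * h.T @ err / n             # gradient for the new head
    grad_W1 = 2 * x.T @ np.outer(err, w2) / n  # gradient for the pre-trained weights
    new_w2 = w2 - lr * grad_w2
    # Feature extraction: W1 stays frozen. Fine-tuning: W1 is updated too.
    new_W1 = W1 - lr * grad_W1 if base_trainable else W1
    return new_W1, new_w2


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10))
y = rng.normal(size=32)
W1 = rng.normal(size=(10, 4))   # placeholder for pre-trained weights
w2 = rng.normal(size=4)         # placeholder for the new head

W1_fe, w2_fe = train_step(x, y, W1, w2, base_trainable=False)
W1_ft, w2_ft = train_step(x, y, W1, w2, base_trainable=True)
print(np.array_equal(W1_fe, W1))  # True: frozen during feature extraction
print(np.allclose(W1_ft, W1))     # False: updated during fine-tuning
```

In both modes the head receives the same update; the only difference is whether the gradient also flows into the pre-trained weights.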

Decide the number of layers to fine-tune

To decide the number of layers that should be fine-tuned, we can use the following approach:

  1. Keep the learning rate low (~0.001).
  2. Keep the number of training iterations small.
  3. Unfreeze the last layer(s) of the pre-trained model, train, and monitor the model's loss.
  4. Iteratively unfreeze more layers, moving toward the input, and keep monitoring the model's loss after each round of training.
  5. Stop when the first layer is reached or the model's loss plateaus.

This approach is known as progressive fine-tuning.
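The steps above can be sketched as a framework-agnostic loop. The callback `train_and_evaluate(num_unfrozen)` and the plateau threshold `min_improvement` are hypothetical: the callback is assumed to unfreeze the last `num_unfrozen` layers, train briefly with a low learning rate, and return the validation loss. In Keras, it could set `vgg_model.trainable = True`, set `layer.trainable = False` for all but the last `num_unfrozen` entries of `vgg_model.layers`, and recompile before training.

```python
def progressive_fine_tune(num_layers, train_and_evaluate, min_improvement=1e-3):
    """Unfreeze layers from the end, one more per round, until the loss plateaus.

    Returns (number of unfrozen layers, loss) for the best round found.
    """
    best_loss = train_and_evaluate(0)   # baseline: only the new head is trained
    best_unfrozen = 0
    for num_unfrozen in range(1, num_layers + 1):
        loss = train_and_evaluate(num_unfrozen)
        if best_loss - loss < min_improvement:
            break                       # loss plateaued or got worse: stop unfreezing
        best_loss, best_unfrozen = loss, num_unfrozen
    # Reaching the end of the loop means the first layer was reached
    return best_unfrozen, best_loss


# Usage with made-up losses, just to exercise the stopping rule
fake_losses = [1.0, 0.8, 0.7, 0.6995, 0.5]
print(progressive_fine_tune(4, lambda k: fake_losses[k]))  # (2, 0.7)
```

Stopping at the first plateau is a greedy choice; it keeps the number of training rounds small, at the risk of missing a later improvement such as the 0.5 in the made-up losses.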

Applying feature extraction vs. fine-tuning

| Criterion | Feature extraction | Fine-tuning |
|---|---|---|
| Dataset size | Smaller (100 to 1,000 examples) | Larger (~1 million examples) |
| Prediction task | Different from that of the pre-trained model | Same as or similar to that of the pre-trained model |
| Budget | Preferred when the budget is low | Preferred when the budget is high, since updating the weights of the pre-trained model increases training time and computation cost |
