Transfer Learning with Pre-trained Models

Keras Basics

2 min read

Published Nov 17 2025



Keras · Neural Networks · Python · TensorFlow

Transfer learning is one of the most powerful techniques in modern deep learning.


Instead of training a model from scratch on your dataset, you:

  1. Take a model pre-trained on a huge dataset (such as ImageNet)
  2. Freeze its convolutional base
  3. Add your own classifier layer(s)
  4. Train for a short time
  5. (Optional) fine-tune deeper layers

Benefits:

  • Much higher accuracy than training from scratch
  • Works with small datasets
  • Faster training
  • Lower compute cost

In this section we apply transfer learning to the Cats vs Dogs dataset from TensorFlow Datasets (TFDS).






Load the Cats vs Dogs Dataset

TensorFlow Datasets (TFDS) provides the cats_vs_dogs dataset ready to use.


Install TFDS if needed:

pip install tensorflow-datasets

Load dataset:

import tensorflow_datasets as tfds
import tensorflow as tf

(ds_train, ds_val), ds_info = tfds.load(
    "cats_vs_dogs",
    split=["train[:80%]", "train[80%:]"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True
)

print(ds_info)

Labels:

  • 0 = cat
  • 1 = dog
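
If you want to sanity-check this mapping, TFDS exposes the label names through ds_info. A quick sketch (the printed image shape will vary, since the raw images come in different sizes):

for image, label in ds_train.take(1):
    print(image.shape, label.numpy())
    # ClassLabel.int2str maps 0 -> "cat", 1 -> "dog"
    print(ds_info.features["label"].int2str(int(label.numpy())))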





Build Input Pipeline

The pre-trained model expects a consistent input size, so resize every image.

img_size = (160, 160)
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def format_example(image, label):
    image = tf.image.resize(image, img_size)
    # normalize to 0–1
    image = image / 255.0
    return image, label

train_ds = (
    ds_train
    .map(format_example, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(1000)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

val_ds = (
    ds_val
    .map(format_example, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

This pipeline:

  • Resizes images
  • Normalises pixels
  • Batches data
  • Prefetches for speed
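
An optional sanity check on one batch, assuming the pipeline above, looks like this:

for images, labels in train_ds.take(1):
    # expect (32, 160, 160, 3) images and (32,) labels, pixels in [0, 1]
    print(images.shape, labels.shape)
    print(float(tf.reduce_min(images)), float(tf.reduce_max(images)))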





Load a Pre-trained Base Model (MobileNetV2)

We use MobileNetV2 (fast, accurate, lightweight):

from tensorflow.keras import applications

base_model = applications.MobileNetV2(
    input_shape=img_size + (3,),
    include_top=False, # remove ImageNet classifier
    weights="imagenet"
)

Freeze it:

base_model.trainable = False

This makes it a fixed feature extractor.
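
You can confirm the freeze took effect; with trainable = False the base contributes no trainable variables:

# 0 trainable variables: only the new head will be updated during training
print(len(base_model.trainable_variables))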






Add Your Own Classification Head

Typical transfer-learning head:

from tensorflow.keras import layers, Sequential

model = Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.2),
    # binary classification
    layers.Dense(1, activation="sigmoid")
])

Why this design?

  • GlobalAveragePooling2D → collapses each feature map to a single average value, so no large Flatten/Dense block is needed
  • Dropout → reduces overfitting
  • Sigmoid → outputs the probability of "dog" (label 1)
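
To see why the pooling layer helps, here is a small sketch of the shapes involved (the exact numbers assume MobileNetV2 with 160×160 inputs; other base models differ):

for images, _ in train_ds.take(1):
    features = base_model(images, training=False)    # (32, 5, 5, 1280) feature maps
    pooled = layers.GlobalAveragePooling2D()(features)
    print(features.shape, pooled.shape)              # pooled: (32, 1280)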





Compile the Model

Use a small learning rate:

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)





Train the Frozen Model (Feature Extraction Phase)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5
)

Typical validation accuracy is 93–96% after just a few epochs; this is the power of transfer learning.
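
If you want to see the curves rather than just the final numbers, a minimal plotting sketch (assuming matplotlib is installed) is:

import matplotlib.pyplot as plt

plt.plot(history.history["accuracy"], label="train")
plt.plot(history.history["val_accuracy"], label="validation")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()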






Fine-Tuning for Extra Accuracy

Once the head is trained, unfreeze part of the base model:

base_model.trainable = True

Keep the earlier layers frozen and fine-tune only the top layers (recommended):

# fine-tune from this layer onwards; layers below stay frozen
fine_tune_at = 100

for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
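
The value 100 is a judgment call rather than a magic number; it helps to check how many layers the base actually has before choosing it:

# MobileNetV2 has roughly 150+ layers, depending on the Keras version
print("Layers in base model:", len(base_model.layers))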

Recompile with a very low learning rate so the pre-trained weights are only adjusted gently:

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

Train again:

history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5
)

Fine-tuning often adds another 1–3 percentage points of accuracy.






Data Augmentation (Highly Recommended)

Adding augmentation improves generalisation:

data_augmentation = Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

Use it at the start of your model:

model = Sequential([
    data_augmentation,
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.2),
    layers.Dense(1, activation='sigmoid')
])
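
One useful property of these layers: they only apply their random transforms during training and pass images through unchanged at inference time, so predictions stay deterministic. A quick check:

for images, _ in train_ds.take(1):
    out = data_augmentation(images, training=False)
    # True: augmentation is a no-op outside of training
    print(bool(tf.reduce_all(out == images)))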





Using Other Pre-trained Keras Models

You can replace MobileNetV2 with any Keras Application model:

applications.ResNet50
applications.EfficientNetB0
applications.EfficientNetB3
applications.DenseNet121
applications.InceptionV3
applications.Xception

All follow the same pattern:

  1. include_top=False
  2. weights="imagenet"
  3. Freeze
  4. Add classifier head
  5. Train
  6. Optionally fine-tune
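
One detail when swapping models: each family ships its own preprocess_input, which you would normally use in place of the divide-by-255 step above. A sketch with ResNet50 (the function name format_example_resnet is just illustrative):

from tensorflow.keras.applications.resnet50 import preprocess_input

def format_example_resnet(image, label):
    image = tf.image.resize(image, img_size)
    # ResNet50 expects its own scaling, applied to 0-255 pixel values
    return preprocess_input(image), label

base_model = applications.ResNet50(
    input_shape=img_size + (3,),
    include_top=False,
    weights="imagenet"
)
base_model.trainable = False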





Full Working Transfer Learning Script

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, Sequential, applications

# Load dataset
(ds_train, ds_val), ds_info = tfds.load(
    "cats_vs_dogs",
    split=["train[:80%]", "train[80%:]"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True
)

# Preprocessing
img_size = (160, 160)
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def format_example(image, label):
    image = tf.image.resize(image, img_size)
    image = image / 255.0
    return image, label

train_ds = (
    ds_train
    .map(format_example, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(1000)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

val_ds = (
    ds_val
    .map(format_example, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

# Load base model
base_model = applications.MobileNetV2(
    input_shape=img_size + (3,),
    include_top=False,
    weights="imagenet"
)
base_model.trainable = False

# Build model
model = Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.2),
    layers.Dense(1, activation="sigmoid")
])

# Compile
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

# Train
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5
)
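
As an optional follow-up (not shown in the script above), you can evaluate on the validation split and save the trained model; the filename here is just an example:

# Evaluate on the validation set and save the model
loss, acc = model.evaluate(val_ds)
print(f"Validation accuracy: {acc:.3f}")
model.save("cats_vs_dogs_mobilenetv2.keras")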
