DEV Community

Manish Dhakal

Super Resolution with GAN and Keras (SRGAN)

Prior Knowledge

  • Neural Networks
  • Python
  • Keras (better to have)

Generative Adversarial Networks (GAN)

GAN is a neural network architecture introduced by Ian Goodfellow and his colleagues. SRGAN is a GAN-based method for increasing the resolution of an image.

GAN

A GAN has two parts: a generator and a discriminator. The generator produces refined output data from input noise. The discriminator receives two types of data: real-world data and the output of the generator. For the discriminator, real data has the label ‘1’ and generated data has the label ‘0’. By analogy, the generator is an artist and the discriminator is a critic: the artist creates artwork, and the critic judges it.

ARTIST AND CRITIC

As the generator improves with training, the discriminator's performance gets worse because it can no longer easily tell real from fake. In theory, the discriminator eventually reaches 50% accuracy, no better than a coin flip.
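
To make the labels and the coin-flip limit concrete, here is a minimal sketch in plain Python of binary cross-entropy, the loss this article later uses for the discriminator. The helper name `binary_cross_entropy` and the probabilities are made up for illustration; Keras computes the same quantity internally over whole batches.

```python
import math

def binary_cross_entropy(label, prob, eps=1e-7):
    """Binary cross-entropy for one sample: label is 0 or 1, prob is the predicted P(real)."""
    prob = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))

# A confident, correct discriminator has a small loss on both kinds of input.
loss_real_confident = binary_cross_entropy(1, 0.95)  # real image, label 1
loss_fake_confident = binary_cross_entropy(0, 0.05)  # generated image, label 0

# A discriminator that cannot tell the two apart outputs 0.5 for everything,
# which gives a loss of ln(2) regardless of the label.
loss_real_unsure = binary_cross_entropy(1, 0.5)
loss_fake_unsure = binary_cross_entropy(0, 0.5)

print(loss_real_confident, loss_fake_confident)
print(loss_real_unsure, loss_fake_unsure, math.log(2))
```

A discriminator loss that settles near ln(2) is therefore a rough sign that the critic is down to guessing.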

So our motto is to decrease the accuracy of the people who judge us and focus on our artwork.

Structure of SRGAN

SRGAN MODEL

Alternate Training

The generator and discriminator are trained alternately. First the discriminator is trained for one or more epochs, then the generator is trained for one or more epochs; this completes one cycle. A pretrained VGG19 model is used to extract features from the images during training.
While the generator is being trained, the parameters of the discriminator are frozen; otherwise the generator would be chasing a moving target and might never converge.


Code

Import necessary dependencies

import numpy as np
from keras import Model
from keras.layers import Conv2D, PReLU, BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, Input, add


Some necessary variables

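lr_ip = Input(shape=(25,25,3))    # low resolution input
hr_ip = Input(shape=(100,100,3))  # high resolution input
# train_lr, train_hr: training image arrays normalized between 0 and 1 (load your own data here)
# test_lr, test_hr: testing image arrays normalized between 0 and 1 (load your own data here)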
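
The article does not show how the low and high resolution arrays are prepared. The following is a sketch of one possible approach, assuming the high resolution images are already loaded as a `uint8` array of shape `(N, 100, 100, 3)`. The helpers `normalize` and `downsample_4x` are hypothetical, and 4x4 block averaging is an assumption for illustration, not necessarily how the author built the dataset.

```python
import numpy as np

def normalize(images_uint8):
    """Scale uint8 pixel values from [0, 255] to floats in [0, 1]."""
    return images_uint8.astype(np.float32) / 255.0

def downsample_4x(hr_batch):
    """Average each 4x4 pixel block: (N, 100, 100, 3) -> (N, 25, 25, 3)."""
    n, h, w, c = hr_batch.shape
    assert h % 4 == 0 and w % 4 == 0, "height and width must be divisible by 4"
    blocks = hr_batch.reshape(n, h // 4, 4, w // 4, 4, c)
    return blocks.mean(axis=(2, 4))

# Stand-in data: 8 random "images" instead of a real dataset.
rng = np.random.default_rng(0)
raw_hr = rng.integers(0, 256, size=(8, 100, 100, 3), dtype=np.uint8)

demo_hr = normalize(raw_hr)
demo_lr = downsample_4x(demo_hr)
print(demo_hr.shape, demo_lr.shape)
```

Arrays built this way match the `(25,25,3)` and `(100,100,3)` input shapes declared above.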



Define Generator

We define a function that returns the generator model, which produces the high resolution image. The residual block is a function that returns the sum of its input and the output of its final layer.

# Residual block
def res_block(ip):

    res_model = Conv2D(64, (3,3), padding = "same")(ip)
    res_model = BatchNormalization(momentum = 0.5)(res_model)
    res_model = PReLU(shared_axes = [1,2])(res_model)

    res_model = Conv2D(64, (3,3), padding = "same")(res_model)
    res_model = BatchNormalization(momentum = 0.5)(res_model)

    return add([ip,res_model])

# Upscale the image 2x
def upscale_block(ip):    
    up_model = Conv2D(256, (3,3), padding="same")(ip)
    up_model = UpSampling2D( size = 2 )(up_model)
    up_model = PReLU(shared_axes=[1,2])(up_model)

    return up_model

num_res_block = 16

# Generator Model
def create_gen(gen_ip):
    layers = Conv2D(64, (9,9), padding="same")(gen_ip)
    layers = PReLU(shared_axes=[1,2])(layers)
    temp = layers
    for i in range(num_res_block):
        layers = res_block(layers)
    layers = Conv2D(64, (3,3), padding="same")(layers)
    layers = BatchNormalization(momentum=0.5)(layers)
    layers = add([layers,temp])
    layers = upscale_block(layers)
    layers = upscale_block(layers)
    op = Conv2D(3, (9,9), padding="same")(layers)
    return Model(inputs=gen_ip, outputs=op)
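
The `upscale_block` above enlarges the feature map with `UpSampling2D`, which repeats pixels and keeps all 256 channels. The SRGAN paper instead describes sub-pixel convolution, where the channels produced by the convolution are rearranged into space. For comparison, here is a NumPy sketch of that rearrangement; the function name `pixel_shuffle` is hypothetical. In TensorFlow the same operation is available as `tf.nn.depth_to_space`.

```python
import numpy as np

def pixel_shuffle(x, r=2):
    """Rearrange (H, W, C * r * r) into (H * r, W * r, C)."""
    h, w, c_in = x.shape
    assert c_in % (r * r) == 0, "channels must be divisible by r * r"
    c_out = c_in // (r * r)
    x = x.reshape(h, w, r, r, c_out)   # split channels into an r x r grid
    x = x.transpose(0, 2, 1, 3, 4)     # interleave the grid with H and W
    return x.reshape(h * r, w * r, c_out)

# A 25x25 feature map with 256 channels becomes 50x50 with 64 channels.
features = np.zeros((25, 25, 256), dtype=np.float32)
shuffled = pixel_shuffle(features)
print(shuffled.shape)

# Four channels of a single pixel become a 2x2 spatial patch.
tiny = np.arange(4).reshape(1, 1, 4)
print(pixel_shuffle(tiny)[:, :, 0])
```

Either way, two 2x steps take the 25x25 input to the 100x100 output; the variants differ in how many channels the later layers have to process.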



Define Discriminator

This block of code defines the structure of the discriminator model and all the layers it uses to distinguish real images from generated ones. As we go deeper, the number of filters doubles after every two layers.

#Small block inside the discriminator
def discriminator_block(ip, filters, strides=1, bn=True):

    disc_model = Conv2D(filters, (3,3), strides=strides, padding="same")(ip)
    disc_model = LeakyReLU( alpha=0.2 )(disc_model)
    if bn:
        disc_model = BatchNormalization( momentum=0.8 )(disc_model)
    return disc_model

# Discriminator Model
def create_disc(disc_ip):
    df = 64

    d1 = discriminator_block(disc_ip, df, bn=False)
    d2 = discriminator_block(d1, df, strides=2)
    d3 = discriminator_block(d2, df*2)
    d4 = discriminator_block(d3, df*2, strides=2)
    d5 = discriminator_block(d4, df*4)
    d6 = discriminator_block(d5, df*4, strides=2)
    d7 = discriminator_block(d6, df*8)
    d8 = discriminator_block(d7, df*8, strides=2)

    d8_5 = Flatten()(d8)
    d9 = Dense(df*16)(d8_5)
    d10 = LeakyReLU(alpha=0.2)(d9)
    validity = Dense(1, activation='sigmoid')(d10)
    return Model(disc_ip, validity)
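
To see how the discriminator shrinks a 100x100 image, the shapes can be traced by hand. This is a small sketch in plain Python, assuming the standard rule that a convolution with "same" padding and stride s produces an output of size ceil(input / s). The helper `same_conv_size` is hypothetical.

```python
import math

def same_conv_size(size, stride):
    """Output size of a "same"-padded convolution along one spatial axis."""
    return math.ceil(size / stride)

# (filters, stride) for d1..d8 in create_disc, with df = 64.
blocks = [(64, 1), (64, 2), (128, 1), (128, 2),
          (256, 1), (256, 2), (512, 1), (512, 2)]

size = 100  # height and width of the high resolution input
trace = []
for filters, stride in blocks:
    size = same_conv_size(size, stride)
    trace.append((size, size, filters))

flat_units = trace[-1][0] * trace[-1][1] * trace[-1][2]
dense_params = flat_units * 1024 + 1024  # weights plus biases of Dense(df * 16)

for shape in trace:
    print(shape)
print(flat_units, dense_params)
```

Under that assumption, most of the discriminator's parameters sit in the first dense layer, so larger input images make this layer grow quickly.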



VGG19 Model

In this code block, we use a VGG19 model pretrained on the ImageNet dataset to extract features. This model is frozen later so that its parameters are not updated.

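from keras.applications import VGG19
# Build the VGG19 model up to the 10th layer (index 9)
# Used to extract the features of high resolution images
def build_vgg():
    # include_top=False drops the dense classifier so that 100x100 inputs are
    # accepted; this follows the fix suggested in the comments below the article.
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=(100,100,3))
    return Model(inputs=vgg.input, outputs=vgg.layers[9].output)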
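
The features extracted by VGG19 are compared with mean squared error, which is the "mse" loss passed to the combined model later in this article. Here is a NumPy sketch of that comparison. The function name `content_loss` is hypothetical, and the random arrays only stand in for real feature maps; the shape `(25, 25, 256)` is what I would expect from layer index 9 for a 100x100 input, stated here as an assumption.

```python
import numpy as np

def content_loss(features_real, features_generated):
    """Mean squared error between two batches of feature maps."""
    diff = features_real - features_generated
    return float(np.mean(diff ** 2))

# Stand-ins for VGG19 feature maps (batch of 2, shape assumed for 100x100 inputs).
rng = np.random.default_rng(0)
feat_hr = rng.random((2, 25, 25, 256)).astype(np.float32)
feat_sr = feat_hr + 0.1  # pretend the generated features are off by a constant

print(content_loss(feat_hr, feat_hr))
print(content_loss(feat_hr, feat_sr))
```

Comparing features rather than raw pixels rewards the generator for reproducing textures and edges, not only for matching pixel values on average.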



Combined Model

Now we attach the generator and discriminator models. The resulting model is used only to train the generator. While training this combined model, the discriminator must stay frozen.

# Attach the generator and discriminator
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    gen_img = gen_model(lr_ip)
    gen_features = vgg(gen_img)
    disc_model.trainable = False
    validity = disc_model(gen_img)
    return Model([lr_ip, hr_ip],[validity,gen_features])



Declare models

Then we declare the generator, discriminator and VGG models. These models are passed as arguments to the combined model.
Any change to a smaller model inside the combined model, such as a weight update or freezing, also affects that model outside it.

generator = create_gen(lr_ip)
discriminator = create_disc(hr_ip)
discriminator.compile(loss="binary_crossentropy", optimizer="adam",
                      metrics=['accuracy'])
vgg = build_vgg()
vgg.trainable = False
gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
gan_model.compile(loss=["binary_crossentropy","mse"],
                  loss_weights=[1e-3, 1], optimizer="adam")
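
The `loss_weights` argument controls how the two outputs of the combined model contribute to the total loss: the adversarial term is scaled by 1e-3 and the feature (content) term by 1. Here is a sketch of that arithmetic in plain Python, with made-up loss values and a hypothetical helper `combined_loss`.

```python
def combined_loss(adversarial_bce, content_mse, weights=(1e-3, 1.0)):
    """Weighted sum that Keras reports as the total loss of a two-output model."""
    w_adv, w_content = weights
    return w_adv * adversarial_bce + w_content * content_mse

# Made-up per-batch losses, only to show the relative scale of the two terms.
example_total = combined_loss(adversarial_bce=0.693, content_mse=0.02)
print(example_total)
```

With these weights the content term dominates, and the adversarial term acts as a small nudge towards outputs the discriminator accepts as real.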



Sample the training data into small batches

Because the training set is too large to process at once, we split the images into small batches to avoid a Resource Exhausted error. Resources such as RAM are not enough to train on all the images at once.

batch_size = 20
train_lr_batches = []
train_hr_batches = []
for it in range(int(train_hr.shape[0] / batch_size)):
    start_idx = it * batch_size
    end_idx = start_idx + batch_size
    train_hr_batches.append(train_hr[start_idx:end_idx])
    train_lr_batches.append(train_lr[start_idx:end_idx])
train_lr_batches = np.array(train_lr_batches)
train_hr_batches = np.array(train_hr_batches)
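
Note that the loop above drops any images left over after the last full batch. An alternative, sketched here with a hypothetical helper `iter_batches`, is a Python generator that yields the same slices on demand without building a second array of batches. The zero-filled arrays are stand-ins for real data.

```python
import numpy as np

def iter_batches(lr_images, hr_images, batch_size):
    """Yield aligned (lr, hr) batches, dropping the incomplete final batch."""
    n_batches = len(hr_images) // batch_size
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        yield lr_images[start:end], hr_images[start:end]

# Stand-in arrays: 105 samples do not divide evenly into batches of 20.
demo_lr = np.zeros((105, 25, 25, 3), dtype=np.float32)
demo_hr = np.zeros((105, 100, 100, 3), dtype=np.float32)

batches = list(iter_batches(demo_lr, demo_hr, batch_size=20))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
```

Keeping every batch the same size matters here because the training loop creates its label arrays once per epoch with exactly `batch_size` rows.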



Training the model

This block is the core of the whole program. Here we train the discriminator and generator alternately, as described above. At this point the discriminator is frozen, so remember to unfreeze it before training the discriminator and to freeze it again afterwards, as the code below does.

epochs = 100
for e in range(epochs):
    gen_label = np.zeros((batch_size, 1))
    real_label = np.ones((batch_size,1))
    g_losses = []
    d_losses = []
    for b in range(len(train_hr_batches)):
        lr_imgs = train_lr_batches[b]
        hr_imgs = train_hr_batches[b]
        gen_imgs = generator.predict_on_batch(lr_imgs)
        # Don't forget to make the discriminator trainable
        discriminator.trainable = True

        #Train the discriminator
        d_loss_gen = discriminator.train_on_batch(gen_imgs,
          gen_label)
        d_loss_real = discriminator.train_on_batch(hr_imgs,
          real_label)
        discriminator.trainable = False
        d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)
        image_features = vgg.predict(hr_imgs)

        #Train the generator
        g_loss, _, _ = gan_model.train_on_batch([lr_imgs, hr_imgs], 
          [real_label, image_features])

        d_losses.append(d_loss)
        g_losses.append(g_loss)
    g_losses = np.array(g_losses)
    d_losses = np.array(d_losses)

    g_loss = np.sum(g_losses, axis=0) / len(g_losses)
    d_loss = np.sum(d_losses, axis=0) / len(d_losses)
    print("epoch:", e+1 ,"g_loss:", g_loss, "d_loss:", d_loss)
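
One detail in the loop is easy to miss: because the discriminator was compiled with `metrics=['accuracy']`, I would expect `train_on_batch` to return a pair `[loss, accuracy]` rather than a single number, so `d_loss` is a two-element array. The sketch below uses made-up values to show how the averaging behaves under that assumption.

```python
import numpy as np

# Made-up return values of discriminator.train_on_batch: [loss, accuracy].
d_loss_gen = [0.70, 0.50]   # on generated images
d_loss_real = [0.60, 0.80]  # on real images

d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)  # element-wise mean of the two
print(d_loss)  # first entry is the loss, second is the accuracy

# Averaging over batches with axis=0 keeps the two entries separate.
d_losses = np.array([d_loss, d_loss])
epoch_d_loss = np.sum(d_losses, axis=0) / len(d_losses)
print(epoch_d_loss)
```

The `d_loss` printed at the end of each epoch is therefore a pair of numbers: the mean loss and the mean accuracy.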



Evaluate the model

Here we measure the performance of the generator on the test dataset. The loss may be slightly larger than on the training dataset, but that is not a concern as long as the difference is small.

label = np.ones((len(test_lr),1))
test_features = vgg.predict(test_hr)
# Named test_loss rather than eval to avoid shadowing Python's built-in eval()
test_loss,_,_ = gan_model.evaluate([test_lr, test_hr], [label,test_features])
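
The combined loss is hard to interpret on its own. A common complementary measure for super resolution is the peak signal-to-noise ratio (PSNR) between the generated image and the true high resolution image. This is not part of the original code; the sketch below assumes images scaled to [0, 1], and the helper name `psnr` is hypothetical.

```python
import numpy as np

def psnr(reference, estimate, max_value=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_value]."""
    mse = np.mean((reference.astype(np.float64) - estimate.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))

reference = np.zeros((100, 100, 3))
estimate = np.full((100, 100, 3), 0.1)  # every pixel is off by 0.1
print(psnr(reference, estimate))
```

Higher values mean the generated image is closer to the reference pixel by pixel, although a high PSNR does not always match perceived sharpness.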



Predict the output

We can generate high resolution images with the generator model.

test_prediction = generator.predict_on_batch(test_lr)
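
The last layer of the generator has no activation, so predicted values can fall slightly outside the range 0 to 1. Before displaying or saving the images, it is usually safer to clip and convert them. This is a sketch with a hypothetical helper `to_uint8` and a tiny made-up array in place of real generator output.

```python
import numpy as np

def to_uint8(images):
    """Clip float images to [0, 1] and convert them to uint8 in [0, 255]."""
    clipped = np.clip(images, 0.0, 1.0)
    return np.round(clipped * 255.0).astype(np.uint8)

# Stand-in for generator output, including values outside [0, 1].
fake_prediction = np.array([[[-0.2, 0.0, 0.5], [1.0, 1.3, 0.25]]])
pixels = to_uint8(fake_prediction)
print(pixels)
```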

The output is quite amazing…
SRGAN Output

You can find my implementation, which was trained on Google Colab, on my GitHub profile.


Tips

  • Always keep track of which model is trainable and which is frozen.
  • While training the generator, use a label value of one.
  • It is better to use input images larger than 25x25, as they carry more detail for the generated images.
  • Do not forget to normalize the NumPy dataset to between 0 and 1.

References

Jason Brownlee. 2019. Generative Adversarial Networks with Python.
SRGAN paper: https://arxiv.org/pdf/1609.04802

Top comments (3)

Moritz Messner

Hi Manish, thanks for the article!

When I tried to run your code, I ran into a problem with VGG. I was able to solve it by modifying the build_vgg function.

If anyone has the same problem, here is my solution for it.

from tensorflow.keras.applications import VGG19

def build_vgg():
    vgg = VGG19(weights="imagenet",input_shape =(img_width,img_height,3),include_top=False)
    outputs = [vgg.layers[9].output]
    return Model(vgg.input, outputs)

mehnaz1985

Nice explanation. As I am a novice in this field, I would like to know whether I could apply training on my custom dataset or not? Is there any specification to create the dataset? suppose I have train and test CSV file or train or test images in a separate folder. Please help me in this regard.

whiterose-fsociety

def upscale_block(ip):

    up_model = Conv2D(256, (3,3), padding="same")(ip)
    up_model = UpSampling2D( size = 2 )(up_model)
    up_model = PReLU(shared_axes=[1,2])(up_model)

I wanted to ask about this.

I thought it would have been this way:
    up_model = UpSampling2D( size = 2 )(ip)
    up_model = Conv2D(256, (3,3), padding="same")( up_model)

Since the pooling layer does not perform any learning.
I assumed this (the one I wrote) is the same as Conv2DTranspose