Style Transfer with VGG-16

Python project, TensorFlow.

This article shows how to reuse the feature extractor of a model trained for object detection in a new model designed for style transfer. First, the domain of style transfer is introduced, then we walk through the implementation. Finally, useful tools are developed to visualize our results.

GitHub link: https://github.com/Apiquet/Style_transfer

Table of contents

  1. Style Transfer
  2. Principle
  3. Implementation
    1. The model
    2. Training
  4. Test
  5. Conclusion

1) Style Transfer

Style Transfer refers to the task of applying a specific style to an image. The following example, obtained with the neural network we will implement further, illustrates it:

This can be done with any content image and any style image. In the following example, the original content image is shown at the bottom left, and each style image is displayed at the bottom left of its corresponding style transfer:

There are several ways to implement a style transfer algorithm; this article focuses on one of them. This method uses a neural network and works as follows:

  1. Load the model and freeze all its layers (because we do not train the model)
  2. Run the model on the initial content image
  3. Calculate a loss that depends on two metrics:
    1. how far the output is from the original content image
    2. how far the output is from the style image
  4. Use gradient descent with respect to the image (to update its pixels to match both content and style images)
  5. Feed the model with the updated image, and go back to step 3

The next section goes deeper into the implementation.

2) Principle

In general, when we use a neural network, we have to gather a lot of data and train the model on it; as a result, the network becomes more and more accurate as training progresses. In Style Transfer, we do the opposite. We take a pre-trained network that has already learned to extract features from an input image (good lower-resolution representations of its content). During “training”, we do not update its weights: the only thing we update is the image. For this, we have to calculate the gradients with respect to the image instead of the model’s weights. We will go deeper into this gradient descent later.
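
To make this concrete, here is a minimal sketch of the key difference (the feature extractor and the loss below are only placeholders, not the SSD-based VGG-16 used later in the article): the image is wrapped in a tf.Variable and the gradient is taken with respect to that variable rather than the model’s weights.

import tensorflow as tf

# placeholder: any frozen feature extractor and any differentiable loss
model = tf.keras.applications.VGG16(include_top=False, weights=None)
model.trainable = False

# the image itself is the variable we optimize
image = tf.Variable(tf.random.uniform([1, 224, 224, 3]))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.02)

with tf.GradientTape() as tape:
    features = model(image)
    loss = tf.reduce_mean(tf.square(features))  # dummy loss for illustration

# gradient with respect to the image, not model.trainable_variables
grads = tape.gradient(loss, image)
optimizer.apply_gradients([(grads, image)])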

One important thing to know is that the style of an image depends on low-resolution content, while the main content of an image (e.g. “a dog is present”) corresponds to high-resolution content. As this can be difficult to grasp, I hope the following example will help. In the next illustration, the two images have the same content but not the same style:

We can clearly recognize the presence of a dog and a ball, but the “low-resolution” contents are different, which leads to different styles. The style is defined by the value of each pixel, while the dog is defined by a large number of pixels.

This fundamental difference between style and content will help us define the loss of the network. Indeed, it is not logical to define the style loss from the neural network’s output, because there we only get high-resolution representations! The neural network’s output should only be used to calculate the content loss. The style loss must use the first layers, which extract low-resolution data.

To continue the series available on my website, I chose to re-use the VGG-16 network from the SSD300 model developed in a previous article. This neural network (SSD) was trained on the VOC2012 dataset with 21 classes, so its feature extractor (VGG-16) learned to extract features that best represent the content of the input image. The following illustration shows the “training” process and should clarify the above paragraphs:

The illustration above explains the process used to update the image at each pass (epoch). The style loss is calculated with the first N layers (based on the difference with the style image). The content loss is calculated with a layer from the end of the model. According to experiments and online resources, Style Transfer seems to work better if we do not take the very last layer: it gives a very high-level representation, while the image content is better captured at a medium/high level, so we should take the layer at index -2 or -3 (-2 in the illustration).

The following animation shows the result, epoch after epoch:

The content of the image (the dog and the ball) stays in the image but the style changes over the epochs. A last thing to know before moving to the implementation is the weights that we can apply to the content and style losses:
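
In equation form, the total loss is most likely of the form below, with a and b the weights applied to the content and style losses (the exact values used in the implementation are shown in section 3-2):

L_{total} = a \cdot L_{content} + b \cdot L_{style}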

These weights (a, b) allow us to adjust the final result (to be closer to the content image or to the style image).

3) Implementation

3-1) The model

To have a complete description of the VGG-16 implementation, please refer to the “SSD300 Implementation” article. To have a complete description of how to load weights to re-use the feature extractor (VGG-16) of SSD, please refer to the “Image Segmentation” article.

The following code instantiates VGG-16 from the SSD300 repository and loads the pre-trained weights:

import tensorflow as tf

from models.SSD300 import SSD300

# get SSD300 model (21 classes, as it was trained on VOC2012)
floatType = tf.float32  # dtype used for the model
SSD300_model = SSD300(21, floatType)
input_shape = (300, 300, 3)

# run model on zero tensor to initialize it
confs, locs = SSD300_model(tf.zeros([32, 300, 300, 3], floatType))
if ssd_weights_path is not None:
    SSD300_model.load_weights(ssd_weights_path)

# get the VGG-16 part from SSD300
SSD_backbone = SSD300_model.getVGG16()

# get a new VGG-16 model until the stage 5
from models.VGG16 import VGG16
self.VGG16 = VGG16(input_shape=input_shape)
self.VGG16_tilStage5 = self.VGG16.getUntilStage5()

# load weights from SSD to new VGG-16 model
ssd_seq_idx = 0
ssd_layer_idx = 0
for i in range(len(self.VGG16_tilStage5.layers)):
    ssd_layer_idx = i
    if i >= 13:
        ssd_seq_idx = 1
        ssd_layer_idx -= 13
    self.VGG16_tilStage5.get_layer(index=i).set_weights(
        SSD_backbone.get_layer(index=ssd_seq_idx).get_layer(
            index=ssd_layer_idx).get_weights())
    self.VGG16_tilStage5.get_layer(index=i).trainable = False

# del models that we won't use anymore
del SSD_backbone
del SSD300_model

Then, we need to create a new model with 6 outputs:

  • The outputs of the first 5 layers, used for the style loss
  • The output of the layer at index -2, used for the content loss

self.style_layers = []

self.inputs = tf.keras.layers.Input(shape=input_shape)
self.x = self.VGG16_tilStage5.get_layer(index=0)(self.inputs)

for i in range(1, 6):
    # get first 5 layers for the style loss
    self.style_layers.append(self.x)
    self.x = self.VGG16_tilStage5.get_layer(index=i)(self.x)

# apply the remaining layers up to the layer at index -2
for i in range(6, len(self.VGG16_tilStage5.layers)-1):
    self.x = self.VGG16_tilStage5.get_layer(index=i)(self.x)

# keep the output of the layer at index -2 for the content loss
self.content_layers = [self.x]

self.model = tf.keras.Model(
    inputs=self.inputs, outputs=self.style_layers+self.content_layers)

Our model is now implemented and has 6 output layers: the first 5 of the architecture and the one at index -2. The next section will explain the training process and how to calculate the losses.
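
Before moving on, a quick sanity check (an illustrative snippet reusing the self.model attribute built above) can verify that the model indeed returns 6 tensors:

# feed a dummy image and check that we get 6 feature maps back
dummy_image = tf.zeros([1, 300, 300, 3])
outputs = self.model(dummy_image)
print(len(outputs))                      # expected: 6
print([tuple(o.shape) for o in outputs])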

3-2) Training

The training loop is very similar to a regular one; the main difference is that we apply the gradients to the image instead of to the model’s weights:

from tqdm import tqdm

# any tf.keras optimizer works here (the learning rate is only an example)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.02)

# infer the model on the style image to get the style targets (outputs of the first 5 layers)
style_targets, _ = self.get_features(style_image)

# infer the model on the content image to get the content target (output of the layer at index -2)
_, content_targets = self.get_features(content_image)

# generate a copy of our content image
# this copy will be updated with the gradients over the epochs
generated_image = tf.cast(content_image, dtype=tf.float32)
generated_image = tf.Variable(generated_image)

# training loop
for n in tqdm(range(epochs), position=0, leave=True):
    with tf.GradientTape() as tape:
        # run the model on the current image (image updated at each run)
        # get the style feature (outputs of the first 5 layers) and content feature (outputs of the layer with index -2)
        style_features, content_features = self.get_features(generated_image)
        # calculate the loss
        loss = self.get_loss(style_targets, style_features, content_targets, content_features)

    # get gradients
    gradients = tape.gradient(loss, generated_image)
    # apply the gradients to the image to update its pixels
    optimizer.apply_gradients([(gradients, generated_image)])
    # clip image to have a range of [0, 255]
    generated_image.assign(tf.clip_by_value(generated_image, clip_value_min=0.0, clip_value_max=255.0))
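
The get_features helper is not shown above; here is a minimal sketch of what it does, assuming the 6-output model built in the previous section (the first 5 outputs are the style features, the last one is the content feature):

def get_features(self, image):
    # run the 6-output model on the image (cast to float32 to match the variable type)
    outputs = self.model(tf.cast(image, tf.float32))
    # first 5 outputs: style feature maps, last output: content feature map
    style_features = outputs[:5]
    content_features = outputs[5]
    return style_features, content_features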

The style and content losses are calculated with the code below:

def get_loss(self, style_target, style_feature,
             content_target, content_feature):
    style_loss = tf.add_n([tf.reduce_mean(tf.square(features - targets))
                           for features, targets in zip(style_feature, style_target)])

    content_loss = 0.5 * tf.reduce_sum(tf.square(content_feature - content_target))

    return 1 * style_loss + 0.001 * content_loss

This corresponds to the following equations:
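
With F_l and S_l the feature maps of style layer l for the generated image and the style image, and C, C_target the content feature maps of the generated image and the content image, the code above computes:

L_{style} = \sum_{l=1}^{5} \operatorname{mean}\left[ (F_l - S_l)^2 \right]

L_{content} = \frac{1}{2} \sum (C - C_{target})^2

L_{total} = b \cdot L_{style} + a \cdot L_{content} \qquad (b = 1,\ a = 0.001 \text{ in the code above})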

The following images show the difference between having a medium weight for content loss and having a very low one:

If the weight of the content loss is too low (picture on the right), the dog may no longer be recognizable. Style transfer consists of finding a good compromise between style and content representation in the final image.

4) Test

The GitHub repository contains useful methods and functions to run the model on images and videos. The code is not difficult to implement (it is mostly image management), so I won’t explain it here. However, I will explain how to use it (all the code is available in the GitHub repository).

The notebook style_transfer_examples.ipynb shows how to load the model and run it on images or videos. For instance, we can run the following code with style_image_path = “imgs/style_1.jpg”, “imgs/style_2.jpg” and “imgs/style_3.jpg”:
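
The exact code is in the notebook; as an illustrative stand-in (run_style_transfer is a hypothetical wrapper around the training loop of section 3-2, and the content image path is an assumption):

# hypothetical wrapper around the training loop of section 3-2:
# it loads both images, runs the optimization and saves the result
for i, style_image_path in enumerate(
        ["imgs/style_1.jpg", "imgs/style_2.jpg", "imgs/style_3.jpg"], start=1):
    run_style_transfer(content_image_path="imgs/content.jpg",  # content path assumed
                       style_image_path=style_image_path,
                       output_path="imgs/style_transfer_{}.jpg".format(i))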

This will create the imgs/style_transfer_1.jpg result (and style_transfer_2.jpg and style_transfer_3.jpg if we try all 3 style_1/2/3.jpg images). Then, we can concatenate them with the following code:
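
The repository has its own helper for this; a simple equivalent with NumPy and Pillow (the output file name is an assumption) could be:

import numpy as np
from PIL import Image

# load the three results (they share the same size since they come from
# the same content image) and stack them horizontally
results = [np.array(Image.open("imgs/style_transfer_{}.jpg".format(i)))
           for i in range(1, 4)]
concatenated = np.concatenate(results, axis=1)  # axis=1: side by side
Image.fromarray(concatenated).save("imgs/style_transfer_concat.jpg")  # name assumed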

This will produce the image below:

We can also run the model on a video:
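
Again, the repository provides the real helpers; the idea is simply to run the same optimization frame by frame. A sketch with OpenCV (run_style_transfer_on_array is a hypothetical wrapper that takes and returns a NumPy frame, and the video paths are assumptions):

import cv2

cap = cv2.VideoCapture("videos/input.mp4")  # input path assumed
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
out = cv2.VideoWriter("videos/style_transfer.mp4",  # output path assumed
                      cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

while True:
    ret, frame = cap.read()
    if not ret:
        break
    # hypothetical wrapper: stylize a single frame with the training loop above
    stylized = run_style_transfer_on_array(frame, "imgs/style_1.jpg")
    out.write(stylized)

cap.release()
out.release()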

5) Conclusion

In this article, we learned the principle of Style Transfer using the feature extractor of a model trained on a different task (object detection). We learned how to use it for Style Transfer, how to calculate our loss and how to implement the training function that updates the image.

After this series of articles, we know how to use a feature extractor (here VGG-16) for multiple tasks such as: object detection, multi-object tracking, image segmentation and style transfer. I hope this will help other people who want to learn more about neural networks.


Here you can find my project:

https://github.com/Apiquet/Style_transfer

Video source: coveer