In [1]:
from IPython.display import Image

CNTK 206: Part A - Basic GAN with MNIST data

Prerequisites: We assume that you have successfully downloaded the MNIST data by completing the tutorial titled CNTK_103A_MNIST_DataLoader.ipynb.

Introduction

Generative models have gained a lot of attention in the deep learning community, which has traditionally leveraged discriminative models for (semi-supervised) and unsupervised learning. In generative modeling, the idea is to collect a large amount of data in a domain of interest (e.g., pictures, audio, words) and train a model that generates data resembling those real-world data sets. This is an active area of research that needs large datasets and mechanisms to scale up training. As stated in the OpenAI blog, such approaches may be used for computer-aided art generation, or to morph images according to word descriptions such as “make my smile wider”. Generative models have also found use in image denoising, inpainting, super-resolution, structured prediction, exploration in reinforcement learning, and neural network pretraining in cases where labeled data is expensive.

Building models that can produce realistic content (images, sounds, etc.) mimicking real-world observations is challenging. The Generative Adversarial Network (GAN) is one of the approaches that holds promise. A quote from Yann LeCun describes GANs and their variations as the most important idea in the last 10 years. The original idea was proposed by Goodfellow et al. at NIPS 2014. In this tutorial, we show how to use the Cognitive Toolkit to create a basic GAN for generating synthetic MNIST digits.

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os

import cntk as C
import cntk.tests.test_utils
cntk.tests.test_utils.set_device_from_pytest_env() # (only needed for our build system)
C.cntk_py.set_fixed_random_seed(1) # fix a random seed for CNTK components

%matplotlib inline

There are two run modes:

  • Fast mode: isFast is set to True. This is the default mode for the notebooks, which means we train for fewer iterations or train / test on limited data. This ensures functional correctness of the notebook, though the models produced are far from what a completed training would produce.
  • Slow mode: We recommend setting this flag to False once you have gained familiarity with the notebook content and want to gain insight from running the notebooks for a longer period with different training parameters.

Note: If isFast is set to False, the notebook will take a few hours on a GPU-enabled machine. You can try fewer iterations by setting num_minibatches to a smaller number, say 20,000, at the expense of the quality of the generated images.

In [3]:
isFast = True

Data Reading

The input to the GAN will be a vector of random numbers. At the end of the training, the GAN “learns” to generate images of handwritten digits drawn from the MNIST database. We will be using the same MNIST data generated in tutorial 103A. A more in-depth discussion of the data format and reading methods can be found in previous tutorials. For our purposes, just know that the following function returns an object that will be used to read images from the MNIST dataset. Since we are building an unsupervised model, we only need to read in the features and can ignore the labels.

In [4]:
# Ensure the training data is generated and available for this tutorial
# We search in two locations in the toolkit for the cached MNIST data set.

data_found = False
for data_dir in [os.path.join("..", "Examples", "Image", "DataSets", "MNIST"),
                 os.path.join("data", "MNIST")]:
    train_file = os.path.join(data_dir, "Train-28x28_cntk_text.txt")
    if os.path.isfile(train_file):
        data_found = True
        break

if not data_found:
    raise ValueError("Please generate the data by completing CNTK 103 Part A")

print("Data directory is {0}".format(data_dir))
Data directory is ..\Examples\Image\DataSets\MNIST
In [5]:
def create_reader(path, is_training, input_dim, label_dim):
    deserializer = C.io.CTFDeserializer(
        filename = path,
        streams = C.io.StreamDefs(
            labels_unused = C.io.StreamDef(field = 'labels', shape = label_dim, is_sparse = False),
            features = C.io.StreamDef(field = 'features', shape = input_dim, is_sparse = False
            )
        )
    )
    return C.io.MinibatchSource(
        deserializers = deserializer,
        randomize = is_training,
        max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1
    )
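
For orientation, the sketch below parses one line of the CNTK text format in plain Python. It assumes the layout written by tutorial 103A, in which each line holds a `|labels` field followed by a `|features` field. The helper name `parse_ctf_line` and the toy line are hypothetical; they only illustrate why the `labels_unused` stream can be read and then ignored.

```python
def parse_ctf_line(line):
    """Split one CTF-style line into {field_name: [float, ...]}."""
    fields = {}
    for chunk in line.strip().split("|"):
        if not chunk.strip():
            continue  # skip the empty piece before the first '|'
        name, *values = chunk.split()
        fields[name] = [float(v) for v in values]
    return fields

# Toy line: 3 label entries and 4 pixel values (an assumption for brevity;
# the reader above expects 10 label entries and 784 pixel values per line).
sample = "|labels 0 1 0 |features 0 128 255 64"
parsed = parse_ctf_line(sample)
features = parsed["features"]   # the only stream the GAN consumes
print(len(features), features)
```

The real reader does this work inside `C.io.CTFDeserializer`; the sketch is not a replacement for it.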

The random noise used to train the GAN is provided by the noise_sample function, which draws samples from a uniform distribution over the interval [-1, 1].

In [6]:
np.random.seed(123)
def noise_sample(num_samples):
    return np.random.uniform(
        low = -1.0,
        high = 1.0,
        size = [num_samples, g_input_dim]
    ).astype(np.float32)
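
Note that noise_sample reads g_input_dim from the global scope, so that variable (set in a later cell) must exist before the function is first called. The following is a minimal self-contained variant for checking the shape, type, and range of the noise batch; the name `sample_noise` and its explicit `noise_dim` and `rng` parameters are assumptions made for this sketch and are not part of the tutorial code.

```python
import numpy as np

def sample_noise(num_samples, noise_dim, rng):
    """Draw a (num_samples, noise_dim) float32 batch from a uniform
    distribution on [-1, 1)."""
    return rng.uniform(-1.0, 1.0, size=(num_samples, noise_dim)).astype(np.float32)

rng = np.random.RandomState(123)   # fixed seed for repeatability
batch = sample_noise(4, 100, rng)  # 4 noise vectors of dimension 100
print(batch.shape, batch.dtype)
print(batch.min() >= -1.0, batch.max() <= 1.0)
```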

Model Creation

A GAN network is composed of two sub-networks, one called the Generator (\(G\)) and the other the Discriminator (\(D\)).

  • The Generator takes a random noise vector (\(z\)) as input and strives to output a synthetic (fake) image (\(x^*\)) that is indistinguishable from a real image (\(x\)) from the MNIST dataset.
  • The Discriminator strives to differentiate between the real image (\(x\)) and the fake image (\(x^*\)).

In [7]:
# Figure 1
Image(url="https://www.cntk.ai/jup/GAN_basic_flow.png")
Out[7]:

In each training iteration, the Generator learns to produce more realistic fake images (in other words, it minimizes the difference between the real images and their generated counterparts), while the Discriminator maximizes the probability of assigning the correct label (real vs. fake) to both real examples (from the training set) and generated fake ones. The two conflicting objectives of the sub-networks (\(G\) and \(D\)) ideally lead the GAN (when trained) to converge to an equilibrium, where the Generator produces realistic-looking fake MNIST images and the Discriminator can at best randomly guess whether images are real or fake. Once trained, the resulting Generator produces realistic MNIST images from a random noise vector as input.

Model config

First, we establish some of the architectural and training hyper-parameters for our model.

  • The generator network is a fully-connected network with a single hidden layer. The input will be a 100-dimensional random vector and the output will be a 784-dimensional vector, corresponding to a flattened version of a 28 x 28 fake image. The discriminator is also a dense network with a single hidden layer. It takes as input the 784-dimensional output of the generator or a real MNIST image and outputs a single scalar: the estimated probability that the input image is a real MNIST image.

Model components

We build a computational graph for our model, one each for the generator and the discriminator. First, we establish some of the architectural parameters of our model.

  • The generator takes a 100-dimensional random vector (for starters) as input (\(z\)) and outputs a 784-dimensional vector, corresponding to a flattened version of a 28 x 28 fake (synthetic) image (\(x^*\)). In this tutorial we simply model the generator with two dense layers. We use a tanh activation on the last layer to make sure that the output of the generator function is confined to the interval [-1, 1]. This is necessary because we also scale the MNIST images to this interval, and the outputs of the generator must be able to emulate the actual images as closely as possible.
  • The discriminator takes as input either the 784-dimensional output of the generator (\(x^*\)) or a real MNIST image and outputs the estimated probability that the input image is a real MNIST image. We also model this with two dense layers, with a sigmoid activation in the last layer ensuring that the discriminator produces a valid probability.
In [8]:
# architectural parameters
g_input_dim = 100
g_hidden_dim = 128
g_output_dim = d_input_dim = 784
d_hidden_dim = 128
d_output_dim = 1
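
As a rough size check, the sketch below counts the trainable parameters implied by these dimensions. It assumes every dense layer carries a bias vector in addition to its weight matrix; `dense_params` is a hypothetical helper written for this illustration.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer (bias assumed)."""
    return n_in * n_out + n_out

g_input_dim, g_hidden_dim, g_output_dim = 100, 128, 784
d_input_dim, d_hidden_dim, d_output_dim = 784, 128, 1

# Each network is two dense layers: input -> hidden -> output.
g_params = dense_params(g_input_dim, g_hidden_dim) + dense_params(g_hidden_dim, g_output_dim)
d_params = dense_params(d_input_dim, d_hidden_dim) + dense_params(d_hidden_dim, d_output_dim)
print(g_params, d_params)   # 114064 100609
```

Both networks are small, on the order of 100,000 parameters each, and most of that comes from the 784-wide layer.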
In [9]:
def generator(z):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(g_hidden_dim, activation = C.relu)(z)
        return C.layers.Dense(g_output_dim, activation = C.tanh)(h1)
In [10]:
def discriminator(x):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(d_hidden_dim, activation = C.relu)(x)
        return C.layers.Dense(d_output_dim, activation = C.sigmoid)(h1)
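
To see the shapes and output ranges of the two networks without CNTK, here is a NumPy sketch of the same layer structure. The random weights drawn from a normal distribution are an assumption for illustration only (the tutorial code uses Xavier initialization), and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.RandomState(0)

def dense(x, w, b, activation):
    """One fully connected layer: activation(x @ w + b)."""
    return activation(x @ w + b)

def relu(a):
    return np.maximum(a, 0.0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Illustrative random weights, not the tutorial's initialization.
g_w1, g_b1 = rng.normal(0, 0.1, (100, 128)), np.zeros(128)
g_w2, g_b2 = rng.normal(0, 0.1, (128, 784)), np.zeros(784)
d_w1, d_b1 = rng.normal(0, 0.1, (784, 128)), np.zeros(128)
d_w2, d_b2 = rng.normal(0, 0.1, (128, 1)), np.zeros(1)

z = rng.uniform(-1.0, 1.0, (5, 100))                      # 5 noise vectors
x_fake = dense(dense(z, g_w1, g_b1, relu), g_w2, g_b2, np.tanh)
p_real = dense(dense(x_fake, d_w1, d_b1, relu), d_w2, d_b2, sigmoid)

print(x_fake.shape, float(x_fake.min()), float(x_fake.max()))  # values within [-1, 1]
print(p_real.shape, float(p_real.min()), float(p_real.max()))  # values within (0, 1)
```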

We use a minibatch size of 1024 and a fixed learning rate of 0.00005 for training. In the fast mode (isFast = True) we verify only functional correctness with 300 iterations.

Note: In the slow mode, the results look a lot better, but training requires patience (a few hours, depending on your hardware). In general, the more minibatches one trains on, the better the fidelity of the generated images.

In [11]:
# training config
minibatch_size = 1024
num_minibatches = 300 if isFast else 40000
lr = 0.00005

Build the graph

The rest of the computational graph is mostly responsible for coordinating the training algorithms and parameter updates, which is particularly tricky with GANs for a couple of reasons.

  • First, the discriminator must be used on both the real MNIST images and fake images generated by the generator function. One way to represent this in the computational graph is to create a clone of the output of the discriminator function, but with substituted inputs. Setting method=share in the clone function ensures that both paths through the discriminator model use the same set of parameters.
  • Second, we need to update the parameters for the generator and discriminator model separately using the gradients from different loss functions. We can get the parameters for a Function in the graph with the parameters attribute. However, when updating the model parameters, update only the parameters of the respective models while keeping the other parameters unchanged. In other words, when updating the generator we will update only the parameters of the \(G\) function while keeping the parameters of the \(D\) function fixed and vice versa.
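
The second point can be illustrated with a toy example in plain Python. This is a sketch under strong simplifying assumptions, not the tutorial's implementation: the scalar models, the hand-derived gradients, and the names `d_step` and `g_step` are all hypothetical, and the generator step uses the loss -log D(G(z)). It only shows that each update touches one parameter set while the other stays fixed.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

# Hypothetical scalar models: D(x) = sigmoid(w*x + b), G(z) = a*z + c.
d_params = {"w": 0.1, "b": 0.0}
g_params = {"a": 1.0, "c": 0.0}

def D(x):
    return sigmoid(d_params["w"] * x + d_params["b"])

def G(z):
    return g_params["a"] * z + g_params["c"]

def d_step(x_real, z, lr=0.05):
    """Gradient step on -(log D(x) + log(1 - D(G(z)))); only d_params change."""
    x_fake = G(z)
    p_real, p_fake = D(x_real), D(x_fake)
    d_params["w"] -= lr * (-(1.0 - p_real) * x_real + p_fake * x_fake)
    d_params["b"] -= lr * (-(1.0 - p_real) + p_fake)

def g_step(z, lr=0.05):
    """Gradient step on -log D(G(z)); only g_params change."""
    p_fake = D(G(z))
    grad_out = -(1.0 - p_fake) * d_params["w"]   # d(-log D(G(z))) / d G(z)
    g_params["a"] -= lr * grad_out * z
    g_params["c"] -= lr * grad_out

random.seed(0)
g_before = dict(g_params)
d_step(x_real=3.0, z=random.uniform(-1.0, 1.0))
g_unchanged = (g_params == g_before)   # generator untouched by the D step

d_before = dict(d_params)
g_step(z=random.uniform(-1.0, 1.0))
d_unchanged = (d_params == d_before)   # discriminator untouched by the G step
print(g_unchanged, d_unchanged)        # True True
```

In the CNTK code the same separation is obtained by handing each learner only the parameters of its own model.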

Training the Model

The code for training the GAN very closely follows the algorithm as presented in the original NIPS 2014 paper. In this implementation, we train \(D\) to maximize the probability of assigning the correct label (fake vs. real) to both training examples and the samples from \(G\). In other words, \(D\) and \(G\) play the following two-player minimax game with the value function \(V(D,G)\):

\[\min_G \max_D V(D,G)= \mathbb{E}_{x}[ \log D(x) ] + \mathbb{E}_{z}[ \log(1 - D(G(z))) ]\]

At the optimal point of this game, the generator produces realistic-looking data while the discriminator predicts that a generated image is fake with a probability of 0.5. This tutorial implements the algorithm shown in Figure 2 below.
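
As a worked check of the value function, suppose the discriminator outputs a constant probability for real and for fake inputs, so that the expectations reduce to single terms. When it outputs 0.5 for both, \(V = \log 0.5 + \log 0.5 = -\log 4 \approx -1.386\). The sketch below evaluates this case and a more confident discriminator; the function name `value` and the constant-output simplification are assumptions made for illustration.

```python
import math

def value(d_real, d_fake):
    """V(D, G) when D outputs d_real on real data and d_fake on generated data."""
    return math.log(d_real) + math.log(1.0 - d_fake)

v_equilibrium = value(0.5, 0.5)   # D cannot tell real from fake
v_confident = value(0.9, 0.1)     # D separates the two well
print(round(v_equilibrium, 4), round(v_confident, 4))   # -1.3863 -0.2107
```

A discriminator that separates the two classes well pushes \(V\) toward 0, which is why \(D\) maximizes \(V\) while \(G\) tries to pull it back down.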

In [12]:
# Figure 2
Image(url="https://www.cntk.ai/jup/GAN_goodfellow_NIPS2014.png", width = 500)
Out[12]:
In [13]:
def build_graph(noise_shape, image_shape,
                G_progress_printer, D_progress_printer):
    input_dynamic_axes = [C.Axis.default_batch_axis()]
    Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes)
    X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes)
    X_real_scaled = 2*(X_real / 255.0) - 1.0

    # Create the model function for the generator and discriminator models
    X_fake = generator(Z)
    D_real = discriminator(X_real_scaled)
    D_fake = D_real.clone(
        method = 'share',
        substitutions = {X_real_scaled.output: X_fake.output}
    )

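    # Create loss functions and configure optimization algorithms.
    # G_loss is the non-saturating form -log(D(G(z))) plus a constant 1.0;
    # the constant does not change the gradients.
    G_loss = 1.0 - C.log(D_fake)
    D_loss = -(C.log(D_real) + C.log(1.0 - D_fake))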

    G_learner = C.fsadagrad(
        parameters = X_fake.parameters,
        lr = C.learning_parameter_schedule_per_sample(lr),
        momentum = C.momentum_schedule_per_sample(0.9985724484938566)
    )
    D_learner = C.fsadagrad(
        parameters = D_real.parameters,
        lr = C.learning_parameter_schedule_per_sample(lr),
        momentum = C.momentum_schedule_per_sample(0.9985724484938566)
    )

    # Instantiate the trainers
    G_trainer = C.Trainer(
        X_fake,
        (G_loss, None),
        G_learner,
        G_progress_printer
    )
    D_trainer = C.Trainer(
        D_real,
        (D_loss, None),
        D_learner,
        D_progress_printer
    )

    return X_real, X_fake, Z, G_trainer, D_trainer
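
To get a feel for the magnitudes these two loss expressions produce, the sketch below evaluates them in plain Python for scalar discriminator outputs. The functions `g_loss` and `d_loss` are hypothetical stand-ins that mirror the expressions in build_graph; they are not CNTK calls. Note that the generator loss minimizes -log D(G(z)) rather than log(1 - D(G(z))) from the value function, an alternative the original paper suggests to obtain stronger gradients early in training.

```python
import math

def g_loss(d_fake):
    """Mirror of G_loss: 1 - log D(G(z))."""
    return 1.0 - math.log(d_fake)

def d_loss(d_real, d_fake):
    """Mirror of D_loss: -(log D(x) + log(1 - D(G(z))))."""
    return -(math.log(d_real) + math.log(1.0 - d_fake))

# (D(x), D(G(z))) pairs: an undecided and a confident discriminator.
for d_real, d_fake in [(0.5, 0.5), (0.9, 0.1)]:
    print(d_real, d_fake, round(g_loss(d_fake), 3), round(d_loss(d_real, d_fake), 3))
# 0.5 0.5 1.693 1.386
# 0.9 0.1 3.303 0.211
```

With an undecided discriminator, the generator loss sits near 1.69 and the discriminator loss near 1.39; a confident discriminator raises the former and lowers the latter.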

With the value functions defined, we proceed to iteratively train the GAN model. Training can take a long time depending on the hardware, especially if the isFast flag is set to False.

In [14]:
def train(reader_train):
    k = 2

    # print out the loss for each model up to 50 times
    print_frequency_mbsize = num_minibatches // 50
    pp_G = C.logging.ProgressPrinter(print_frequency_mbsize)
    pp_D = C.logging.ProgressPrinter(print_frequency_mbsize * k)

    X_real, X_fake, Z, G_trainer, D_trainer = \
        build_graph(g_input_dim, d_input_dim, pp_G, pp_D)

    input_map = {X_real: reader_train.streams.features}
    for train_step in range(num_minibatches):

        # train the discriminator model for k steps
        for disc_train_step in range(k):
            Z_data = noise_sample(minibatch_size)
            X_data = reader_train.next_minibatch(minibatch_size, input_map)
            if X_data[X_real].num_samples == Z_data.shape[0]:
                batch_inputs = {X_real: X_data[X_real].data,
                                Z: Z_data}
                D_trainer.train_minibatch(batch_inputs)

        # train the generator model for a single step
        Z_data = noise_sample(minibatch_size)
        batch_inputs = {Z: Z_data}
        G_trainer.train_minibatch(batch_inputs)

        G_trainer_loss = G_trainer.previous_minibatch_loss_average

    return Z, X_fake, G_trainer_loss
In [15]:
reader_train = create_reader(train_file, True, d_input_dim, label_dim=10)

G_input, G_output, G_trainer_loss = train(reader_train)
Learning rate per sample: 5e-05
Learning rate per sample: 5e-05
 Minibatch[   1-  12]: loss = 0.460688 * 12288, metric = 0.00% * 12288;
 Minibatch[   1-   6]: loss = 2.744962 * 6144, metric = 0.00% * 6144;
 Minibatch[  13-  24]: loss = 0.496159 * 12288, metric = 0.00% * 12288;
 Minibatch[   7-  12]: loss = 2.408652 * 6144, metric = 0.00% * 6144;
 Minibatch[  25-  36]: loss = 0.468794 * 12288, metric = 0.00% * 12288;
 Minibatch[  13-  18]: loss = 2.497859 * 6144, metric = 0.00% * 6144;
 Minibatch[  37-  48]: loss = 0.549526 * 12288, metric = 0.00% * 12288;
 Minibatch[  19-  24]: loss = 2.715347 * 6144, metric = 0.00% * 6144;
 Minibatch[  49-  60]: loss = 1.273582 * 12288, metric = 0.00% * 12288;
 Minibatch[  25-  30]: loss = 1.957965 * 6144, metric = 0.00% * 6144;
 Minibatch[  61-  72]: loss = 1.263255 * 12288, metric = 0.00% * 12288;
 Minibatch[  31-  36]: loss = 2.055775 * 6144, metric = 0.00% * 6144;
 Minibatch[  73-  84]: loss = 1.041039 * 12288, metric = 0.00% * 12288;
 Minibatch[  37-  42]: loss = 1.913864 * 6144, metric = 0.00% * 6144;
 Minibatch[  85-  96]: loss = 1.190252 * 12288, metric = 0.00% * 12288;
 Minibatch[  43-  48]: loss = 1.834975 * 6144, metric = 0.00% * 6144;
 Minibatch[  97- 108]: loss = 0.916644 * 12288, metric = 0.00% * 12288;
 Minibatch[  49-  54]: loss = 1.856111 * 6144, metric = 0.00% * 6144;
 Minibatch[ 109- 120]: loss = 0.949138 * 12288, metric = 0.00% * 12288;
 Minibatch[  55-  60]: loss = 1.612386 * 6144, metric = 0.00% * 6144;
 Minibatch[ 121- 132]: loss = 1.159921 * 12288, metric = 0.00% * 12288;
 Minibatch[  61-  66]: loss = 1.720942 * 6144, metric = 0.00% * 6144;
 Minibatch[ 133- 144]: loss = 1.069388 * 12288, metric = 0.00% * 12288;
 Minibatch[  67-  72]: loss = 1.742269 * 6144, metric = 0.00% * 6144;
 Minibatch[ 145- 156]: loss = 1.160007 * 12288, metric = 0.00% * 12288;
 Minibatch[  73-  78]: loss = 1.752263 * 6144, metric = 0.00% * 6144;
 Minibatch[ 157- 168]: loss = 1.109970 * 12288, metric = 0.00% * 12288;
 Minibatch[  79-  84]: loss = 2.054680 * 6144, metric = 0.00% * 6144;
 Minibatch[ 169- 180]: loss = 1.092135 * 12288, metric = 0.00% * 12288;
 Minibatch[  85-  90]: loss = 1.904574 * 6144, metric = 0.00% * 6144;
 Minibatch[ 181- 192]: loss = 1.127416 * 12288, metric = 0.00% * 12288;
 Minibatch[  91-  96]: loss = 1.740344 * 6144, metric = 0.00% * 6144;
 Minibatch[ 193- 204]: loss = 1.096188 * 12288, metric = 0.00% * 12288;
 Minibatch[  97- 102]: loss = 1.672640 * 6144, metric = 0.00% * 6144;
 Minibatch[ 205- 216]: loss = 1.076912 * 12288, metric = 0.00% * 12288;
 Minibatch[ 103- 108]: loss = 1.567401 * 6144, metric = 0.00% * 6144;
 Minibatch[ 217- 228]: loss = 1.157102 * 12288, metric = 0.00% * 12288;
 Minibatch[ 109- 114]: loss = 1.651072 * 6144, metric = 0.00% * 6144;
 Minibatch[ 229- 240]: loss = 1.028144 * 12288, metric = 0.00% * 12288;
 Minibatch[ 115- 120]: loss = 1.683286 * 6144, metric = 0.00% * 6144;
 Minibatch[ 241- 252]: loss = 0.904863 * 12288, metric = 0.00% * 12288;
 Minibatch[ 121- 126]: loss = 1.785782 * 6144, metric = 0.00% * 6144;
 Minibatch[ 253- 264]: loss = 0.852108 * 12288, metric = 0.00% * 12288;
 Minibatch[ 127- 132]: loss = 1.827977 * 6144, metric = 0.00% * 6144;
 Minibatch[ 265- 276]: loss = 0.696002 * 12288, metric = 0.00% * 12288;
 Minibatch[ 133- 138]: loss = 2.021851 * 6144, metric = 0.00% * 6144;
 Minibatch[ 277- 288]: loss = 0.584447 * 12288, metric = 0.00% * 12288;
 Minibatch[ 139- 144]: loss = 2.205963 * 6144, metric = 0.00% * 6144;
 Minibatch[ 289- 300]: loss = 0.560834 * 12288, metric = 0.00% * 12288;
 Minibatch[ 145- 150]: loss = 2.256393 * 6144, metric = 0.00% * 6144;
 Minibatch[ 301- 312]: loss = 0.571917 * 12288, metric = 0.00% * 12288;
 Minibatch[ 151- 156]: loss = 2.246409 * 6144, metric = 0.00% * 6144;
 Minibatch[ 313- 324]: loss = 0.603770 * 12288, metric = 0.00% * 12288;
 Minibatch[ 157- 162]: loss = 2.246180 * 6144, metric = 0.00% * 6144;
 Minibatch[ 325- 336]: loss = 0.646258 * 12288, metric = 0.00% * 12288;
 Minibatch[ 163- 168]: loss = 2.244751 * 6144, metric = 0.00% * 6144;
 Minibatch[ 337- 348]: loss = 0.743782 * 12288, metric = 0.00% * 12288;
 Minibatch[ 169- 174]: loss = 2.081807 * 6144, metric = 0.00% * 6144;
 Minibatch[ 349- 360]: loss = 0.766006 * 12288, metric = 0.00% * 12288;
 Minibatch[ 175- 180]: loss = 2.005488 * 6144, metric = 0.00% * 6144;
 Minibatch[ 361- 372]: loss = 0.751595 * 12288, metric = 0.00% * 12288;
 Minibatch[ 181- 186]: loss = 2.073293 * 6144, metric = 0.00% * 6144;
 Minibatch[ 373- 384]: loss = 0.741732 * 12288, metric = 0.00% * 12288;
 Minibatch[ 187- 192]: loss = 2.031459 * 6144, metric = 0.00% * 6144;
 Minibatch[ 385- 396]: loss = 0.800578 * 12288, metric = 0.00% * 12288;
 Minibatch[ 193- 198]: loss = 2.146479 * 6144, metric = 0.00% * 6144;
 Minibatch[ 397- 408]: loss = 0.703397 * 12288, metric = 0.00% * 12288;
 Minibatch[ 199- 204]: loss = 2.148305 * 6144, metric = 0.00% * 6144;
 Minibatch[ 409- 420]: loss = 0.808739 * 12288, metric = 0.00% * 12288;
 Minibatch[ 205- 210]: loss = 1.946966 * 6144, metric = 0.00% * 6144;
 Minibatch[ 421- 432]: loss = 0.878929 * 12288, metric = 0.00% * 12288;
 Minibatch[ 211- 216]: loss = 1.935109 * 6144, metric = 0.00% * 6144;
 Minibatch[ 433- 444]: loss = 0.838549 * 12288, metric = 0.00% * 12288;
 Minibatch[ 217- 222]: loss = 1.906769 * 6144, metric = 0.00% * 6144;
 Minibatch[ 445- 456]: loss = 0.946877 * 12288, metric = 0.00% * 12288;
 Minibatch[ 223- 228]: loss = 1.882014 * 6144, metric = 0.00% * 6144;
 Minibatch[ 457- 468]: loss = 0.883135 * 12288, metric = 0.00% * 12288;
 Minibatch[ 229- 234]: loss = 1.900080 * 6144, metric = 0.00% * 6144;
 Minibatch[ 469- 480]: loss = 0.911140 * 12288, metric = 0.00% * 12288;
 Minibatch[ 235- 240]: loss = 1.947830 * 6144, metric = 0.00% * 6144;
 Minibatch[ 481- 492]: loss = 0.814674 * 12288, metric = 0.00% * 12288;
 Minibatch[ 241- 246]: loss = 2.027527 * 6144, metric = 0.00% * 6144;
 Minibatch[ 493- 504]: loss = 0.846405 * 12288, metric = 0.00% * 12288;
 Minibatch[ 247- 252]: loss = 1.890244 * 6144, metric = 0.00% * 6144;
 Minibatch[ 505- 516]: loss = 0.807406 * 12288, metric = 0.00% * 12288;
 Minibatch[ 253- 258]: loss = 2.031275 * 6144, metric = 0.00% * 6144;
 Minibatch[ 517- 528]: loss = 0.795118 * 12288, metric = 0.00% * 12288;
 Minibatch[ 259- 264]: loss = 2.050018 * 6144, metric = 0.00% * 6144;
 Minibatch[ 529- 540]: loss = 0.828629 * 12288, metric = 0.00% * 12288;
 Minibatch[ 265- 270]: loss = 1.943776 * 6144, metric = 0.00% * 6144;
 Minibatch[ 541- 552]: loss = 0.851117 * 12288, metric = 0.00% * 12288;
 Minibatch[ 271- 276]: loss = 1.986959 * 6144, metric = 0.00% * 6144;
 Minibatch[ 553- 564]: loss = 0.760867 * 12288, metric = 0.00% * 12288;
 Minibatch[ 277- 282]: loss = 2.125000 * 6144, metric = 0.00% * 6144;
 Minibatch[ 565- 576]: loss = 0.727046 * 12288, metric = 0.00% * 12288;
 Minibatch[ 283- 288]: loss = 2.122426 * 6144, metric = 0.00% * 6144;
 Minibatch[ 577- 588]: loss = 0.753314 * 12288, metric = 0.00% * 12288;
 Minibatch[ 289- 294]: loss = 2.060394 * 6144, metric = 0.00% * 6144;
 Minibatch[ 589- 600]: loss = 0.794642 * 12288, metric = 0.00% * 12288;
 Minibatch[ 295- 300]: loss = 1.989756 * 6144, metric = 0.00% * 6144;
In [16]:
# Print the generator loss
print("Training loss of the generator is: {0:.2f}".format(G_trainer_loss))
Training loss of the generator is: 2.07
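
The interleaved log lines above come from two progress printers with different reporting intervals. The arithmetic below, using the fast-mode settings from this notebook, shows why lines reporting 12288 samples belong to the discriminator and lines reporting 6144 samples belong to the generator. The MNIST training-set size of 60,000 images is a known property of the dataset rather than something stated in this notebook.

```python
minibatch_size = 1024
num_minibatches = 300      # fast mode
k = 2                      # discriminator updates per generator update

print_frequency = num_minibatches // 50                      # generator printer interval
g_samples_per_line = print_frequency * minibatch_size        # samples per generator log line
d_samples_per_line = print_frequency * k * minibatch_size    # samples per discriminator log line
d_updates_total = num_minibatches * k                        # last discriminator minibatch index
sweeps = d_updates_total * minibatch_size / 60000            # passes over the MNIST training set

print(print_frequency, g_samples_per_line, d_samples_per_line, d_updates_total, sweeps)
# 6 6144 12288 600 10.24
```

So the discriminator log runs to minibatch 600 while the generator log stops at 300, and fast mode already reads the training images about ten times.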

Generating Fake (Synthetic) Images

Now that we have trained the model, we can create fake images simply by feeding random noise into the generator and displaying the outputs. Below are a few images generated from random samples. To get a new set of samples, you can re-run the cell below.

In [17]:
def plot_images(images, subplot_shape):
    plt.style.use('ggplot')
    fig, axes = plt.subplots(*subplot_shape)
    for image, ax in zip(images, axes.flatten()):
        ax.imshow(image.reshape(28, 28), vmin = 0, vmax = 1.0, cmap = 'gray')
        ax.axis('off')
    plt.show()

noise = noise_sample(36)
images = G_output.eval({G_input: noise})
plot_images(images, subplot_shape =[6, 6])
_images/CNTK_206A_Basic_GAN_28_0.png
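
The generator's tanh output lies in [-1, 1], while plot_images passes vmin = 0 to imshow, so any value below 0 is drawn as black. One possible alternative, offered here as an assumption rather than as the tutorial's method, is to map the output back to [0, 1] before plotting. The helper names below are hypothetical; `to_tanh_range` repeats the scaling applied to real images in build_graph.

```python
import numpy as np

def to_tanh_range(pixels):
    """Map raw pixel values from [0, 255] to [-1, 1], as in build_graph."""
    return 2.0 * (np.asarray(pixels, dtype=np.float32) / 255.0) - 1.0

def to_unit_range(x):
    """Map values from [-1, 1] to [0, 1] for display."""
    return (np.asarray(x, dtype=np.float32) + 1.0) / 2.0

raw = np.array([0, 51, 255])
scaled = to_tanh_range(raw)        # approximately [-1.0, -0.6, 1.0]
display = to_unit_range(scaled)    # approximately [ 0.0,  0.2, 1.0]
print(scaled, display)
```

With this mapping, a call such as `ax.imshow(to_unit_range(image).reshape(28, 28), vmin = 0, vmax = 1.0, cmap = 'gray')` would show the full output range instead of clipping the negative half.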

A larger number of iterations should generate more realistic-looking MNIST images. A sample of such generated images is shown below.

In [18]:
# Figure 3
Image(url="http://www.cntk.ai/jup/GAN_basic_slowmode.jpg")
Out[18]:

Note: It takes a large number of iterations to capture a representation of the real-world signal. Even simple dense networks can be quite effective at modelling data, although MNIST is also a relatively simple dataset.

Suggested Tasks

  • Explore the impact of changing the dimension of the input random noise (say from 100 to 10) in terms of computation time, loss and memory footprint for the same number of iterations.
  • Scale the image from 0 to 1. What other changes in the network are needed?
  • Performance is a key aspect of deep neural network training. Study how changing the minibatch size impacts performance, both with regard to the quality of the generated images and the time it takes to train a model.
  • Try generating fake images using the CIFAR-10 data set as the training data. How does the network above perform?