GANs on MNIST Part 1

Inspired by examples like these, I've been getting more interested in building generative models with PyTorch. A fun first step is to re-implement a solution to a well-solved problem: generating fake digits from MNIST. That's what this post is about. In the end, we'll be rewarded with a GAN whose generator starts off producing 'bad' forgeries (beginning of the gif) and ends up producing pretty 'good' ones (end of the gif, after 100 epochs):

Generator

Our generator is pretty simple: Three linear layers with leaky ReLUs, followed by a final linear layer that we pipe through tanh. Implementing this in PyTorch is a breeze:

import torch
import torch.nn as nn

IMG_X_DIM = 28
IMG_Y_DIM = 28
IMG_FLAT_DIM = IMG_X_DIM * IMG_Y_DIM  # 784 pixels per flattened MNIST image

GEN_INPUT_DIM = 64  # dimensionality of the random noise vector

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(GEN_INPUT_DIM, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),

            # Map up to a flattened image; tanh squashes pixel values
            # into [-1, 1] to match the normalized MNIST inputs.
            nn.Linear(1024, IMG_FLAT_DIM),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x.view(x.size(0), GEN_INPUT_DIM))
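
To check the shapes, here's a quick usage snippet (mine, not from the original post): sample a batch of Gaussian noise vectors and reshape the generator's flat output back into 28x28 images. The batch size of 16 is arbitrary.

generator = Generator()
z = torch.randn(16, GEN_INPUT_DIM)                       # batch of noise vectors
fake_flat = generator(z)                                 # shape: (16, 784)
fake_images = fake_flat.view(-1, IMG_X_DIM, IMG_Y_DIM)   # shape: (16, 28, 28)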

Discriminator

Our discriminator is pretty straightforward too: Three linear layers with leaky ReLU and dropout add-ons, followed by a final linear layer through a sigmoid.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(IMG_FLAT_DIM, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            # Sigmoid squashes the single logit into a probability
            # that the input image is real.
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Flatten the incoming image to a 784-dimensional vector.
        out = self.model(x.view(x.size(0), IMG_FLAT_DIM))
        out = out.view(out.size(0), -1)
        return out
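
Continuing the hypothetical snippet above, the discriminator maps a batch of images (flattened inside its forward) to per-image probabilities of being real:

discriminator = Discriminator()
scores = discriminator(fake_images)  # shape: (16, 1), each value in (0, 1)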

Training

For training, I wrote a small helper class to manage iteration over each epoch. Lying slightly about the signatures, the class looks roughly like this:

class TrainingManager(object):
    def __init__(self, train_loader, generator, gen_optimizer,
                       discriminator, disc_optimizer, loss_func):
        ...

    def _train_epoch(self):
        for images, _ in self.train_loader:

            # Train the discriminator
            ##########################################################################
            ...


            # Train the generator
            ##########################################################################
            ...

    def train(self, epochs=200):
        for epoch in range(epochs):
            disc_loss, gen_loss, d_x, d_gen_z = self._train_epoch()
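
For reference, here's a minimal sketch of what the two elided steps typically look like in a vanilla GAN like this one. It assumes loss_func is nn.BCELoss and that the generator and discriminator each have their own optimizer; the variable names are mine, not necessarily those in the actual code.

    def _train_epoch(self):
        for images, _ in self.train_loader:
            batch_size = images.size(0)
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # Train the discriminator: push D(real) toward 1 and D(G(z)) toward 0.
            self.disc_optimizer.zero_grad()
            d_real = self.discriminator(images)
            z = torch.randn(batch_size, GEN_INPUT_DIM)
            fake_images = self.generator(z)
            d_fake = self.discriminator(fake_images.detach())
            disc_loss = self.loss_func(d_real, real_labels) + \
                        self.loss_func(d_fake, fake_labels)
            disc_loss.backward()
            self.disc_optimizer.step()

            # Train the generator: push D(G(z)) toward 1.
            self.gen_optimizer.zero_grad()
            d_gen_z = self.discriminator(fake_images)
            gen_loss = self.loss_func(d_gen_z, real_labels)
            gen_loss.backward()
            self.gen_optimizer.step()

        return (disc_loss.item(), gen_loss.item(),
                d_real.mean().item(), d_gen_z.mean().item())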

Commentary

Here's a plot showing the losses of the generator and the discriminator over each of the 100 epochs:

Our generator starts off doing quite poorly, as the graph shows. The very first set of generated images look like this:

We see substantial improvement after 20 epochs:

And more after 60 (total):

And (marginally) more after 100 (total):

One interesting thing to note is that for any given image, the GAN's generator isn't trying to turn its input (normal random noise) into any particular digit; any will do, as long as it looks like something from MNIST. As a result, you can see the generated images fluctuate between multiple digits in the gif. For example, look at the bottom-left square, which moves between 3 and 5 (and even 2, early on):

It'd be fun to see what changes if we supply a target digit for each generator input, so that the generator is conditionally trying to create a particular digit. My guess is that, aside from obviously getting more control over our output, we'd be able to lower our loss faster too (because the model would spend less time moving between local minima caused by digit shifting).
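
As a rough sketch of that idea (not something implemented in this post), a conditional generator could embed the requested digit and concatenate the embedding with the noise vector before the first linear layer. The class below is hypothetical; NUM_CLASSES and LABEL_EMBED_DIM are my own names.

NUM_CLASSES = 10
LABEL_EMBED_DIM = 10  # arbitrary choice for this sketch

class ConditionalGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(NUM_CLASSES, LABEL_EMBED_DIM)
        self.model = nn.Sequential(
            nn.Linear(GEN_INPUT_DIM + LABEL_EMBED_DIM, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, IMG_FLAT_DIM),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Concatenate the noise with an embedding of the requested digit,
        # so each sample is steered toward a specific class.
        conditioned = torch.cat([z, self.label_embedding(labels)], dim=1)
        return self.model(conditioned)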

Credits

I found this notebook from GitHub user prcastro to be helpful. The network architecture above comes directly from that source, with a few minor changes. You can see my code here.