
GANs on MNIST, Part 3

In this final post about GANs on MNIST, we'll implement our most sophisticated architecture, taking ideas from the excellent pix2pix paper. (Click here for Parts 1, 2, and 2.5.) In particular, we'll implement a U-Net generator and a PatchGAN discriminator. For fun, we'll also test a human (me) against the output: Will I be able to tell the difference between the real and generated MNIST images?

Generator

The idea for the U-Net generator is to make something that looks like a convolutional autoencoder, but with skip layers. Olaf Ronneberger (one of the authors on the original U-Net paper) has a nice diagram on his site.

Jun-Yan Zhu, one of the pix2pix authors, reimplemented an instance of this architecture in PyTorch here. Cleverly, the codebase defines the architecture recursively (you pass the "lower" part of the U into the next stage of the construction). Simplified, the skeleton looks something like this:

# 2 Blocks from Bottom
nn.Conv2d(...)

# 1 Block from Bottom
nn.LeakyReLU(...)
nn.Conv2d(...)
nn.BatchNorm2d(...)

# Innermost Layer
nn.LeakyReLU(...)
nn.Conv2d(...)
nn.ReLU(...)
nn.ConvTranspose2d(...)
nn.BatchNorm2d(...)

# 1 Block from Bottom
nn.ReLU(...)
nn.ConvTranspose2d(...)
nn.BatchNorm2d(...)

# 2 Blocks from Bottom
nn.ReLU(...)
nn.ConvTranspose2d(...)
nn.Tanh()

That looks like a convolutional autoencoder, but how do the skip layers work? This is implemented in the forward pass, which looks like this for each individual layer (this is important):

def forward(self, x):
    if self.outermost:
        return self.model(x)
    else:
        # Skip connection: concatenate the input with the output along the channel dimension
        return torch.cat([x, self.model(x)], 1)

Here, self.outermost is a property set at the layer level. So, in our skeleton above, it's going to be true for the "2 Blocks from Bottom" part but false for the others.

So, if we think about this for a minute, here's what happens. We get an input tensor x. The outermost layer, two_from_bottom, has self.outermost = True, so this "2 Blocks from Bottom" layer calls two_from_bottom.model(x). That applies nn.Conv2d to x, producing x_1, and then passes x_1 along to one_from_bottom. Since one_from_bottom.outermost is False, its forward pass returns torch.cat([x_1, one_from_bottom.model(x_1)], 1). Let's avoid going a layer deeper for now and just think about what happens next, assuming we know what one_from_bottom.model(x_1) is. We pass torch.cat([x_1, one_from_bottom.model(x_1)], 1) back up to the last part of "2 Blocks from Bottom", namely nn.ReLU(...), nn.ConvTranspose2d(...), nn.Tanh(). This is how the skip layers are implemented: not only does x_1 travel down the stack, it's also passed directly to the end of the same-level layer. Again, this isn't saying anything different from what the diagram linked above indicates; it's just interesting to see how it gets implemented in practice.
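To see how those pieces fit together, here's a stripped-down, runnable sketch in the same spirit. To be clear, this is my own simplification, not the actual pix2pix code; the kernel sizes, strides, and channel counts are placeholders chosen so the shapes line up.

import torch
import torch.nn as nn

class SkipBlock(nn.Module):
    """One level of the U: downsample, run the level below (if any),
    upsample, and (unless outermost) concatenate the level's input with its output."""
    def __init__(self, outer_ch, inner_ch, submodule=None, outermost=False, innermost=False):
        super().__init__()
        self.outermost = outermost
        if innermost:
            layers = [nn.LeakyReLU(0.2),
                      nn.Conv2d(outer_ch, inner_ch, 4, stride=2, padding=1),
                      nn.ReLU(),
                      nn.ConvTranspose2d(inner_ch, outer_ch, 4, stride=2, padding=1),
                      nn.BatchNorm2d(outer_ch)]
        elif outermost:
            layers = [nn.Conv2d(outer_ch, inner_ch, 4, stride=2, padding=1),
                      submodule,
                      nn.ReLU(),
                      # inner_ch * 2 because the submodule returns a concatenated tensor
                      nn.ConvTranspose2d(inner_ch * 2, outer_ch, 4, stride=2, padding=1),
                      nn.Tanh()]
        else:
            layers = [nn.LeakyReLU(0.2),
                      nn.Conv2d(outer_ch, inner_ch, 4, stride=2, padding=1),
                      nn.BatchNorm2d(inner_ch),
                      submodule,
                      nn.ReLU(),
                      nn.ConvTranspose2d(inner_ch * 2, outer_ch, 4, stride=2, padding=1),
                      nn.BatchNorm2d(outer_ch)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # The skip connection: the block's input rides along with its output
        return torch.cat([x, self.model(x)], 1)

# Build the U from the inside out: each new block wraps the one below it
inner = SkipBlock(outer_ch=128, inner_ch=256, innermost=True)
middle = SkipBlock(outer_ch=64, inner_ch=128, submodule=inner)
net = SkipBlock(outer_ch=1, inner_ch=64, submodule=middle, outermost=True)

# With these strides the spatial size halves at each level, so the input needs
# to be a power of two; 28x28 MNIST images would have to be padded to 32x32 here
x = torch.randn(16, 1, 32, 32)
print(net(x).shape)  # torch.Size([16, 1, 32, 32])

The three SkipBlock instances correspond to the three levels in the skeleton above: inner is the "Innermost Layer", middle is "1 Block from Bottom", and net is "2 Blocks from Bottom".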

For our generator, we'll implement the same effect less elegantly but more simply.
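Concretely, that might look something like the sketch below: write the down and up paths out explicitly, hold on to the intermediate activations, and do the concatenation by hand in forward(). The channel counts and two-level depth here are placeholders for illustration, not the exact network we'll end up with.

import torch
import torch.nn as nn

class SimpleUNetGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        # Downsampling path: 28x28 -> 14x14 -> 7x7
        self.down1 = nn.Sequential(nn.Conv2d(1, 64, 4, stride=2, padding=1),
                                   nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(nn.Conv2d(64, 128, 4, stride=2, padding=1),
                                   nn.BatchNorm2d(128),
                                   nn.LeakyReLU(0.2))
        # Upsampling path: 7x7 -> 14x14 -> 28x28
        self.up1 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU())
        # 64 + 64 input channels because of the skip connection in forward()
        self.up2 = nn.Sequential(nn.ConvTranspose2d(64 + 64, 1, 4, stride=2, padding=1),
                                 nn.Tanh())

    def forward(self, x):
        d1 = self.down1(x)           # (N, 64, 14, 14)
        d2 = self.down2(d1)          # (N, 128, 7, 7)
        u1 = self.up1(d2)            # (N, 64, 14, 14)
        u1 = torch.cat([u1, d1], 1)  # skip connection: (N, 128, 14, 14)
        return self.up2(u1)          # (N, 1, 28, 28)

gen = SimpleUNetGenerator()
print(gen(torch.randn(16, 1, 28, 28)).shape)  # torch.Size([16, 1, 28, 28])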

Discriminator

We're going to use a PatchGAN discriminator. Here's how the pix2pix authors describe the idea:

In order to model high-frequencies, it is sufficient to restrict our attention to the structure in local image patches. Therefore, we design a discriminator architecture – which we term a PatchGAN – that only penalizes structure at the scale of patches. This discriminator tries to classify if each N × N patch in an image is real or fake. We run this discriminator convolutionally across the image, averaging all responses to provide the ultimate output of D.

Again, we can see Jun-Yan Zhu's implementation here. Simplifying, the skeleton looks something like this:

nn.Conv2d(in_channels, out_channels=64, ...)
nn.LeakyReLU(...)

# Going to do this block 4x, doubling the
# channel count each time
nn.Conv2d(in_channels=64, out_channels=128)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

nn.Conv2d(in_channels=128, out_channels=256)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

nn.Conv2d(in_channels=256, out_channels=512)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

nn.Conv2d(in_channels=512, out_channels=1024)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

# Now go down to 1 channel
nn.Conv2d(1024, 1, ...)
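To make the output shape concrete, here's a small runnable version of that skeleton with the blanks filled in. The kernel sizes, strides, and input resolution are assumptions for illustration only, not the settings from pix2pix or the ones we'll use later:

import torch
import torch.nn as nn

discriminator = nn.Sequential(
    nn.Conv2d(1, 64, 4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, 4, stride=1, padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2),
    nn.Conv2d(256, 512, 4, stride=1, padding=1),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2),
    nn.Conv2d(512, 1024, 4, stride=1, padding=1),
    nn.BatchNorm2d(1024),
    nn.LeakyReLU(0.2),
    nn.Conv2d(1024, 1, 4, stride=1, padding=1),  # down to 1 channel
)

x = torch.randn(16, 1, 32, 32)  # a batch of 32x32 single-channel images
print(discriminator(x).shape)   # torch.Size([16, 1, 4, 4]): one prediction per patch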

We end up with a whole grid of patch predictions as our output rather than a single number. One might wonder: If we're supposed to be labeling images as 0s (believed fake) or 1s (believed real), how are we going to use this multi-dimensional output? The answer is that we'll apply our loss function to the entire output, comparing it against a same-shaped tensor of zeroes (if the image is fake) or ones (if the image is real). You can see how pix2pix does it here.
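As a sketch of that pattern (assuming a raw-logit patch map like the 4x4 one above and a BCE-with-logits loss; the linked pix2pix code wraps the same idea in a helper class):

import torch
import torch.nn as nn

# Hypothetical output from the discriminator: one raw score per patch
patch_predictions = torch.randn(16, 1, 4, 4)

criterion = nn.BCEWithLogitsLoss()

# Real images: compare every patch prediction against a target of all ones
loss_real = criterion(patch_predictions, torch.ones_like(patch_predictions))
# Fake images: compare every patch prediction against a target of all zeros
loss_fake = criterion(patch_predictions, torch.zeros_like(patch_predictions))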