In this final post about GANs on MNIST, we'll implement our most sophisticated architecture yet, taking ideas from the excellent pix2pix paper. (Click here for Parts 1, 2, and 2.5.) In particular, we'll implement a U-Net generator and a PatchGAN discriminator. For fun, we'll also test a human (me) against the output: Will I be able to tell the difference between real and generated MNIST images?
The idea behind the U-Net generator is to build something that looks like a convolutional autoencoder, but with skip layers. Olaf Ronneberger (one of the authors of the original U-Net paper) has a nice diagram on his site.
Jun-Yan Zhu, one of the pix2pix authors, reimplemented an instance of this architecture in PyTorch here. Cleverly, the codebase defines the architecture recursively (you pass the "lower" part of the U into the next stage of the construction). Simplified, the skeleton looks something like this:
# 2 Blocks from Bottom
nn.Conv2d(...)

# 1 Block from Bottom
nn.LeakyReLU(...)
nn.Conv2d(...)
nn.BatchNorm2d(...)

# Innermost Layer
nn.LeakyReLU(...)
nn.Conv2d(...)
nn.ReLU(...)
nn.ConvTranspose2d(...)
nn.BatchNorm2d(...)

# 1 Block from Bottom
nn.ReLU(...)
nn.ConvTranspose2d(...)
nn.BatchNorm2d(...)

# 2 Blocks from Bottom
nn.ReLU(...)
nn.ConvTranspose2d(...)
nn.Tanh()
That looks like a convolutional autoencoder, but how do the skip layers work? This is implemented in the forward pass, which looks like this for each individual layer (this is important):
def forward(self, x):
    if self.outermost:
        return self.model(x)
    else:
        return torch.cat([x, self.model(x)], 1)
self.outermost is a property set at the layer level. So, in our skeleton above, it's going to be true for the "2 Blocks from the Bottom" part but false for the others.
So, if we think about this for a minute, here's what happens. We get an input tensor x. The outermost layer has outermost == True, so this "2 Blocks from Bottom" layer calls two_from_bottom.model(x). This computes nn.Conv2d(x) = x_1 and then passes the result down to one_from_bottom. Since one_from_bottom.outermost == False, it returns torch.cat([x_1, one_from_bottom.model(x_1)], 1). Let's avoid going a layer deeper for now, and just think about what happens next, assuming that we know what one_from_bottom.model(x_1) is. We pass torch.cat([x_1, one_from_bottom.model(x_1)], 1) back up to the last part of "2 Blocks from Bottom", namely nn.ReLU(...), nn.ConvTranspose2d(...), nn.Tanh(). In this way, the skip layers are implemented: not only does x_1 travel down the stack, it's also passed directly to the end of the same-level layer. Again, this isn't saying anything different from what the diagram above indicates; it's just interesting to see how it gets implemented in practice.
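To make the walkthrough concrete, here's a minimal runnable sketch of the recursive pattern. The channel counts are made up and the batch norms are omitted for brevity; the real implementation is in the repo linked above.

```python
import torch
import torch.nn as nn

class SkipBlock(nn.Module):
    """Toy version of the recursive U-Net block: `submodule` is the
    lower part of the U; non-outermost blocks concatenate their input
    onto their output along the channel dimension (the skip layer)."""
    def __init__(self, outer_nc, inner_nc, submodule=None,
                 outermost=False, innermost=False):
        super().__init__()
        self.outermost = outermost
        downconv = nn.Conv2d(outer_nc, inner_nc, 4, stride=2, padding=1)
        if outermost:
            # Upconv takes inner_nc * 2 channels because the submodule
            # returns its input concatenated with its output.
            up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 4, 2, 1)
            model = [downconv, submodule, nn.ReLU(), up, nn.Tanh()]
        elif innermost:
            up = nn.ConvTranspose2d(inner_nc, outer_nc, 4, 2, 1)
            model = [nn.LeakyReLU(0.2), downconv, nn.ReLU(), up]
        else:
            up = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 4, 2, 1)
            model = [nn.LeakyReLU(0.2), downconv, submodule, nn.ReLU(), up]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)

# Build the U from the inside out, as in the pix2pix codebase.
inner = SkipBlock(64, 128, innermost=True)
mid = SkipBlock(32, 64, submodule=inner)
net = SkipBlock(1, 32, submodule=mid, outermost=True)
out = net(torch.randn(1, 1, 32, 32))
print(out.shape)  # torch.Size([1, 1, 32, 32])
```

Note how each non-outermost block's up-convolution expects twice its inner channel count: that extra half is exactly the skip tensor concatenated by the layer below it.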
For our generator, we'll implement the same effect less elegantly but more simply.
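For instance, a non-recursive version can define each stage explicitly and do the concatenation by hand in forward(). The channel counts below are hypothetical, just to illustrate the idea:

```python
import torch
import torch.nn as nn

class SimpleUNet(nn.Module):
    """Explicit (non-recursive) sketch of a U-Net with one skip layer."""
    def __init__(self):
        super().__init__()
        self.down1 = nn.Sequential(nn.Conv2d(1, 32, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(nn.Conv2d(32, 64, 4, 2, 1), nn.LeakyReLU(0.2))
        self.up1 = nn.Sequential(nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU())
        # Takes 64 channels: 32 from up1 plus 32 skipped over from down1.
        self.up2 = nn.Sequential(nn.ConvTranspose2d(64, 1, 4, 2, 1), nn.Tanh())

    def forward(self, x):
        d1 = self.down1(x)                        # (N, 32, 16, 16) for 32x32 input
        d2 = self.down2(d1)                       # (N, 64, 8, 8)
        u1 = self.up1(d2)                         # (N, 32, 16, 16)
        return self.up2(torch.cat([d1, u1], 1))   # (N, 1, 32, 32)

gen = SimpleUNet()
fake = gen(torch.randn(2, 1, 32, 32))
print(fake.shape)  # torch.Size([2, 1, 32, 32])
```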
We're going to use a PatchGAN discriminator. Here's how the pix2pix authors describe the idea:
In order to model high-frequencies, it is sufficient to restrict our attention to the structure in local image patches. Therefore, we design a discriminator architecture – which we term a PatchGAN – that only penalizes structure at the scale of patches. This discriminator tries to classify if each N × N patch in an image is real or fake. We run this discriminator convolutionally across the image, averaging all responses to provide the ultimate output of D.
Again, we can see Jun-Yan Zhu's implementation here. Simplifying, the skeleton looks something like this:
nn.Conv2d(in_channels, out_channels=64, ...)
nn.LeakyReLU(...)

# Going to do this block 4x, doubling the channel count each time
nn.Conv2d(in_channels=64, out_channels=128)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

nn.Conv2d(in_channels=128, out_channels=256)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

nn.Conv2d(in_channels=256, out_channels=512)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

nn.Conv2d(in_channels=512, out_channels=1024)
nn.BatchNorm2d(...)
nn.LeakyReLU(...)

# Now go down to 1 channel
nn.Conv2d(1024, 1, ...)
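Here's a runnable version of that skeleton. The kernel sizes, strides, and padding are assumptions (pix2pix's real settings differ), chosen so that a 32×32 input yields a 32×32 map of per-patch scores, i.e. 1024 values per image:

```python
import torch
import torch.nn as nn

def patchgan(in_channels=1):
    # First block: conv + LeakyReLU, no batch norm.
    layers = [nn.Conv2d(in_channels, 64, 3, 1, 1), nn.LeakyReLU(0.2)]
    channels = 64
    # The "4x" block: double the channel count each time.
    for _ in range(4):
        layers += [nn.Conv2d(channels, channels * 2, 3, 1, 1),
                   nn.BatchNorm2d(channels * 2),
                   nn.LeakyReLU(0.2)]
        channels *= 2
    # Now go down to 1 channel: one score per spatial location.
    layers.append(nn.Conv2d(channels, 1, 3, 1, 1))
    return nn.Sequential(*layers)

D = patchgan()
scores = D(torch.randn(2, 1, 32, 32))  # batch of 2 fake "images"
print(scores.shape)  # torch.Size([2, 1, 32, 32]) -> 1024 scores per image
```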
We end up with a 1024-dimensional tensor as our output. One might wonder: if we're supposed to be labeling images as 0s (believed fake) or 1s (believed real), how are we going to use this 1024-dimensional output? The answer is that we'll apply our loss function to this entire vector, comparing it against a 1024-dimensional vector of zeros (if the image is fake) or ones (if the image is real). You can see how pix2pix does it here.
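As a sketch of the idea, assuming the discriminator emits raw logits and using hypothetical shapes (pix2pix's actual loss wrapper is in the linked code):

```python
import torch
import torch.nn as nn

# Hypothetical: `scores` stands in for the discriminator's flattened
# output for a batch of 8 images, 1024 patch scores each.
scores = torch.randn(8, 1024)

# Compare every patch score against an all-ones target (real) or an
# all-zeros target (fake) with binary cross-entropy on raw logits.
criterion = nn.BCEWithLogitsLoss()
loss_real = criterion(scores, torch.ones_like(scores))
loss_fake = criterion(scores, torch.zeros_like(scores))
print(loss_real.item(), loss_fake.item())
```

Because BCEWithLogitsLoss averages over all elements by default, this is exactly the "averaging all responses" step the paper describes, folded into the loss.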