Alan Martyn – 3 April 2019
The theory behind residual shortcuts, described in the original ResNet paper, suggests that they should make image-to-image identity mappings easier to learn. Specifically, if we provide an image as input to a convolutional network and set the objective to output the input unchanged, then we expect a residual shortcut from input to output to improve performance.
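To make the intuition concrete, here is a minimal sketch (not part of the experiment below): with a shortcut the block computes y = F(x) + x, so the identity mapping only requires the residual branch F to output zero, which a plain block would have to learn from scratch.

import torch
import torch.nn as nn

# A single conv layer standing in for the residual branch F(x)
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
nn.init.zeros_(conv.weight)  # drive the residual branch to zero

x = torch.rand(1, 1, 8, 8)
y_residual = conv(x) + x     # with shortcut: already the identity
y_plain = conv(x)            # without shortcut: an all-black image

assert torch.allclose(y_residual, x)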
Here we test that assumption in a near-minimal setting: a 2-layer convolutional network equivalent to the ResNet block described in the original paper's CIFAR-10 experiment.
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
We use the cell membrane dataset from the U-Net paper: a set of 30 black-and-white microscopy images of cells.
train_pth = '../data/membrane/train/target'
class DataGen:
    """
    Open images from disk and format as numpy arrays
    as expected by the PyTorch DataLoader.
    """
def __init__(self, root_dir):
self.root_dir = root_dir
def __len__(self):
return len(os.listdir(self.root_dir))
def __getitem__(self, idx):
filepath = f'{self.root_dir}/{idx}.png'
img = Image.open(filepath)
# binary black and white image 0=>black, 1=>white
arr = np.asarray(img, dtype=np.float32) / 255
return np.expand_dims(arr, axis=0)
# Initialise the dataset and dataloader
dataset = DataGen(train_pth)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
The images in this dataset are binary black and white, which means that each pixel is either totally black (0.0) or totally white (1.0). This makes an L1 loss an ideal objective; if pixel values were continuous, the L1 objective alone would encourage blurring.
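A small illustration of the median property behind this claim (a sketch with made-up pixel values, not part of the experiment): the constant prediction that minimises L1 loss is the median of the observed values, so with continuous-valued pixels L1 settles on intermediate grey values, while with binary pixels the exact target is always reachable.

import torch

# Hypothetical pixel that takes continuous values across the dataset
targets = torch.tensor([0.2, 0.4, 0.9])
preds = torch.linspace(0, 1, 101)              # candidate constant predictions
l1 = (preds[:, None] - targets).abs().mean(dim=1)
print(preds[l1.argmin()])                      # ~0.4: the median wins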
# Plot a sample image from the dataset
x = dataset[0]
plt.imshow(x[0], cmap='gray');
np.unique(x)  # confirm the only pixel values are 0.0 and 1.0
Set up the PyTorch model and utility functions.
class resnet_block(nn.Module):
"""
The ResNet block as described in the CIFAR-10 experiment of:
Deep Residual Learning for Image Recognition
"""
def __init__(self, channels):
super().__init__()
# Setup layers
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels, track_running_stats=True)
self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels, track_running_stats=True)
self.relu2 = nn.ReLU()
# Initialise weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x, shortcuts=False):
# Forward pass
z = self.conv1(x)
z = self.bn1(z)
z = self.relu1(z)
z = self.conv2(z)
z = self.bn2(z)
# Shortcut connection
if shortcuts:
z = z + x
z = self.relu2(z)
return z
def tensor2arr(tensor):
"""Convert torch tensor to numpy array"""
arr = tensor.detach().numpy()
return arr
def row_show(img_ts: list, titles: list):
"""Plot list of image tensors in a row with title"""
cols = len(img_ts)
fig, axs = plt.subplots(1, cols, figsize=(10, 3))
for i, ax in enumerate(axs):
ax.imshow(tensor2arr(img_ts[i][0, 0]), cmap='gray')
ax.axis('off')
ax.set_title(titles[i])
plt.show()
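Before training, a quick sanity check (illustrative, with a hypothetical dummy tensor) confirms that the block maps an image to an output of the same shape, with or without the shortcut:

# The block preserves channel count and spatial size either way,
# so it can be trained as an image-to-image identity mapping
block = resnet_block(1)
dummy = torch.rand(1, 1, 64, 64)  # (batch, channels, height, width)
assert block(dummy, shortcuts=False).shape == dummy.shape
assert block(dummy, shortcuts=True).shape == dummy.shape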
We can now train two models in parallel to test our hypothesis: at each step we train one model without a residual shortcut and one with. After each epoch we record the mean training loss and generate a sample image from each model to evaluate progress.
# Initialise two models: one to train with the shortcut,
# one without
shortcut_model = resnet_block(1)
s_optimizer = torch.optim.Adam(shortcut_model.parameters())
noshortcut_model = resnet_block(1)
ns_optimizer = torch.optim.Adam(noshortcut_model.parameters())
# Both models trained with L1 criterion
criterion = nn.L1Loss()
results = {'no_shortcut': [], 'shortcut': []}
for epoch in range(50):
    # Accumulate loss over this epoch
ns_loss = 0.0
s_loss = 0.0
for i, x in enumerate(loader, 0):
# Train a net without shortcut
ns_optimizer.zero_grad()
z = torch.sigmoid(noshortcut_model(x, shortcuts=False))
loss = criterion(z, x)
loss.backward()
ns_optimizer.step()
ns_loss += loss.item()
# Train a net with shortcut
s_optimizer.zero_grad()
z = torch.sigmoid(shortcut_model(x, shortcuts=True))
loss = criterion(z, x)
loss.backward()
s_optimizer.step()
s_loss += loss.item()
    # Record mean loss for this epoch
    ns_loss /= len(loader)
    s_loss /= len(loader)
    results['no_shortcut'].append(ns_loss)
    results['shortcut'].append(s_loss)
    # Evaluate: forward pass through each model on the last batch,
    # without building a computation graph
    with torch.no_grad():
        ns_img = torch.sigmoid(noshortcut_model(x, shortcuts=False))
        s_img = torch.sigmoid(shortcut_model(x, shortcuts=True))
# Plot results
print(f'Epoch {epoch}')
row_show([ns_img, s_img, x],
             [f'no shortcut loss: {ns_loss:.2f}',
              f'shortcut loss: {s_loss:.2f}',
'target'])
Let's plot the results:
df = pd.DataFrame(results)
plt.figure(figsize=(10, 6))
plt.plot(df.index, df['shortcut'], label='shortcut')
plt.plot(df.index, df['no_shortcut'], label='no shortcut')
plt.ylabel('mean L1 loss')
plt.xlabel('epoch')
plt.title('Training curves for image-to-image identity mapping of 512px images with and without shortcut')
plt.legend();
The model with the residual shortcut learns significantly faster and, after 50 epochs, reaches a lower loss than the model without a shortcut.