In this tutorial, we will implement a UNet to solve Kaggle's 2018 Data Science Bowl Competition. The challenge asks participants to find the locations of nuclei in images of cells. The source of this tutorial and instructions to reproduce this analysis can be found in the thomasjpfan/ml-journal repo.
We can now define the training and validation datasets:
from pathlib import Path

from sklearn.model_selection import train_test_split

samples_dirs = [d for d in Path('data/cells/').iterdir() if d.is_dir()]
train_dirs, valid_dirs = train_test_split(
    samples_dirs, test_size=0.2, random_state=42)

train_cell_ds = CellsDataset(train_dirs)
valid_cell_ds = CellsDataset(valid_dirs)
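The implementation of CellsDataset lives in the repo linked above. As a rough sketch, assuming the Kaggle directory layout (each sample directory holds an images/ folder with a single PNG and a masks/ folder with one binary PNG per nucleus), it might look like:

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class CellsDataset(Dataset):
    """Loads one image and its combined nuclei mask per sample directory."""

    def __init__(self, sample_dirs):
        self.sample_dirs = sample_dirs

    def __len__(self):
        return len(self.sample_dirs)

    def __getitem__(self, idx):
        sample_dir = self.sample_dirs[idx]
        image_path = next((sample_dir / 'images').glob('*.png'))
        image = np.array(Image.open(image_path).convert('RGB'))
        # combine the per-nucleus masks into a single binary mask
        mask = np.zeros(image.shape[:2], dtype=np.float32)
        for mask_path in (sample_dir / 'masks').glob('*.png'):
            mask = np.maximum(mask, np.array(Image.open(mask_path)) / 255.0)
        return image, mask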
Overall, the cell images come in different sizes and fall into three different categories:
Most of the data is of Type 2. Training a single model to find the nuclei in all three types may not be the best option, but we will give it a try! For reference, here are the corresponding masks for the three types above:
In order to train a neural net, each image we feed in must be the same size. For our dataset, we break the images up into 256x256 patches. The UNet architecture typically has a hard time dealing with objects on the edge of an image, so we pad our images by 16 pixels using reflection padding. The patching and image augmentation are handled by PatchedDataset, whose implementation can be found in the thomasjpfan/ml-journal repo:
train_ds = PatchedDataset(
    train_cell_ds, patch_size=(256, 256), padding=16, random_flips=True)
val_ds = PatchedDataset(
    valid_cell_ds, patch_size=(256, 256), padding=16, random_flips=False)
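To give a feel for what PatchedDataset does under the hood, here is a minimal NumPy sketch of the padding-and-patching step. It assumes a 2-D (grayscale) image whose dimensions are multiples of the patch size, and extract_patches is an illustrative name, not the repo's API:

import numpy as np

def extract_patches(image, patch_size=(256, 256), padding=16):
    """Reflection-pad a 2-D image, then slice it into fixed-size patches."""
    ph, pw = patch_size
    padded = np.pad(image, padding, mode='reflect')
    patches = []
    for top in range(0, image.shape[0], ph):
        for left in range(0, image.shape[1], pw):
            # each patch keeps `padding` pixels of reflected context per side
            patches.append(padded[top:top + ph + 2 * padding,
                                  left:left + pw + 2 * padding])
    return patches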
Now we define the UNet module with the pretrained VGG16_bn as a feature encoder. The details of this module can be found in the thomasjpfan/ml-journal repo:
module = UNet(pretrained=True)
The features generated by VGG16_bn are prefixed with conv. These weights will be frozen, which restricts training to only our decoder layers.
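As a quick aside (not part of the original tutorial), we can confirm the prefix by listing the module's parameter names:

# parameters whose names start with 'conv' belong to the pretrained encoder
encoder_params = [name for name, _ in module.named_parameters()
                  if name.startswith('conv')]
print(encoder_params[:4])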
from skorch.callbacks import Freezer

freezer = Freezer('conv*')
We use a Cyclic Learning Rate scheduler to train our neural network.
from skorch.callbacks import LRScheduler
from skorch.callbacks.lr_scheduler import CyclicLR

cyclicLR = LRScheduler(
    policy=CyclicLR,
    base_lr=0.002,
    max_lr=0.2,
    step_size_up=540,
    step_size_down=540)
Why is step_size_up 540?

Since we are using a batch size of 32, each epoch will have about 54 (len(train_ds)//32) training iterations. We are also setting max_epochs to 20, which gives a total of 1080 (max_epochs*54) training iterations. We construct our Cyclic Learning Rate policy to peak at the 10th epoch by setting step_size_up to 540. This can be shown with a plot of the learning rate:
import matplotlib.pyplot as plt

_, ax = plt.subplots(figsize=(10, 5))
ax.set_title('Cyclic Learning Rate Scheduler')
ax.set_xlabel('Training iteration')
ax.set_ylabel('Learning Rate')
ax.plot(cyclicLR.simulate(1080, 0.002));
A Checkpoint callback is used to save the model weights whenever the validation loss improves:
from skorch.callbacks import Checkpoint

checkpoint = Checkpoint(dirname='unet')
Since we have padded our images and masks, the loss function needs to ignore the padding when calculating the binary log loss. We define a BCEWithLogitsLossPadding to filter out the padding:
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits

class BCEWithLogitsLossPadding(nn.Module):
    def __init__(self, padding=16):
        super().__init__()
        self.padding = padding

    def forward(self, input, target):
        # crop the reflection padding off before computing the loss
        input = input.squeeze_(dim=1)[
            :, self.padding:-self.padding, self.padding:-self.padding]
        target = target.squeeze_(dim=1)[
            :, self.padding:-self.padding, self.padding:-self.padding]
        return binary_cross_entropy_with_logits(input, target)
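As a quick sanity check (not from the original tutorial; the shapes assume 256x256 patches with 16 pixels of padding on each side, i.e. 288x288 inputs):

import torch

loss_fn = BCEWithLogitsLossPadding(padding=16)
logits = torch.randn(4, 1, 288, 288)                     # raw model outputs
targets = torch.randint(0, 2, (4, 1, 288, 288)).float()  # binary masks
print(loss_fn(logits, targets))  # scalar loss over the unpadded 256x256 region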
Now we can define the skorch NeuralNet to train our UNet!
from skorch.net import NeuralNet
from skorch.helper import predefined_split

net = NeuralNet(
    module,
    criterion=BCEWithLogitsLossPadding,
    criterion__padding=16,
    batch_size=32,
    max_epochs=20,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=4,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=4,
    train_split=predefined_split(val_ds),
    callbacks=[('freezer', freezer),
               ('cycleLR', cyclicLR),
               ('checkpoint', checkpoint)],
    device='cuda',
)
Let's highlight some parameters in our NeuralNet definition:

- criterion__padding=16 - Passes the padding to our BCEWithLogitsLossPadding criterion.
- train_split=predefined_split(val_ds) - Sets val_ds to be the validation set during training.
- callbacks=[(..., Checkpoint(f_params='best_params.pt'))] - Saves the best parameters to best_params.pt.
Next we train our UNet with the training dataset:
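The fit call itself is not shown above; since the PatchedDataset already yields (input, target) pairs and the validation split is predefined, it presumably looks like:

# y=None because the dataset itself provides the targets
_ = net.fit(train_ds, y=None)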
  epoch    train_loss    valid_loss    cp      dur
-------  ------------  ------------  ----  -------
      1        0.4901        0.4193     +  53.9509
      2        0.3803        0.3331     +  46.7676
      3        0.2797        0.2307     +  46.9844
      4        0.1653        0.1053     +  46.9767
      5        0.1076        0.1025     +  46.9547
      6        0.0825        0.0780     +  47.0113
      7        0.0765        0.0747     +  47.1332
      8        0.0732        0.0641     +  47.0073
      9        0.0632        0.0548     +  47.0701
     10        0.0574        0.0537     +  46.9553
     11        0.0565        0.0537        47.1040
     12        0.0544        0.0536     +  47.1731
     13        0.0543        0.0513     +  47.2048
     14        0.0523        0.0513        47.1222
     15        0.0520        0.0503     +  47.3969
     16        0.0515        0.0512        47.1741
     17        0.0514        0.0503     +  46.9930
     18        0.0522        0.0501     +  47.0438
     19        0.0517        0.0501        47.3764
     20        0.0515        0.0519        47.2810
Before we evaluate our model, we load the checkpoint with the best weights into the NeuralNet.
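skorch's load_params can consume the checkpoint callback directly:

net.load_params(checkpoint=checkpoint)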
Now that we have trained our model, let's see how we did on the three types presented at the beginning of this tutorial. Since our UNet module is designed to output logits, we must convert these values to probabilities:
import numpy as np

val_masks = net.predict(val_ds).squeeze(1)
val_prob_masks = 1 / (1 + np.exp(-val_masks))  # sigmoid: logits -> probabilities
We plot the predicted mask with its corresponding true mask and original image:
Our UNet is able to predict the location of the nuclei for all three types of cell images!
In this tutorial, we used skorch to train a UNet to predict the location of nuclei in an image. There are still areas where our solution can be improved:
- Since there are three types of images in our dataset, we could improve our results by training a separate UNet model for each of the three types.
- We can use traditional image processing to fill in the holes in the masks our UNet produces.
- Our loss function can include a term analogous to the competition's metric of intersection over union; a rough sketch of such a loss follows this list.
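A differentiable stand-in for intersection over union is the soft Jaccard loss. The sketch below is hypothetical and not part of the original solution:

import torch
import torch.nn as nn

class SoftIoULoss(nn.Module):
    """Hypothetical soft-IoU (Jaccard) loss over N x 1 x H x W logits."""

    def forward(self, input, target):
        probs = torch.sigmoid(input)
        # per-sample intersection and union, summed over all non-batch dims
        intersection = (probs * target).sum(dim=(1, 2, 3))
        union = (probs + target - probs * target).sum(dim=(1, 2, 3))
        return 1 - (intersection / union.clamp(min=1e-6)).mean()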