Source code for models.backbones.backboned_unet

import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F

from src.models.backbones.resnet import ResNet50, ResNet18, ResNet34, ResNet152, ResNet101


# This module is adapted from https://github.com/mkisantal/backboned-unet/blob/master/backboned_unet/unet.py

def get_backbone(name, pretrained=True):
    """ Loading backbone, defining names for skip-connections and encoder output. """

    # TODO: More backbones

    # loading backbone model
    if name == 'resnet18':
        backbone = ResNet18()
    elif name == 'resnet34':
        backbone = ResNet34()
    elif name == 'resnet50':
        backbone = ResNet50()
    elif name == 'resnet101':
        backbone = ResNet101()
    elif name == 'resnet152':
        backbone = ResNet152()
    elif name == 'vgg16':
        backbone = models.vgg16_bn(pretrained=pretrained).features
    elif name == 'vgg19':
        backbone = models.vgg19_bn(pretrained=pretrained).features
    # elif name == 'inception_v3':
    #     backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
    elif name == 'densenet121':
        backbone = models.densenet121(pretrained=pretrained).features
    elif name == 'densenet161':
        backbone = models.densenet161(pretrained=pretrained).features
    elif name == 'densenet169':
        backbone = models.densenet169(pretrained=pretrained).features
    elif name == 'densenet201':
        backbone = models.densenet201(pretrained=pretrained).features
    elif name == 'unet_encoder':
        backbone = UnetEncoder(3)
    else:
        raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))

    # specifying skip feature and output names
    if name.startswith('resnet'):
        feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
        backbone_output = 'layer4'
    elif name == 'vgg16':
        # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
        feature_names = ['5', '12', '22', '32', '42']
        backbone_output = '43'
    elif name == 'vgg19':
        feature_names = ['5', '12', '25', '38', '51']
        backbone_output = '52'
    # elif name == 'inception_v3':
    #     feature_names = [None, 'Mixed_5d', 'Mixed_6e']
    #     backbone_output = 'Mixed_7c'
    elif name.startswith('densenet'):
        feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
        backbone_output = 'denseblock4'
    elif name == 'unet_encoder':
        feature_names = ['module1', 'module2', 'module3', 'module4']
        backbone_output = 'module5'
    else:
        raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))

    return backbone, feature_names, backbone_output

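# Usage sketch (illustrative, not part of the original module): get_backbone
# returns the encoder itself, the names of the child modules whose outputs feed
# the skip connections, and the name of the bottleneck module. For 'resnet50':
#
#     backbone, feature_names, backbone_output = get_backbone('resnet50', pretrained=False)
#     # feature_names   -> [None, 'relu', 'layer1', 'layer2', 'layer3']
#     # backbone_output -> 'layer4'
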
class UpsampleBlock(nn.Module):

    # TODO: separate parametric and non-parametric classes?
    # TODO: skip connection concatenated OR added

    def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
        super(UpsampleBlock, self).__init__()

        self.parametric = parametric
        ch_out = ch_in // 2 if ch_out is None else ch_out

        # first convolution: either transposed conv, or conv following the skip connection
        if parametric:
            # versions: kernel=4 padding=1, kernel=2 padding=0
            self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
                                         stride=2, padding=1, output_padding=0, bias=(not use_bn))
            self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
        else:
            self.up = None
            ch_in = ch_in + skip_in
            self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
                                   stride=1, padding=1, bias=(not use_bn))
            self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None

        self.relu = nn.ReLU(inplace=True)

        # second convolution
        conv2_in = ch_out if not parametric else ch_out + skip_in
        self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
                               stride=1, padding=1, bias=(not use_bn))
        self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
    def forward(self, x, skip_connection=None):

        x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2,
                                                             mode='bilinear', align_corners=None)
        if self.parametric:
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)

        if skip_connection is not None:
            x = torch.cat([x, skip_connection], dim=1)

        if not self.parametric:
            x = self.conv1(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x) if self.bn2 is not None else x
        x = self.relu(x)

        return x

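# A minimal shape check for UpsampleBlock (illustrative, not from the original
# source). With parametric=True the transposed convolution doubles the spatial
# resolution, and the skip tensor is concatenated before the second conv:
#
#     block = UpsampleBlock(ch_in=256, ch_out=128, skip_in=64, parametric=True)
#     x = torch.zeros(1, 256, 14, 14)
#     skip = torch.zeros(1, 64, 28, 28)
#     out = block(x, skip)   # expected shape: (1, 128, 28, 28)
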
class Unet(nn.Module):

    """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""

    def __init__(self,
                 backbone_name='resnet50',
                 pretrained=False,
                 encoder_freeze=False,
                 num_classes=21,
                 decoder_filters=(256, 128, 64, 32, 16),
                 parametric_upsampling=True,
                 shortcut_features='default',
                 decoder_use_batchnorm=True):
        super(Unet, self).__init__()

        self.backbone_name = backbone_name

        self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
        shortcut_chs, bb_out_chs = self.infer_skip_channels()
        if shortcut_features != 'default':
            self.shortcut_features = shortcut_features

        # build decoder part
        self.upsample_blocks = nn.ModuleList()
        decoder_filters = decoder_filters[:len(self.shortcut_features)]  # avoiding having more blocks than skip connections
        decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
        num_blocks = len(self.shortcut_features)
        for i, (filters_in, filters_out) in enumerate(zip(decoder_filters_in, decoder_filters)):
            self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
                                                      skip_in=shortcut_chs[num_blocks - i - 1],
                                                      parametric=parametric_upsampling,
                                                      use_bn=decoder_use_batchnorm))

        # self.final_conv = nn.Conv2d(decoder_filters[-1], num_classes, kernel_size=(1, 1))

        if encoder_freeze:
            self.freeze_encoder()

        self.replaced_conv1 = False  # for accommodating inputs with different number of channels later
    def freeze_encoder(self):

        """ Freezing encoder parameters; the newly initialized decoder parameters remain trainable. """

        for param in self.backbone.parameters():
            param.requires_grad = False
    def forward(self, *input):

        """ Forward propagation in U-Net. """

        x, features = self.forward_backbone(*input)

        for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
            skip_features = features[skip_name]
            x = upsample_block(x, skip_features)

        # x = self.final_conv(x)

        return x
    def forward_backbone(self, x):

        """ Forward propagation in backbone encoder network. """

        features = {None: None} if None in self.shortcut_features else dict()
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                features[name] = x
            if name == self.bb_out_name:
                break

        return x, features
    def infer_skip_channels(self):

        """ Getting the number of channels at skip connections and at the output of the encoder. """

        x = torch.zeros(1, 3, 224, 224)
        has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
        channels = [] if has_fullres_features else [0]  # only VGG and the plain U-Net encoder keep features at full resolution

        # forward run in backbone to count channels (dirty solution but works for *any* Module)
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                channels.append(x.shape[1])
            if name == self.bb_out_name:
                out_channels = x.shape[1]
                break

        return channels, out_channels

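# End-to-end sketch (illustrative, not part of the original module): building
# the U-Net with a ResNet backbone and running a forward pass. Since final_conv
# is commented out above, the network returns decoder feature maps with
# decoder_filters[-1] channels rather than num_classes logits:
#
#     net = Unet(backbone_name='resnet50', pretrained=False)
#     out = net(torch.zeros(1, 3, 224, 224))   # expected shape: (1, 16, 224, 224)
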
class UnetDownModule(nn.Module):

    """ U-Net downsampling block. """

    def __init__(self, in_channels, out_channels, downsample=True):
        super(UnetDownModule, self).__init__()

        # layers: optional downsampling, 2 x (conv + bn + relu)
        self.maxpool = nn.MaxPool2d((2, 2)) if downsample else None
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        if self.maxpool is not None:
            x = self.maxpool(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x

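# Quick illustration (not in the original source): with downsample=True the
# block halves the spatial size before the two convolutions:
#
#     module = UnetDownModule(64, 128)
#     y = module(torch.zeros(1, 64, 56, 56))   # expected shape: (1, 128, 28, 28)
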
class UnetEncoder(nn.Module):

    """ U-Net encoder. https://arxiv.org/pdf/1505.04597.pdf """

    def __init__(self, num_channels):
        super(UnetEncoder, self).__init__()
        self.module1 = UnetDownModule(num_channels, 64, downsample=False)
        self.module2 = UnetDownModule(64, 128)
        self.module3 = UnetDownModule(128, 256)
        self.module4 = UnetDownModule(256, 512)
        self.module5 = UnetDownModule(512, 1024)
    def forward(self, x):
        x = self.module1(x)
        x = self.module2(x)
        x = self.module3(x)
        x = self.module4(x)
        x = self.module5(x)
        return x  # bottleneck features

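# Usage sketch (illustrative): the plain U-Net encoder downsamples by a factor
# of 16 overall, since module1 keeps full resolution and modules 2-5 each halve it:
#
#     encoder = UnetEncoder(3)
#     z = encoder(torch.zeros(1, 3, 224, 224))   # expected shape: (1, 1024, 14, 14)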