# Adapted from https://github.com/fregu856/deeplabv3
# NOTE! OS: output stride, the ratio of input image resolution to final output resolution (OS16: output size is (img_h/16, img_w/16)) (OS8: output size is (img_h/8, img_w/8))
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.backbones.resnetdd import resnet18, resnet34, resnet50, resnet101, resnet152
def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1):
    """Stack ``num_blocks`` residual blocks into a single ``nn.Sequential``.

    Only the first block receives ``stride``; every following block uses
    stride 1. After each block the input width of the next one is set to
    ``block.expansion * channels``.

    Args:
        block: Block class (e.g. ``BasicBlock`` or ``Bottleneck``). It must
            expose an ``expansion`` attribute and accept the keyword
            arguments ``in_channels``, ``channels``, ``stride``, ``dilation``.
        in_channels: Number of channels fed to the first block.
        channels: Base channel count passed to every block.
        num_blocks: Number of blocks to stack.
        stride: Stride of the first block.
        dilation: Dilation passed to every block.

    Returns:
        nn.Sequential: The blocks chained in order.
    """
    # (stride == 2, num_blocks == 4 --> strides == [2, 1, 1, 1])
    strides = [stride] + [1] * (num_blocks - 1)

    blocks = []
    # A distinct loop name keeps the `stride` argument from being shadowed.
    for block_stride in strides:
        blocks.append(
            block(in_channels=in_channels, channels=channels, stride=block_stride, dilation=dilation)
        )
        # Every block after the first consumes the previous block's output.
        in_channels = block.expansion * channels

    # (*blocks: call with unpacked list entries as arguments)
    return nn.Sequential(*blocks)
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, with optional dilation.

    The shortcut branch is a 1x1 convolution followed by batch norm when the
    main branch changes the spatial size (``stride != 1``) or the channel
    count; otherwise it is an empty ``nn.Sequential`` (identity).

    Args:
        in_channels: Number of input channels.
        channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution and of the shortcut projection.
        dilation: Dilation (and matching padding) of both 3x3 convolutions.
    """

    # Output channels are `expansion * channels`.
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super().__init__()

        out_channels = self.expansion * channels

        # padding == dilation keeps the spatial size unchanged for a 3x3
        # kernel at stride 1.
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        if (stride != 1) or (in_channels != out_channels):
            # Project the input so it can be added to the main branch.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm2d(out_channels)
            self.downsample = nn.Sequential(conv, bn)
        else:
            # Shapes already match: an empty Sequential returns its input.
            self.downsample = nn.Sequential()

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape (batch_size, in_channels, h, w).

        Returns:
            Tensor of shape (batch_size, channels, h, w) if stride == 1,
            (batch_size, channels, h/2, w/2) if stride == 2.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual connection, then the final non-linearity.
        out = out + self.downsample(x)
        out = F.relu(out)

        return out
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 (strided/dilated) -> 1x1.

    The final 1x1 convolution expands to ``expansion * channels`` output
    channels. The shortcut branch is a 1x1 convolution followed by batch
    norm when the main branch changes the spatial size (``stride != 1``) or
    the channel count; otherwise it is an empty ``nn.Sequential`` (identity).

    Args:
        in_channels: Number of input channels.
        channels: Width of the inner 1x1 and 3x3 convolutions.
        stride: Stride of the 3x3 convolution and of the shortcut projection.
        dilation: Dilation (and matching padding) of the 3x3 convolution.
    """

    # Output channels are `expansion * channels`.
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super().__init__()

        out_channels = self.expansion * channels

        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        # padding == dilation keeps the spatial size unchanged for a 3x3
        # kernel at stride 1.
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        if (stride != 1) or (in_channels != out_channels):
            # Project the input so it can be added to the main branch.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm2d(out_channels)
            self.downsample = nn.Sequential(conv, bn)
        else:
            # Shapes already match: an empty Sequential returns its input.
            self.downsample = nn.Sequential()

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape (batch_size, in_channels, h, w).

        Returns:
            Tensor of shape (batch_size, out_channels, h, w) if stride == 1,
            (batch_size, out_channels, h/2, w/2) if stride == 2.
        """
        out = F.relu(self.bn1(self.conv1(x)))  # (batch_size, channels, h, w)
        out = F.relu(self.bn2(self.conv2(out)))  # spatial size reduced here when stride == 2
        out = self.bn3(self.conv3(out))  # channels expanded to out_channels

        # Residual connection, then the final non-linearity.
        out = out + self.downsample(x)
        out = F.relu(out)

        return out
class ResNet_Bottleneck_OS16(nn.Module):
    """Bottleneck ResNet (50/101/152) backbone with output stride 16.

    The last three children of the backbone are dropped and replaced by a
    freshly built, dilated ``layer5`` with stride 1.

    Args:
        num_layers: Depth of the backbone; must be 50, 101 or 152.
        pretrained: Forwarded to the backbone constructor.

    Raises:
        ValueError: If ``num_layers`` is not one of the supported depths.
    """

    def __init__(self, num_layers, pretrained=False):
        super().__init__()

        if num_layers == 50:
            resnet = resnet50(pretrained)
        elif num_layers == 101:
            resnet = resnet101(pretrained)
        elif num_layers == 152:
            resnet = resnet152(pretrained)
        else:
            # ValueError is an Exception subclass, so existing handlers that
            # catch Exception keep working.
            raise ValueError("num_layers must be in {50, 101, 152}!")

        # remove fully connected layer, avg pool and layer5:
        self.resnet = nn.Sequential(*list(resnet.children())[:-3])

        # Dilation 2 with stride 1 replaces the removed stage.
        self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2)

    def forward(self, x):
        """Compute the backbone feature map.

        Args:
            x: Tensor of shape (batch_size, 3, h, w).

        Returns:
            Tensor of shape (batch_size, 4*512, h/16, w/16).
        """
        # pass x through (parts of) the pretrained ResNet:
        c4 = self.resnet(x)  # (shape: (batch_size, 4*256, h/16, w/16)) (it's called c4 since 16 == 2^4)

        output = self.layer5(c4)  # (shape: (batch_size, 4*512, h/16, w/16))

        return output
class ResNet_BasicBlock_OS16(nn.Module):
    """BasicBlock ResNet (18/34) backbone with output stride 16.

    The last three children of the backbone are dropped and replaced by a
    freshly built, dilated ``layer5`` with stride 1.

    Args:
        num_layers: Depth of the backbone; must be 18 or 34.
        pretrained: Forwarded to the backbone constructor.

    Raises:
        ValueError: If ``num_layers`` is not one of the supported depths.
    """

    def __init__(self, num_layers, pretrained=False):
        super().__init__()

        if num_layers == 18:
            resnet = resnet18(pretrained)
            num_blocks = 2
        elif num_layers == 34:
            resnet = resnet34(pretrained)
            num_blocks = 3
        else:
            # ValueError is an Exception subclass, so existing handlers that
            # catch Exception keep working.
            raise ValueError("num_layers must be in {18, 34}!")

        # remove fully connected layer, avg pool and layer5:
        self.resnet = nn.Sequential(*list(resnet.children())[:-3])

        # Dilation 2 with stride 1 replaces the removed stage.
        self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks, stride=1, dilation=2)

    def forward(self, x):
        """Compute the backbone feature map.

        Args:
            x: Tensor of shape (batch_size, 3, h, w).

        Returns:
            Tensor of shape (batch_size, 512, h/16, w/16).
        """
        # pass x through (parts of) the pretrained ResNet:
        c4 = self.resnet(x)  # (shape: (batch_size, 256, h/16, w/16)) (it's called c4 since 16 == 2^4)

        output = self.layer5(c4)  # (shape: (batch_size, 512, h/16, w/16))

        return output
class ResNet_BasicBlock_OS8(nn.Module):
    """BasicBlock ResNet (18/34) backbone with output stride 8.

    The last four children of the backbone are dropped and replaced by
    freshly built ``layer4`` (dilation 2) and ``layer5`` (dilation 4), both
    with stride 1.

    Args:
        num_layers: Depth of the backbone; must be 18 or 34.
        pretrained: Forwarded to the backbone constructor.

    Raises:
        ValueError: If ``num_layers`` is not one of the supported depths.
    """

    def __init__(self, num_layers, pretrained=False):
        super().__init__()

        if num_layers == 18:
            resnet = resnet18(pretrained)
            num_blocks_layer_4 = 2
            num_blocks_layer_5 = 2
        elif num_layers == 34:
            resnet = resnet34(pretrained)
            num_blocks_layer_4 = 6
            num_blocks_layer_5 = 3
        else:
            # ValueError is an Exception subclass, so existing handlers that
            # catch Exception keep working.
            raise ValueError("num_layers must be in {18, 34}!")

        # remove fully connected layer, avg pool, layer4 and layer5:
        self.resnet = nn.Sequential(*list(resnet.children())[:-4])

        # Both rebuilt stages keep stride 1; the dilation doubles from one
        # stage to the next.
        self.layer4 = make_layer(BasicBlock, in_channels=128, channels=256, num_blocks=num_blocks_layer_4, stride=1, dilation=2)
        self.layer5 = make_layer(BasicBlock, in_channels=256, channels=512, num_blocks=num_blocks_layer_5, stride=1, dilation=4)

    def forward(self, x):
        """Compute the backbone feature map.

        Args:
            x: Tensor of shape (batch_size, 3, h, w).

        Returns:
            Tensor of shape (batch_size, 512, h/8, w/8).
        """
        # pass x through (parts of) the pretrained ResNet:
        c3 = self.resnet(x)  # (shape: (batch_size, 128, h/8, w/8)) (it's called c3 since 8 == 2^3)

        output = self.layer4(c3)  # (shape: (batch_size, 256, h/8, w/8))
        output = self.layer5(output)  # (shape: (batch_size, 512, h/8, w/8))

        return output
def ResNet18_OS16(pretrained=False):
    """Build a ``ResNet_BasicBlock_OS16`` with ``num_layers=18``.

    Args:
        pretrained: Forwarded to ``ResNet_BasicBlock_OS16``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_BasicBlock_OS16(num_layers=18, pretrained=pretrained)
def ResNet34_OS16(pretrained=False):
    """Build a ``ResNet_BasicBlock_OS16`` with ``num_layers=34``.

    Args:
        pretrained: Forwarded to ``ResNet_BasicBlock_OS16``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_BasicBlock_OS16(num_layers=34, pretrained=pretrained)
def ResNet50_OS16(pretrained=False):
    """Build a ``ResNet_Bottleneck_OS16`` with ``num_layers=50``.

    Args:
        pretrained: Forwarded to ``ResNet_Bottleneck_OS16``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_Bottleneck_OS16(num_layers=50, pretrained=pretrained)
def ResNet101_OS16(pretrained=False):
    """Build a ``ResNet_Bottleneck_OS16`` with ``num_layers=101``.

    Args:
        pretrained: Forwarded to ``ResNet_Bottleneck_OS16``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_Bottleneck_OS16(num_layers=101, pretrained=pretrained)
def ResNet152_OS16(pretrained=False):
    """Build a ``ResNet_Bottleneck_OS16`` with ``num_layers=152``.

    Args:
        pretrained: Forwarded to ``ResNet_Bottleneck_OS16``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_Bottleneck_OS16(num_layers=152, pretrained=pretrained)
def ResNet18_OS8(pretrained=False):
    """Build a ``ResNet_BasicBlock_OS8`` with ``num_layers=18``.

    Args:
        pretrained: Forwarded to ``ResNet_BasicBlock_OS8``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_BasicBlock_OS8(num_layers=18, pretrained=pretrained)
def ResNet34_OS8(pretrained=False):
    """Build a ``ResNet_BasicBlock_OS8`` with ``num_layers=34``.

    Args:
        pretrained: Forwarded to ``ResNet_BasicBlock_OS8``. Defaults to
            ``False``, matching that class's own default.
    """
    return ResNet_BasicBlock_OS8(num_layers=34, pretrained=pretrained)