Source code for models.backbones.deeplabv3_aspp

# Adapted from https://github.com/fregu856/deeplabv3

import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head producing per-pixel class scores.

    Five parallel branches are applied to the same backbone feature map:

    * one 1x1 convolution,
    * three 3x3 dilated convolutions (dilation 6, 12 and 18, padded so the
      spatial size is preserved),
    * one image-level branch (global average pool -> 1x1 convolution ->
      bilinear upsampling back to the feature-map size).

    Every branch outputs 256 channels and is followed by batch norm + ReLU.
    The branches are concatenated (5 * 256 = 1280 channels), fused by a 1x1
    convolution to 256 channels, and projected to ``num_classes`` channels.

    Args:
        num_classes: Number of output channels (one score map per class).
        in_channels: Number of channels of the incoming feature map.
            Defaults to 512, the value that was previously hard-coded.

    Note:
        In training mode the image-level branch feeds a
        ``(batch_size, 256, 1, 1)`` tensor into ``BatchNorm2d``, which needs
        more than one value per channel; use ``batch_size > 1`` when training.
    """

    def __init__(self, num_classes: int, in_channels: int = 512) -> None:
        super().__init__()

        # Attribute names and registration order are kept unchanged so that
        # state_dict keys (and therefore existing checkpoints) stay valid.

        # Branch 1: plain 1x1 convolution.
        self.conv_1x1_1 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm2d(256)

        # Branches 2-4: 3x3 dilated convolutions; padding == dilation keeps
        # the spatial size unchanged for a 3x3 kernel with stride 1.
        self.conv_3x3_1 = nn.Conv2d(
            in_channels, 256, kernel_size=3, stride=1, padding=6, dilation=6
        )
        self.bn_conv_3x3_1 = nn.BatchNorm2d(256)

        self.conv_3x3_2 = nn.Conv2d(
            in_channels, 256, kernel_size=3, stride=1, padding=12, dilation=12
        )
        self.bn_conv_3x3_2 = nn.BatchNorm2d(256)

        self.conv_3x3_3 = nn.Conv2d(
            in_channels, 256, kernel_size=3, stride=1, padding=18, dilation=18
        )
        self.bn_conv_3x3_3 = nn.BatchNorm2d(256)

        # Branch 5: image-level features (global average pool + 1x1 conv).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.conv_1x1_2 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.bn_conv_1x1_2 = nn.BatchNorm2d(256)

        # Fusion of the five concatenated branches (1280 = 5 * 256).
        self.conv_1x1_3 = nn.Conv2d(5 * 256, 256, kernel_size=1)
        self.bn_conv_1x1_3 = nn.BatchNorm2d(256)

        # Final per-pixel classifier.
        self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        """Compute class score maps at the resolution of ``feature_map``.

        Args:
            feature_map: Tensor of shape ``(batch_size, in_channels, H, W)``.
                With an output-stride-16 backbone ``H, W`` are ``h/16, w/16``
                of the input image; with output stride 8 they are
                ``h/8, w/8``.

        Returns:
            Tensor of shape ``(batch_size, num_classes, H, W)``. The scores
            are NOT upsampled to the input image resolution here.
        """
        feature_map_h = feature_map.size(2)
        feature_map_w = feature_map.size(3)

        # Each spatial branch: (batch_size, 256, H, W).
        out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))
        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))
        out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))
        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))

        # Image-level branch.
        out_img = self.avg_pool(feature_map)  # (batch_size, in_channels, 1, 1)
        out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))  # (batch_size, 256, 1, 1)
        # Spread the pooled descriptor back over the full feature-map grid.
        out_img = F.interpolate(
            out_img,
            size=(feature_map_h, feature_map_w),
            mode="bilinear",
            align_corners=True,
        )  # (batch_size, 256, H, W)

        # Channel-wise concatenation: (batch_size, 1280, H, W).
        out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)
        out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))  # (batch_size, 256, H, W)
        out = self.conv_1x1_4(out)  # (batch_size, num_classes, H, W)

        return out
class ASPP_Bottleneck(nn.Module):
    """Atrous Spatial Pyramid Pooling head for wide (4*512-channel) features.

    Same layout as a 512-channel ASPP head, but the default input width is
    ``4 * 512`` channels (presumably the output of a ResNet built from
    bottleneck blocks -- confirm against the backbone that feeds this module).

    Five parallel branches are applied to the same feature map:

    * one 1x1 convolution,
    * three 3x3 dilated convolutions (dilation 6, 12 and 18, padded so the
      spatial size is preserved),
    * one image-level branch (global average pool -> 1x1 convolution ->
      bilinear upsampling back to the feature-map size).

    Every branch outputs 256 channels and is followed by batch norm + ReLU.
    The branches are concatenated (5 * 256 = 1280 channels), fused by a 1x1
    convolution to 256 channels, and projected to ``num_classes`` channels.

    Args:
        num_classes: Number of output channels (one score map per class).
        in_channels: Number of channels of the incoming feature map.
            Defaults to ``4 * 512``, the value that was previously hard-coded.

    Note:
        In training mode the image-level branch feeds a
        ``(batch_size, 256, 1, 1)`` tensor into ``BatchNorm2d``, which needs
        more than one value per channel; use ``batch_size > 1`` when training.
    """

    def __init__(self, num_classes: int, in_channels: int = 4 * 512) -> None:
        super().__init__()

        # Attribute names and registration order are kept unchanged so that
        # state_dict keys (and therefore existing checkpoints) stay valid.

        # Branch 1: plain 1x1 convolution.
        self.conv_1x1_1 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm2d(256)

        # Branches 2-4: 3x3 dilated convolutions; padding == dilation keeps
        # the spatial size unchanged for a 3x3 kernel with stride 1.
        self.conv_3x3_1 = nn.Conv2d(
            in_channels, 256, kernel_size=3, stride=1, padding=6, dilation=6
        )
        self.bn_conv_3x3_1 = nn.BatchNorm2d(256)

        self.conv_3x3_2 = nn.Conv2d(
            in_channels, 256, kernel_size=3, stride=1, padding=12, dilation=12
        )
        self.bn_conv_3x3_2 = nn.BatchNorm2d(256)

        self.conv_3x3_3 = nn.Conv2d(
            in_channels, 256, kernel_size=3, stride=1, padding=18, dilation=18
        )
        self.bn_conv_3x3_3 = nn.BatchNorm2d(256)

        # Branch 5: image-level features (global average pool + 1x1 conv).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.conv_1x1_2 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.bn_conv_1x1_2 = nn.BatchNorm2d(256)

        # Fusion of the five concatenated branches (1280 = 5 * 256).
        self.conv_1x1_3 = nn.Conv2d(5 * 256, 256, kernel_size=1)
        self.bn_conv_1x1_3 = nn.BatchNorm2d(256)

        # Final per-pixel classifier.
        self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        """Compute class score maps at the resolution of ``feature_map``.

        Args:
            feature_map: Tensor of shape ``(batch_size, in_channels, H, W)``;
                the original comments describe it as
                ``(batch_size, 4*512, h/16, w/16)`` relative to the input
                image.

        Returns:
            Tensor of shape ``(batch_size, num_classes, H, W)``. The scores
            are NOT upsampled to the input image resolution here.
        """
        feature_map_h = feature_map.size(2)
        feature_map_w = feature_map.size(3)

        # Each spatial branch: (batch_size, 256, H, W).
        out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))
        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))
        out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))
        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))

        # Image-level branch. Pooling keeps the channel count, so this is
        # (batch_size, in_channels, 1, 1) -- i.e. 4*512 channels by default.
        out_img = self.avg_pool(feature_map)
        out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))  # (batch_size, 256, 1, 1)
        # Spread the pooled descriptor back over the full feature-map grid.
        out_img = F.interpolate(
            out_img,
            size=(feature_map_h, feature_map_w),
            mode="bilinear",
            align_corners=True,
        )  # (batch_size, 256, H, W)

        # Channel-wise concatenation: (batch_size, 1280, H, W).
        out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)
        out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))  # (batch_size, 256, H, W)
        out = self.conv_1x1_4(out)  # (batch_size, num_classes, H, W)

        return out