Source code for models.backbones.adaptive_unet

import torch
from torch import nn


[docs]def encoding_block(in_c, out_c): conv = nn.Sequential( nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(inplace=True), nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(inplace=True) ) return conv
[docs]def encoding_block1(in_c, out_c): conv = nn.Sequential( nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(inplace=True), nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU(inplace=True), nn.Dropout(0.4) ) return conv
[docs]def decoding_block(in_c, out_c): conv = nn.Sequential( nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2), nn.ReLU(inplace=True), ) return conv
[docs]class Adaptive_Unet(nn.Module): def __init__(self, out_channels=4, features=[32, 64, 128, 256]): super(Adaptive_Unet, self).__init__() self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) self.conv1 = encoding_block(3, features[0]) self.conv2 = encoding_block(features[0], features[1]) self.conv3 = encoding_block(features[1], features[2]) self.conv4 = encoding_block1(features[2], features[3]) self.conv5 = encoding_block(features[3] * 2, features[3]) self.conv6 = encoding_block(features[3], features[2]) self.conv7 = encoding_block(features[2], features[1]) self.conv8 = encoding_block(features[1], features[0]) self.tconv1 = decoding_block(features[-1] * 2, features[-1]) self.tconv2 = decoding_block(features[-1], features[-2]) self.tconv3 = decoding_block(features[-2], features[-3]) self.tconv4 = decoding_block(features[-3], features[-4]) self.bottleneck = encoding_block1(features[3], features[3] * 2) self.final_layer = nn.Conv2d(features[0], out_channels, kernel_size=1)
[docs] def forward(self, x): # encoder x_1 = self.conv1(x) # Convolution, ReLU 32 3×3 # To concat # print(x_1.size()) x_2 = self.pool1(x_1) # Maxpooling 2×2 ([1, 32, 672, 480]) # print(x_2.size()) x_3 = self.conv2(x_2) # 2 conv ([1, 64, 672, 480]) # To concat # print(x_3.size()) x_4 = self.pool2(x_3) # Maxpooling 2×2 ([1, 64, 336, 240]) # print(x_4.size()) x_5 = self.conv3(x_4) # 2 conv ([1, 128, 336, 240]) # To concat # print(x_5.size()) x_6 = self.pool3(x_5) # Maxpooling 2×2 ([1, 128, 168, 120]) # print(x_6.size()) x_7 = self.conv4(x_6) # 2 conv ([1, 256, 168, 120]) # To concat # print(x_7.size()) x_8 = self.pool4(x_7) # Maxpooling 2×2 [1, 256, 84, 60]) # print(x_8.size()) x_9 = self.bottleneck(x_8) # 2 conv ([1, 512, 84, 60]) # print(x_9.size()) # decoder x_10 = self.tconv1(x_9) # deconv ([1, 256, 168, 120]) # print(x_10.size()) x_11 = torch.cat((x_7, x_10), dim=1) # ([1, 512, 168, 120]) # print(x_11.size()) x_12 = self.conv5(x_11) # 2 conv ([1, 256, 168, 120]) # print(x_12.size()) x_13 = self.tconv2(x_12) # deconv [1, 128, 336, 240]) # print(x_13.size()) x_14 = torch.cat((x_5, x_13), dim=1) # ([1, 256, 336, 240]) # print(x_14.size()) x_15 = self.conv6(x_14) # 2 conv ([1, 128, 336, 240]) # print(x_15.size()) x_16 = self.tconv3(x_15) # deconv [1, 64, 672, 480]) # print(x_16.size()) x_17 = torch.cat((x_3, x_16), dim=1) # ([1, 128, 672, 480]) # print(x_17.size()) x_18 = self.conv7(x_17) # 2 conv ([1, 64, 672, 480]) # print(x_18.size()) x_19 = self.tconv4(x_18) # deconv [1, 32, 1344, 960]) # print(x_19.size()) x_20 = torch.cat((x_1, x_19), dim=1) # ([1, 64, 1344, 960]) # print(x_20.size()) x_21 = self.conv8(x_20) # 2 conv ([1, 32, 1344, 960] # print(x_21.size()) x = self.final_layer(x_21) # print(x.size()) return x