import torch
from torch import nn
from torch.nn import functional as F
class OldUNet(nn.Module):
    """
    U-Net segmentation network ending in a 1x1 convolution that maps to ``num_classes`` channels.

    Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
    <https://arxiv.org/abs/1505.04597>`_

    Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox

    Implemented by:

        - `Annika Brundyn <https://github.com/annikabrundyn>`_
        - `Akshay Kulkarni <https://github.com/akshaykvnit>`_

    ``self.layers`` holds, in order: one ``DoubleConv`` stem, ``num_layers - 1`` ``Down``
    blocks (each doubling the width), ``num_layers - 1`` ``Up`` blocks (each halving it),
    and the final 1x1 ``Conv2d`` head.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        input_channels: Number of channels in input images (default 3).
        num_layers: Number of levels on each side of the U (default 5); must be >= 1.
        features_start: Width of the first level (default 64); level ``k`` has
            ``features_start * 2**k`` channels.
        bilinear: Use bilinear interpolation instead of transposed convolutions
            (the default) for upsampling.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(
        self,
        num_classes: int,
        input_channels: int = 3,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
    ):
        if num_layers < 1:
            raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0")
        super().__init__()
        self.num_layers = num_layers

        widths = [features_start * 2**level for level in range(num_layers)]
        # Modules are instantiated stem -> Down blocks -> Up blocks -> head; this creation
        # order determines how the RNG is consumed during weight initialisation.
        stem = [DoubleConv(input_channels, features_start)]
        downs = [Down(w, 2 * w) for w in widths[:-1]]
        ups = [Up(w, w // 2, bilinear) for w in reversed(widths[1:])]
        head = [nn.Conv2d(features_start, num_classes, kernel_size=1)]
        self.layers = nn.ModuleList(stem + downs + ups + head)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 head."""
        encoder = self.layers[: self.num_layers]
        decoder = self.layers[self.num_layers : -1]
        head = self.layers[-1]

        skips = [encoder[0](x)]
        for down in encoder[1:]:
            skips.append(down(skips[-1]))

        out = skips[-1]
        # The deepest decoder block pairs with the second-deepest encoder output, and so on
        # back up to the stem output.
        for up, skip in zip(decoder, reversed(skips[:-1])):
            out = up(out, skip)
        return head(out)
class UNet(nn.Module):
    """
    U-Net encoder/decoder that returns the last decoder feature map (no classification head).

    Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
    <https://arxiv.org/abs/1505.04597>`_

    Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox

    Implemented by:

        - `Annika Brundyn <https://github.com/annikabrundyn>`_
        - `Akshay Kulkarni <https://github.com/akshaykvnit>`_

    Unlike ``OldUNet``, this variant has no final 1x1 convolution and takes no
    ``num_classes`` argument: ``forward`` returns the output of the last ``Up`` block
    (or of the stem when ``num_layers == 1``), which has ``features_start`` channels.

    Args:
        input_channels: Number of channels in input images (default 3).
        num_layers: Number of levels on each side of the U (default 5); must be >= 1.
        features_start: Number of features in the first level (default 64); each
            ``Down`` block doubles the width and each ``Up`` block halves it again.
        bilinear: Whether to use bilinear interpolation or transposed convolutions
            (default) for upsampling.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(
        self,
        input_channels: int = 3,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
    ):
        if num_layers < 1:
            raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0")
        super().__init__()
        self.num_layers = num_layers

        # layers = [stem, Down * (num_layers - 1), Up * (num_layers - 1)]
        layers = [DoubleConv(input_channels, features_start)]
        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2
        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear))
            feats //= 2
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Return the decoder's final feature map for input ``x``."""
        xi = [self.layers[0](x)]
        # Down path: keep every encoder output for use as a skip connection.
        for layer in self.layers[1 : self.num_layers]:
            xi.append(layer(xi[-1]))
        # Up path: the running result lives in xi[-1]; the i-th Up block consumes the
        # encoder output one level shallower than the previous one.
        for i, layer in enumerate(self.layers[self.num_layers :]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        return xi[-1]
class DoubleConv(nn.Module):
    """[ Conv2d => BatchNorm (optional) => ReLU ] x 2.

    Both convolutions are 3x3 with padding 1 and stride 1, so height and width are
    preserved; only the channel count changes (``in_ch`` -> ``out_ch`` -> ``out_ch``).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
        batch_norm: If True (default), insert a ``BatchNorm2d`` after each convolution.
            With the default the layer order and ``state_dict`` keys (``net.0``,
            ``net.1``, ...) are exactly the original ones.
    """

    def __init__(self, in_ch: int, out_ch: int, *, batch_norm: bool = True):
        super().__init__()
        layers = []
        for src_ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Down(nn.Module):
    """Encoder stage: 2x2 max pooling (stride 2) to halve height/width, then a DoubleConv.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by the DoubleConv.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        refine = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, refine)

    def forward(self, x):
        downsampled = self.net(x)
        return downsampled
class Up(nn.Module):
    """Decoder stage: upsample ``x1`` 2x, align it to the skip tensor ``x2``, concatenate, DoubleConv.

    Upsampling is either a 2x2 stride-2 ``ConvTranspose2d`` (default) or bilinear
    interpolation followed by a 1x1 convolution; both reduce ``in_ch`` to ``in_ch // 2``.
    ``x2`` must therefore supply ``in_ch - in_ch // 2`` channels so that the
    concatenation has the ``in_ch`` channels the DoubleConv expects.

    Args:
        in_ch: Channels of ``x1``.
        out_ch: Channels of the output.
        bilinear: Use bilinear interpolation instead of a transposed convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        half_ch = in_ch // 2
        if not bilinear:
            self.upsample = nn.ConvTranspose2d(in_ch, half_ch, kernel_size=2, stride=2)
        else:
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, half_ch, kernel_size=1),
            )
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        upsampled = self.upsample(x1)
        # Zero-pad the upsampled map to x2's height/width; an odd difference puts the
        # extra pixel on the bottom/right.
        dh = x2.shape[2] - upsampled.shape[2]
        dw = x2.shape[3] - upsampled.shape[3]
        top, left = dh // 2, dw // 2
        upsampled = F.pad(upsampled, [left, dw - left, top, dh - top])
        # Skip tensor first, then the upsampled one, along the channel axis.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class Baby_UNet(UNet):
    """Small ``UNet`` preset: two levels per side and 32 features in the first level.

    All other ``UNet`` arguments keep their defaults (3 input channels,
    transposed-convolution upsampling).
    """

    def __init__(self):
        super().__init__(features_start=32, num_layers=2)
def encoding_block(in_c, out_c):
    """Return ``nn.Sequential`` of two (3x3 Conv2d -> BatchNorm2d -> in-place ReLU) stages.

    Convolutions use stride 1 and padding 1, so spatial size is preserved; the first
    maps ``in_c`` -> ``out_c`` channels and the second ``out_c`` -> ``out_c``.
    """
    stages = []
    for src, dst in ((in_c, out_c), (out_c, out_c)):
        stages += [
            nn.Conv2d(src, dst, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dst),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*stages)
class UNetNajoua(nn.Module):
    """Four-level U-Net with a constant channel width that returns decoder features (no head).

    The encoder stages, the bottleneck and the decoder stages all output ``features[0]``
    channels. Each decoder stage concatenates an encoder output (``features[0]`` channels)
    with a transposed-conv upsampled map (``features[0]`` channels), so its input width is
    ``2 * features[0]``, which must equal ``features[1]``.

    The encoder applies four 2x2 stride-2 max-pools, so input height and width must be at
    least 16. When they are not multiples of 16, each upsampled map is zero-padded to its
    skip tensor's size before concatenation, as ``Up`` does.

    Args:
        num_classes: Unused. The final 1x1 classification layer is disabled, so the output
            has ``features[0]`` channels. The argument is kept for interface compatibility.
        features: Pair ``(width, 2 * width)`` (default ``(16, 32)``).
        in_channels: Number of channels of the input images (default 3).

    Raises:
        ValueError: If ``features[1] != 2 * features[0]``.
    """

    def __init__(self, num_classes=4, features=(16, 32), in_channels=3):
        super().__init__()
        width, cat_width = features[0], features[1]
        if cat_width != 2 * width:
            # Previously this only failed later, inside forward(), at the first decoder conv.
            raise ValueError(
                f"features[1] = {cat_width} must equal 2 * features[0] = {2 * width} "
                "(decoder stages take the concatenation of two features[0]-channel maps)"
            )
        # Creation order matches the original layout so seeded initialisation is unchanged.
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        # Encoder stages.
        self.conv1 = encoding_block(in_channels, width)
        self.conv2 = encoding_block(width, width)
        self.conv3 = encoding_block(width, width)
        self.conv4 = encoding_block(width, width)
        # Decoder stages: input is cat(skip, upsampled) with cat_width channels.
        self.conv5 = encoding_block(cat_width, width)
        self.conv6 = encoding_block(cat_width, width)
        self.conv7 = encoding_block(cat_width, width)
        self.conv8 = encoding_block(cat_width, width)
        # 2x upsampling between decoder stages.
        self.tconv1 = nn.ConvTranspose2d(width, width, kernel_size=2, stride=2)
        self.tconv2 = nn.ConvTranspose2d(width, width, kernel_size=2, stride=2)
        self.tconv3 = nn.ConvTranspose2d(width, width, kernel_size=2, stride=2)
        self.tconv4 = nn.ConvTranspose2d(width, width, kernel_size=2, stride=2)
        self.bottleneck = encoding_block(width, width)

    @staticmethod
    def _match_spatial(x, ref):
        """Zero-pad ``x`` so its last two dimensions match ``ref`` (no-op when they already match)."""
        diff_h = ref.shape[2] - x.shape[2]
        diff_w = ref.shape[3] - x.shape[3]
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        # Encoder: each stage's output is kept as a skip connection.
        skip1 = self.conv1(x)
        skip2 = self.conv2(self.pool(skip1))
        skip3 = self.conv3(self.pool(skip2))
        skip4 = self.conv4(self.pool(skip3))
        out = self.bottleneck(self.pool(skip4))
        # Decoder: upsample, concatenate with the skip (skip first, matching the trained
        # channel order), then refine.
        for upsample, refine, skip in (
            (self.tconv1, self.conv5, skip4),
            (self.tconv2, self.conv6, skip3),
            (self.tconv3, self.conv7, skip2),
            (self.tconv4, self.conv8, skip1),
        ):
            upsampled = self._match_spatial(upsample(out), skip)
            out = refine(torch.cat((skip, upsampled), dim=1))
        return out
class UNet16(UNetNajoua):
    """``UNetNajoua`` preset with 16 features per stage (32 at each decoder concatenation)."""

    def __init__(self, num_classes=4):
        base = 16
        super().__init__(num_classes=num_classes, features=[base, 2 * base])
class UNet32(UNetNajoua):
    """``UNetNajoua`` preset with 32 features per stage (64 at each decoder concatenation)."""

    def __init__(self, num_classes=4):
        base = 32
        super().__init__(num_classes=num_classes, features=[base, 2 * base])
class UNet64(UNetNajoua):
    """``UNetNajoua`` preset with 64 features per stage (128 at each decoder concatenation)."""

    def __init__(self, num_classes=4):
        base = 64
        super().__init__(num_classes=num_classes, features=[base, 2 * base])
def one_conv1(in_c, out_c):
    """Return ``nn.Sequential(Conv2d, ReLU)``: one 3x3 biased conv (stride 1, padding 1) + in-place ReLU."""
    conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=True)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)
def one_conv2(in_c, out_c):
    """Build a single 3x3 convolution (stride 1, padding 1, with bias) followed by an in-place ReLU."""
    layers = [
        nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        ),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)