# Source code for models.backbones.deeplabv3

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

import urllib
import os

from src.models.backbones.deeplabv3_resnet import ResNet18_OS16, ResNet34_OS16, ResNet50_OS16, ResNet101_OS16, \
    ResNet152_OS16, ResNet18_OS8, ResNet34_OS8
from src.models.backbones.deeplabv3_aspp import ASPP, ASPP_Bottleneck

# Registry of the supported DeepLabV3 model names and the ResNet backbone class each one uses.
# The ``osN`` suffix means the backbone's feature map is 1/N of the input resolution.
# Insertion order is kept; it is the order in which the names are listed in log messages.
CLASS_NAMES = dict(
    deeplabv3_resnet18_os16=ResNet18_OS16,
    deeplabv3_resnet34_os16=ResNet34_OS16,
    deeplabv3_resnet50_os16=ResNet50_OS16,
    deeplabv3_resnet101_os16=ResNet101_OS16,
    deeplabv3_resnet152_os16=ResNet152_OS16,
    deeplabv3_resnet18_os8=ResNet18_OS8,
    deeplabv3_resnet34_os8=ResNet34_OS8,
)


def deeplabv3(num_classes, **kwargs):
    """Build the default DeepLabV3 variant (``deeplabv3_resnet18_os8``); see ``deeplabv3_builder``."""
    return deeplabv3_builder('deeplabv3', num_classes, **kwargs)


def deeplabv3_resnet18_os16(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-18 backbone with output stride 16."""
    return deeplabv3_builder('deeplabv3_resnet18_os16', num_classes, **kwargs)


def deeplabv3_resnet34_os16(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-34 backbone with output stride 16."""
    return deeplabv3_builder('deeplabv3_resnet34_os16', num_classes, **kwargs)


def deeplabv3_resnet50_os16(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-50 backbone with output stride 16."""
    return deeplabv3_builder('deeplabv3_resnet50_os16', num_classes, **kwargs)


def deeplabv3_resnet101_os16(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-101 backbone with output stride 16."""
    return deeplabv3_builder('deeplabv3_resnet101_os16', num_classes, **kwargs)


def deeplabv3_resnet152_os16(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-152 backbone with output stride 16."""
    return deeplabv3_builder('deeplabv3_resnet152_os16', num_classes, **kwargs)


def deeplabv3_resnet18_os8(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-18 backbone with output stride 8."""
    return deeplabv3_builder('deeplabv3_resnet18_os8', num_classes, **kwargs)


def deeplabv3_resnet34_os8(num_classes, **kwargs):
    """Build DeepLabV3 on a ResNet-34 backbone with output stride 8."""
    return deeplabv3_builder('deeplabv3_resnet34_os8', num_classes, **kwargs)
# *********************************************************************************
def deeplabv3_builder(model_name, output_channels, pretrained=False, resume=None, cityscapes=False, **kwargs):
    """Construct a DeepLabV3 model and optionally load saved weights into it.

    Args:
        model_name: ``'deeplabv3'`` (alias for ``'deeplabv3_resnet18_os8'``) or one of the
            keys of ``CLASS_NAMES``.
        output_channels: number of output channels of the segmentation head (passed to
            ``DeepLabV3`` as ``num_classes``).
        pretrained: forwarded to the ResNet backbone constructor.
        resume: optional path to a checkpoint dict whose weights are stored under
            ``'state_dict'``; they are loaded with ``strict=False``.
        cityscapes: when True and the backbone is ``deeplabv3_resnet18_os8``, load the
            Cityscapes weights from ``get_cityscapes_model_path`` (``strict=False``).
            For any other backbone a warning is logged and nothing is loaded.
        **kwargs: forwarded to ``DeepLabV3`` and to ``get_cityscapes_model_path``.

    Returns:
        The ``DeepLabV3`` module. Errors while applying loaded weights are logged as
        warnings rather than raised.
    """
    if model_name == 'deeplabv3':
        logging.info('ResNet type not specified, running "deeplabv3_resnet18_os8". (choose from %s)',
                     ", ".join(CLASS_NAMES.keys()))
        # Normalise the alias so the Cityscapes check below also applies to it.
        model_name = "deeplabv3_resnet18_os8"
    model = DeepLabV3(model_name, pretrained, output_channels, **kwargs)

    # load a model from a path
    if resume:
        if os.path.isfile(resume):
            # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines;
            # load_state_dict copies the values into the model's existing parameters.
            model_dict = torch.load(resume, map_location='cpu')
            logging.info('Loading a saved model')
            try:
                model.load_state_dict(model_dict['state_dict'], strict=False)
            except Exception as exp:
                logging.warning(exp)
        else:
            logging.error("No model dict found at '%s'", resume)

    # The pre-trained Cityscapes weights are only usable with the "deeplabv3_resnet18_os8"
    # set-up. (The original condition tested a non-empty string literal and was always true.)
    if cityscapes:
        if model_name == "deeplabv3_resnet18_os8":
            try:
                path = get_cityscapes_model_path(**kwargs)
                model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
            except Exception as exp:
                logging.warning(exp)
        else:
            logging.warning('Cityscapes weights are only available for "deeplabv3_resnet18_os8"; '
                            'not loading them into "%s"', model_name)
    return model
class DeepLabV3(nn.Module):
    """DeepLabV3 segmentation network: a ResNet backbone followed by an ASPP head.

    Args:
        model_name: key of ``CLASS_NAMES`` selecting the backbone.
        pretrained: forwarded to the backbone constructor.
        num_classes: number of output channels produced by the ASPP head.
        **kwargs: accepted for call-site compatibility and ignored.

    Raises:
        KeyError: if ``model_name`` is not a key of ``CLASS_NAMES`` (the message lists
            the valid choices).
    """

    def __init__(self, model_name, pretrained, num_classes, **kwargs):
        super().__init__()
        try:
            backbone_cls = CLASS_NAMES[model_name]
        except KeyError:
            raise KeyError('Unknown DeepLabV3 model "{}" (choose from {})'.format(
                model_name, ", ".join(CLASS_NAMES))) from None

        self.num_classes = num_classes
        self.resnet = backbone_cls(pretrained)
        # ResNet18/34 backbones output 512 channels and pair with the plain ASPP head;
        # the ResNet50-152 variants output 4*512 channels and need ASPP_Bottleneck.
        if 'resnet18' in model_name or 'resnet34' in model_name:
            self.aspp = ASPP(num_classes=self.num_classes)
        else:
            self.aspp = ASPP_Bottleneck(num_classes=self.num_classes)

    def forward(self, x):
        """Return per-pixel class scores upsampled to the input's spatial size.

        Args:
            x: input batch of shape (batch_size, 3, h, w).

        Returns:
            Tensor of shape (batch_size, num_classes, h, w).
        """
        h, w = x.size(2), x.size(3)
        # (batch_size, 512, h/16, w/16) for ResNet18/34_OS16; (batch_size, 512, h/8, w/8)
        # for the OS8 variants; (batch_size, 4*512, h/16, w/16) for ResNet50-152.
        feature_map = self.resnet(x)
        output = self.aspp(feature_map)  # (batch_size, num_classes, h/16, w/16) for OS16
        # Bilinear upsampling back to the input resolution.
        return F.interpolate(output, size=(h, w), mode="bilinear", align_corners=True)
def get_cityscapes_model_path(**kwargs):
    """Return the local path of the Cityscapes-pretrained DeepLabV3 checkpoint.

    The file is cached at ``<cwd>/models/deeplabv3_13_2_2_2_epoch_580.pth``. If it is
    missing, it is downloaded from GitHub. The download goes to a temporary file in the
    same directory and is renamed into place only when complete, so an interrupted
    download never leaves a truncated file that later calls would treat as cached.

    Args:
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        The path of the cached checkpoint file.

    Raises:
        OSError / urllib.error.URLError: if the directory cannot be created or the
            download fails.
    """
    # ``import urllib`` alone does not load the ``urllib.request`` submodule.
    import tempfile
    import urllib.request

    download_path = os.path.join(os.getcwd(), "models/deeplabv3_13_2_2_2_epoch_580.pth")
    if os.path.exists(download_path):
        return download_path

    download_dir = os.path.dirname(download_path)
    os.makedirs(download_dir, exist_ok=True)
    url = "https://github.com/fregu856/deeplabv3/blob/master/pretrained_models/model_13_2_2_2_epoch_580.pth?raw=true"
    logging.info('Downloading %s...', url)

    fd, tmp_path = tempfile.mkstemp(dir=download_dir, suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, download_path)
    finally:
        # After a successful os.replace the temp file no longer exists; otherwise drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return download_path