Source code for models.backbone_header_model

from typing import Union, Optional, OrderedDict

import pytorch_lightning as pl
import torch.nn
from torchvision.models._utils import IntermediateLayerGetter


[docs]class BackboneHeaderModel(pl.LightningModule): """ A generic model class to provide the possibility to create different backbone/header combinations. The backbone and header compatibility can be tested with the callback :class:`CheckBackboneHeaderCompatibility` during runtime. The loading of the different parts is done in :class:`execute`. :param backbone: The backbone model :type backbone: Union[pl.LightningModule, torch.nn.Module] :param header: The header model :type header: Union[pl.LightningModule, torch.nn.Module] :param backbone_output_layer: The name of the output layer of the backbone. If None, the last layer of the backbone is used. :type backbone_output_layer: Optional[str] """ def __init__(self, backbone: Union[pl.LightningModule, torch.nn.Module], header: Union[pl.LightningModule, torch.nn.Module], backbone_output_layer: Optional[str] = None): super().__init__() # sanity check if the last layer of the backbone is compatible with the first layer of the header if backbone_output_layer is not None: return_layer = {backbone_output_layer: 'out'} self.backbone = IntermediateLayerGetter(model=backbone, return_layers=return_layer) else: self.backbone = backbone self.header = header
[docs] def forward(self, x): x = self.backbone(x) if isinstance(x, OrderedDict): x = x['out'] x = self.header(x) return x