import logging
import os
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import get_filesystem

log = logging.getLogger(__name__)


class SaveModelStateDictAndTaskCheckpoint(ModelCheckpoint):
"""
Saves the neural network weights into a pth file.
It produces a file for each the encoder and the header.
The encoder file is named after the backbone_filename and the header file after the header_filename.
The backbone_filename and the header_filename can be specified in the constructor.
:param backbone_filename: Filename of the backbone checkpoint
:type backbone_filename: str, optional
:param header_filename: Filename of the header checkpoint
:type header_filename: str, optional
"""
    def __init__(self, backbone_filename: Optional[str] = 'backbone', header_filename: Optional[str] = 'header',
                 **kwargs):
        """
        Constructor of the SaveModelStateDictAndTaskCheckpoint class.
        """
        super().__init__(**kwargs)
        self.backbone_filename = backbone_filename
        self.header_filename = header_filename
        self.CHECKPOINT_NAME_LAST = 'task_last'

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """
        Saves the header and backbone state dicts together with the task checkpoint in a combined folder.
        The folder is named after the epoch in which the checkpoint was created.
        The task checkpoint is saved as task_last.pth if it is the last checkpoint and as task_epoch=x.pth otherwise.

        :param trainer: The trainer object
        :type trainer: pl.Trainer
        :param filepath: The filepath of the model state dict
        :type filepath: str
        """
        if not trainer.is_global_zero:
            return
        super()._save_checkpoint(trainer=trainer, filepath=filepath)

        model = trainer.lightning_module.model
        metric_candidates = self._monitor_candidates(trainer)

        # check whether this is the save of the last checkpoint or a regular per-epoch one
        if 'last' not in filepath:
            # resolve the filename templates (e.g. epoch placeholders) against the monitored metrics
            format_backbone_filename = self._format_checkpoint_name(filename=self.backbone_filename,
                                                                    metrics=metric_candidates)
            format_header_filename = self._format_checkpoint_name(filename=self.header_filename,
                                                                  metrics=metric_candidates)
            if trainer.current_epoch > 0:
                # remove the checkpoint folder of the previous epoch
                filepath_old = os.path.join(self.dirpath, f"epoch={trainer.current_epoch - 1}")
                self._del_old_folder(filepath_old, trainer)
        else:
            format_backbone_filename = self.backbone_filename.split('/')[-1] + '_last'
            format_header_filename = self.header_filename.split('/')[-1] + '_last'

        torch.save(model.backbone.state_dict(), os.path.join(self.dirpath, format_backbone_filename + '.pth'))
        torch.save(model.header.state_dict(), os.path.join(self.dirpath, format_header_filename + '.pth'))
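
    # Illustrative note (paths hypothetical): the .pth files written above hold plain state
    # dicts, so downstream code would typically restore them along the lines of:
    #
    #     state = torch.load(os.path.join(checkpoint_dir, "epoch=3", "backbone.pth"))
    #     model.backbone.load_state_dict(state)
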
    @rank_zero_only
    def _del_old_folder(self, filepath: str, trainer: "pl.Trainer") -> None:
        """
        Deletes the folder of the previous checkpoint once it has been superseded based on the monitored metric.

        :param filepath: The filepath of the old folder
        :type filepath: str
        :param trainer: The trainer object
        :type trainer: pl.Trainer
        """
        file_system = get_filesystem(filepath)
        if file_system.exists(filepath):
            # delete all files in the directory
            for path in file_system.ls(filepath):
                if file_system.exists(path):
                    trainer.strategy.remove_checkpoint(path)
            # delete the directory itself
            file_system.rm(filepath, recursive=True)
            log.debug(f"Removed checkpoint: {filepath}")
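

# Usage sketch (not part of the original module): how this callback might be wired into a
# pl.Trainer. The dirpath, filename templates and monitor key below are assumptions for
# illustration; the fitted LightningModule is expected to expose its network as ``self.model``
# with ``backbone`` and ``header`` submodules, as accessed in ``_save_checkpoint`` above.
if __name__ == "__main__":
    checkpoint_callback = SaveModelStateDictAndTaskCheckpoint(
        dirpath="checkpoints",                 # where the task checkpoint and the .pth files end up
        filename="task_{epoch}",               # regular task checkpoint name
        backbone_filename="{epoch}/backbone",  # backbone .pth inside a per-epoch folder
        header_filename="{epoch}/header",      # header .pth inside a per-epoch folder
        monitor="val/loss",                    # hypothetical metric logged by the task
        mode="min",
        save_top_k=1,
        save_last=True,
    )
    trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_callback])
    # trainer.fit(task, datamodule) would then write per-epoch folders such as
    # checkpoints/epoch=3/backbone.pth and header.pth (older folders are cleaned up),
    # plus backbone_last.pth / header_last.pth next to the regular task checkpoint.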