import os
from abc import ABCMeta
from pathlib import Path
from typing import Optional, Union, Type, Mapping, Sequence, Callable, Dict, Any, Tuple, List
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import wandb
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassConfusionMatrix
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from omegaconf import OmegaConf
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from src.callbacks.wandb_callbacks import get_wandb_logger
from src.tasks.utils.outputs import OutputKeys
from src.tasks.utils.task_utils import get_callable_dict
from src.utils import utils
log = utils.get_logger(__name__)
class AbstractTask(LightningModule, metaclass=ABCMeta):
"""
Inspired by `Pytorch Lightning <https://github.com/PyTorchLightning/lightning-flash/blob/2ec52e633bb3679f50dd7e30526885a4547e1851/flash/core/model.py>`_.
A general abstract Task. It provides the basic functionality for training, validation and testing. A ``step`` method
is provided which can be overridden for custom behavior; it is called in the training, validation and
testing loops and should return a dictionary with the following keys (a sketch follows the list):
- ``OutputKeys.PREDICTION``: The prediction of the model.
- ``OutputKeys.LOSS``: The loss of the model.
- ``OutputKeys.LOG``: A dictionary with all the logs. The keys are the metric names and the values are the
metric values.
- ``OutputKeys.TARGET``: The target of the model.
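A minimal sketch of such a dictionary (the loss and metric names are placeholders)::

    {
        OutputKeys.PREDICTION: y_hat,
        OutputKeys.LOSS: loss,
        OutputKeys.LOG: {'crossentropyloss': loss, 'accuracy': acc},
        OutputKeys.TARGET: y,
    }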
:param model: Composed model to use for the task.
:type model: nn.Module
:param loss_fn: Loss function for training
:type loss_fn: Union[Callable, Mapping, Sequence]
:param optimizer: Optimizer to use for training, defaults to :class:`torch.optim.Adam`.
:type optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer]
:param optimizer_kwargs: Keyword arguments to pass to the optimizer.
:type optimizer_kwargs: Optional[Dict[str, Any]]
:param scheduler: Learning rate scheduler to use for training, defaults to None.
:type scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]]
:param scheduler_kwargs: Keyword arguments to pass to the scheduler.
:type scheduler_kwargs: Optional[Dict[str, Any]]
:param metric_train: Metrics to compute for training.
:type metric_train: Optional[MetricCollection]
:param metric_val: Metrics to compute for evaluation.
:type metric_val: Optional[MetricCollection]
:param metric_test: Metrics to compute for testing.
:type metric_test: Optional[MetricCollection]
:param confusion_matrix_val: Whether to compute the confusion matrix for the validation set.
:type confusion_matrix_val: bool
:param confusion_matrix_test: Whether to compute the confusion matrix for the test set.
:type confusion_matrix_test: bool
:param confusion_matrix_log_every_n_epoch: Interval, in epochs, at which the confusion matrix is computed and logged.
:type confusion_matrix_log_every_n_epoch: int
:param lr: Learning rate to use for training, defaults to ``1e-3``.
:type lr: float
:param test_output_path: Path, relative to the output folder, where the test output is saved.
:type test_output_path: Union[str, Path]
:param predict_output_path: Path, relative to the output folder, where the predict output is saved.
:type predict_output_path: Union[str, Path]
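Example of instantiating a concrete subclass (a sketch; ``MyTask``, ``MyModel`` and the metric choice are
placeholders, ``MulticlassAccuracy`` comes from ``torchmetrics.classification``)::

    task = MyTask(
        model=MyModel(),
        loss_fn=nn.CrossEntropyLoss(),
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={'weight_decay': 0.01},
        metric_val=MetricCollection({'accuracy': MulticlassAccuracy(num_classes=10)}),
        lr=1e-3,
    )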
"""
def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metric_train: Optional[MetricCollection] = None,
metric_val: Optional[MetricCollection] = None,
metric_test: Optional[MetricCollection] = None,
confusion_matrix_val: Optional[bool] = False,
confusion_matrix_test: Optional[bool] = False,
confusion_matrix_log_every_n_epoch: Optional[int] = 1,
lr: float = 1e-3,
test_output_path: Optional[Union[str, Path]] = 'test_output',
predict_output_path: Optional[Union[str, Path]] = 'predict_output'
):
super().__init__()
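# Register an OmegaConf resolver so that configs can reference task attributes via
# interpolation, e.g. ${task:lr} resolves to getattr(self, 'lr').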
resolver_name = 'task'
if not OmegaConf.has_resolver(resolver_name):
OmegaConf.register_new_resolver(
resolver_name,
lambda name: getattr(self, name),
use_cache=False
)
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer = optimizer
self.scheduler = scheduler
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}
self.metric_train = nn.ModuleDict({}) if metric_train is None else metric_train
self.metric_val = nn.ModuleDict({}) if metric_val is None else metric_val
self.metric_test = nn.ModuleDict({}) if metric_test is None else metric_test
self.confusion_matrix_val = confusion_matrix_val
self.confusion_matrix_test = confusion_matrix_test
self.confusion_matrix_log_every_n_epoch = confusion_matrix_log_every_n_epoch
self.lr = lr
self.test_output_path = Path(test_output_path)
self.predict_output_path = Path(predict_output_path)
# self.save_hyperparameters()
def setup(self, stage: str):
if self.confusion_matrix_val:
self.metric_conf_mat_val = MulticlassConfusionMatrix(num_classes=self.trainer.datamodule.num_classes)
if self.confusion_matrix_test:
self.metric_conf_mat_test = MulticlassConfusionMatrix(num_classes=self.trainer.datamodule.num_classes)
if self.trainer.strategy.strategy_name == 'ddp':
batch_size = self.trainer.datamodule.batch_size
if stage == 'fit':
num_samples = len(self.trainer.datamodule.train)
datasplit_name = 'train'
elif stage == 'test':
num_samples = len(self.trainer.datamodule.test)
datasplit_name = 'test'
elif stage == 'predict':
num_samples = len(self.trainer.datamodule.predict)
datasplit_name = 'predict'
else:
log.warning(f'Unknown stage ({stage}) during setup!')
num_samples = -1
datasplit_name = None
if datasplit_name is not None and num_samples % batch_size != 0:
    log.warning(
        f'Number of samples ({num_samples}) in {datasplit_name} is not divisible by the batch size ({batch_size}).')
    log.warning('The last batch will be incomplete. Behavior depends on the datamodule drop_last setting.')
def step(self,
batch: Any,
metric_kwargs: Optional[Dict[str, Dict[str, Any]]] = None) -> Union[Dict[OutputKeys, Any], Tuple[Any, Any]]:
"""
The training/validation/test step. Override for custom behavior.
:param batch: the batch with the images (x) and the ground truth (y), in the order (x, y)
:type batch: Any
:param metric_kwargs: a dictionary with an entry per metric holding its additional keyword arguments
    (the prediction and target are always provided). E.g. if you have two metrics (A, B) and B takes
    additional arguments x and y, the dictionary would look like this: ``{'B': {'x': 'value', 'y': 'value'}}``
:type metric_kwargs: Optional[Dict[str, Dict[str, Any]]]
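Example call (``B`` is a hypothetical metric name registered in the current metric collection)::

    output = self.step(batch, metric_kwargs={'B': {'x': 1, 'y': 2}})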
"""
for key in self.loss_fn:
if hasattr(self.loss_fn[key], 'weight') and self.loss_fn[key].weight is not None:
if torch.is_tensor(self.loss_fn[key].weight):
self.loss_fn[key].weight = self.loss_fn[key].weight.to(self.device)
if metric_kwargs is None:
metric_kwargs = {}
x, y = batch
y_hat = self(x)
if isinstance(y_hat, dict):
y_hat = y_hat['out']
output = {OutputKeys.PREDICTION: y_hat}
y_hat = self.to_loss_format(output[OutputKeys.PREDICTION])
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(output[OutputKeys.PREDICTION])
current_metric = self._get_current_metric()
for name, metric in current_metric.items():
if name in metric_kwargs:
logs[name] = metric(y_hat, y, **metric_kwargs[name])
else:
logs[name] = metric(y_hat, y)
logs.update(losses)
if len(losses.values()) > 1:
logs["total_loss"] = sum(losses.values())
return logs["total_loss"], logs
output[OutputKeys.LOSS] = list(losses.values())[0]
output[OutputKeys.LOG] = logs
output[OutputKeys.TARGET] = y
return output
def forward(self, x: Any) -> Any:
# some frameworks, like torchvision, return a dict
x = self.model(x)
# torch.jit.isinstance keeps this check scriptable; some IDEs (e.g. PyCharm) flag it as an error (false positive)
if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
out = x['out']
elif torch.is_tensor(x):
out = x
else:
raise NotImplementedError(f"Unsupported network output type: {type(x)}")
return out
def training_step(self, batch: Any, batch_idx: int, **kwargs) -> Any:
"""
The training step. Calls the step method and logs the metrics and loss.
:param batch: The current batch to train on.
:type batch: Any
:param batch_idx: The index of the current batch.
:type batch_idx: int
:param kwargs: Additional arguments.
:type kwargs: Any
:return: The output of the step method.
:rtype: Any
"""
output = self.step(batch=batch, **kwargs)
self._log_metrics_and_loss(output, stage='train')
return output
def validation_step(self, batch: Any, batch_idx: int, **kwargs) -> Any:
"""
The validation step. Calls the step method and logs the metrics and loss.
:param batch: The current batch to validate on.
:type batch: Any
:param batch_idx: The index of the current batch.
:type batch_idx: int
:param kwargs: Additional arguments.
:type kwargs: Any
:return: The output of the step method.
:rtype: Any
"""
output = self.step(batch=batch, **kwargs)
if self.trainer.state.stage == RunningStage.SANITY_CHECKING:
return output
self._log_metrics_and_loss(output, stage='val')
if self.confusion_matrix_val and (
self.trainer.current_epoch + 1) % self.confusion_matrix_log_every_n_epoch == 0:
self.metric_conf_mat_val(preds=output[OutputKeys.PREDICTION], target=output[OutputKeys.TARGET])
return output
def validation_epoch_end(self, outputs: Any) -> None:
if self.trainer.state.stage == RunningStage.SANITY_CHECKING \
or not self.confusion_matrix_val \
or (self.trainer.current_epoch + 1) % self.confusion_matrix_log_every_n_epoch != 0:
return
hist = self.metric_conf_mat_val.compute()
hist = hist.cpu().numpy()
self._create_conf_mat(matrix=hist, stage='val')
self.metric_conf_mat_val.reset()
def test_step(self, batch: Any, batch_idx: int, **kwargs) -> Any:
output = self.step(batch=batch, **kwargs)
if self.confusion_matrix_test:
self.metric_conf_mat_test(preds=output[OutputKeys.PREDICTION], target=output[OutputKeys.TARGET])
self._log_metrics_and_loss(output, stage='test')
return output
def test_epoch_end(self, outputs: Any) -> None:
if not self.confusion_matrix_test:
return
hist = self.metric_conf_mat_test.compute()
hist = hist.cpu().numpy()
self._create_conf_mat(matrix=hist, stage='test')
self.metric_conf_mat_test.reset()
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
y_hat = self(batch)
return {OutputKeys.PREDICTION: y_hat}
def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
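"""
Instantiate the learning rate scheduler from ``self.scheduler``.

Example (sketch): ``scheduler=torch.optim.lr_scheduler.StepLR`` together with
``scheduler_kwargs={'step_size': 10}`` results in ``StepLR(optimizer, step_size=10)``.
"""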
scheduler = self.scheduler
if isinstance(scheduler, _LRScheduler):
return scheduler
elif isinstance(scheduler, type) and issubclass(scheduler, _LRScheduler):
return scheduler(optimizer, **self.scheduler_kwargs)
raise MisconfigurationException(
"scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
f"or a built-in scheduler in {self.available_schedulers()}"
)
def _get_current_metric(self) -> MetricCollection:
"""
Get the current metrics for the current stage.
:return: The current metrics.
:rtype: MetricCollection
"""
if self.trainer.state.stage == RunningStage.TRAINING:
return self.metric_train
if self.trainer.state.stage == RunningStage.VALIDATING:
return self.metric_val
if self.trainer.state.stage == RunningStage.TESTING:
return self.metric_test
return {}
def _log_metrics_and_loss(self, output: Dict[str, Any], stage: str) -> None:
"""
Log the metrics and loss for the current stage to the logger.
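Example: a per-class metric tensor of length 3 logged under the (hypothetical) name ``iou`` produces
``{stage}/iou_c_0``, ``{stage}/iou_c_1`` and ``{stage}/iou_c_2``; scalar values are logged as ``{stage}/iou``.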
:param output: The output dictionary of the step method; must contain ``OutputKeys.LOG``.
:type output: Dict[str, Any]
:param stage: The current stage.
:type stage: str
:return: None
:rtype: None
"""
for key, value in output[OutputKeys.LOG].items():
if value.dim() != 0 and len(value) != 1:
for i, v in enumerate(value):
self.log(f"{stage}/{key}_c_{i}", v, on_epoch=True, on_step=True, sync_dist=True,
rank_zero_only=True)
else:
self.log(f"{stage}/{key}", value, on_epoch=True, on_step=True, sync_dist=True, rank_zero_only=True)
def _create_conf_mat(self, matrix: np.ndarray, stage: str = 'val') -> None:
"""
Create the confusion matrix and save it to disk (as TSV) and to wandb (as an image).
:param matrix: The confusion matrix.
:type matrix: np.ndarray
:param stage: The current stage.
:type stage: str
:return: None
:rtype: None
:raises ValueError: If an unexpected stage is passed. A mismatch between the matrix sum and the
    expected sum only logs a warning.
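The matrix is written as a tab-separated file to ``conf_mats/{stage}/CM_epoch_{epoch}.txt`` under the
current working directory and, if a wandb logger is available, additionally logged to wandb as an image.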
"""
# verify sum of conf mat entries
pixels_per_crop = self.trainer.datamodule.dims[1] * self.trainer.datamodule.dims[2]
num_processes = self.trainer.num_devices
if stage == 'val':
dataloader = self.trainer.val_dataloaders[0]
elif stage == 'test':
assert not self.trainer.test_dataloaders[0].drop_last
dataloader = self.trainer.test_dataloaders[0]
else:
raise ValueError(f'_create_conf_mat received unexpected stage ({stage})')
num_samples = len(dataloader.dataset)
batch_size = dataloader.batch_size
if not dataloader.drop_last:
num_samples_in_last_step = num_samples % (batch_size * num_processes) # == 0, when last step is filled
num_process_filled = num_samples_in_last_step % num_processes # == 0, when all evenly filled
additional_crops = num_processes - num_process_filled if num_process_filled > 0 else 0
total_samples = num_samples + additional_crops
expected_sum = total_samples * pixels_per_crop
else:
num_steps = len(dataloader)
expected_sum = num_steps * num_processes * batch_size * pixels_per_crop
matrix_sum = matrix.sum()
if not np.isclose(a=expected_sum, b=matrix_sum, rtol=2.5e-7):
log.warning(f'matrix.sum() is not close to expected_sum '
            f'({matrix_sum} != {expected_sum}, '
            f'diff: {matrix_sum - expected_sum})')
# print(f'With all_gather: {str(len(outputs[0][OutputKeys.PREDICTION][0]))}')
conf_mat_name = f'CM_epoch_{self.trainer.current_epoch}'
# set figure size
plt.figure(figsize=(14, 8))
# set label font scale
sn.set(font_scale=1.4)
# draw the heatmap; annot_kws controls the annotation font size
fig = sn.heatmap(matrix, annot=True, annot_kws={"size": 8}, fmt="g")
for i in range(matrix.shape[0]):
fig.add_patch(Rectangle((i, i), 1, 1, fill=False, edgecolor='yellow', lw=3))
plt.xlabel('Predictions')
plt.ylabel('Targets')
plt.title(conf_mat_name)
conf_mat_path = Path(os.getcwd()) / 'conf_mats' / stage
conf_mat_path.mkdir(parents=True, exist_ok=True)
conf_mat_file_path = conf_mat_path / (conf_mat_name + '.txt')
df = pd.DataFrame(matrix)
# save as TSV to disk
if self.trainer.is_global_zero:
df.to_csv(path_or_buf=conf_mat_file_path, sep='\t')
try:
# save tsv to wandb
logger = get_wandb_logger(self.trainer)
experiment = logger.experiment
experiment.save(glob_str=str(conf_mat_file_path), base_path=os.getcwd())
# names should be unique, or else charts from different experiments in wandb will overlap
experiment.log({f"confusion_matrix_{stage}_img/ep_{self.trainer.current_epoch}": wandb.Image(plt)},
commit=False)
except ValueError:
log.warning('No wandb logger found. Confusion matrix images are not saved.')