# Source code for metrics.divahisdb

from typing import Any, Optional, Callable

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric


class HisDBIoU(Metric):
    """
    Implementation of the mIoU metric used in the paper of
    `Alberti et al. <https://ieeexplore.ieee.org/abstract/document/8270257>`_.
    Using it just makes sense if the gt is in the DIVA-HisDB format.

    :param num_classes: number of classes
    :type num_classes: Optional[int]
    :param mask_modifies_prediction: if True, the mask is used to modify the prediction,
        otherwise the prediction is used to modify the mask
    :type mask_modifies_prediction: bool
    :param compute_on_step: Forward only calls ``update()`` and return None if this
        is set to False. default: True
    :type compute_on_step: bool
    :param dist_sync_on_step: Synchronize metric state across processes at each
        ``forward()`` before returning the value at the step. default: False
    :type dist_sync_on_step: bool
    :param process_group: Specify the process group on which synchronization is
        called. default: None (which selects the entire world)
    :type process_group: Optional[Any]
    """

    def __init__(self,
                 num_classes: Optional[int] = None,
                 mask_modifies_prediction: bool = True,
                 compute_on_step: bool = True,
                 dist_sync_on_step: bool = False,
                 process_group: Optional[Any] = None,
                 dist_sync_fn: Optional[Callable] = None,
                 ) -> None:
        super().__init__(compute_on_step=compute_on_step,
                         dist_sync_on_step=dist_sync_on_step,
                         process_group=process_group,
                         dist_sync_fn=dist_sync_fn)

        self.num_classes = num_classes
        self.mask_modifies_prediction = mask_modifies_prediction
        # per-class accumulators; dist_reduce_fx="sum" merges them across processes
        self.add_state("tps", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: torch.Tensor, target: torch.Tensor,
               mask: Optional[torch.Tensor] = None, **kwargs) -> None:
        """
        Accumulates the confusion-histogram statistics for one batch.

        :param pred: predicted class index per pixel (batch x H x W)
        :param target: ground-truth class index per pixel (batch x H x W)
        :param mask: optional boolean boundary mask (batch x H x W)
        """
        # take into account the boundary pixels like done in the official evaluator
        # https://github.com/DIVA-DIA/DIVA_Layout_Analysis_Evaluator/blob/87a11ede232f8fb490401a382b8764697b65ea8d/src/main/java/ch/unifr/LayoutAnalysisEvaluator.java#L225
        if mask is not None:
            # boundary pixels where background was predicted are not counted as errors:
            # copy the other side's label over them (clone first to avoid mutating inputs)
            mask_and_bg_predicted = torch.logical_and(mask, torch.eq(pred, 0))
            if self.mask_modifies_prediction:
                pred = pred.clone()
                pred[mask_and_bg_predicted] = target[mask_and_bg_predicted]
            else:
                target = target.clone()
                target[mask_and_bg_predicted] = pred[mask_and_bg_predicted]

        hist = torch.zeros((self.num_classes, self.num_classes)).type_as(target)
        for lt, lp in zip(target, pred):
            try:
                # the images all have the same size
                hist = torch.add(hist, self._fast_hist(lt.flatten(), lp.flatten(), self.num_classes))
            except ValueError:
                # the images have different sizes: flatten each one and concatenate
                # (fix: the old fallback called .flatten() on a Python list, which
                # would have raised AttributeError)
                hist = torch.add(hist, self._fast_hist(torch.cat([t.flatten() for t in lt]),
                                                       torch.cat([p.flatten() for p in lp]),
                                                       self.num_classes))
        # NOTE: the former np.errstate(...) context was dropped — NumPy's error
        # state has no effect on torch operations.
        self.tps = torch.add(self.tps, torch.diag(hist))
        # denominator of IoU per class: union = row sum + column sum - diagonal
        self.total = torch.add(self.total, hist.sum(dim=1) + hist.sum(dim=0) - torch.diag(hist))

    def compute(self) -> Any:
        """
        Returns the mean IoU over all classes that actually occur
        (classes with an empty union produce NaN and are excluded).
        """
        res = torch.div(self.tps.float(), self.total)
        return res[~res.isnan()].mean()

    @staticmethod
    def _fast_hist(label_true: Tensor, label_pred: Tensor, n_class: int):
        """
        Creates a histogram in a fast fashion taking advantage of the hardware.
        Inspired from `https://github.com/wkentaro/pytorch-fcn`_.

        :param label_true: matrix (batch size x H x W) contains the true class labels for each pixel
        :type label_true: torch.Tensor
        :param label_pred: matrix (batch size x H x W) contains the predicted class for each pixel
        :type label_pred: torch.Tensor
        :param n_class: int number possible classes
        :type n_class: int
        :return histogram
        :rtype: torch.Tensor
        """
        # ignore labels outside [0, n_class); encode (true, pred) pairs as a single
        # index and bincount them into an n_class x n_class confusion matrix
        mask = torch.bitwise_and(torch.ge(label_true, 0), torch.lt(label_true, n_class))
        hist = torch.bincount(
            torch.add(torch.mul(n_class, label_true[mask]), label_pred[mask]),
            minlength=n_class ** 2).reshape(n_class, n_class)
        return hist