Source code for datamodules.DivaHisDB.utils.image_analytics

# Utils
import logging
import os
from pathlib import Path
from typing import Tuple, Any, Dict, List, Union

import numpy as np
# Torch related stuff
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from numpy import ndarray, dtype

from src.datamodules.utils.misc import save_json, check_missing_analytics
from src.datamodules.utils.image_analytics import compute_mean_std


def get_analytics(input_path: Path, data_folder_name: str, gt_folder_name: str, get_gt_data_paths_func) \
        -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Return the analytics of the dataset, computing and saving whatever is missing.

    The analytics are looked up in two JSON files located directly in ``input_path``: one for the data
    (keys ``mean`` and ``std``) and one for the ground truth (keys ``class_weights`` and ``class_encodings``).
    Whatever ``check_missing_analytics`` reports as missing is recomputed from the ``train`` sub-folder and
    written back with ``save_json``.

    :param input_path: Path to the root of the dataset
    :type input_path: Path
    :param data_folder_name: Name of the folder containing the data
    :type data_folder_name: str
    :param gt_folder_name: Name of the folder containing the ground truth
    :type gt_folder_name: str
    :param get_gt_data_paths_func: Function returning the (data path, ground truth path) pairs; it is called with
        the train folder plus the keyword arguments ``data_folder_name`` and ``gt_folder_name``
    :type get_gt_data_paths_func: Callable
    :return: Tuple of analytics for the data and ground truth
    :rtype: Tuple[Dict[str, Any], Dict[str, Any]]
    """
    data_json_path = input_path / f'analytics.data.{data_folder_name}.json'
    gt_json_path = input_path / f'analytics.gt.hisDB.{gt_folder_name}.json'

    analytics_data, data_incomplete = check_missing_analytics(data_json_path, ['mean', 'std'])
    analytics_gt, gt_incomplete = check_missing_analytics(gt_json_path, ['class_weights', 'class_encodings'])

    # Nothing to recompute: both files already hold every expected key
    if not data_incomplete and not gt_incomplete:
        return analytics_data, analytics_gt

    # Only the training split is used to compute the statistics
    path_pairs = get_gt_data_paths_func(input_path / 'train',
                                        data_folder_name=data_folder_name,
                                        gt_folder_name=gt_folder_name)
    data_files = np.asarray([str(pair[0]) for pair in path_pairs])
    gt_files = np.asarray([str(pair[1]) for pair in path_pairs])

    if data_incomplete:
        mean, std = compute_mean_std(file_names=data_files)
        analytics_data = {'mean': mean.tolist(),
                          'std': std.tolist()}
        save_json(analytics_data, data_json_path)

    if gt_incomplete:
        # Class weights are needed for class balancing
        logging.info('Measuring class weights')
        weights, encodings = _get_class_frequencies_weights_segmentation_hisdb(gt_images=gt_files)
        analytics_gt = {'class_weights': weights.tolist(),
                        'class_encodings': encodings.tolist()}
        save_json(analytics_gt, gt_json_path)

    return analytics_data, analytics_gt
def get_class_weights(input_folder, workers=4) -> np.ndarray:
    """
    Get the weights proportional to the inverse of their class frequencies.
    The vector sums up to 1.

    :param input_folder: Path to the dataset folder; it is loaded with ``torchvision.datasets.ImageFolder``
    :type input_folder: Path
    :param workers: Number of DataLoader workers; only used by the fallback that collects the labels by
        iterating over the whole dataset
    :type workers: int
    :return: The weights vector as a 1D array normalized (sum up to 1), with one entry per label value that
        occurs in the dataset, ordered by label value
    :rtype: np.ndarray
    :raises FileNotFoundError: If ``input_folder`` is not an existing directory
    """
    # Sanity check on the folder
    if not os.path.isdir(input_folder):
        message = f"Folder {input_folder} does not exist"
        logging.error(message)
        # Carry the message in the exception too, so callers see which folder was missing
        raise FileNotFoundError(message)

    # Load the dataset
    ds = datasets.ImageFolder(input_folder, transform=transforms.Compose([transforms.ToTensor()]))

    logging.info('Begin computing class frequencies weights')

    if hasattr(ds, 'targets'):
        labels = ds.targets
    elif hasattr(ds, 'labels'):
        labels = ds.labels
    else:
        # This is a fail-safe net in case a custom dataset changed the name of the internal variables
        data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=workers)
        labels = []
        # Each batch is an (image, label) pair; only the label is needed
        for _, label in data_loader:
            labels.append(label)
        labels = np.concatenate(labels).reshape(len(ds))

    # np.unique sorts the label values, so the counts are ordered by label value
    class_support = np.unique(labels, return_counts=True)[1]
    class_frequencies = class_support / len(labels)
    # Class weights are the inverse of the class frequencies
    class_weights = 1 / class_frequencies
    # Normalize vector to sum up to 1.0 (in case the Loss function does not do it)
    class_weights /= class_weights.sum()

    logging.info('Finished computing class frequencies weights ')
    logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}')
    logging.info(f'Class weights (rounded): {np.around(class_weights * 100, decimals=2)}')

    return class_weights
def _get_class_frequencies_weights_segmentation_hisdb(gt_images: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ Get the weights proportional to the inverse of their class frequencies. The vector sums up to 1 :param gt_images: Path to all ground truth images, which contain the pixel-wise label :type gt_images: np.ndarray :return: The weights vector as a 1D array normalized (sum up to 1) :rtype: Tuple[Union[Union[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]], Any], ndarray] """ logging.info('Begin computing class frequencies weights') total_num_pixels = 0 label_counter = {} for path in gt_images: img = np.array(Image.open(path))[:, :, 2].flatten() total_num_pixels += len(img) for i, j in zip(*np.unique(img, return_counts=True)): label_counter[i] = label_counter.get(i, 0) + j classes = np.array(sorted(label_counter.keys())) num_samples_per_class = np.array([label_counter[k] for k in classes]) class_frequencies = (num_samples_per_class / total_num_pixels) logging.info('Finished computing class frequencies weights') logging.info(f'Class frequencies (rounded): {np.around(class_frequencies * 100, decimals=2)}') # Normalize vector to sum up to 1.0 (in case the Loss function does not do it) return (1 / num_samples_per_class) / ((1 / num_samples_per_class).sum()), classes