Source code for datamodules.RGB.datasets.full_page_dataset

"""
Load a dataset of historic documents by specifying the folder where it is located.
"""

from dataclasses import dataclass
# Utils
from pathlib import Path
from typing import List, Tuple, Union, Optional, Any

import torch.utils.data as data
from omegaconf import ListConfig
from PIL import Image
from torch import is_tensor, Tensor
from torchvision.datasets.folder import pil_loader, has_file_allowed_extension
from torchvision.transforms import ToTensor

from src.datamodules.utils.misc import ImageDimensions, get_output_file_list, selection_validation
from src.utils import utils

# File extensions that this module treats as image files when pairing data and
# ground truth files; also listed in the "Found 0 images" error message.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.gif')

# Module-level logger obtained from the project's logging helper.
log = utils.get_logger(__name__)


class DatasetRGB(data.Dataset):
    """
    Full-page dataset that pairs every image with its ground truth image.

    The images are expected to be arranged in this way: ::

        root/gt/xxx.png
        root/gt/xxy.png
        root/gt/xxz.png

        root/data/xxx.png
        root/data/xxy.png
        root/data/xxz.png

    :param path: path to the dataset
    :type path: Path
    :param data_folder_name: name of the folder where the data is located
    :type data_folder_name: str
    :param gt_folder_name: name of the folder where the ground truth is located
    :type gt_folder_name: str
    :param image_dims: dimensions every loaded image and ground truth must have
    :type image_dims: ImageDimensions
    :param selection: selection of the data, can be an int or a list of strings
    :type selection: Optional[Union[int, List[str]]]
    :param is_test: flag to indicate if the dataset is used for testing
    :type is_test: bool, optional
    :param image_transform: transformation called with (image, gt), returning both
    :type image_transform: callable, optional
    :param target_transform: transformation called with (image, gt) after the conversion to tensors
    :type target_transform: callable, optional
    :param twin_transform: transformation called with (image, gt); skipped when ``is_test`` is set
    :type twin_transform: callable, optional
    """

    def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str,
                 image_dims: ImageDimensions,
                 selection: Optional[Union[int, List[str]]] = None,
                 is_test: bool = False, image_transform: callable = None,
                 target_transform: callable = None, twin_transform: callable = None,
                 **kwargs):
        """
        Store the configuration and collect the (data, gt) file pairs.

        Additional keyword arguments are accepted and ignored.

        :raises RuntimeError: if no (data, gt) pair was found in ``path``
        """
        self.path = path
        self.data_folder_name = data_folder_name
        self.gt_folder_name = gt_folder_name
        self.selection = selection

        self.image_dims = image_dims

        # transformations
        self.image_transform = image_transform
        self.target_transform = target_transform
        self.twin_transform = twin_transform

        self.is_test = is_test

        # List of tuples that contain the path to the gt and image that belong together
        self.img_gt_path_list = self.get_img_gt_path_list(path,
                                                          data_folder_name=self.data_folder_name,
                                                          gt_folder_name=self.gt_folder_name,
                                                          selection=self.selection)

        # These two attributes only exist on test datasets.
        if is_test:
            self.image_path_list = [img_gt_path[0] for img_gt_path in self.img_gt_path_list]
            self.output_file_list = get_output_file_list(image_path_list=self.image_path_list)

        self.num_samples = len(self.img_gt_path_list)
        if self.num_samples == 0:
            raise RuntimeError("Found 0 images in: {} \n Supported image extensions are: {}".format(
                path, ",".join(IMG_EXTENSIONS)))

    def __len__(self):
        """
        Return the number of (data, gt) pairs, i.e. the length of one epoch.
        """
        return self.num_samples

    def __getitem__(self, index: int) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, int]]:
        """
        Return the data and the ground truth for a given index.

        If the dataset is used for testing, the index itself is returned as a third element.

        :param index: index of the image
        :type index: int
        :return: the item at the given index
        :rtype: Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, int]]
        """
        if self.is_test:
            return self._get_test_items(index=index)
        return self._get_train_val_items(index=index)

    def _get_train_val_items(self, index: int) -> Tuple[Tensor, Tensor]:
        """
        Return the transformed data and ground truth for a given index.

        :param index: index of the image
        :type index: int
        :return: the item at the given index
        :rtype: Tuple[Tensor, Tensor]
        """
        data_img, gt_img = self._load_data_and_gt(index=index)
        img, gt = self._apply_transformation(data_img, gt_img)
        assert img.shape[-2:] == gt.shape[-2:], \
            'image and ground truth have different spatial dimensions after the transformations'
        return img, gt

    def _get_test_items(self, index: int) -> Tuple[Tensor, Tensor, int]:
        """
        Return the transformed data and ground truth for a given index, plus the index itself.

        :param index: index of the image
        :type index: int
        :return: the item at the given index and the index
        :rtype: Tuple[Tensor, Tensor, int]
        """
        data_img, gt_img = self._load_data_and_gt(index=index)
        img, gt = self._apply_transformation(data_img, gt_img)
        assert img.shape[-2:] == gt.shape[-2:], \
            'image and ground truth have different spatial dimensions after the transformations'
        return img, gt, index

    def _load_data_and_gt(self, index: int) -> Tuple[Image.Image, Image.Image]:
        """
        Load the data image and the ground truth image for a given index.

        Both images must have exactly the size given by ``self.image_dims``.

        :param index: index of the image
        :type index: int
        :return: the data image and the ground truth image
        :rtype: Tuple[Image.Image, Image.Image]
        """
        data_path, gt_path = self.img_gt_path_list[index][:2]
        data_img = pil_loader(data_path)
        gt_img = pil_loader(gt_path)

        expected_height, expected_width = self.image_dims.height, self.image_dims.width
        assert data_img.height == expected_height and data_img.width == expected_width, \
            'data image size does not match image_dims'
        assert gt_img.height == expected_height and gt_img.width == expected_width, \
            'ground truth image size does not match image_dims'

        return data_img, gt_img

    def _apply_transformation(self, img: Union[Tensor, Image.Image], gt: Union[Tensor, Image.Image]) \
            -> Tuple[Tensor, Tensor]:
        """
        Apply the configured transformations to an image and its ground truth.

        Order: twin transform (skipped for test datasets), image transform, conversion of
        everything that is not yet a tensor with ``ToTensor``, target transform. Each
        transformation is called with ``(img, gt)`` and must return both.

        :param img: Original image
        :type img: Union[Tensor, Image.Image]
        :param gt: Corresponding ground truth image
        :type gt: Union[Tensor, Image.Image]
        :return: Transformed image and ground truth
        :rtype: Tuple[Tensor, Tensor]
        """
        if self.twin_transform is not None and not self.is_test:
            img, gt = self.twin_transform(img, gt)

        if self.image_transform is not None:
            # perform transformations
            img, gt = self.image_transform(img, gt)

        if not is_tensor(img):
            img = ToTensor()(img)
        # A transformation may have returned None as ground truth; leave it untouched then.
        if gt is not None and not is_tensor(gt):
            gt = ToTensor()(gt)

        if self.target_transform is not None:
            img, gt = self.target_transform(img, gt)

        return img, gt

    @staticmethod
    def get_img_gt_path_list(directory: Path, data_folder_name: str, gt_folder_name: str,
                             selection: Optional[Union[int, List[str]]] = None) \
            -> List[Tuple[Path, Path, str]]:
        """
        Return a list of tuples that contain the path to the image and gt that belong together.

        Structure of the folder ::

            directory/<data_folder_name>/FILE_NAME.png
            directory/<gt_folder_name>/FILE_NAME.png

        The entries of both folders are sorted and paired by position.

        :param directory: Path to dataset folder (train / val / test)
        :type directory: Path
        :param data_folder_name: name of the folder that contains the data
        :type data_folder_name: str
        :param gt_folder_name: name of the folder that contains the ground truth
        :type gt_folder_name: str
        :param selection: selection of the data, defaults to None. An int keeps the first
            ``selection`` pairs, a list keeps the pairs whose data file stem is in the list.
        :type selection: Optional[Union[int, List[str]]], optional
        :return: List of tuples (data path, gt path, stem of the data file)
        :rtype: List[Tuple[Path, Path, str]]
        """
        paths = []
        directory = directory.expanduser()

        path_data_root = directory / data_folder_name
        path_gt_root = directory / gt_folder_name

        # Both folders are required, so report the problem as soon as either one is missing.
        # The listing below still raises for the missing folder.
        if not (path_data_root.is_dir() and path_gt_root.is_dir()):
            log.error("folder data or gt not found in " + str(directory))

        # get all files sorted
        files_in_data_root = sorted(path_data_root.iterdir())

        # check the selection parameter
        if selection:
            selection = selection_validation(files_in_data_root, selection, full_page=True)

        files_in_gt_root = sorted(path_gt_root.iterdir())

        # zip() stops at the shorter list, so surplus files would be dropped without any notice.
        if len(files_in_data_root) != len(files_in_gt_root):
            log.error(f"number of files differs between {path_data_root} ({len(files_in_data_root)}) "
                      f"and {path_gt_root} ({len(files_in_gt_root)}); surplus files are ignored")

        # The counter covers every pair (also skipped ones); it is needed for the int selection.
        for counter, (path_data_file, path_gt_file) in enumerate(zip(files_in_data_root, files_in_gt_root),
                                                                 start=1):
            if selection:
                if isinstance(selection, int):
                    if counter > selection:
                        break
                elif isinstance(selection, (ListConfig, list)):
                    if path_data_file.stem not in selection:
                        continue

            # NOTE(review): this only checks that both files agree on being an image or not,
            # so a pair of two non-image files passes - confirm that this is intended.
            assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \
                   has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS), \
                'get_img_gt_path_list(): image file aligned with non-image file'

            paths.append((path_data_file, path_gt_file, path_data_file.stem))

        return paths