Source code for datamodules.IndexedFormats.datasets.full_page_dataset

"""
Load a dataset of historic documents by specifying the folder where its located.
"""

# Utils
from pathlib import Path
from typing import List, Tuple, Union, Optional

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from omegaconf import ListConfig
from torch import is_tensor
from torchvision.datasets.folder import pil_loader, has_file_allowed_extension
from torchvision.transforms import ToTensor

from src.datamodules.utils.misc import ImageDimensions, selection_validation, pil_loader_gif, get_output_file_list
from src.utils import utils

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')
GT_EXTENSION = ('.gif')

log = utils.get_logger(__name__)


[docs]class DatasetIndexed(data.Dataset): """A dataset where the images are arranged in this way: root/gt/xxx.gif root/gt/xxy.gif root/gt/xxz.gif root/data/xxx.png root/data/xxy.png root/data/xxz.png And the ground truth is represented in an index format like GIF. :param path: Path to the dataset :type path: Path :param data_folder_name: Name of the folder where the data is located :type data_folder_name: str :param gt_folder_name: Name of the folder where the ground truth is located :type gt_folder_name: str :param image_dims: Image dimensions of the dataset :type image_dims: ImageDimensions :param is_test: Flag to indicate if the dataset is used for testing :type is_test: bool :param selection: Selection of the dataset, can be an integer or a list of strings :type selection: Optional[Union[int, List[str]]] :param image_transform: Transformations that are applied to the image :type image_transform: Optional[Callable] """ def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str, image_dims: ImageDimensions, is_test=False, selection: Optional[Union[int, List[str]]] = None, image_transform=None) -> None: """ Constructor method for the DatasetIndexed class. """ self.path = path self.data_folder_name = data_folder_name self.gt_folder_name = gt_folder_name self.selection = selection self.image_dims = image_dims # transformations self.image_transform = image_transform self.is_test = is_test # List of tuples that contain the path to the gt and image that belong together self.img_gt_path_list = self.get_img_gt_path_list(path, data_folder_name=self.data_folder_name, gt_folder_name=self.gt_folder_name, selection=self.selection) if is_test: self.image_path_list = [img_gt_path[0] for img_gt_path in self.img_gt_path_list] self.output_file_list = get_output_file_list(image_path_list=self.image_path_list) self.num_samples = len(self.img_gt_path_list) if self.num_samples == 0: raise RuntimeError(f"Found 0 images in: {path} \n " f"Supported image extensions are: {' '.join(IMG_EXTENSIONS)}\n" f"Supported ground truth extensions are: {' '.join(GT_EXTENSION)}") def __len__(self): """ This function returns the length of an epoch so the data loader knows when to stop. The length is different during train/val and test, because we process the whole image during testing, and only sample from the images during train/val. """ return self.num_samples def __getitem__(self, index: int) -> Union[ Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, int]]: data_img, gt_img = self._load_data_and_gt(index=index) img, gt = self._apply_transformation(data_img, gt_img) assert img.shape[-2:] == gt.shape[-2:] if self.is_test: return img, gt, index else: return img, gt def _load_data_and_gt(self, index: int) -> Tuple[Image.Image, Image.Image]: """ Load the data and the ground truth. :param index: Index of the image :type index: int :return: Data and ground truth as PIL Image :rtype: Tuple[Image.Image, Image.Image] """ data_img = pil_loader(str(self.img_gt_path_list[index][0])) gt_img = pil_loader_gif(self.img_gt_path_list[index][1]) assert data_img.height == self.image_dims.height and data_img.width == self.image_dims.width assert gt_img.height == self.image_dims.height and gt_img.width == self.image_dims.width return data_img, gt_img def _apply_transformation(self, img: Image, gt: Image) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply transformations to the image and the ground truth. :param img: Original image :type img: Image :param gt: Ground truth as an image :type gt: Image :return: Original and ground Truth as Tensor with applied transformations :rtype: Tuple[torch.Tensor, torch.Tensor] """ if self.image_transform is not None: # perform transformations img, _ = self.image_transform(img, gt) if not is_tensor(img): img = ToTensor()(img) # remove first dim s.t. gt is just w x h gt_np = np.asarray(gt) if len(gt_np.shape) == 3: gt_np = np.squeeze(gt_np, axis=0) gt = torch.tensor(gt_np, dtype=torch.long) return img, gt
[docs] @staticmethod def get_img_gt_path_list(directory: Path, data_folder_name: str, gt_folder_name: str, selection: Optional[Union[int, List[str]]] = None) \ -> List[Tuple[Path, Path]]: """ Structure of the folder directory/data/FILE_NAME.png directory/gt/FILE_NAME.gif :param directory: Path to the dataset :type directory: Path :param data_folder_name: Name of the folder where the data is located :type data_folder_name: str :param gt_folder_name: Name of the folder where the ground truth is located :type gt_folder_name: str :param selection: Selection of the dataset, can be an integer or a list of strings :type selection: Optional[Union[int, List[str]]] :return: List of tuples with the path to the data and the ground truth :rtype: List[Tuple[Path, Path]] :raises ValueError: If the folder data or gt is not found in the directory """ paths = [] directory = directory.expanduser() path_data_root = directory / data_folder_name path_gt_root = directory / gt_folder_name if not (path_data_root.is_dir() or path_gt_root.is_dir()): raise ValueError("folder data or gt not found in " + str(directory)) # get all files sorted files_in_data_root = sorted(path_data_root.iterdir()) # check the selection parameter if selection: selection = selection_validation(files_in_data_root, selection, full_page=True) counter = 0 # Counter for subdirectories, needed for selection parameter for path_data_file, path_gt_file in zip(sorted(files_in_data_root), sorted(path_gt_root.iterdir())): counter += 1 if selection: if isinstance(selection, int) and counter > selection: break elif isinstance(selection, ListConfig) or isinstance(selection, list): if path_data_file.stem not in selection: continue assert has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS) == \ has_file_allowed_extension(path_gt_file.name, GT_EXTENSION), \ 'get_img_gt_path_list(): image file aligned with non-image file' paths.append((path_data_file, path_gt_file)) return paths