# Source code for datamodules.RGB.datasets.cropped_dataset

"""
Load a dataset of historic documents by specifying the folder where it is located.
"""

# Utils
import re
from pathlib import Path
from typing import List, Tuple, Union, Optional, Any, Callable

import torch.utils.data as data
from PIL import Image
from omegaconf import ListConfig
from torch import is_tensor, Tensor
from torchvision.datasets.folder import pil_loader, has_file_allowed_extension
from torchvision.transforms import ToTensor

from src.datamodules.utils.misc import selection_validation
from src.utils import utils

# Image file extensions (lower-case, with leading dot) that are accepted when
# scanning the dataset folders via ``has_file_allowed_extension``.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.gif',
)

# Logger for this module, created through the project's ``utils`` helper.
log = utils.get_logger(__name__)


class CroppedDatasetRGB(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        path
        ├── data_folder_name
        │   ├── original_image_name_1
        │   │   ├── image_crop_1.png
        │   │   ├── ...
        │   │   └── image_crop_N.png
        │   └── original_image_name_N
        │       ├── image_crop_1.png
        │       ├── ...
        │       └── image_crop_N.png
        └── gt_folder_name
            ├── original_image_name_1
            │   ├── image_crop_1.png
            │   ├── ...
            │   └── image_crop_N.png
            └── original_image_name_N
                ├── image_crop_1.png
                ├── ...
                └── image_crop_N.png

    :param path: Path to dataset folder (train / val / test)
    :type path: Path
    :param data_folder_name: name of the folder that contains the data
    :type data_folder_name: str
    :param gt_folder_name: name of the folder that contains the ground truth
    :type gt_folder_name: str
    :param selection: selection of the data, defaults to None
    :type selection: Optional[Union[int, List[str]]], optional
    :param is_test: flag to indicate if the dataset is used for testing, defaults to False
    :type is_test: bool, optional
    :param image_transform: image transformation, defaults to None
    :type image_transform: callable, optional
    :param target_transform: target transformation, defaults to None
    :type target_transform: callable, optional
    :param twin_transform: twin transformation, defaults to None
    :type twin_transform: callable, optional
    """

    def __init__(self, path: Path, data_folder_name: str, gt_folder_name: str,
                 selection: Optional[Union[int, List[str]]] = None,
                 is_test: bool = False,
                 image_transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 twin_transform: Optional[Callable] = None):
        """
        Constructor method for the class :class:`CroppedDatasetRGB`.

        Scans ``path`` for matching (data, ground truth) crop pairs.

        :raises RuntimeError: if no image pairs were found below ``path``
        """
        self.path = path
        self.data_folder_name = data_folder_name
        self.gt_folder_name = gt_folder_name
        self.selection = selection

        # transformations
        self.image_transform = image_transform
        self.target_transform = target_transform
        self.twin_transform = twin_transform

        self.is_test = is_test

        # List of tuples (data_path, gt_path, page_name, crop_name) that belong together
        self.img_paths_per_page = self.get_gt_data_paths(path,
                                                         data_folder_name=self.data_folder_name,
                                                         gt_folder_name=self.gt_folder_name,
                                                         selection=self.selection)

        self.num_samples = len(self.img_paths_per_page)
        if self.num_samples == 0:
            raise RuntimeError("Found 0 images in subfolders of: {} \n Supported image extensions are: {}".format(
                path, ",".join(IMG_EXTENSIONS)))

    def __len__(self):
        """
        Returns the number of (image, ground truth) crop pairs found on disk.

        The same value is used for train, val and test; every crop pair is one sample.
        """
        return self.num_samples

    def __getitem__(self, index: int) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor]]:
        """
        Returns the sample at ``index``.

        :param index: index of the sample
        :type index: int
        :return: ``(img, gt, index)`` when ``is_test`` is set, otherwise ``(img, gt)``
        :rtype: Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor]]
        """
        if self.is_test:
            return self._get_test_items(index=index)
        return self._get_train_val_items(index=index)

    def _get_train_val_items(self, index: int) -> Tuple[Tensor, Tensor]:
        """
        Returns the image and the ground truth image at the given index.
        If transformations have been defined, they are applied here.

        :param index: index of the image to return
        :type index: int
        :return: The image and the corresponding ground truth image with transformations applied
        :rtype: Tuple[Tensor, Tensor]
        """
        data_img, gt_img = self._load_data_and_gt(index=index)
        img, gt = self._apply_transformation(data_img, gt_img)
        return img, gt

    def _get_test_items(self, index: int) -> Tuple[Tensor, Tensor, int]:
        """
        Returns the image and the ground truth image at the given index for testing.
        If transformations have been defined, they are applied here.
        In addition to :meth:`_get_train_val_items`, the index of the image is returned.

        :param index: index of the image to return
        :type index: int
        :return: The image and the corresponding ground truth image with transformations applied, and the index
        :rtype: Tuple[Tensor, Tensor, int]
        """
        data_img, gt_img = self._load_data_and_gt(index=index)
        img, gt = self._apply_transformation(data_img, gt_img)
        return img, gt, index

    def _load_data_and_gt(self, index: int) -> Tuple[Image.Image, Image.Image]:
        """
        Loads the image and the ground truth image at the given index with ``pil_loader``.

        :param index: index of the image to return
        :type index: int
        :return: The image and the corresponding ground truth image
        :rtype: Tuple[Image.Image, Image.Image]
        """
        path_data_file, path_gt_file = self.img_paths_per_page[index][:2]
        data_img = pil_loader(path_data_file)
        gt_img = pil_loader(path_gt_file)
        return data_img, gt_img

    def _apply_transformation(self, img: Union[Image.Image, Tensor], gt: Union[Image.Image, Tensor]) \
            -> Tuple[Tensor, Tensor]:
        """
        Applies the transformations that have been defined in the setup (setup.py).

        Order of application:

        1. ``twin_transform`` (only when not in test mode)
        2. ``image_transform``
        3. conversion of ``img`` and ``gt`` to tensors with ``ToTensor`` if they are not tensors yet
        4. ``target_transform``

        Each of the transforms is called with ``(img, gt)`` and must return both.
        The result is always a pair of tensors, even if no transformation is defined.

        :param img: The original image to apply the transformations to
        :type img: Union[Image.Image, Tensor]
        :param gt: The corresponding ground truth image to apply the transformations to
        :type gt: Union[Image.Image, Tensor]
        :return: The transformed image and the transformed ground truth image
        :rtype: Tuple[Tensor, Tensor]
        """
        # twin transforms (e.g. random crops/flips) are only used for train/val
        if self.twin_transform is not None and not self.is_test:
            img, gt = self.twin_transform(img, gt)

        if self.image_transform is not None:
            # perform transformations
            img, gt = self.image_transform(img, gt)

        to_tensor = ToTensor()
        if not is_tensor(img):
            img = to_tensor(img)
        if not is_tensor(gt):
            gt = to_tensor(gt)

        if self.target_transform is not None:
            img, gt = self.target_transform(img, gt)

        return img, gt

    @staticmethod
    def get_gt_data_paths(directory: Path, data_folder_name: str, gt_folder_name: str,
                          selection: Optional[Union[int, List[str]]] = None) \
            -> List[Tuple[Path, Path, str, str]]:
        """
        Returns a list of tuples that contain the path to the gt and image that belong together.

        Structure of the folder

        directory/data/ORIGINAL_FILENAME/FILE_NAME_X_Y.png
        directory/gt/ORIGINAL_FILENAME/FILE_NAME_X_Y.png

        Files inside each page folder are paired by sorted order; non-image files are skipped.

        :param directory: Path to dataset folder (train / val / test)
        :type directory: Path
        :param data_folder_name: name of the folder that contains the data
        :type data_folder_name: str
        :param gt_folder_name: name of the folder that contains the ground truth
        :type gt_folder_name: str
        :param selection: selection of the data, defaults to None. An ``int`` keeps the first
            ``selection`` page folders (in sorted order); a list keeps only the page folders
            whose names are in the list.
        :type selection: Optional[Union[int, List[str]]], optional
        :return: List of tuples ``(data_file, gt_file, page_folder_stem, crop_file_stem)``
        :rtype: List[Tuple[Path, Path, str, str]]
        """
        paths = []
        directory = directory.expanduser()

        path_data_root = directory / data_folder_name
        path_gt_root = directory / gt_folder_name

        # Both roots are required, so report if *either* of them is missing.
        if not path_data_root.is_dir() or not path_gt_root.is_dir():
            log.error("folder data or gt not found in " + str(directory))

        # get all subitems (and files) sorted
        subitems = sorted(path_data_root.iterdir())

        # check the selection parameter
        if selection:
            selection = selection_validation(subitems, selection, full_page=False)

        counter = 0  # Counter for subdirectories, needed for selection parameter
        for path_data_subdir in subitems:
            if not path_data_subdir.is_dir():
                if has_file_allowed_extension(path_data_subdir.name, IMG_EXTENSIONS):
                    log.warning("image file found in data root: " + str(path_data_subdir))
                continue

            counter += 1

            if selection:
                if isinstance(selection, int):
                    if counter > selection:
                        break
                elif isinstance(selection, (ListConfig, list)):
                    if path_data_subdir.name not in selection:
                        continue

            # NOTE(review): the gt folder is looked up by the data folder's *stem*, so a
            # data folder name containing a dot maps to a gt folder without that suffix.
            path_gt_subdir = path_gt_root / path_data_subdir.stem
            assert path_gt_subdir.is_dir(), \
                f'get_img_gt_path_list(): gt folder not found: {path_gt_subdir}'

            data_files = sorted(path_data_subdir.iterdir())
            gt_files = sorted(path_gt_subdir.iterdir())
            if len(data_files) != len(gt_files):
                # zip() below stops at the shorter list, so surplus files are ignored
                log.warning(f"number of files differs between {path_data_subdir} ({len(data_files)}) "
                            f"and {path_gt_subdir} ({len(gt_files)}); extra files are ignored")

            for path_data_file, path_gt_file in zip(data_files, gt_files):
                is_data_img = has_file_allowed_extension(path_data_file.name, IMG_EXTENSIONS)
                is_gt_img = has_file_allowed_extension(path_gt_file.name, IMG_EXTENSIONS)
                assert is_data_img == is_gt_img, \
                    'get_img_gt_path_list(): image file aligned with non-image file'

                if is_data_img and is_gt_img:
                    assert path_data_file.stem == path_gt_file.stem, \
                        'get_img_gt_path_list(): mismatch between data filename and gt filename'
                    paths.append((path_data_file, path_gt_file, path_data_subdir.stem, path_data_file.stem))

        return paths