from glob import glob
from pathlib import Path
from typing import List, Tuple
import torch.utils.data as data
from torch import is_tensor, Tensor
from torchvision.datasets.folder import pil_loader
from torchvision.transforms import ToTensor
from PIL import Image
from src.datamodules.utils.misc import ImageDimensions, get_output_file_list
from src.utils import utils
# Module-level logger, obtained from the project's logging helper and named after this module.
log = utils.get_logger(__name__)
class DatasetPredict(data.Dataset):
    """
    Dataset for running predictions on images given as file paths or glob patterns.

    Every entry of ``image_path_list`` is expanded with :func:`glob.glob`. The matching files are loaded
    one at a time and returned as tensors together with their index in the expanded list.

    :param image_path_list: list of image paths or glob patterns
    :type image_path_list: List[str]
    :param image_dims: expected dimensions (``width`` and ``height``) of every image
    :type image_dims: ImageDimensions
    :param image_transform: optional transformation; it is called as ``image_transform(img, None)`` and
        must return a pair whose first element is the transformed image
    :type image_transform: Callable
    :param target_transform: target transformation (only stored, not used inside this class)
    :type target_transform: Callable
    :param twin_transform: twin transformation (only stored, not used inside this class)
    :type twin_transform: Callable
    :raises RuntimeError: if the given paths/patterns do not match any file
    """

    def __init__(self, image_path_list: List[str], image_dims: "ImageDimensions",
                 image_transform=None, target_transform=None, twin_transform=None):
        """
        Constructor method for the DatasetPredict class.
        """
        # keep the patterns exactly as given, next to their expanded form
        self._raw_image_path_list = list(image_path_list)
        self.image_path_list = self.expend_glob_path_list(glob_path_list=self._raw_image_path_list)
        # derived from the expanded image paths by the project helper
        self.output_file_list = get_output_file_list(image_path_list=self.image_path_list)

        self.image_dims = image_dims

        # transformations
        self.image_transform = image_transform
        self.target_transform = target_transform
        self.twin_transform = twin_transform

        self.num_samples = len(self.image_path_list)
        if self.num_samples == 0:
            raise RuntimeError('List of image paths is empty!')

    def __len__(self) -> int:
        """
        Returns the number of images in the dataset, i.e. the number of files found by expanding the
        given paths/patterns.

        :returns: number of samples
        :rtype: int
        """
        return self.num_samples

    def __getitem__(self, index: int) -> Tuple[Tensor, int]:
        """
        Returns the image at the given index as a tensor, together with the index itself.

        :param index: index of the sample
        :type index: int
        :return: image tensor and the index of the sample
        :rtype: Tuple[Tensor, int]
        """
        data_img = self._load_data_and_gt(index=index)
        data_tensor = self._apply_transformation(img=data_img)
        return data_tensor, index

    def _load_data_and_gt(self, index: int) -> "Image.Image":
        """
        Loads the image with the given index from the disk and checks its dimensions.
        No ground truth is loaded, despite the name of the method.

        :param index: index of the sample
        :type index: int
        :returns: The image at the given index
        :rtype: Image.Image
        :raises AssertionError: if the image size differs from ``image_dims``
        """
        image_path = self.image_path_list[index]
        data_img = pil_loader(image_path)

        # Explicit check instead of an ``assert`` statement: asserts are removed when Python runs with -O.
        # The exception type stays AssertionError, so existing callers observe the same error.
        if data_img.height != self.image_dims.height or data_img.width != self.image_dims.width:
            raise AssertionError(
                f"Image '{image_path}' has size {data_img.width}x{data_img.height} (width x height), "
                f"expected {self.image_dims.width}x{self.image_dims.height}"
            )
        return data_img

    def _apply_transformation(self, img: "Image.Image") -> Tensor:
        """
        Applies ``image_transform`` (if one was given) to the image and makes sure the result is a tensor.
        If the image is not a tensor after the optional transformation, it is converted with ``ToTensor``.

        :param img: Original image to apply transformation on
        :type img: PIL image
        :returns: transformed image
        :rtype: Tensor
        """
        if self.image_transform is not None:
            # the transformation takes (image, gt) and returns a pair; there is no gt here, so None is passed
            img, _ = self.image_transform(img, None)

        if not is_tensor(img):
            img = ToTensor()(img)
        return img

    @staticmethod
    def expend_glob_path_list(glob_path_list: List[str]) -> List[Path]:
        """
        Expands a list of glob patterns to a list of paths without duplicates.

        The patterns are processed in the given order and the matches of each pattern are sorted.
        A path that is matched more than once is only kept at its first occurrence.

        :param glob_path_list: list of glob paths
        :type glob_path_list: List[str]
        :returns: list of paths
        :rtype: List[Path]
        """
        output_list = []
        # set of the paths collected so far: constant-time duplicate check instead of scanning the list
        seen = set()
        for glob_path in glob_path_list:
            for s in sorted(glob(glob_path)):
                path = Path(s)
                if path in seen:
                    continue
                seen.add(path)
                output_list.append(path)
        return output_list