datamodules.RotNet package

Submodules

datamodules.RotNet.datamodule_cropped module

class RotNetDivaHisDBDataModuleCropped(data_dir: str, data_folder_name: str, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, crop_size: int = 256, num_workers: int = 4, batch_size: int = 8, shuffle: bool = True, drop_last: bool = True)[source]

Bases: AbstractDatamodule

Datamodule implementation of the RotNet paper by Gidaris et al. This datamodule is used for the DivaHisDB dataset in a cropped setup; a usage sketch follows the parameter list below.

The structure of the folder should be as follows:

data_dir
├── train_folder_name
│   ├── data_folder_name
│   │   ├── original_image_name_1
│   │   │   ├── image_crop_1.png
│   │   │   ├── ...
│   │   │   └── image_crop_N.png
│   │   └── original_image_name_N
│   │       ├── image_crop_1.png
│   │       ├── ...
│   │       └── image_crop_N.png
│   └── gt_folder_name
│       ├── original_image_name_1
│       │   ├── image_crop_1.png
│       │   ├── ...
│       │   └── image_crop_N.png
│       └── original_image_name_N
│           ├── image_crop_1.png
│           ├── ...
│           └── image_crop_N.png
├── validation_folder_name
│   ├── data_folder_name
│   │   ├── original_image_name_1
│   │   │   ├── image_crop_1.png
│   │   │   ├── ...
│   │   │   └── image_crop_N.png
│   │   └── original_image_name_N
│   │       ├── image_crop_1.png
│   │       ├── ...
│   │       └── image_crop_N.png
│   └── gt_folder_name
│       ├── original_image_name_1
│       │   ├── image_crop_1.png
│       │   ├── ...
│       │   └── image_crop_N.png
│       └── original_image_name_N
│           ├── image_crop_1.png
│           ├── ...
│           └── image_crop_N.png
└── test_folder_name
    ├── data_folder_name
    │   ├── original_image_name_1
    │   │   ├── image_crop_1.png
    │   │   ├── ...
    │   │   └── image_crop_N.png
    │   └── original_image_name_N
    │       ├── image_crop_1.png
    │       ├── ...
    │       └── image_crop_N.png
    └── gt_folder_name
        ├── original_image_name_1
        │   ├── image_crop_1.png
        │   ├── ...
        │   └── image_crop_N.png
        └── original_image_name_N
            ├── image_crop_1.png
            ├── ...
            └── image_crop_N.png
Parameters:
  • data_dir (str) – Path to the root directory of the dataset (the folder containing the train/validation/test folders)

  • data_folder_name (str) – Name of the folder that holds the image crops inside each of the train/validation/test folders (see the structure above)

  • selection_train (Optional[Union[int, List[str]]]) – Selection of the train set. Can be either a list of strings or an integer. If it is a list of strings, it should contain the names of the images to be used. If it is an integer, it should be the number of images to be used. If None, all images are used.

  • selection_val (Optional[Union[int, List[str]]]) – Selection of the validation set. Can be either a list of strings or an integer. If it is a list of strings, it should contain the names of the images to be used. If it is an integer, it should be the number of images to be used. If None, all images are used.

  • selection_test (Optional[Union[int, List[str]]]) – Selection of the test set. Can be either a list of strings or an integer. If it is a list of strings, it should contain the names of the images to be used. If it is an integer, it should be the number of images to be used. If None, all images are used.

  • crop_size (int) – Size of the crop to be used

  • num_workers (int) – Number of workers to be used for loading the data

  • batch_size (int) – Batch size to be used

  • shuffle (bool) – Whether to shuffle the data

  • drop_last (bool) – Whether to drop the last batch
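
Example (a minimal usage sketch; the import path is inferred from the module heading above, and the path and folder names are placeholder assumptions, not values prescribed by this datamodule):

from datamodules.RotNet.datamodule_cropped import RotNetDivaHisDBDataModuleCropped

# data_dir must contain the train/validation/test structure shown above
data_module = RotNetDivaHisDBDataModuleCropped(
    data_dir='/path/to/data_dir',   # placeholder path
    data_folder_name='data',        # assumed name of the image-crop folder
    selection_train=None,           # use all training images
    selection_val=10,               # an int limits the number of validation images used
    crop_size=256,
    num_workers=4,
    batch_size=8,
    shuffle=True,
    drop_last=True,
)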

setup(stage: Optional[str] = None)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # runs on a single process: download and tokenize only
        download_data()
        tokenize()

        # don't do this: state set here is not shared across processes
        self.something = some_state

    def setup(self, stage):
        # runs on every process: safe to build data-dependent models here
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
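
For this datamodule the Trainer normally calls setup() itself; the following is a hedged sketch of driving it manually (useful only when using the dataloaders outside a Trainer run, and assuming the data_module instance from the example further above):

data_module.prepare_data()
data_module.setup(stage='fit')     # builds the train and validation datasets
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

data_module.setup(stage='test')    # builds the test dataset
test_loader = data_module.test_dataloader()
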
test_dataloader(*args, **kwargs) → Union[DataLoader, List[DataLoader]][source]

Implement one or multiple PyTorch DataLoaders for testing.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
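
As an illustration of the note above (a sketch only; the step body and loader names are placeholders):

def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx matches the position of the loader in the returned list
    if dataloader_idx == 0:
        ...  # batch came from loader_a
    else:
        ...  # batch came from loader_b, and so on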

train_dataloader(*args, **kwargs) → DataLoader[source]

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, see the examples below.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
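
For example (a sketch; reload_dataloaders_every_n_epochs is a Trainer argument):

import pytorch_lightning as pl

# rebuild the dataloaders every epoch instead of only once per run
trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)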

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloader, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
val_dataloader(*args, **kwargs) → Union[DataLoader, List[DataLoader]][source]

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). The related hooks are:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

Module contents