datamodules.IndexedFormats package

Subpackages

Submodules

datamodules.IndexedFormats.datamodule module

class DataModuleIndexed(data_dir: str, data_folder_name: str, gt_folder_name: str, train_folder_name: str = 'train', val_folder_name: str = 'val', test_folder_name: str = 'test', pred_file_path_list: Optional[List[str]] = None, selection_train: Optional[Union[int, List[str]]] = None, selection_val: Optional[Union[int, List[str]]] = None, selection_test: Optional[Union[int, List[str]]] = None, num_workers: int = 4, batch_size: int = 8, shuffle: bool = True, drop_last: bool = True)

Bases: AbstractDatamodule

DataModule for datasets whose ground truth is encoded in an indexed image format (e.g., GIF, TIF).

The folder structure is as follows:

data_dir
├── train_folder_name
│   ├── data_folder_name
│   │   ├── image1.png
│   │   ├── ...
│   │   └── imageN.png
│   └── gt_folder_name
│       ├── image1.png
│       ├── ...
│       └── imageN.png
├── val_folder_name
│   ├── data_folder_name
│   │   ├── image1.png
│   │   ├── ...
│   │   └── imageN.png
│   └── gt_folder_name
│       ├── image1.png
│       ├── ...
│       └── imageN.png
└── test_folder_name
    ├── data_folder_name
    │   ├── image1.png
    │   ├── ...
    │   └── imageN.png
    └── gt_folder_name
        ├── image1.png
        ├── ...
        └── imageN.png

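As a rough illustration of the layout above, the following sketch builds a throwaway dataset tree and reports images that have no ground-truth file of the same name. The helper find_unpaired_files and the folder names 'data' and 'gt' are assumptions made for this example; they are not part of the package, and the class itself may validate the layout differently.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Sequence


def find_unpaired_files(data_dir: Path, data_folder_name: str, gt_folder_name: str,
                        splits: Sequence[str] = ("train", "val", "test")) -> Dict[str, List[str]]:
    """Hypothetical helper: per split, list image names without a matching ground-truth file."""
    unpaired = {}
    for split in splits:
        images = {p.name for p in (data_dir / split / data_folder_name).iterdir()}
        gts = {p.name for p in (data_dir / split / gt_folder_name).iterdir()}
        unpaired[split] = sorted(images - gts)
    return unpaired


# Build a tiny example tree (folder names 'data' and 'gt' are placeholders).
root = Path(tempfile.mkdtemp())
for split in ("train", "val", "test"):
    for folder in ("data", "gt"):
        (root / split / folder).mkdir(parents=True)
        (root / split / folder / "image1.png").touch()

# Add one training image without ground truth to show what the check reports.
(root / "train" / "data" / "image2.png").touch()

report = find_unpaired_files(root, "data", "gt")
print(report)
```

Running a check like this before training can catch missing ground-truth files early.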
Parameters:
  • data_dir (str) – path to the dataset folder that contains the train/val/test folders

  • data_folder_name (str) – name of the folder inside each of the train/val/test folders that contains the images

  • gt_folder_name (str) – name of the folder inside each of the train/val/test folders that contains the ground truth images

  • train_folder_name (str) – name of the train folder

  • val_folder_name (str) – name of the validation folder

  • test_folder_name (str) – name of the test folder

  • pred_file_path_list (Optional[List[str]]) – list of file paths to predict

  • selection_train (Optional[Union[int, List[str]]]) – number of files or list of files that should be taken into account for the train split.

  • selection_val (Optional[Union[int, List[str]]]) – number of files or list of files that should be taken into account for the validation split.

  • selection_test (Optional[Union[int, List[str]]]) – number of files or list of files that should be taken into account for the test split.

  • num_workers (int) – number of workers for the dataloader

  • batch_size (int) – batch size

  • shuffle (bool) – shuffle the data

  • drop_last (bool) – drop the last batch if it is smaller than the batch size
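The selection_* and drop_last parameters can be pictured with a small sketch. The helpers apply_selection and num_batches are hypothetical and not taken from the package. In particular, taking the first N files for an integer selection is an assumption; the actual class may choose files differently.

```python
from typing import List, Optional, Union


def apply_selection(files: List[str],
                    selection: Optional[Union[int, List[str]]]) -> List[str]:
    """Hypothetical helper: restrict a file list by count or by explicit names."""
    if selection is None:
        return list(files)  # no restriction
    if isinstance(selection, int):
        if selection < 0 or selection > len(files):
            raise ValueError(f"selection must be between 0 and {len(files)}")
        return files[:selection]  # assumption: keep the first N files
    wanted = set(selection)
    missing = sorted(wanted - set(files))
    if missing:
        raise ValueError(f"files not found: {missing}")
    return [f for f in files if f in wanted]  # keeps the original order


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch; with drop_last the incomplete final batch is skipped."""
    full, remainder = divmod(num_samples, batch_size)
    return full if drop_last or remainder == 0 else full + 1


files = ["image1.png", "image2.png", "image3.png"]
first_two = apply_selection(files, 2)
by_name = apply_selection(files, ["image3.png", "image1.png"])
batches_dropped = num_batches(20, 8, drop_last=True)
batches_kept = num_batches(20, 8, drop_last=False)
```

With the default batch_size of 8 and drop_last=True, 20 samples give two batches per epoch and the remaining four samples are skipped in that epoch.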

get_output_filename_predict(index: int) → str

Returns the original filename of the doc image. This method can only be used during prediction.

Parameters:

index (int) – index of the image

Raises:

ValueError – if the method is called during training

Returns:

filename of the image

Return type:

str

get_output_filename_test(index: int) → str

Returns the original filename of the doc image. This method can only be used during testing.

Parameters:

index (int) – index of the image

Raises:

ValueError – if the method is called during training

Returns:

filename of the image

Return type:

str
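The guard described for the two get_output_filename_* methods could look roughly like the following. FilenameLookup is a hypothetical stand-in, not the package's class. It assumes that the current stage is stored on the object and that the returned name includes the file extension; neither detail is confirmed by the documentation above.

```python
from pathlib import Path
from typing import List, Optional


class FilenameLookup:
    """Hypothetical stand-in showing a stage guard around a filename lookup."""

    def __init__(self, pred_file_path_list: List[str]):
        self.pred_file_path_list = pred_file_path_list
        self.stage: Optional[str] = None  # assumption: set when setup(stage) runs

    def get_output_filename_predict(self, index: int) -> str:
        # Refuse to answer outside of prediction, e.g. during training.
        if self.stage != "predict":
            raise ValueError("get_output_filename_predict is only available during prediction")
        # Assumption: the name is returned with its extension.
        return Path(self.pred_file_path_list[index]).name


lookup = FilenameLookup(["/tmp/pages/page_01.png"])
lookup.stage = "predict"
name = lookup.get_output_filename_predict(0)
print(name)
```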

predict_dataloader() → Union[DataLoader, List[DataLoader]]

Implement one or multiple PyTorch DataLoaders for prediction.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

Note

In the case where you return multiple prediction dataloaders, the predict_step() will have an argument dataloader_idx which matches the order here.
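The correspondence between the returned list and dataloader_idx can be sketched without Lightning. Plain lists of batches stand in for DataLoaders here, and iterate_dataloaders is a hypothetical helper. It only illustrates how the index follows the list order, not how Lightning schedules the loaders internally.

```python
from typing import Iterable, List, Tuple


def iterate_dataloaders(dataloaders: List[Iterable]) -> List[Tuple[int, int, object]]:
    """Collect (dataloader_idx, batch_idx, batch) in the order the loaders were returned."""
    calls = []
    for dataloader_idx, loader in enumerate(dataloaders):
        for batch_idx, batch in enumerate(loader):
            # In Lightning, predict_step(batch, batch_idx, dataloader_idx) would run here.
            calls.append((dataloader_idx, batch_idx, batch))
    return calls


loader_a = [["a1", "a2"], ["a3"]]  # stand-in: two batches
loader_b = [["b1"]]                # stand-in: one batch
calls = iterate_dataloaders([loader_a, loader_b])
for call in calls:
    print(call)
```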

setup(stage: Optional[str] = None) → None

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

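        # don't do this: prepare_data() is not called on every process,
        # so state assigned here is missing on the others
        self.something = something_else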

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

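For a datamodule, the stage argument typically decides which splits need to be built. The mapping below is a sketch under that assumption. splits_for_stage is a hypothetical helper, and treating stage=None as "build train, val and test" is a choice made for this example rather than documented behaviour of DataModuleIndexed.

```python
from typing import List, Optional

# Stage names as listed for setup(); the split names on the right are illustrative.
STAGE_TO_SPLITS = {
    "fit": ["train", "val"],
    "validate": ["val"],
    "test": ["test"],
    "predict": ["predict"],
}


def splits_for_stage(stage: Optional[str]) -> List[str]:
    """Hypothetical helper: which splits a setup(stage) call would prepare."""
    if stage is None:
        return ["train", "val", "test"]  # assumption: no stage means all folder splits
    if stage not in STAGE_TO_SPLITS:
        raise ValueError(f"unknown stage: {stage!r}")
    return STAGE_TO_SPLITS[stage]


fit_splits = splits_for_stage("fit")
test_splits = splits_for_stage("test")
print(fit_splits, test_splits)
```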
test_dataloader(*args, **kwargs) → Union[DataLoader, List[DataLoader]]

Implement one or multiple PyTorch DataLoaders for testing.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
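The split between prepare_data() and setup() described for this hook can be sketched with a plain Python class. SketchDataModule is a hypothetical stand-in and not a LightningDataModule. Placeholder files are written instead of downloading anything, and the 4/1 split is arbitrary.

```python
import tempfile
from pathlib import Path
from typing import List, Optional


class SketchDataModule:
    """Plain-Python stand-in illustrating where state should be assigned."""

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.train_files: List[str] = []
        self.val_files: List[str] = []

    def prepare_data(self) -> None:
        # One-time work (here: writing placeholder files instead of downloading).
        # Results go to disk; nothing is stored on self.
        for i in range(5):
            (self.data_dir / f"image{i}.png").touch()

    def setup(self, stage: Optional[str] = None) -> None:
        # Called on every process, so state assigned here is available everywhere.
        files = sorted(p.name for p in self.data_dir.iterdir())
        self.train_files, self.val_files = files[:4], files[4:]


dm = SketchDataModule(Path(tempfile.mkdtemp()))
dm.prepare_data()
dm.setup("fit")
print(dm.train_files, dm.val_files)
```

Because prepare_data() communicates only through the file system, a process that never ran it can still build the same state in setup().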

train_dataloader(*args, **kwargs) → DataLoader

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of torch.utils.data.DataLoader specifying training samples. For the case of multiple dataloaders, see the PyTorch Lightning documentation on multiple dataloaders.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloader, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}

val_dataloader(*args, **kwargs) → Union[DataLoader, List[DataLoader]]

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

See also:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

Module contents