datamodules package

Subpackages

Submodules

datamodules.base_datamodule module

class AbstractDatamodule

Bases: LightningDataModule

Abstract base class for all datamodules; every datamodule should inherit from it. It provides basic functionality such as checking the number of samples and the number of classes. It also provides a resolver for the datamodule object itself, so that the datamodule can be referenced from the config. Subclasses must set the class variable dims.
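The dims contract can be illustrated with a small, dependency-free sketch. The classes DimsCheckedBase and MyImageDatamodule are hypothetical. Checking at instantiation time is only one possible way to enforce the requirement, not necessarily how AbstractDatamodule does it. Treating dims as the per-sample input shape is also an assumption.

```python
from typing import ClassVar, Optional, Tuple


class DimsCheckedBase:
    """Hypothetical stand-in illustrating the 'dims must be set' contract."""

    # Assumption: dims holds the per-sample input shape, e.g. (C, H, W).
    dims: ClassVar[Optional[Tuple[int, ...]]] = None

    def __init__(self) -> None:
        # Fail early if a concrete subclass forgot to override the class variable.
        if type(self).dims is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set the class variable 'dims'"
            )


class MyImageDatamodule(DimsCheckedBase):
    # Hypothetical subclass: RGB images of 256 x 256 pixels.
    dims = (3, 256, 256)
```

Instantiating MyImageDatamodule succeeds, while a subclass that leaves dims unset raises NotImplementedError as soon as it is constructed.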

static check_min_num_samples(num_devices: int, batch_size_input: int, num_samples: int, data_split: str, drop_last: bool)

Checks if the number of samples is sufficient for the given batch size and number of devices.

Parameters:
  • num_devices (int) – The number of devices

  • batch_size_input (int) – The batch size

  • num_samples (int) – The number of samples

  • data_split (str) – The data split (train, val, test)

  • drop_last (bool) – Whether to drop the last batch if it is smaller than the batch size

Raises:

ValueError – If the number of samples is not sufficient
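The exact rule used by check_min_num_samples is not spelled out here. The following is a hypothetical re-implementation under stated assumptions. With drop_last=True, an incomplete batch is discarded, so it assumes every device needs at least one full batch: num_samples >= num_devices * batch_size_input. With drop_last=False, it assumes one sample per device is enough. The function name check_min_num_samples_sketch is illustrative only.

```python
def check_min_num_samples_sketch(
    num_devices: int,
    batch_size_input: int,
    num_samples: int,
    data_split: str,
    drop_last: bool,
) -> None:
    """Hypothetical check; the real method's rule may differ."""
    if drop_last:
        # Partial batches are dropped, so each device needs one full batch,
        # otherwise its dataloader would yield no batches at all (assumption).
        required = num_devices * batch_size_input
    else:
        # Partial batches are kept; assume one sample per device suffices.
        required = num_devices
    if num_samples < required:
        raise ValueError(
            f"The {data_split} split has {num_samples} samples, but at least "
            f"{required} are required for {num_devices} device(s) with batch "
            f"size {batch_size_input} (drop_last={drop_last})."
        )
```

For example, 7 training samples on 2 devices with a batch size of 4 fail the check when drop_last=True (8 are needed) but pass when drop_last=False.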

setup(stage: Optional[str] = None) → None

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

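import torch.nn as nn
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # Called from a single process: download and preprocess data to disk.
        # download_data() and tokenize() are placeholders for your own helpers.
        download_data()
        tokenize()

        # Don't do this: state assigned here is not visible to other processes.
        self.something = "else"

    def setup(self, stage):
        # Called on every process: load the prepared data and build layers
        # whose shape depends on it (load_data() is a placeholder).
        data = load_data()
        self.l1 = nn.Linear(28, data.num_classes)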
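In a datamodule, the stage argument typically decides which splits setup prepares. The sketch below is a hypothetical, dependency-free illustration. The class SplitByStageSketch, the 80/10/10 split, and the mapping of stages to splits are assumptions, not AbstractDatamodule's behavior. Because setup runs on every process under DDP, the split is computed deterministically (no shuffling), so every process sees the same partition.

```python
from typing import List, Optional


class SplitByStageSketch:
    """Hypothetical datamodule-like class showing stage-dependent setup."""

    def __init__(self, samples: List[int]) -> None:
        self.samples = samples
        self.train: Optional[List[int]] = None
        self.val: Optional[List[int]] = None
        self.test: Optional[List[int]] = None

    def setup(self, stage: Optional[str] = None) -> None:
        # Deterministic 80/10/10 split so all DDP processes agree (assumption).
        n = len(self.samples)
        n_train = int(0.8 * n)
        n_val = int(0.1 * n)
        # 'fit' needs train and val; 'validate' needs only val; None builds everything.
        if stage in ("fit", None):
            self.train = self.samples[:n_train]
        if stage in ("fit", "validate", None):
            self.val = self.samples[n_train:n_train + n_val]
        if stage in ("test", None):
            self.test = self.samples[n_train + n_val:]
        # 'predict' is deliberately not handled in this sketch.
```

With ten samples, setup("fit") assigns the first eight to train and the ninth to val, and leaves test unset until setup("test") is called.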

Module contents