datamodules package
Subpackages
- datamodules.Classification package
- datamodules.DivaHisDB package
- datamodules.IndexedFormats package
- datamodules.RGB package
- datamodules.RolfFormat package
- datamodules.RotNet package
- datamodules.utils package
  - Submodules
    - datamodules.utils.dataset_predict module
    - datamodules.utils.exceptions module
    - datamodules.utils.functional module
    - datamodules.utils.image_analytics module
    - datamodules.utils.misc module
    - datamodules.utils.output_tools module
    - datamodules.utils.single_transforms module
    - datamodules.utils.twin_transforms module
    - datamodules.utils.wrapper_transforms module
  - Module contents
Submodules
datamodules.base_datamodule module
- class AbstractDatamodule
Bases: LightningDataModule
Abstract base class for all datamodules; every datamodule should inherit from it. It provides shared functionality such as checking the number of samples and the number of classes. It also provides a resolver for the datamodule object itself, so that the datamodule can be referenced in the config. Subclasses must set the class variable dims.
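The description above does not say how the requirement on dims is enforced. As a hedged illustration only, the sketch below uses a hypothetical stand-in base class, DatamoduleBase (not part of this package, and written without the Lightning dependency), that rejects any subclass that leaves dims unset, using Python's __init_subclass__ hook. The real AbstractDatamodule may handle this differently.

```python
from typing import ClassVar, Optional, Tuple


class DatamoduleBase:
    """Hypothetical stand-in for AbstractDatamodule (no Lightning dependency)."""

    # Subclasses are expected to override this with the sample shape, e.g. (C, H, W).
    dims: ClassVar[Optional[Tuple[int, ...]]] = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs when a subclass is defined: reject it if dims was never set.
        if cls.dims is None:
            raise TypeError(f"{cls.__name__} must set the class variable 'dims'")


class RGBDatamodule(DatamoduleBase):
    # Illustrative values only: channels, height, width.
    dims = (3, 256, 256)
```

Checking at class-definition time surfaces a missing dims as soon as the module is imported, rather than later during training.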
- static check_min_num_samples(num_devices: int, batch_size_input: int, num_samples: int, data_split: str, drop_last: bool)
Checks if the number of samples is sufficient for the given batch size and number of devices.
- Parameters:
num_devices (int) – The number of devices
batch_size_input (int) – The batch size
num_samples (int) – The number of samples
data_split (str) – The data split (train, val, test)
drop_last (bool) – Whether to drop the last batch if it is smaller than the batch size
- Raises:
ValueError – If the number of samples is not sufficient
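The exact rule used by check_min_num_samples is not documented here. The following is a minimal sketch of one plausible rule, assuming that each device draws its own batch of batch_size_input samples. With drop_last=True it conservatively requires at least one full batch per device. With drop_last=False it requires at least one distinct sample per device, which is an assumption of this sketch. It is not the package's actual implementation.

```python
def check_min_num_samples(
    num_devices: int,
    batch_size_input: int,
    num_samples: int,
    data_split: str,
    drop_last: bool,
) -> None:
    """Hypothetical re-implementation; the real rule in base_datamodule may differ."""
    # Assumption: one step consumes one batch on every device.
    samples_per_step = num_devices * batch_size_input
    if drop_last and num_samples < samples_per_step:
        # Conservative: demand one full batch per device, ignoring any sampler padding.
        raise ValueError(
            f"{data_split}: {num_samples} samples < {samples_per_step} "
            f"({num_devices} devices x batch size {batch_size_input}) with drop_last=True"
        )
    if num_samples < num_devices:
        # Assumption: reject splits too small to give each device a distinct sample.
        raise ValueError(
            f"{data_split}: {num_samples} samples cannot be split across {num_devices} devices"
        )
```

Under this rule, a 10-sample training split with 2 devices and a batch size of 8 is rejected when drop_last=True but accepted when drop_last=False.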
- setup(stage: Optional[str] = None) → None
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
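Example:
    import pytorch_lightning as pl
    from torch import nn

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = None

        def prepare_data(self):
            # download_data() and tokenize() are placeholder helpers
            download_data()
            tokenize()

            # don't do this: prepare_data() is called on only one process (per node),
            # so state assigned here is not shared with the other processes
            self.something = something_else

        def setup(self, stage):
            # load_data() is a placeholder; setup() runs on every process
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)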
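For a datamodule, setup() is typically where the datasets for the requested stage are built. The sketch below illustrates that stage dispatch with a hypothetical, dependency-free class, ToyDatamodule. It does not inherit from LightningDataModule, and it uses plain lists in place of real datasets. Treating stage=None as "build every split" is an assumption of this sketch.

```python
from typing import Dict, List, Optional


class ToyDatamodule:
    """Hypothetical, dependency-free sketch of stage-dependent setup()."""

    def __init__(self, samples: Dict[str, List[int]]):
        # Maps split name -> placeholder list of samples.
        self.samples = samples
        self.train = self.val = self.test = self.predict = None

    def setup(self, stage: Optional[str] = None) -> None:
        # 'fit' covers training plus validation, so it needs both splits.
        if stage in ("fit", None):
            self.train = self.samples["train"]
            self.val = self.samples["val"]
        if stage in ("validate", None):
            self.val = self.samples["val"]
        if stage in ("test", None):
            self.test = self.samples["test"]
        if stage in ("predict", None):
            self.predict = self.samples["predict"]


dm = ToyDatamodule({"train": [1, 2], "val": [3], "test": [4], "predict": [5]})
dm.setup("fit")  # builds only the train and val splits
```

Building only the splits a stage needs avoids loading test or prediction data during training. Because setup() runs on every process under DDP, each process builds its own copy.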