callbacks package

Submodules

callbacks.general_callbacks module

class TimeTracker[source]

Bases: Callback

A callback that tracks training and testing time. It logs the duration of each epoch as well as the total time spent in training and testing.
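The timing logic can be illustrated with a minimal, dependency-free sketch. `EpochTimer` and its attribute names are assumptions for illustration, not the real `TimeTracker` API; the real callback receives a `Trainer` and `LightningModule` in each hook and logs the durations instead of storing them.

```python
import time

class EpochTimer:
    """Minimal sketch of a time-tracking callback (stand-in for TimeTracker).

    The (trainer, pl_module) arguments mirror the Lightning hook signatures
    but are unused here so the sketch stays dependency-free.
    """

    def __init__(self):
        self.epoch_durations = []   # seconds taken by each training epoch
        self.total_train_time = None
        self._train_start = None
        self._epoch_start = None

    def on_train_start(self, trainer=None, pl_module=None):
        self._train_start = time.monotonic()

    def on_train_epoch_start(self, trainer=None, pl_module=None):
        self._epoch_start = time.monotonic()

    def on_train_epoch_end(self, trainer=None, pl_module=None):
        # Record this epoch's duration and the running total.
        self.epoch_durations.append(time.monotonic() - self._epoch_start)
        self.total_train_time = time.monotonic() - self._train_start
```

The same start/end pairing applies to the test hooks below.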

on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)[source]

Called when the test epoch begins.

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module OR

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
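Option 2 above can be sketched as follows. `LossAverager` is a hypothetical callback, not part of this package; it caches each batch's output dict in `on_train_batch_end` and post-processes the cache when the epoch ends.

```python
class LossAverager:
    """Sketch of caching batch data inside a callback for end-of-epoch
    post-processing. Names are illustrative, not part of this package."""

    def __init__(self):
        self._batch_losses = []
        self.epoch_mean = None

    def on_train_batch_end(self, trainer=None, pl_module=None,
                           outputs=None, batch=None, batch_idx=0):
        # Lightning passes the training step's output here; cache what we need.
        self._batch_losses.append(float(outputs["loss"]))

    def on_train_epoch_end(self, trainer=None, pl_module=None):
        # Post-process the cached values, then reset for the next epoch.
        if self._batch_losses:
            self.epoch_mean = sum(self._batch_losses) / len(self._batch_losses)
        self._batch_losses.clear()
```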

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

on_train_start(trainer, pl_module)[source]

Called when the train begins.

callbacks.model_callbacks module

class CheckBackboneHeaderCompatibility[source]

Bases: Callback

Checks whether the backbone and the header are compatible by passing a random tensor through the backbone and then feeding the result through the header. If the two are not compatible, the program is terminated.
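The probe boils down to a forward pass with a dummy input. A minimal sketch, assuming duck-typed callables in place of the real torch modules (`check_compatibility` and its arguments are illustrative names; the real callback builds a random torch tensor and terminates via the trainer rather than returning a flag):

```python
import sys

def check_compatibility(backbone, header, dummy_input):
    """Feed a dummy input through the backbone, then its output through
    the header; report whether the chained forward pass succeeds."""
    try:
        features = backbone(dummy_input)
        header(features)
    except Exception as err:
        # Shape or type mismatch between backbone output and header input.
        print(f"Backbone and header are incompatible: {err}", file=sys.stderr)
        return False
    return True
```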

setup(trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) → None[source]

Called when fit, validate, test, predict, or tune begins.

class SaveModelStateDictAndTaskCheckpoint(backbone_filename: Optional[str] = 'backbone', header_filename: Optional[str] = 'header', **kwargs)[source]

Bases: ModelCheckpoint

Saves the neural network weights to pth files, producing one file for the backbone and one for the header. The backbone file is named after backbone_filename and the header file after header_filename; both names can be set in the constructor.

Parameters:
  • backbone_filename (str, optional) – Filename of the backbone checkpoint

  • header_filename (str, optional) – Filename of the header checkpoint
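Splitting a combined model checkpoint into separate backbone and header files amounts to partitioning the state dict by parameter-name prefix. A sketch under the assumption that the model exposes its parts as `backbone.*` and `header.*` keys (the prefixes and function name are assumptions; the real callback subclasses ModelCheckpoint and writes pth files):

```python
def split_state_dict(state_dict,
                     backbone_prefix="backbone.",
                     header_prefix="header."):
    """Partition a combined state dict into backbone and header dicts,
    stripping the part prefix from each key."""
    backbone = {k[len(backbone_prefix):]: v
                for k, v in state_dict.items() if k.startswith(backbone_prefix)}
    header = {k[len(header_prefix):]: v
              for k, v in state_dict.items() if k.startswith(header_prefix)}
    return backbone, header
```

Each returned dict could then be written to its own file, e.g. with `torch.save(backbone, "backbone.pth")`.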

callbacks.wandb_callbacks module

class WatchModelWithWandb(log_category: str = 'gradients', log_freq: int = 100)[source]

Bases: Callback

Makes the WandbLogger watch the model at the beginning of the run.

Parameters:
  • log_category (str) – Category of the model to log (“gradients”, “parameters”, “all”, or None). Default: “gradients”.

  • log_freq (int) – How often to log the model. Default: 100.

on_train_start(trainer, pl_module)[source]

Called when the train begins.
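The hook reduces to one `watch` call on the W&B logger. A minimal sketch with a duck-typed logger (`WatchModelSketch` is an illustrative stand-in; the real callback locates the WandbLogger via get_wandb_logger rather than assuming `trainer.logger`):

```python
class WatchModelSketch:
    """Sketch of WatchModelWithWandb: on train start, ask the W&B logger
    to watch the model's gradients/parameters."""

    def __init__(self, log_category="gradients", log_freq=100):
        self.log_category = log_category
        self.log_freq = log_freq

    def on_train_start(self, trainer, pl_module):
        logger = trainer.logger  # assumed here to be a WandbLogger
        logger.watch(pl_module, log=self.log_category, log_freq=self.log_freq)
```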

get_wandb_logger(trainer: Trainer) → WandbLogger[source]

Get WandbLogger from trainer or loggers.

Parameters:

trainer (Trainer) – PyTorch Lightning trainer.

Returns:

The WandbLogger instance attached to the trainer.

Return type:

WandbLogger

Raises:

ValueError – If no WandbLogger is found.
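The lookup can be sketched as a scan over the trainer's loggers. `find_logger` is a generalized illustration, not the package's function; the real helper searches specifically for a WandbLogger:

```python
def find_logger(trainer, logger_type):
    """Return the first logger of the wanted type attached to the trainer,
    raising ValueError if none is found (mirrors get_wandb_logger)."""
    for logger in getattr(trainer, "loggers", []):
        if isinstance(logger, logger_type):
            return logger
    raise ValueError(f"No {logger_type.__name__} found among trainer loggers")
```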

Module contents