# Source code for callbacks.general_callbacks

from time import time

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning import Callback


class TimeTracker(Callback):
    """Callback that logs wall-clock durations for training and testing.

    Logged keys:

    * ``train/epoch_time``: seconds between the start and the end of each
      training epoch.
    * ``train/total_time``: seconds since ``on_train_start``; logged only at
      the end of the epoch whose index equals ``trainer.max_epochs - 1``.
    * ``test/total_time``: seconds between the start and the end of the
      test epoch.

    Every hook is decorated with ``rank_zero_only``, so timestamps are taken
    and metrics are logged by the rank-zero process only.
    """

    def __init__(self):
        super().__init__()
        # Timestamps from ``time()``; stay ``None`` until the matching
        # ``*_start`` hook has run.
        self.start_time_train = None
        self.start_time_train_epoch = None
        self.start_time_test = None

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Record the moment training starts."""
        self.start_time_train = time()

    @rank_zero_only
    def on_train_epoch_start(self, trainer, pl_module):
        """Record the moment the current training epoch starts."""
        self.start_time_train_epoch = time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Log the epoch duration and, on the final epoch, the total duration.

        Args:
            trainer: Read for ``current_epoch`` and ``max_epochs``.
            pl_module: Module whose ``log`` method receives the metrics.
        """
        # Single clock reading so both metrics share the same end point.
        now = time()
        # The base ``Callback`` class has no ``log`` method, so metrics are
        # logged through the module instead.
        # NOTE: the total is only emitted when the run reaches its configured
        # last epoch; a run that stops earlier never logs it.
        if trainer.current_epoch == trainer.max_epochs - 1:
            pl_module.log("train/total_time", now - self.start_time_train)
        pl_module.log("train/epoch_time", now - self.start_time_train_epoch)

    @rank_zero_only
    def on_test_epoch_start(self, trainer, pl_module):
        """Record the moment the test epoch starts."""
        self.start_time_test = time()

    @rank_zero_only
    def on_test_epoch_end(self, trainer, pl_module):
        """Log the time elapsed since the test epoch started.

        Args:
            trainer: Unused; part of the hook signature.
            pl_module: Module whose ``log`` method receives the metric.
        """
        pl_module.log("test/total_time", time() - self.start_time_test)