Source code for callbacks.wandb_callbacks

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only

from src.utils import utils


[docs]def get_wandb_logger(trainer: Trainer) -> WandbLogger: """ Get WandbLogger from trainer or loggers. :param trainer: PyTorch Lightning trainer. :type trainer: Trainer :returns: WandbLogger :rtype: WandbLogger :raises ValueError: If WandbLogger was not found """ if isinstance(trainer.logger, WandbLogger): return trainer.logger for logger in trainer.loggers: if isinstance(logger, WandbLogger): return logger raise ValueError( "You are using wandb related callback, but WandbLogger was not found for some reason..." )
[docs]class WatchModelWithWandb(Callback): """ Make WandbLogger watch model at the beginning of the run. :param log_category: Category of the model to log ("gradients", "parameters", "all", or None). Default: "gradients". :type log_category: str :param log_freq: How often to log the model. Default: 100. :type log_freq: int """ def __init__(self, log_category: str = "gradients", log_freq: int = 100) -> None: self.log_category = log_category self.log_freq = log_freq
[docs] @rank_zero_only def on_train_start(self, trainer, pl_module): try: logger = get_wandb_logger(trainer=trainer) logger.watch(model=pl_module.model, log=self.log_category, log_freq=self.log_freq) except ValueError: logger = utils.get_logger(__name__) logger.error('No wandb logger found. WatchModelWithWandb callback will not do anything.')