# Source code for tasks.utils.outputs

from typing import Dict, List

import numpy
import numpy as np
from pytorch_lightning.utilities import LightningEnum


class OutputKeys(LightningEnum):
    """
    Enumeration of the keys used to index the output dictionary.

    Each member wraps the string under which the corresponding entry is
    stored.
    """

    PREDICTION = 'pred'
    TARGET = 'target'
    LOG = 'logs'
    LOSS = 'loss'

    def __hash__(self):
        """
        Hash a member by its underlying string value.

        :return: The hash of the member's value
        :rtype: int
        """
        # NOTE(review): presumably defined so members stay usable as dict keys
        # alongside the base class' equality -- confirm against LightningEnum.
        return hash(self.value)
def reduce_dict(input_dict: Dict, key_list: List) -> Dict:
    """
    Build a new dictionary holding only the requested entries of ``input_dict``.

    Keys from ``key_list`` that are not present in ``input_dict`` are skipped
    silently. The result follows the order of ``key_list`` and ``input_dict``
    itself is left untouched.

    :param input_dict: The dictionary to reduce
    :type input_dict: Dict
    :param key_list: List of keys to keep
    :type key_list: List
    :return: A new dictionary with the entries of input_dict whose key is in key_list
    :rtype: Dict
    """
    reduced = {}
    for wanted in key_list:
        # Missing keys are ignored rather than raising a KeyError.
        if wanted in input_dict:
            reduced[wanted] = input_dict[wanted]
    return reduced
def save_numpy_files(trainer, test_output_path, input_idx, output):
    """
    Save every prediction of a batch as its own ``.npy`` file.

    Each prediction is written to
    ``test_output_path / 'patches' / <image name> / <patch name>.npy``, where
    the two names are the first and second element of what
    ``trainer.datamodule.get_img_name_coordinates`` returns for the matching
    index. Missing destination folders are created on the fly.

    :param trainer: Object whose ``datamodule`` attribute provides ``get_img_name_coordinates``
    :param test_output_path: Root output folder; must support the ``/`` operator
        (presumably a ``pathlib.Path`` -- confirm against the caller)
    :param input_idx: Indices of the batch items; must support ``detach().cpu().numpy()``
    :param output: Output dictionary holding the predictions under ``OutputKeys.PREDICTION``
    :raises NotImplementedError: If the datamodule has no ``get_img_name_coordinates`` attribute
    """
    # Fail early: without this lookup there is no way to name the output files.
    if not hasattr(trainer.datamodule, 'get_img_name_coordinates'):
        raise NotImplementedError('Datamodule does not provide detailed information of the crop')

    # Convert both inputs to numpy arrays before pairing them up.
    predictions = output[OutputKeys.PREDICTION].detach().cpu().numpy()
    indices = input_idx.detach().cpu().numpy()

    for prediction, index in zip(predictions, indices):
        info = trainer.datamodule.get_img_name_coordinates(index)
        image_name, crop_name = info[0], info[1]

        # One sub-folder per image, created only when a patch of it is saved.
        image_dir = test_output_path / 'patches' / image_name
        image_dir.mkdir(parents=True, exist_ok=True)

        np.save(file=str(image_dir / f'{crop_name}.npy'), arr=prediction)