Source code for datamodules.Classification.utils.misc

from pathlib import Path

from src.datamodules.utils.exceptions import PathNone, PathNotDir, PathMissingSplitDir


[docs]def validate_path_for_classification(data_dir: str) -> Path: """ Checks if the path is valid for classification :param data_dir: path to the root dir of the dataset :type data_dir: str :return: path to the root dir of the dataset :rtype: Path """ if data_dir is None: raise PathNone("Please provide the path to root dir of the dataset " "(folder containing the train/val folder)") else: split_names = ['train', 'val'] data_folder = Path(data_dir) if not data_folder.is_dir(): raise PathNotDir("Please provide the path to root dir of the dataset " "(folder containing the train/val folder)") split_folders = [d for d in data_folder.iterdir() if d.is_dir() and d.name in split_names] if len(split_folders) != 2: raise PathMissingSplitDir('Your path needs to contain train/val and ' 'each of them a folder per class') return Path(data_dir)