from typing import List

import numpy as np
import torch

from sklearn.preprocessing import OneHotEncoder

def gt_to_int_encoding(matrix: torch.Tensor, class_encodings: List[int]) -> torch.Tensor:
    """
    Convert ground truth tensor to integer encoded matrix

    The input is scaled by 255. Channel index 2 is read as the class channel
    (referred to as the "blue" channel by this module) and channel index 0 as
    the border marker: every pixel whose border channel is non-zero is set to
    the value 1 before the class lookup. Pixels whose value matches no entry of
    ``class_encodings`` keep the fill value -1. The input tensor is not modified.

    NOTE(review): the channel order of the input (RGB vs. BGR) is not visible
    from this function -- confirm against the loader that produces ``matrix``.

    :param matrix: Image as a tensor of size [C x H x W] with at least 3 channels,
        values expected in the range [0.0, 1.0]
    :type matrix: torch.Tensor
    :param class_encodings: class encoding so which class (index) has what value (element)
    :type class_encodings: List[int]
    :return: integer encoded matrix of size [H x W]; -1 where no class matched
    :rtype: torch.Tensor
    """
    scaled = matrix * 255

    # take only the class ("blue") channel
    img_blue = scaled[2, :, :]
    if img_blue.is_floating_point():
        # Round to the nearest integer so that float error introduced by the
        # scaling cannot make the exact equality checks below miss a class.
        img_blue = torch.round(img_blue)

    # change border pixels to background (the comparison already yields a bool mask)
    border_mask = scaled[0, :, :] != 0
    img_blue[border_mask] = 1

    # allocate on the input's device so the boolean masks can index the result
    integer_encoded = torch.full(size=img_blue.shape, fill_value=-1, dtype=torch.long,
                                 device=img_blue.device)
    for index, encoding in enumerate(class_encodings):
        integer_encoded[img_blue == encoding] = index

    return integer_encoded
def gt_to_one_hot(matrix: torch.Tensor, class_encodings: List[int]):
    """
    Convert ground truth tensor or numpy matrix to one-hot encoded matrix

    The input is scaled by 255 and rounded to 8-bit integers. Channel index 2
    is read as the class channel (referred to as the "blue" channel by this
    module) and channel index 0 as the border marker. Pixels whose border
    channel is non-zero, and pixels whose class value is 0 (fillers), are set
    to the value 1 before the class lookup. The input is not modified.

    NOTE(review): the channel order of the input (RGB vs. BGR) is not visible
    from this function -- confirm against the loader that produces ``matrix``.
    Only the channel-first layout is handled; a (H x W x C) array must be
    transposed by the caller.

    :param matrix: float tensor from to_tensor() or numpy array of shape (C x H x W)
        in the range [0.0, 1.0]
    :type matrix: torch.Tensor or np.ndarray
    :param class_encodings: List of int class-channel values that encode the different classes
    :type class_encodings: List[int]
    :return: Tensor of size [#C x H x W] one-hot encoded multi-class matrix,
        where #C is the number of classes
    :rtype: torch.LongTensor
    :raises KeyError: if a pixel value (after border/filler replacement) is not
        contained in ``class_encodings``
    """
    num_classes = len(class_encodings)

    if isinstance(matrix, torch.Tensor):
        scaled = (matrix * 255).detach().cpu().numpy()
    else:
        scaled = np.asarray(matrix) * 255
    # Round instead of truncating: a value that ends up a hair below an integer
    # after scaling must not be mapped to the class value below it.
    np_array = np.rint(scaled).astype(np.uint8)

    # Copy of the class channel in a wide integer type so that any int in
    # class_encodings can be compared against it safely.
    im_np = np_array[2, :, :].astype(np.int64)

    # change border pixels to background
    border_mask = np_array[0, :, :] != 0
    im_np[border_mask] = 1
    # needed to deal with 0 fillers at the borders during testing (replace with background)
    im_np[im_np == 0] = 1

    # Map every pixel value to its class index; later duplicates in
    # class_encodings win, like in a dict built from zip().
    class_index = np.full(im_np.shape, -1, dtype=np.int64)
    for index, encoding in enumerate(class_encodings):
        class_index[im_np == encoding] = index

    unknown = class_index < 0
    if unknown.any():
        # report the first offending value in row-major order
        raise KeyError(int(im_np[unknown][0]))

    # Row i of the identity matrix is the one-hot vector of class i, so fancy
    # indexing yields the (H x W x #C) one-hot matrix in one step.
    one_hot_matrix = np.eye(num_classes, dtype=np.int64)[class_index]

    return torch.from_numpy(np.ascontiguousarray(one_hot_matrix.transpose((2, 0, 1))))