Source code for datamodules.utils.functional
import torch
def argmax_onehot(tensor: torch.Tensor) -> torch.Tensor:
    """
    Returns the argmax of a one-hot encoded tensor.

    The reduction is taken along dimension 0, so the one-hot axis is
    expected to be the first dimension of ``tensor``; that dimension is
    removed from the result.

    :param tensor: The one-hot encoded tensor
    :type tensor: torch.Tensor
    :returns: The indices of the maximum along dimension 0, as an int64
        tensor located on the same device as ``tensor``
    :rtype: torch.Tensor
    """
    # torch.argmax already yields an int64 tensor. Wrapping it in the legacy
    # torch.LongTensor constructor is redundant and rejects any tensor that
    # is not on the CPU, so the result is returned directly.
    return torch.argmax(tensor, dim=0)