I think the following function will throw an error if the dataset does not have some age values represented:
def task_importance_weights(label_array):
    uniq = torch.unique(label_array)
    num_examples = label_array.size(0)
    m = torch.zeros(uniq.shape[0])
    for i, t in enumerate(torch.arange(torch.min(uniq), torch.max(uniq))):
        m_k = torch.max(torch.tensor([label_array[label_array > t].size(0),
                                      num_examples - label_array[label_array > t].size(0)]))
        m[i] = torch.sqrt(m_k.float())
    imp = m/torch.max(m)
    return imp
For the AFAD training set, the line m = torch.zeros(uniq.shape[0]) will generate a tensor of length 23, since 3 age label groups are missing from the training set (age labels 15, 22, and 24). Enumerating through torch.arange(torch.min(uniq), torch.max(uniq)) assumes that all age label groups are represented, so the loop will run over a different number of entries than m has.
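To make the mismatch concrete, here is a minimal sketch in plain Python that mirrors only the two length computations from the function above. The helper name `weight_buffer_vs_loop_length` and the label range 0..25 are assumptions made for illustration; they are not taken from the repository or from AFAD itself.

```python
def weight_buffer_vs_loop_length(labels):
    """Return (buffer_len, loop_len) as task_importance_weights would compute them."""
    uniq = sorted(set(labels))
    # Mirrors m = torch.zeros(uniq.shape[0]): one slot per label that is present.
    buffer_len = len(uniq)
    # Mirrors torch.arange(torch.min(uniq), torch.max(uniq)): one step per
    # integer in [min, max), whether or not that label is present.
    loop_len = len(range(uniq[0], uniq[-1]))
    return buffer_len, loop_len


# Hypothetical label set 0..25 with labels 15, 22, and 24 absent.
labels = [a for a in range(26) if a not in (15, 22, 24)]
buffer_len, loop_len = weight_buffer_vs_loop_length(labels)
print(buffer_len, loop_len)  # 23 25
```

With these assumed labels the loop index reaches 24 while m only has indices 0 to 22, so the assignment m[i] would fail once i reaches 23.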
Yeah, most of the code here is taken straight from what was used in the paper. So, given this code, users are expected to label the dataset such that the labels start at 0 and are renumbered so that they don't contain any gaps.
I.e., age labels 15, 22, and 24 would become 15, 16, 17.
Maybe an automatic function that relabels the labels from 0 to n-1 (where n is the number of labels), together with a mapping dictionary that is shown to the user when using the network to make predictions, would be the best way to handle this.
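A rough sketch of what such a relabeling helper could look like, in plain Python. The name `relabel_consecutive` and its return values are hypothetical and not part of the repository; it assumes the labels are passed as a list of integers (for a tensor, something like `label_array.tolist()` would do).

```python
def relabel_consecutive(labels):
    """Map arbitrary integer labels onto 0..n-1, preserving their order.

    Returns the relabeled list plus both mapping dictionaries, so that
    predictions can be translated back to the original labels.
    """
    uniq = sorted(set(labels))
    to_new = {old: new for new, old in enumerate(uniq)}
    to_old = {new: old for old, new in to_new.items()}
    return [to_new[label] for label in labels], to_new, to_old


new_labels, to_new, to_old = relabel_consecutive([15, 17, 17, 20, 15])
print(new_labels)  # [0, 1, 1, 2, 0]
print(to_old[2])   # 20
```

The relabeled values could then be used for training, with `to_old` shown to the user to translate a predicted index back to the original age label.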