error in task_importance_weights function: afad-coral.py #29

Open
jpcenteno80 opened this issue Jul 31, 2020 · 1 comment
Comments

@jpcenteno80

jpcenteno80 commented Jul 31, 2020

Hi,

I think the following function will throw an error if the dataset does not have some age values represented:

def task_importance_weights(label_array):
    uniq = torch.unique(label_array)
    num_examples = label_array.size(0)

    m = torch.zeros(uniq.shape[0])

    for i, t in enumerate(torch.arange(torch.min(uniq), torch.max(uniq))):
        m_k = torch.max(torch.tensor([label_array[label_array > t].size(0), 
                                      num_examples - label_array[label_array > t].size(0)]))
        m[i] = torch.sqrt(m_k.float())

    imp = m/torch.max(m)
    return imp

For the AFAD training set, the line m = torch.zeros(uniq.shape[0]) creates a tensor of length 23, since 3 age label groups are missing from the training set (age labels 15, 22, and 24). The loop, however, iterates over torch.arange(torch.min(uniq), torch.max(uniq)), which yields every integer between the smallest and largest label whether or not it occurs in the data. It therefore assumes all age label groups are represented and can have more entries than m, in which case m[i] indexes out of bounds.
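The mismatch can be reproduced on a small example. The snippet below is only a sketch: the toy labels and the helper name weight_and_threshold_counts are made up for illustration and are not part of afad-coral.py. It computes the two lengths the same way the function above does and then shows the out-of-bounds write.

```python
import torch


def weight_and_threshold_counts(label_array):
    """Return (len(m), number of loop iterations) as the quoted function computes them."""
    uniq = torch.unique(label_array)
    m = torch.zeros(uniq.shape[0])
    thresholds = torch.arange(uniq.min().item(), uniq.max().item())
    return m.shape[0], thresholds.shape[0]


# Hypothetical toy labels covering 0..6 with 3 and 5 missing (not the AFAD data).
toy_labels = torch.tensor([0, 0, 1, 2, 2, 4, 6, 6])
n_weights, n_thresholds = weight_and_threshold_counts(toy_labels)
print(n_weights, n_thresholds)  # 5 slots in m, 6 loop iterations

# Writing one value per loop iteration overruns m on the last step.
m = torch.zeros(n_weights)
try:
    for i in range(n_thresholds):
        m[i] = 1.0
except IndexError as err:
    print("IndexError:", err)
```

Note that with a single missing label the two lengths happen to coincide, so the problem only shows up once two or more labels are absent.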

@rasbt
Member

rasbt commented Aug 1, 2020

Yeah, most of the code here is the code that was used in the paper, as-is. So, given this code, it is expected that users label the dataset such that the labels start at 0 and are renumbered so that they don't contain any gaps.

I.e., age labels 15, 22, and 24 would become 15, 16, 17.

Maybe an automatic function that relabels the labels from 0 to n-1 (where n is the number of labels), together with a mapping dictionary that is shown to the user when using the network to make predictions, would be the best way to handle this.
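A minimal sketch of what such a helper could look like, assuming the labels are held in a 1-D integer tensor. The name relabel_consecutive and the returned dictionary layout are hypothetical and not part of the repository; the only library call it relies on is torch.unique with return_inverse=True.

```python
import torch


def relabel_consecutive(label_array):
    """Map the labels present in label_array to 0..n-1, preserving their order.

    Returns the relabeled tensor and a dict from new label to original label.
    """
    # uniq holds the sorted distinct labels; inverse gives, for each example,
    # the position of its label in uniq, which is exactly the new label.
    uniq, inverse = torch.unique(label_array, sorted=True, return_inverse=True)
    new_to_orig = {new: orig for new, orig in enumerate(uniq.tolist())}
    return inverse, new_to_orig


# Hypothetical labels with gaps (18 and 20 are missing).
labels = torch.tensor([16, 17, 19, 16, 21])
new_labels, new_to_orig = relabel_consecutive(labels)
print(new_labels.tolist())  # [0, 1, 2, 0, 3]
print(new_to_orig)          # {0: 16, 1: 17, 2: 19, 3: 21}
```

At prediction time, a predicted label index k could then be translated back with new_to_orig[k].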
