What does the function class_batch_to_labeling_batch(y, y_hat, y_hat_mask=None) do in ctc_cost.py? #6

@star013

Hello, I am doing some research on TIMIT and need to use CTC in my model. I read ctc_cost.py, but I cannot understand the function class_batch_to_labeling_batch(y, y_hat, y_hat_mask=None).
According to the comments, y_hat is a T x B x (C+1) tensor and y_hat_mask is a T x B matrix. Line 65 reads:
y_hat = y_hat * y_hat_mask.dimshuffle(0, 'x', 1)
This puzzles me because y_hat_mask.dimshuffle(0, 'x', 1) has shape T x 1 x B, which in general cannot be multiplied elementwise with y_hat of shape T x B x (C+1). In addition, when I tried to run this function in an IPython notebook, it reported an error.
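To make the mismatch concrete, here is a minimal sketch in NumPy rather than Theano. The sizes T=5, B=3, C+1=4 are made up for illustration, and indexing with `[:, None, :]` is used as the NumPy analogue of `dimshuffle(0, 'x', 1)`. The second variant, with the new axis last, is only the layout that happens to broadcast against T x B x (C+1); I am not claiming it is what the code intends.

```python
import numpy as np

# Hypothetical sizes: time steps, batch size, number of classes plus blank
T, B, C1 = 5, 3, 4

y_hat = np.ones((T, B, C1))   # T x B x (C+1)
y_hat_mask = np.ones((T, B))  # T x B

# NumPy analogue of y_hat_mask.dimshuffle(0, 'x', 1): shape (T, 1, B)
mask_as_written = y_hat_mask[:, None, :]

try:
    y_hat * mask_as_written
    broadcast_ok = True
except ValueError:
    # Trailing axes are C+1 and B, which differ here, so broadcasting fails
    broadcast_ok = False

# NumPy analogue of y_hat_mask.dimshuffle(0, 1, 'x'): shape (T, B, 1)
mask_trailing = y_hat_mask[:, :, None]
masked = y_hat * mask_trailing  # broadcasts over the class axis

print(mask_as_written.shape, broadcast_ok, masked.shape)
```

With these sizes the first multiplication fails and the second succeeds, which matches the error I see whenever B differs from C+1.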
Could you please explain why the line is y_hat = y_hat * y_hat_mask.dimshuffle(0, 'x', 1), and what res is in the function?
Thanks.
