MultiHeadAttentionWrapper should instantiate CausalSelfAttention with d_out = d_out // num_heads? #609
Unanswered
henrythe9th asked this question in Q&A
Replies: 1 comment · 1 reply
-
I believe the confusion lies in how we are interpreting … In your impl, … It's true that it's clearer in the sense that the …
1 reply
-
Since the `MultiHeadAttentionWrapper` class calls `torch.cat([head(x) for head in self.heads], dim=-1)`, shouldn't we be instantiating `CausalSelfAttention` with `d_out = d_out // num_heads`, so that the final `MultiHeadAttentionWrapper` output has the same shape and `d_out` as was specified in the input? In other words, is this a clearer implementation?
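The asker's own snippet is not preserved in this extract, so below is a minimal, self-contained sketch of the variant the question describes. The `CausalSelfAttention` constructor arguments (`d_in`, `d_out`, `context_length`, `dropout`, `qkv_bias`) and its internals are assumptions for illustration, not a copy of the book's code; the one change relative to the wrapper described above is that each head is built with `d_out // num_heads`, so concatenating the heads along the last dimension yields exactly `d_out` output features instead of `num_heads * d_out`.

```python
import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    # Minimal single-head causal attention (assumed interface):
    # maps (batch, num_tokens, d_in) -> (batch, num_tokens, d_out).
    def __init__(self, d_in, d_out, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1).bool(),
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # Causal (masked) scaled dot-product attention
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask[:num_tokens, :num_tokens], float("-inf"))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values


class MultiHeadAttentionWrapper(nn.Module):
    # Variant suggested in the question: give each head d_out // num_heads
    # so the concatenated output has exactly d_out features.
    def __init__(self, d_in, d_out, context_length, dropout=0.0, num_heads=2, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        head_dim = d_out // num_heads
        self.heads = nn.ModuleList(
            CausalSelfAttention(d_in, head_dim, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        )

    def forward(self, x):
        # num_heads outputs of size head_dim concatenated -> d_out features
        return torch.cat([head(x) for head in self.heads], dim=-1)


if __name__ == "__main__":
    torch.manual_seed(123)
    x = torch.randn(2, 6, 768)  # (batch, num_tokens, d_in)
    mha = MultiHeadAttentionWrapper(d_in=768, d_out=768, context_length=6, num_heads=12)
    print(mha(x).shape)  # torch.Size([2, 6, 768]) -- matches the requested d_out
```

Note that this variant requires `d_out` to be divisible by `num_heads`, whereas the wrapper as described above (each head instantiated with the full `d_out`) produces an output of `num_heads * d_out` features.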