
What is the proper way to concatenate two embeddings? #1946

Answered by marcvanzee
JoshVarty asked this question in Q&A

There is nothing wrong with concatenating arrays using jnp.concatenate, but you probably want to concatenate the embeddings along the last axis. Right now it seems you are doing the following:

# "user" embedding table is shape (users_size, 8)
user_emb = nn.Embed(..., name='user')  # user_emb.shape == (1, 8)

# "movie" embedding table is shape (movies_size, 8)
movie_emb = nn.Embed(..., name='movie')  # movie_emb.shape == (1, 8)

x = jnp.concatenate((user_emb, movie_emb))  # x.shape == (2, 8)

Since nn.Dense is applied to the last dimension, you probably want to concatenate along that axis instead:

x = jnp.concatenate((user_emb, movie_emb), axis=-1)  # x.shape == (1, 16)

This will also output a single value (your model currently…

Answer selected by JoshVarty