What is the proper way to concatenate two embeddings? #1946
I'm a first-time JAX user trying to use embeddings on the MovieLens dataset. The inputs to this model are a user ID and a movie ID. I have created a simple model that creates an embedding for each of them:

    import jax.numpy as jnp
    from flax import linen as nn

    class MovieLensModel(nn.Module):
        """A simple embedding model."""
        config: MovieLensConfig

        @nn.compact
        def __call__(self, user_id, movie_id):
            cfg = self.config
            user_id = user_id.astype('int32')
            user_emb = nn.Embed(num_embeddings=cfg.users_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='user')(user_id)
            movie_id = movie_id.astype('int32')
            movie_emb = nn.Embed(num_embeddings=cfg.movies_size, embedding_init=nn.initializers.xavier_uniform(), features=cfg.emb_dim, name='movie')(movie_id)
            # Note: Is this a problem? How should I concatenate two embeddings?
            x = jnp.concatenate((user_emb, movie_emb))
            x = nn.Dense(cfg.dense_size_0, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))(x)
            x = nn.relu(x)
            x = nn.Dense(cfg.dense_size_1, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))(x)
            x = nn.relu(x)
            x = nn.Dense(cfg.out_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))(x)
            return x

When I train my network, I notice that the embeddings do not appear to change. When I log [...]. I'm guessing I am not allowed to concatenate embeddings using jnp.concatenate?
Replies: 1 comment
There is nothing wrong with concatenating arrays using jnp.concatenate, but I think you probably want to concatenate them in the last dimension. Right now it seems you do the following:

    # "user" embedding table is shape (users_size, 8)
    user_emb = nn.Embed(..., name='user')(user_id)  # user_emb.shape == (1, 8)
    # "movie" embedding table is shape (movies_size, 8)
    movie_emb = nn.Embed(..., name='movie')(movie_id)  # movie_emb.shape == (1, 8)
    x = jnp.concatenate((user_emb, movie_emb))  # x.shape == (2, 8)

Since the Dense is applied on the last dimension, it seems you want to do:

    x = jnp.concatenate((user_emb, movie_emb), axis=-1)  # x.shape == (1, 16)

This will also output a single value (your model currently outputs two: one for the user and one for the movie).
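To make the shape difference concrete, here is a minimal sketch that uses plain arrays in place of the `nn.Embed` tables. The sizes, the batch of four IDs, and the names `user_table` and `movie_table` are assumptions chosen for illustration only. They do not come from the original model.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
users_size, movies_size, emb_dim = 10, 20, 8

# Stand-ins for the two embedding tables (assumption: random values).
key_u, key_m = jax.random.split(jax.random.PRNGKey(0))
user_table = jax.random.normal(key_u, (users_size, emb_dim))
movie_table = jax.random.normal(key_m, (movies_size, emb_dim))

# A batch of four (user, movie) pairs.
user_id = jnp.array([0, 3, 5, 9], dtype=jnp.int32)
movie_id = jnp.array([1, 2, 7, 19], dtype=jnp.int32)

# An embedding lookup is a row gather from the table.
user_emb = user_table[user_id]     # shape (4, 8)
movie_emb = movie_table[movie_id]  # shape (4, 8)

# Default axis=0 stacks the rows: users first, then movies.
stacked = jnp.concatenate((user_emb, movie_emb))
# axis=-1 joins the features, giving one row per (user, movie) pair.
joined = jnp.concatenate((user_emb, movie_emb), axis=-1)

print(stacked.shape)  # (8, 8)
print(joined.shape)   # (4, 16)
```

With the default axis, a layer acting on the last dimension sees eight rows of 8 features, so it produces separate outputs for users and movies. With `axis=-1`, it sees four rows of 16 features, so each pair yields a single output row.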