Implementation of DLRM: Embedding Operations #4227
Unanswered · Sir-NoChill asked this question in Q&A
Hello community!

I am currently trying to reimplement Meta's DLRM algorithm, specifically the architecture discussed in this paper, for profiling and performance research. I am having some trouble writing a flax implementation of the sparse vector embedding code.

In Meta's implementation, they initialize a `torch.EmbeddingBag` as follows (see line in the original code), but they subsequently use it like this (refer to this line):
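Roughly, the pattern in Meta's `dlrm_s_pytorch.py` looks like the following (paraphrased from memory; exact variable names and sizes here are illustrative, not the original code):

```python
import torch
import torch.nn as nn

n, m = 10, 4  # rows in the table, embedding dimension (made-up sizes)

# Construction: one EmbeddingBag per sparse feature; mode="sum" pools
# all embeddings belonging to one sample's bag into a single vector.
E = nn.EmbeddingBag(n, m, mode="sum", sparse=True)

# Lookup: a flat 1-D tensor of indices for the whole batch, plus a
# per-sample offsets tensor marking where each sample's bag starts.
sparse_index_group_batch = torch.tensor([0, 1, 2, 3, 4, 5])
sparse_offset_group_batch = torch.tensor([0, 2, 4])  # batch of 3 bags
V = E(sparse_index_group_batch, sparse_offset_group_batch)  # -> (3, 4)
```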
However, I cannot find a way to duplicate this functionality using the flax `nnx.Embed` or `linen.Embed` class. I am also relatively new to jax/flax, so I apologize in advance for my further questions :) My current model is as follows (using nnx): …
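For context, `linen.Embed` on its own only performs the row gather; it has no offsets argument, so the per-bag sum has to be reproduced by hand. A minimal illustration with toy sizes:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=10, features=4)
idx = jnp.array([0, 1, 2])
variables = embed.init(jax.random.PRNGKey(0), idx)

# Embed only gathers rows: 3 indices -> a (3, 4) array. The per-sample
# "bag" sum that torch.EmbeddingBag performs must be added separately
# (e.g. with jax.ops.segment_sum, as in the reply below).
rows = embed.apply(variables, idx)
```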
Replies: 1 comment

Still working on this problem. I switched from nnx back to flax.linen just to have more historical code examples to draw from. My current model looks like the following, which (I think) is correct:

```python
from typing import List, Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


class DLRM_Net(nn.Module):
    m_spa: int
    ln_emb: List[int]
    ln_bot: List[int]
    ln_top: List[int]
    arch_interaction_op: str
    arch_interaction_itself: bool = False
    sigmoid_bot: int = -1
    sigmoid_top: int = -1
    loss_threshold: float = 0.0
    weighted_pooling: Optional[str] = None

    def setup(self):
        # One embedding table per sparse feature
        self.embeddings = [
            nn.Embed(num_embeddings=n, features=self.m_spa)
            for n in self.ln_emb
        ]
        self.bot_mlp = self.create_mlp(self.ln_bot, self.sigmoid_bot)
        self.top_mlp = self.create_mlp(self.ln_top, self.sigmoid_top)

    def create_mlp(self, ln, sigmoid_layer):
        # Dense stack: layer `sigmoid_layer` gets a sigmoid, the rest relu
        layers = []
        for i in range(len(ln) - 1):
            layers.append(nn.Dense(features=ln[i + 1]))
            if i == sigmoid_layer:
                layers.append(nn.sigmoid)
            else:
                layers.append(nn.relu)
        return nn.Sequential(layers)

    def apply_embedding(self, lS_o, lS_i, embeddings):
        """Embedding lookup for the sparse features, emulating
        torch.EmbeddingBag(mode="sum"): lS_i[k] is a flat 1-D index
        array and lS_o[k] holds the per-sample bag offsets into it."""
        ly = []
        for k in range(len(embeddings)):
            E = embeddings[k]
            # Submodules created in setup are bound, so call them directly
            # (E.apply() here would misread lS_i[k] as a variable dict)
            embeds = E(lS_i[k])  # (num_indices, m_spa)
            # Sum the gathered embeddings over each bag specified by lS_o:
            # map every index position to its bag id, then segment-sum
            seg_ids = jnp.searchsorted(
                lS_o[k], jnp.arange(lS_i[k].shape[0]), side="right") - 1
            V = jax.ops.segment_sum(embeds, seg_ids,
                                    num_segments=lS_o[k].shape[0])
            ly.append(V)
        return ly

    def interact_features(self, x, ly):
        if self.arch_interaction_op == "dot":
            # Pairwise dot products between the dense output and all
            # embedding vectors, as in the original DLRM
            T = jnp.concatenate([x] + ly, axis=1).reshape(
                x.shape[0], -1, x.shape[1])
            Z = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
            # Keep the lower-triangular entries; the diagonal
            # (self-interaction) only when arch_interaction_itself is set
            offset = 1 if self.arch_interaction_itself else 0
            li = jnp.array([i for i in range(Z.shape[1]) for j in range(i + offset)])
            lj = jnp.array([j for i in range(Z.shape[2]) for j in range(i + offset)])
            Zflat = Z[:, li, lj]
            R = jnp.concatenate([x, Zflat], axis=1)
        elif self.arch_interaction_op == "cat":
            R = jnp.concatenate([x] + ly, axis=1)
        else:
            raise ValueError(f"Unsupported interaction op: {self.arch_interaction_op}")
        return R

    def __call__(self, dense_x, lS_o, lS_i):
        x = self.bot_mlp(dense_x)
        ly = self.apply_embedding(lS_o, lS_i, self.embeddings)
        z = self.interact_features(x, ly)
        p = self.top_mlp(z)
        # Optionally clamp predictions away from 0/1 for loss stability
        if 0.0 < self.loss_threshold < 1.0:
            p = jnp.clip(p, self.loss_threshold, 1.0 - self.loss_threshold)
        return p
```

though I am having trouble with the …
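A minimal smoke test for this model might look like the following; all sizes are made up for illustration, with `ln_bot` ending in `m_spa` and `ln_top` starting at `m_spa` plus the number of pairwise terms so the dot interaction lines up:

```python
import jax
import jax.numpy as jnp

model = DLRM_Net(
    m_spa=4,
    ln_emb=[10, 20],            # two sparse features
    ln_bot=[8, 6, 4],           # bottom MLP ends in m_spa
    ln_top=[7, 4, 1],           # 7 = m_spa + 3 pairwise dot terms
    arch_interaction_op="dot",
    sigmoid_top=1,              # sigmoid on the last top-MLP layer
)

batch = 3
dense_x = jnp.ones((batch, 8))
# Two indices per bag, one bag per sample, for each sparse feature
lS_i = [jnp.array([0, 1, 2, 3, 4, 5]), jnp.array([1, 2, 3, 4, 5, 6])]
lS_o = [jnp.array([0, 2, 4]), jnp.array([0, 2, 4])]

params = model.init(jax.random.PRNGKey(0), dense_x, lS_o, lS_i)
p = model.apply(params, dense_x, lS_o, lS_i)
print(p.shape)  # (3, 1)
```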