Scaling issue on CPU (x86 and Arm) while using models from Hugging Face. #4357

Open
choudhary-devang opened this issue Nov 5, 2024 · 2 comments

Comments

@choudhary-devang

I was running inference with Hugging Face models (BERT, ResNet, GPT-2, etc.) on CPU and found that performance was poor compared to PyTorch. On closer inspection, the models do not scale properly with core count in the Flax case.

benchmarking results

BERT model

cores | time in sec
1 | 38.05184221
4 | 30.66457105
8 | 29.46852469
16 | 28.98397946
32 | 28.95349479
64 | 28.77831054
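To put the table in terms of scaling, the speedup over the single-core run and the per-core parallel efficiency can be computed directly from the reported timings. This is a small sketch that uses only the numbers above; the variable names are illustrative.

```python
# Timings copied from the BERT table above (cores -> seconds).
timings = {
    1: 38.05184221,
    4: 30.66457105,
    8: 29.46852469,
    16: 28.98397946,
    32: 28.95349479,
    64: 28.77831054,
}

baseline = timings[1]
results = {}
for cores, seconds in timings.items():
    speedup = baseline / seconds      # how much faster than the 1-core run
    efficiency = speedup / cores      # 1.0 would mean perfect linear scaling
    results[cores] = (speedup, efficiency)
    print(f"{cores:>2} cores: speedup {speedup:.2f}x, efficiency {efficiency:.1%}")
```

With these numbers the speedup stays below 1.35x even at 64 cores, which is what "not scaling" means here.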

System information

  • OS Platform: Ubuntu 20.04
  • Flax: 0.10.0, jax: 0.4.35, jaxlib: 0.4.35
  • Python version: 3.10.12
  • CPU: Graviton 3

Problem you have encountered:

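Scaling issue: inference time barely improves as cores are added (for BERT, about 38.1 s on 1 core versus about 28.8 s on 64 cores).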

Script I am using:
Image

CPU usage:
Image

@choudhary-devang
Author

@zaxtax, @cghawthorne, @avital, @lukaszlew could you please look into this?

@cgarciae
Collaborator

@choudhary-devang you need to wrap forward_pass in jax.jit.
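A minimal sketch of what this suggestion could look like. The original script is only available as an image, so the `forward_pass` below is a hypothetical stand-in (a small tanh MLP), not the Hugging Face model call from the issue; the parameter shapes and the `time_it` helper are assumptions for illustration. Without `jax.jit`, JAX dispatches each operation separately from Python, so per-op overhead can dominate and extra cores help little; with `jax.jit`, the whole function is compiled by XLA as one computation.

```python
import time

import jax
import jax.numpy as jnp


def forward_pass(params, x):
    # Hypothetical stand-in for the model's forward pass in the original script.
    for w, b in params:
        x = jnp.tanh(x @ w + b)
    return x


# Illustrative parameters and input (shapes are arbitrary assumptions).
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = [(jax.random.normal(k, (256, 256)) * 0.05, jnp.zeros(256)) for k in keys[:3]]
x = jax.random.normal(keys[3], (32, 256))

# Wrap the forward pass so it is traced once and compiled by XLA.
jitted_forward = jax.jit(forward_pass)

# The first call includes tracing and compilation, so keep it out of the timing.
jitted_forward(params, x).block_until_ready()


def time_it(fn, n=20):
    # JAX dispatch is asynchronous; block_until_ready() waits for the result.
    start = time.perf_counter()
    for _ in range(n):
        fn(params, x).block_until_ready()
    return (time.perf_counter() - start) / n


eager_s = time_it(forward_pass)
jit_s = time_it(jitted_forward)
print(f"eager: {eager_s * 1e3:.3f} ms/call, jit: {jit_s * 1e3:.3f} ms/call")
```

When benchmarking, it is worth excluding the first jitted call and calling `block_until_ready()` on the output, otherwise the measurement mixes compilation time and asynchronous dispatch into the result.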
