JAX ERROR -- (jit(_multiply)/jit(main)/mul_multiply.0) Internal tensorizer error: BirCodeGenLoop:BIRCodegen does not support broadcast patterns, but found one in {0,+,0}[128] #1044

Open
felarof99 opened this issue Nov 27, 2024 · 4 comments

@felarof99

Hi,

I am trying to fine-tune Llama 3.2 1B on AWS Trn1 and am running into the following error.

Error in eager mode (without jax.jit):

2024-11-21 04:44:13.000699:  3926  ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/19af4718-112f-4fb8-92fb-ec725d3f5334/model.MODULE_5516390676483383119+d7517139.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/19af4718-112f-4fb8-92fb-ec725d3f5334/model.MODULE_5516390676483383119+d7517139.neff', '--verbose=35']: 2024-11-21T04:44:13Z [TEN404] (jit(_multiply)/jit(main)/mul_multiply.0) Internal tensorizer error: BirCodeGenLoop:BIRCodegen does not support broadcast patterns, but found one in {0,+,0}[128] - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.

My code is open source; here's the model definition file.

I tried running the entire training step both under JIT and in fully eager mode. Both failed. I've attached the stack traces below.
jitted_run.txt
eager_run_no_jit.txt

Let me know if you need any other info.
Thanks a lot for looking into this!

@awsrjh
Contributor

awsrjh commented Nov 27, 2024

thanks for sending -- we will take a look.

@fayyadd

fayyadd commented Nov 27, 2024

Thanks for reaching out! To debug this further, some additional artifacts would be helpful. Can you run the script with JAX_DUMP_IR_TO=/path and JAX_TRACEBACK_FILTERING=off to dump the JAX-generated IR? We can use it to determine whether the failure is caused by something like an unsupported op or pattern.
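For anyone following along, here is a minimal sketch of setting those two variables from inside the training script rather than from the shell. The helper name `configure_jax_debug_env` and the dump directory under the system temp folder are illustrative assumptions, not part of the original request. The variables are set before `import jax` on the assumption that JAX reads them at import time.

```python
import os
import tempfile


def configure_jax_debug_env(dump_dir):
    """Set the two debug variables requested above (hypothetical helper).

    Call this before `import jax` so the settings are in place when JAX
    initialises its configuration.
    """
    os.makedirs(dump_dir, exist_ok=True)           # make sure the dump target exists
    os.environ["JAX_DUMP_IR_TO"] = dump_dir        # where the JAX-generated IR is written
    os.environ["JAX_TRACEBACK_FILTERING"] = "off"  # show full, unfiltered tracebacks
    return dump_dir


# Illustrative location only; any writable directory works.
dump_dir = configure_jax_debug_env(
    os.path.join(tempfile.gettempdir(), "jax_ir_dump")
)

# import jax  # import JAX (and the training code) only after this point
```

The shell equivalent is to prefix the launch command with `JAX_DUMP_IR_TO=/path JAX_TRACEBACK_FILTERING=off`.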

We also have a list of JAX Neuron known issues that might be helpful: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/jax/setup/jax-neuronx-known-issues.html#jax-neuron-known-issues

@felarof99
Author

Thanks for the known-issues link. I went through it and applied the changes relevant to me: I now set jax.config.update("jax_default_prng_impl", "rbg").

I'm still getting the error when I run the trainer.

I ran the trainer with IR dumping enabled as you suggested (JAX_DUMP_IR_TO=/path and JAX_TRACEBACK_FILTERING=off). Here is the IR dump: https://github.com/felarof99/aws-trn-debug-ir

@fayyadd

fayyadd commented Dec 2, 2024

Thanks for sharing the IR and for using the known issues page! Our team took a look at the IR you shared, and we need some additional information to debug this further. Can you rerun the script with the following set, so that the HLOs are dumped: os.environ['NEURON_CC_FLAGS'] = "--dump=/path"

With the HLOs we will be able to better investigate and assist you with this issue.
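One caveat with assigning `NEURON_CC_FLAGS` directly is that it overwrites any compiler flags already set in the environment. The sketch below appends the `--dump` flag instead. The helper name `add_neuron_dump_flag` is hypothetical, and the code assumes the flags are whitespace-separated and that the dump flag uses the `--dump=<dir>` form shown above.

```python
import os


def add_neuron_dump_flag(dump_dir):
    """Add --dump=<dump_dir> to NEURON_CC_FLAGS, keeping other flags (hypothetical helper)."""
    existing = os.environ.get("NEURON_CC_FLAGS", "")
    # Drop any earlier --dump=... entry so the flag is not duplicated.
    kept = [flag for flag in existing.split() if not flag.startswith("--dump=")]
    kept.append(f"--dump={dump_dir}")
    os.environ["NEURON_CC_FLAGS"] = " ".join(kept)
    return os.environ["NEURON_CC_FLAGS"]


# "/path" is the placeholder from the comment above; replace it with a real directory.
flags = add_neuron_dump_flag("/path")
```

As with the JAX variables, this would need to run before the first compilation is triggered, so it is safest at the very top of the script.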

jeffhataws added the jax label Dec 2, 2024