JAX ERROR -- (jit(_multiply)/jit(main)/mul_multiply.0) Internal tensorizer error: BirCodeGenLoop:BIRCodegen does not support broadcast patterns, but found one in {0,+,0}[128]
#1044 · Open · felarof99 opened this issue on Nov 27, 2024 · 4 comments
I am trying to fine-tune Llama 3.2 1B on AWS Trn1 and I'm running into the following error.
Error in eager mode (without jax.jit):
2024-11-21 04:44:13.000699: 3926 ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/19af4718-112f-4fb8-92fb-ec725d3f5334/model.MODULE_5516390676483383119+d7517139.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/19af4718-112f-4fb8-92fb-ec725d3f5334/model.MODULE_5516390676483383119+d7517139.neff', '--verbose=35']: 2024-11-21T04:44:13Z [TEN404] (jit(_multiply)/jit(main)/mul_multiply.0) Internal tensorizer error: BirCodeGenLoop:BIRCodegen does not support broadcast patterns, but found one in {0,+,0}[128] - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.
My code is open-source, here's the model definition file.
I tried running the entire training step with JIT and in completely eager mode. Both options failed. I've attached the stack traces below too: jitted_run.txt, eager_run_no_jit.txt
Let me know if you need any other info.
Thanks a lot for looking into this!
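For context on the error message: `{0,+,0}[128]` looks like a stride-0 (broadcast) access pattern over a 128-element axis, i.e. one operand of the multiply is effectively repeated along that axis. A hedged illustration of the kind of elementwise multiply that produces such a pattern (the shapes are illustrative assumptions, not taken from the model; shown with NumPy since XLA follows the same broadcasting rules):

```python
import numpy as np

# A size-1 operand multiplied against a 128-wide axis: the size-1
# dimension is read with stride 0 (broadcast) at every position.
scale = np.ones((1,), dtype=np.float32) * 2.0   # shape (1,)
activations = np.arange(128, dtype=np.float32)  # shape (128,)

# Broadcast multiply: the (1,) operand expands against the 128-element axis.
out = scale * activations                       # shape (128,)
```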
Thanks for reaching out! To debug this further, some additional artifacts will be helpful. Can you run the script with JAX_DUMP_IR_TO=/path and JAX_TRACEBACK_FILTERING=off to dump the JAX-generated IR? This can be used to determine whether it's due to an issue such as unsupported ops/patterns.
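A minimal sketch of how those variables might be set (the dump path `/tmp/jax_ir` is a placeholder, not a required location; set the variables before JAX is imported so they take effect):

```python
import os

# Dump the JAX-generated IR to a directory for inspection.
# "/tmp/jax_ir" is a placeholder path -- use any writable directory.
os.environ["JAX_DUMP_IR_TO"] = "/tmp/jax_ir"

# Turn off traceback filtering so the full stack trace is reported.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

# Import jax and run the training script only after setting these.
```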
Thanks for sharing the IR and utilizing the known issues page! Our team took a look at the IR you shared, and based on this we need some additional information to debug this further. Can you run the script and use the following to dump the HLOs: os.environ['NEURON_CC_FLAGS'] = "--dump=/path"
With the HLOs we will be able to better investigate and assist you with this issue.
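A sketch of the HLO-dump setup described above (the path `/tmp/neuron_dump` is a placeholder; the flag must be set before the Neuron compiler is invoked):

```python
import os

# Ask the Neuron compiler to dump its intermediate artifacts (including HLOs).
# "/tmp/neuron_dump" is a placeholder path -- use any writable directory.
os.environ["NEURON_CC_FLAGS"] = "--dump=/tmp/neuron_dump"
```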