-
Hi,
Is it true that we cannot use Flax/JAX to implement NMS (non-maximum suppression)? It seems to me we can implement it directly with jax.numpy; am I wrong? Is there any special concern when implementing NMS with jax.numpy? Is there an official Flax model that uses NMS that I could use as a reference?
Replies: 2 comments 1 reply
-
Hi @wztdream, I don't know the context of that comment (it is very old), but I am sure we can implement this in JAX somehow. Perhaps @Marvin182 knows more about this, since he was involved in that project? I don't know of any reference implementation, though, but if you have any specific questions, feel free to ask.
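For illustration, here is a minimal sketch of greedy NMS written only with jax.numpy and jax.lax, so it can be wrapped in jax.jit. It is not taken from any Flax or MLPerf code. It assumes boxes are given as (y1, x1, y2, x2) and that the caller picks a fixed `max_output_size` (a hypothetical parameter name), because jit needs a static output shape. Kept indices are returned in descending score order, padded with -1.

```python
from functools import partial

import jax
import jax.numpy as jnp


def pairwise_iou(boxes):
    # boxes: [N, 4] in (y1, x1, y2, x2) format (an assumption of this sketch).
    y1, x1, y2, x2 = jnp.split(boxes, 4, axis=-1)  # each has shape [N, 1]
    area = (y2 - y1) * (x2 - x1)                    # [N, 1]
    inter_h = jnp.maximum(0.0, jnp.minimum(y2, y2.T) - jnp.maximum(y1, y1.T))
    inter_w = jnp.maximum(0.0, jnp.minimum(x2, x2.T) - jnp.maximum(x1, x1.T))
    inter = inter_h * inter_w                       # [N, N]
    union = area + area.T - inter
    return inter / jnp.maximum(union, 1e-9)


@partial(jax.jit, static_argnames=("max_output_size",))
def nms(boxes, scores, iou_threshold=0.5, max_output_size=10):
    n = boxes.shape[0]                 # static under jit
    order = jnp.argsort(-scores)       # highest score first
    iou = pairwise_iou(boxes[order])

    def body(i, keep):
        # Box i survives unless an already-kept, higher-scoring box
        # overlaps it by more than the threshold.
        overlaps = (iou[i] > iou_threshold) & keep & (jnp.arange(n) < i)
        return keep.at[i].set(~jnp.any(overlaps))

    keep = jax.lax.fori_loop(0, n, body, jnp.zeros(n, dtype=bool))
    # Fixed-size result: positions (in sorted order) of kept boxes, padded with -1.
    sel = jnp.nonzero(keep, size=max_output_size, fill_value=-1)[0]
    # Map back to indices into the original `boxes` array.
    return jnp.where(sel >= 0, order[sel], -1)
```

For example, with boxes `[[0,0,10,10], [1,1,11,11], [20,20,30,30]]` and scores `[0.9, 0.8, 0.7]`, the second box overlaps the first with IoU of about 0.68, so it is suppressed at a 0.5 threshold. The loop is sequential by nature (each decision depends on earlier ones), so `lax.fori_loop` is used instead of a Python loop to keep compilation size independent of N.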
-
@wztdream - there's actually an NMS implementation in the older JAX MLPerf Training 0.7 "SSD" entry at:
https://github.com/mlperf/training_results_v0.7/blob/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
Note that some of the other neural-net code there uses the deprecated pre-Linen Flax API, but the NMS implementation is pure JAX, so hopefully it will help.
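On the "special concern" from the original question: the main one, as far as I know, is that NMS naturally produces a variable number of boxes, while jax.jit requires static shapes, so something like `boxes[keep]` with a traced boolean mask fails inside jit. A common workaround is to return a fixed number of results plus a validity mask. The sketch below illustrates this; the function name `gather_kept` and the zero padding are illustrative choices of mine, not part of the MLPerf code.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("max_output_size",))
def gather_kept(boxes, keep, max_output_size):
    # boxes[keep] would have a data-dependent length, which jit cannot trace.
    # Instead, ask jnp.nonzero for a fixed number of indices (padded with 0).
    idx = jnp.nonzero(keep, size=max_output_size, fill_value=0)[0]
    # Only the first min(sum(keep), max_output_size) slots hold real boxes.
    valid = jnp.arange(max_output_size) < jnp.sum(keep)
    # Zero out the padded rows so they are easy to spot downstream.
    padded = jnp.where(valid[:, None], boxes[idx], 0.0)
    return padded, valid
```

Downstream code (losses, metrics) then uses `valid` as a mask instead of relying on the array length.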