Consider implementing hooks as MLIR/custom calls with enif_send #1541
I looked into this; here are a few notes.

JAX has external callbacks, used in particular for printing. On CPU and GPU they use a custom call that invokes the Python callback function (on GPU the custom call first copies the data off the GPU and then invokes the callback exactly as on CPU). On TPU they use send/recv operations (and eventually invoke the callback as well). I found one PR with more context, jax-ml/jax#13759 (I also asked some questions there, but no answer so far). We could implement the custom calls with

On a separate note, we could possibly even add a callback API for getting the hook's return value into defn. One idea I have would be to create a resource in the custom call that would hold a mutex, a condition variable and a value field. The resource ref would be part of the message. After sending the message we would call

All that said, implementing the custom calls is rather annoying, and it doesn't automatically translate to more platforms. For example, if we ended up adding the Metal plugin, we would need to implement another custom call (and I expect custom calls may not even be a thing there, so perhaps it needs yet another mechanism). The only reason to make that change now would be switching from the StreamExecutor GPU implementation to the PjRt plugins (which don't support infeed/outfeed), but it doesn't really make a difference for the end user, and we can likely maintain compatibility if we do it in the future (i.e. XLA_TARGET would download/register the necessary plugins, so existing setups would keep working). Given that many of the decisions happen internally and XLA, JAX and TensorFlow are separate efforts, things may shift in the future. For all these reasons we decided to wait and make changes once we really need to.
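To make the resource idea from this comment (a mutex, a condition variable and a value field) more concrete, here is a minimal C++ sketch. Everything in it is hypothetical: `HookReply`, `deliver`, `wait` and `run_hook_roundtrip` are illustrative names, the value is a plain `int` rather than a tensor or term, and a `std::thread` stands in for the Elixir process that would receive the message. It also assumes that the custom call blocks on the condition variable until the reply arrives, which the comment does not spell out.

```cpp
#include <cassert>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>

// Hypothetical stand-in for the resource created inside the custom call.
// A real version would be a NIF resource and hold a term or tensor.
struct HookReply {
  std::mutex mutex;
  std::condition_variable ready;
  bool has_value = false;
  int value = 0;

  // Called by the side that handles the message and produces the result.
  void deliver(int v) {
    {
      std::lock_guard<std::mutex> lock(mutex);
      value = v;
      has_value = true;
    }
    ready.notify_one();
  }

  // Called by the custom call after the message has been sent
  // (assumption: it blocks here until the value is delivered).
  int wait() {
    std::unique_lock<std::mutex> lock(mutex);
    // The predicate guards against spurious wakeups and against
    // deliver() having already run before wait() was entered.
    ready.wait(lock, [this] { return has_value; });
    assert(has_value);
    return value;
  }
};

// Simulates one round trip: "send" the input to a handler thread,
// then block until the handler delivers its result.
int run_hook_roundtrip(int input) {
  auto reply = std::make_shared<HookReply>();
  std::thread handler([reply, input] { reply->deliver(input * 2); });
  int result = reply->wait();
  handler.join();
  return result;
}
```

In this sketch the shared pointer plays the role of the resource ref that travels with the message, so the handler can deliver the value without knowing anything else about the caller.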
I think that the custom call could receive, as an MLIR attribute, something that encodes a

Edit: this is specifically for sending things out from the computation back to Elixir, which is really useful for monitoring values in a long-running computation, for instance, or for debugging via print_value.
@polvalente to pass it as an MLIR attribute (or a constant input, since that's how we pass info to custom calls), it needs to be known at MLIR compile time, so it can't be any transient information like a ref or pid, especially since we can even cache the executable on disk.
@jonatanklosko I hadn't considered the possibility of model serialization. We can actually pass a PID (or any term, for that matter) as a runtime argument if we use

This would add a new input to the function, but it's an alternative to having a fixed value.
@polvalente I thought about using inputs, but it seems to me that it's too much. The custom call alone makes the MLIR Elixir-specific and not necessarily portable, but having a specific input is a step further.
We need to understand whether it will make interoperability better or worse, in particular with regard to IREE and the Apple Metal plugin. Note that JAX in particular emits custom MLIR code for these operations.