If so, what is the broadcastable flag of an Elemwise output?
x = pt.vector("x", shape=(1,), broadcastable=(False,))
y = x + x
assert y.type == ?
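For reference, the current behavior, where a static shape of 1 already implies broadcastable=True (which is exactly what is being questioned here):
import pytensor.tensor as pt

x = pt.vector("x", shape=(1,))
print(x.type.broadcastable)                # (True,): shape 1 currently implies broadcastable
y = x + x
print(y.type.shape, y.type.broadcastable)  # (1,) (True,): the Elemwise output keeps both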
What about clone_replace? It seems like this should fail with strict=False (the default):
import pytensor
import pytensor.tensor as pt
x = pt.vector("x")
y = x + 5
pytensor.dprint(y, print_type=True)
new_x = pt.vector("new_x", shape=(1,))
new_y = pytensor.clone_replace(y, {x: new_x})
pytensor.dprint(new_y, print_type=True)
Add [id A] <Vector(float64, shape=(?,), broadcastable=(False,))>
├─ x [id B] <Vector(float64, shape=(?,), broadcastable=(False,))>
└─ ExpandDims{axis=0} [id C] <Vector(int8, shape=(1,), broadcastable=(True,))>
└─ 5 [id D] <Scalar(int8, shape=(), broadcastable=())>
Add [id A] <Vector(float64, shape=(1,), broadcastable=(True,))>
├─ Unbroadcast{0} [id B] <Vector(float64, shape=(1,), broadcastable=(False,))>
│ └─ new_x [id C] <Vector(float64, shape=(1,), broadcastable=(True,))>
└─ ExpandDims{axis=0} [id D] <Vector(int8, shape=(1,), broadcastable=(True,))>
└─ 5 [id E] <Scalar(int8, shape=(), broadcastable=())>
pymc-devs/pytensor#390 (NEEDS REVIEW), pymc-devs/pytensor#329 (comment)
- Should we allow this, or be as strict as in other Ops?
- pymc-devs/pytensor#381 (NEEDS REVIEW)
- The PR cleans up the logic of many rewrites by not worrying about shape safety. These rewrites are tagged so users can exclude them (sketch below). Seems reasonable?
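Something like this, assuming the rewrites get a "shape_unsafe" tag (the tag name here is my assumption, not taken from the PR):
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.vector("x")
y = x + 5

# Exclude the shape-unsafe rewrites by tag when compiling
mode = get_default_mode().excluding("shape_unsafe")
f = pytensor.function([x], y, mode=mode)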
- Elemwise for non-fuseable 0d arrays doesn't seem to have that much of a drawback (at least in C / JAX backends): pymc-devs/pytensor#349 (comment)
- We already fuse chains of 0d Elemwise, which take care of boxing/unboxing at the edges but are otherwise pure scalar operations in between (see the sketch after this list)
- We can avoid some more cases by "unbroadcasting" constant arrays that go into Elemwise (already done for fusion graphs): pymc-devs/pytensor#361
- Can we consider this solved?
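As a reminder of the fusion mentioned above, a minimal example (the exact Composite printout depends on the mode/backend):
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")             # 0d tensor input
y = pt.exp(pt.log(x) + 1)      # chain of 0d Elemwise operations
f = pytensor.function([x], y)
pytensor.dprint(f)             # under the default mode the chain is fused into a single Composite Elemwise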
Relevant JAX PR on Sparse / RNGs inputs: pymc-devs/pytensor#278 (Mixed reviews so far)
AFAICT this can't be (easily) solved by the multiple-container idea. We could implement a subclass of list that synchronizes on write, but I am afraid this would break on the C backend. Solution: copy shared variables that have incompatible types and tell users how they can be retrieved from the compiled function. Problem: I couldn't figure out how to do this.
- pymc-devs/pytensor#306 (NEEDS REVIEW)
- Should our rewrites support both forms, or just Blockwise?
- My approach was to support only Blockwise, and have a late rewrite that removes useless Blockwise (0 batch dims)
- Many Blockwise Ops, even with batch dims, are useless in some backends (almost everything I looked at is naturally batched in JAX and in numpy linalg, but not in scipy linalg). What about Numba?
- It would be easy to include a simple rewrite that removes backend-specific useless Blockwise.
- Do we want to do a C implementation?
- Who wants to write a Numba impl?
- Who wants to write a JAX impl?
- Do we want to fuse BatchedOps on Numba backend?
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#batching
- We already do something like this for the gradient of Blockwise, where we start from a "core" gradient and vectorize it via a dispatch.
- The dispatch falls back to Blockwise for generic Ops, but there are special cases like Elemwise/CAReduce/DimShuffle/RandomVariable that are "natively" batchable with little to no logic (see the sketch after this list)
- This doesn't have all the bells and whistles of JAX vmap (axes and stuff) but I feel those are not really important? Batching everything to the left sounds easy enough and covers many cases. Objections?
- Should be easy to support in_axes and out_axes; the dispatch functions become a bit more complex, since you sometimes have to transpose inputs and outputs. Not sure it's worth it; Adrian thinks it's neat.
- Can be done in a follow-up PR to Blockwise, even if we change the signature of the dispatch function. I (Ricardo) will take the heat for breaking the "public API".
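A conceptual sketch of the dispatch described above; the name vectorize_node and the exact Blockwise constructor are assumptions here (Blockwise being the class proposed in pymc-devs/pytensor#306), not the PR's actual code:
from functools import singledispatch

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.blockwise import Blockwise  # proposed in pymc-devs/pytensor#306
from pytensor.tensor.elemwise import Elemwise


@singledispatch
def vectorize_node(op: Op, *batched_inputs) -> Apply:
    # Generic fallback: wrap the core Op in Blockwise (gufunc signature inference omitted)
    return Blockwise(op).make_node(*batched_inputs)


@vectorize_node.register(Elemwise)
def vectorize_elemwise(op: Elemwise, *batched_inputs) -> Apply:
    # Elemwise already broadcasts over leading batch dims, so just re-apply the same Op
    return op.make_node(*batched_inputs)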
- Questions of rewrite ordering and worries about duplicating costly operations arise in almost every rewrite PR.
- For instance, we could replace Switch(cond, a, b) -> (empty(), set_subtensor(cond, a[cond]), set_subtensor(~cond, b[~cond])) after broadcasting everything. Indexing operations can then be lifted closer to the inputs, making the switch effectively "lazy". But we don't know when this is useful if we can't know how much lifting can be done (it might otherwise break Elemwise fusion). See the sketch at the end of this section.
- E-graphs ("egg") and some meta-optimizer sound like the right solution for this. Are they? Can they actually reason about, e.g., different permutations of index lifting and fusion rewrites?
- Do we want to consider it seriously?
- What are the biggest obstacles?
- Complexly parametrized Ops come to mind (good luck representing a SetSubtensor symbolically in any useful way)
- Do we need immutable graphs?
- Worth doing a POC and, if it looks promising, trying to get a GSoC / NumFOCUS project for it?
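A sketch of the Switch -> set_subtensor idea from the first bullet above (zeros_like is used only to keep the sketch self-contained; the actual proposal would allocate an uninitialized buffer with empty()):
import pytensor.tensor as pt

cond = pt.vector("cond", dtype="bool")
a = pt.vector("a")
b = pt.vector("b")

# Current eager form: both branches are fully computed
out_eager = pt.switch(cond, a, b)

# Proposed "lazy" form: only the selected entries of each branch are needed,
# so indexing can later be lifted towards the inputs
out_lazy = pt.zeros_like(a)
out_lazy = pt.set_subtensor(out_lazy[cond], a[cond])
out_lazy = pt.set_subtensor(out_lazy[~cond], b[~cond])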
-
Gradient optimizations
- We experimented with running canonicalize/stabilize on the PyMC logp graph before taking the grad.
- Still considering the idea of a lazy dummy Grad Op. I think reasoning in terms of gradient operations could be very interesting.
- Still think it achieves the same thing as value_and_grad optimizations easily, with a very simple kind of PyTensor rewrite (see the example below).
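For reference, the baseline any such rewrite would build on: compiling value and gradient in one function already lets rewrites share work between them, which is what value_and_grad buys you elsewhere:
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
cost = pt.sum(pt.exp(x) ** 2)
dcost_dx = pytensor.grad(cost, x)

# One compiled function for both outputs; shared subgraphs are computed once
value_and_grad = pytensor.function([x], [cost, dcost_dx])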
-
More ergonomic scan
- No updates
-
True IfElse (in JAX and Numba)
- No updates
-
RandomVariable updates logic
- No updates (no pun intended)
-
Gradient logic
- Still don't know if the claim that Rop can be expressed as a double Lop is true (see the sketch below).
- We should still remove the Lop vs grad distinction.
- Consider other names
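For reference, a minimal sketch of the double-Lop construction in question, using pytensor.gradient.Lop and Rop as they exist today (w is a dummy variable with the shape of y):
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop, Rop

x = pt.vector("x")
v = pt.vector("v")          # direction with the shape of x
y = pt.tanh(x)              # some function of x, with Jacobian J = dy/dx

# Direct R-operator: J @ v
jv_direct = Rop(y, x, v)

# Double-Lop trick: g = Lop(y, x, w) is linear in the dummy w,
# so a second Lop with respect to w should recover J @ v
w = pt.vector("w")
g = Lop(y, x, w)
jv_double = Lop(g, w, v)

f = pytensor.function([x, v], [jv_direct, jv_double])
x_val = np.random.normal(size=3).astype(x.dtype)
v_val = np.random.normal(size=3).astype(v.dtype)
print(f(x_val, v_val))      # the two outputs should agree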