rewrites for exp/log combinations #1540
-
Describe the issue: We have a component in a model that basically boils down to the expression in the example below. Even considering the case where x > 0, I thought it would already be simplified, but it looks like no rewrite happens.

Reproducible code example:

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
a = pt.dscalar("a")
b = pt.dscalar("b")
x = pt.dvector("x")
g = pt.exp(a + b * pt.log(x))
g.dprint();
rewrite_graph(g).dprint();

PyTensor version information: pytensor 2.31.6, installed from conda-forge
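For context, the simplification being asked about is the algebraic identity exp(a + b*log(x)) == exp(a) * x**b, which holds for x > 0. A quick NumPy check (my addition, not part of the original report) confirms the two forms agree on well-conditioned inputs:

```python
import numpy as np

# Identity expected to be rewritten: exp(a + b*log(x)) == exp(a) * x**b, for x > 0
rng = np.random.default_rng(0)
a, b = 0.5, -1.5
x = rng.uniform(0.1, 10.0, size=1000)

lhs = np.exp(a + b * np.log(x))  # the form as written in the model
rhs = np.exp(a) * x**b           # the "simplified" power form

print(np.allclose(lhs, rhs))  # True: the forms agree for positive x
```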
-
I'm not sure you would want such a rewrite; if anything, you want to move more things onto the log scale, not fewer (see #177), for stability. In your case, if b is negative for example, it helps keep the exponential small, while making it a power will facilitate underflow. In either case, in the final fused function you only iterate over the vector x once.

The performance will depend on your machine, but on mine the original is roughly 2x faster than the proposed alternative (it also allows smaller outputs before underflowing). Power with an arbitrary float exponent is more expensive than the repeated log and exp. It may change depending on the specific b.

import pytensor
import pytensor.tensor as pt
import numpy as np
a = pt.dscalar("a")
b = pt.dscalar("b")
x = pt.dvector("x")
g = pt.exp(a + b * pt.log(x))
fn = pytensor.function([a, b, x], g, trust_input=True)
fn.dprint()
# Composite{exp((i2 + (i1 * log(i0))))} [id A] 2
# ├─ x [id B]
# ├─ ExpandDims{axis=0} [id C] 1
# │ └─ b [id D]
# └─ ExpandDims{axis=0} [id E] 0
# └─ a [id F]
# Inner graphs:
# Composite{exp((i2 + (i1 * log(i0))))} [id A]
# ← exp [id G] 'o0'
# └─ add [id H]
# ├─ i2 [id I]
# └─ mul [id J]
# ├─ i1 [id K]
# └─ log [id L]
# └─ i0 [id M]
a_test = np.array(100.0)
b_test = np.array(-100.0)
x_test = np.repeat([1000., 2000.], 1000)
res = fn(a_test, b_test, x_test)
print(res.min()) # 2.1205505218333326e-287
%timeit fn(a_test, b_test, x_test) # 16.1 μs ± 419 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
g_alt = pt.exp(a) * (x ** b)
fn_alt = pytensor.function([a, b, x], g_alt, trust_input=True)
fn_alt.dprint()
# Composite{(i2 * (i0 ** i1))} [id A] 3
# ├─ x [id B]
# ├─ ExpandDims{axis=0} [id C] 2
# │ └─ b [id D]
# └─ Exp [id E] 1
# └─ ExpandDims{axis=0} [id F] 0
# └─ a [id G]
# Inner graphs:
# Composite{(i2 * (i0 ** i1))} [id A]
# ← mul [id H] 'o0'
# ├─ i2 [id I]
# └─ pow [id J]
# ├─ i0 [id K]
# └─ i1 [id L]
res_alt = fn_alt(a_test, b_test, x_test)
print(res_alt.min()) # 0.0
%timeit fn_alt(a_test, b_test, x_test) # 32.6 μs ± 489 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
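The underflow difference above can be reproduced with plain NumPy, without compiling anything (my sketch, using the same test values as the benchmark):

```python
import numpy as np

a, b = 100.0, -100.0
x = np.float64(2000.0)

# Log-scale form: the exponent is a moderate negative number,
# so the result stays representable.
log_form = np.exp(a + b * np.log(x))  # exp(-660.09...) ~ 2.1e-287, matching res.min() above

# Power form: x**b underflows to 0.0 before the multiplication,
# so the finite exp(a) factor cannot rescue it.
with np.errstate(under="ignore"):
    pow_form = np.exp(a) * x**b  # exp(100) * 2000**-100 -> 2.69e43 * 0.0 == 0.0

print(log_form, pow_form)
```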
-
Actually
-
Small note when calling
-
Thanks for checking. Should we have the inverse rewrite then, from power to exp of a product with log? I guess the opposite direction is trickier though, because we still want things like
For this, do you mean by the computational backends themselves? So below what PyTensor controls? In case it helps, I tried:

rewrite_graph(g, include=("canonicalize", "stabilize", "specialize", "fusion")).dprint();
# Composite{exp((i2 + (i1 * log(i0))))} [id A]
# ├─ x [id B]
# ├─ ExpandDims{axis=0} [id C]
# │ └─ b [id D]
# └─ ExpandDims{axis=0} [id E]
# └─ a [id F]
#
# Inner graphs:
#
# Composite{exp((i2 + (i1 * log(i0))))} [id A]
# ← exp [id G] 'o0'
# └─ add [id H]
# ├─ i2 [id I]
# └─ mul [id J]
# ├─ i1 [id K]
# └─ log [id L]
# └─ i0 [id M]
rewrite_graph(g_alt, include=("canonicalize", "stabilize", "specialize", "fusion")).dprint();
# Composite{(i2 * (i0 ** i1))} [id A]
# ├─ x [id B]
# ├─ ExpandDims{axis=0} [id C]
# │ └─ b [id D]
# └─ Exp [id E]
# └─ ExpandDims{axis=0} [id F]
# └─ a [id G]
#
# Inner graphs:
#
# Composite{(i2 * (i0 ** i1))} [id A]
# ← mul [id H] 'o0'
# ├─ i2 [id I]
# └─ pow [id J]
# ├─ i0 [id K]
# └─ i1 [id L]
-
Depends on the goal of the rewrite. If you're just compiling for performance/stability, you probably don't want that. If you are rewriting to find whether an expression is equivalent to something (or just to simplify it under some well-defined definition), then it may be fine to have such rewrites, but not include them in the compilation database. This sort of thing may be helped by having hints at the graph level (like the user specifying that x is non-negative, or us inferring it because it comes from the output of an

What was your motivation for this issue?
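The idea of gating a rewrite on a positivity hint can be sketched with a toy expression tree in pure Python (hypothetical names and representation, not PyTensor's actual rewrite API):

```python
# Toy expressions as nested tuples: ("exp", e), ("log", e), ("mul", e1, e2),
# ("pow", e1, e2). Variables are plain strings; `positive` is the set of
# variables the user has hinted as strictly positive.

def rewrite_exp_log(expr, positive):
    """Rewrite exp(b * log(x)) -> x ** b, but only when x is known positive."""
    if (
        isinstance(expr, tuple)
        and expr[0] == "exp"
        and isinstance(expr[1], tuple)
        and expr[1][0] == "mul"
    ):
        b, inner = expr[1][1], expr[1][2]
        if isinstance(inner, tuple) and inner[0] == "log" and inner[1] in positive:
            return ("pow", inner[1], b)
    return expr  # without the hint, leave the graph alone

print(rewrite_exp_log(("exp", ("mul", "b", ("log", "x"))), positive={"x"}))
# ('pow', 'x', 'b')
print(rewrite_exp_log(("exp", ("mul", "b", ("log", "x"))), positive=set()))
# ('exp', ('mul', 'b', ('log', 'x')))  -- unchanged, the hint is missing
```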
-
Yes, pow is probably a standard library function (in C/Numba, and XLA in JAX) that implements it at its own discretion. The final form may also be CPU-specific.
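As a machine-dependent illustration of that point, one can time NumPy's power against the explicit exp/log chain (a rough sketch; absolute numbers will vary with the libm/SIMD implementation your NumPy build dispatches to):

```python
import timeit
import numpy as np

x = np.random.default_rng(0).uniform(1.0, 10.0, size=100_000)
b = -1.7  # arbitrary non-integer exponent

# Same quantity computed two ways; which is faster depends on the machine.
t_pow = timeit.timeit(lambda: x**b, number=200)
t_explog = timeit.timeit(lambda: np.exp(b * np.log(x)), number=200)

print(f"pow: {t_pow:.3f}s  exp(b*log(x)): {t_explog:.3f}s")
print(np.allclose(x**b, np.exp(b * np.log(x))))  # the results agree either way
```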
-
I added a comment while reviewing a model to check whether that was being rewritten; after seeing it wasn't, @tomicapretto and @jessegrabowski were also a bit surprised, so I opened the issue.

In the general case, right? In our case with x > 0, it looks like if we had actually worked out the model analytically on paper first, simplified to the power form, and then implemented that equivalent, we would have ended up with slower and less stable code.
-
Yes, what looks like a simplification mathematically can be very different in terms of performance/stability of floating-point operations. In either case, I wouldn't expect this to ever be a model bottleneck. If you were worried about performance, the first thing would be to profile the logp_dlogp function of the model and see what is taking time.

If you care about stability you would probably want the original expression as well, but that's a different kind of benchmark altogether.
-
I'll move this to a discussion; let me know if you think it should still be an issue.