Conversation

@DhairyaLGandhi
Contributor

No description provided.

@DhairyaLGandhi DhairyaLGandhi changed the title chore: add 5 arg mul detection Add 5-arg mul! Detection Oct 22, 2025
@github-actions
Contributor

github-actions bot commented Oct 23, 2025

Benchmark Results (Julia vlts)

Time benchmarks
master 468e8b7... master / 468e8b7...
arithmetic/addition 0.0806 ± 0.001 ms 0.0781 ± 0.001 ms 1.03 ± 0.019
arithmetic/division 29.8 ± 0.76 μs 29.3 ± 0.75 μs 1.02 ± 0.037
arithmetic/multiplication 0.0634 ± 0.0014 ms 0.0619 ± 0.002 ms 1.02 ± 0.04
overhead/acrule/a+2 2.76 ± 0.071 μs 2.79 ± 0.1 μs 0.989 ± 0.044
overhead/acrule/a+2+b 0.07 ± 0.01 μs 0.07 ± 0.01 μs 1 ± 0.2
overhead/acrule/a+b 4.94 ± 0.15 μs 4.92 ± 0.16 μs 1 ± 0.045
overhead/acrule/noop:Int 0.05 ± 0.01 μs 0.05 ± 0.01 μs 1 ± 0.28
overhead/acrule/noop:Sym 0.05 ± 0.01 μs 0.05 ± 0.01 μs 1 ± 0.28
overhead/get_degrees/large_poly 0.103 ± 0.0044 s 0.109 ± 0.0048 s 0.95 ± 0.058
overhead/rule/noop:Int 0.07 ± 0.01 μs 0.06 ± 0.01 μs 1.17 ± 0.26
overhead/rule/noop:Sym 0.06 ± 0.01 μs 0.06 ± 0.01 μs 1 ± 0.24
overhead/rule/noop:Term 0.06 ± 0.01 μs 0.06 ± 0.01 μs 1 ± 0.24
overhead/ruleset/noop:Int 30 ± 0 ns 30 ± 0 ns 1 ± 0
overhead/ruleset/noop:Sym 0.291 ± 0.011 μs 0.3 ± 0.01 μs 0.97 ± 0.049
overhead/ruleset/noop:Term 1.2 ± 0.021 μs 1.21 ± 0.02 μs 0.992 ± 0.024
overhead/simplify/noop:Int 30 ± 0 ns 30 ± 0 ns 1 ± 0
overhead/simplify/noop:Sym 30 ± 10 ns 30 ± 10 ns 1 ± 0.47
overhead/simplify/noop:Term 0.0331 ± 0.00096 ms 0.0327 ± 0.00088 ms 1.01 ± 0.04
overhead/simplify/randterm (+, *):serial 0.274 ± 0.0075 s 0.271 ± 0.0065 s 1.01 ± 0.037
overhead/simplify/randterm (+, *):thread 0.319 ± 0.16 s 0.483 ± 0.16 s 0.659 ± 0.4
overhead/simplify/randterm (/, *):serial 0.0889 ± 0.0015 ms 0.0936 ± 0.002 ms 0.949 ± 0.026
overhead/simplify/randterm (/, *):thread 0.0923 ± 0.0017 ms 0.0975 ± 0.0022 ms 0.947 ± 0.028
overhead/substitute/a 0.0538 ± 0.0013 ms 0.0517 ± 0.0012 ms 1.04 ± 0.034
overhead/substitute/a,b 0.0671 ± 0.0016 ms 0.0646 ± 0.0014 ms 1.04 ± 0.033
overhead/substitute/a,b,c 0.0603 ± 0.0014 ms 0.0581 ± 0.0012 ms 1.04 ± 0.032
polyform/easy_iszero 24 ± 0.56 μs 25.3 ± 0.61 μs 0.948 ± 0.032
polyform/isone 1.11 ± 0.021 ms 1.17 ± 0.029 ms 0.954 ± 0.03
polyform/isone:noop 0.15 ± 0.001 μs 0.15 ± 0.01 μs 1 ± 0.067
polyform/iszero 0.953 ± 0.015 ms 0.97 ± 0.022 ms 0.982 ± 0.027
polyform/iszero:noop 0.151 ± 0.01 μs 0.16 ± 0.01 μs 0.944 ± 0.086
polyform/simplify_fractions 1.23 ± 0.022 ms 1.26 ± 0.026 ms 0.974 ± 0.026
time_to_load 1.2 ± 0.019 s 1.21 ± 0.0072 s 0.991 ± 0.017
Memory benchmarks
master 468e8b7... master / 468e8b7...
arithmetic/addition 0.438 k allocs: 16 kB 0.438 k allocs: 16 kB 1
arithmetic/division 0.197 k allocs: 6.86 kB 0.2 k allocs: 6.95 kB 0.987
arithmetic/multiplication 0.357 k allocs: 11.7 kB 0.358 k allocs: 11.7 kB 0.997
overhead/acrule/a+2 0.036 k allocs: 1.27 kB 0.037 k allocs: 1.29 kB 0.988
overhead/acrule/a+2+b 0 allocs: 0 B 0 allocs: 0 B
overhead/acrule/a+b 0.051 k allocs: 1.84 kB 0.053 k allocs: 1.88 kB 0.983
overhead/acrule/noop:Int 0 allocs: 0 B 0 allocs: 0 B
overhead/acrule/noop:Sym 0 allocs: 0 B 0 allocs: 0 B
overhead/get_degrees/large_poly 0.601 M allocs: 18.9 MB 0.712 M allocs: 20.6 MB 0.918
overhead/rule/noop:Int 2 allocs: 0.0625 kB 2 allocs: 0.0625 kB 1
overhead/rule/noop:Sym 2 allocs: 0.0625 kB 2 allocs: 0.0625 kB 1
overhead/rule/noop:Term 2 allocs: 0.0625 kB 2 allocs: 0.0625 kB 1
overhead/ruleset/noop:Int 0 allocs: 0 B 0 allocs: 0 B
overhead/ruleset/noop:Sym 3 allocs: 0.109 kB 3 allocs: 0.109 kB 1
overhead/ruleset/noop:Term 12 allocs: 0.391 kB 12 allocs: 0.391 kB 1
overhead/simplify/noop:Int 0 allocs: 0 B 0 allocs: 0 B
overhead/simplify/noop:Sym 0 allocs: 0 B 0 allocs: 0 B
overhead/simplify/noop:Term 0.372 k allocs: 14.2 kB 0.391 k allocs: 14.5 kB 0.98
overhead/simplify/randterm (+, *):serial 2.86 M allocs: 0.105 GB 2.96 M allocs: 0.107 GB 0.985
overhead/simplify/randterm (+, *):thread 2.87 M allocs: 0.261 GB 2.97 M allocs: 0.263 GB 0.994
overhead/simplify/randterm (/, *):serial 0.806 k allocs: 29.8 kB 0.886 k allocs: 31 kB 0.96
overhead/simplify/randterm (/, *):thread 0.831 k allocs: 30.5 kB 0.912 k allocs: 31.8 kB 0.96
overhead/substitute/a 0.308 k allocs: 11 kB 0.324 k allocs: 11.2 kB 0.978
overhead/substitute/a,b 0.394 k allocs: 13.9 kB 0.41 k allocs: 14.2 kB 0.982
overhead/substitute/a,b,c 0.355 k allocs: 12.1 kB 0.37 k allocs: 12.3 kB 0.981
polyform/easy_iszero 0.14 k allocs: 4.83 kB 0.147 k allocs: 4.95 kB 0.975
polyform/isone 8.9 k allocs: 0.579 MB 8.95 k allocs: 0.54 MB 1.07
polyform/isone:noop 2 allocs: 32 B 2 allocs: 32 B 1
polyform/iszero 7.41 k allocs: 0.475 MB 7.54 k allocs: 0.449 MB 1.06
polyform/iszero:noop 2 allocs: 32 B 2 allocs: 32 B 1
polyform/simplify_fractions 10.1 k allocs: 0.619 MB 10 k allocs: 0.578 MB 1.07
time_to_load 0.153 k allocs: 14.5 kB 0.153 k allocs: 14.5 kB 1

Member

@AayushSabharwal AayushSabharwal left a comment


It seems the broad strokes of implementation are here, but I have a few overarching concerns:

  • There's a lot that can be refactored in the core codegen to make it more efficient and type-stable. That's not within the scope of this PR, but I feel that we should try and ensure anything we're adding now at least attempts to be as type-stable as reasonably possible. A lot of the data structures here can be more concretely typed. This also helps convey what they store/represent when passed around.
  • In a similar vein, there are a significant number of extra allocations that can be avoided without much effort.
  • Could the functions (and some core data structures) be documented to make this easier to refactor in the future? It's fine if AI/Claude writes the docs, as long as they're human-reviewed and the commits are marked as being written by AI.

src/matmuladd.jl Outdated
Comment on lines 29 to 33
iscall(rhs(x)) || return false
args = arguments(rhs(x))
all_arrays = all(y -> y <: AbstractArray, symtype.(args))
is_mul = operation(rhs(x)) === *
all_arrays && is_mul
Member


This could be phrased as operation(rhs(x)) === * || return false for an early exit, since checking symtypes is relatively more expensive. Also, array multiplication with a scalar coefficient (e.g. k * A * B * C for scalar k) is phrased as a Term with * as the operation and the scalar coefficient (which may itself be an expression) as the first argument. The check here will miss that case.

EDIT: Actually, this entire condition could just be isterm(x) && operation(x) === * && symtype(x) <: AbstractArray.
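Taken together, the two suggestions above amount to a cheap-check-first predicate. A minimal sketch (names are assumptions: it uses the iscall/operation/symtype accessors from SymbolicUtils and the rhs accessor used in this PR; is_array_mul is a hypothetical name):

```julia
# Sketch of the early-exit predicate suggested above.
function is_array_mul(assignment)
    ex = rhs(assignment)
    iscall(ex) || return false
    # Cheap structural check first...
    operation(ex) === (*) || return false
    # ...expensive symtype check last.
    symtype(ex) <: AbstractArray
end
```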

Contributor Author


Will do. I wasn't sure how robust type propagation is with expressions, or whether we would be better served by ensuring all the arguments are explicitly arrays. Since we don't handle scalars, I wanted to exit explicitly in those cases.

Contributor Author


Is this correct, though? Can we assume the optimizations we are running will be applicable when the symtype is the only thing we rely on? For example, we don't generate the broadcast/allocation statement when C is a scalar. Is it better to be a bit conservative for the first iteration?

Member


Sure. If it helps, multiplying a bunch of scalars and arrays is guaranteed to result in a Term for which args[1] is the scalar coefficient. If the coefficient is absent, then args[1] is just the first array.
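A quick REPL probe of this invariant might look like the following (a sketch only; it assumes a SymbolicUtils-style session and makes no claim about argument ordering beyond what is stated above):

```julia
using SymbolicUtils

@syms a::Real b::Real
ex = 2 * a * b          # scalar coefficient times symbolic factors
args = arguments(ex)    # per the comment above, the coefficient should be args[1]
```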

Comment on lines 37 to 43
plus_candidates_idx = findall(expr.pairs) do x
iscall(rhs(x)) || return false
args = arguments(rhs(x))
all_arrays = all(y -> y <: AbstractArray, symtype.(args))
is_plus = operation(rhs(x)) === +
all_arrays && is_plus
end
Member


This check could be isadd(x) && symtype(x) <: AbstractArray

Contributor Author


I seem to get isadd(x) = false and symtype(x) = Assignment for every x; operation(x) = Assignment as well. So it seems I will need to operate on the rhs, at the least.

src/matmuladd.jl Outdated
mul_vals = lhs.(mul_candidates)
candidates = map(plus_candidates_idx, plus_candidates) do p_idx, p
map(mul_candidates_idx, mul_vals) do m_idx, m_v
if nameof(m_v) in nameof.(arguments(rhs(p)))
Member


Why does this need to check nameof? Why not any(isequal(m_v), arguments(rhs(p)))? Additionally, computing pargs = Set(arguments(rhs(p))) outside the inner map can potentially save a lot of time.

nameof is also not always a valid operation.

Comment on lines 47 to 54
candidates = map(plus_candidates_idx, plus_candidates) do p_idx, p
map(mul_candidates_idx, mul_vals) do m_idx, m_v
if nameof(m_v) in nameof.(arguments(rhs(p)))
(m_idx, m_v) => (p_idx, expr.pairs[p_idx])
end
end
end
candidates = filter(!isnothing, reduce(vcat, candidates))
Member


In fact, this entire block seems to be dead code? The result is not used anywhere

Contributor Author


Yeah, this block can be removed. There is a potentially faster and better check that maps the candidates to the Assignments, but it is not implemented yet.

Contributor Author


this block is now used, and the previous pattern matcher has been removed

Comment on lines +145 to +147
dic[key] = vcat(dic[key], value)
else
dic[key] = value
Member


Suggested change
dic[key] = vcat(dic[key], value)
else
dic[key] = value
push!(dic[key], value)
else
dic[key] = copy(value)

Allocates once instead of every time a key is pushed.
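The suggested pattern generalizes to the standard grow-a-vector-per-key idiom in Julia; a small self-contained sketch with generic names (not the PR's actual dic/value types):

```julia
# Append values to a per-key vector without reallocating on every insert.
dic = Dict{Symbol, Vector{Int}}()
for (key, value) in [(:a, 1), (:a, 2), (:b, 3)]
    # get! creates the vector once per key; push! then mutates it in place,
    # unlike vcat, which allocates a fresh vector on every insertion.
    push!(get!(Vector{Int}, dic, key), value)
end
# dic == Dict(:a => [1, 2], :b => [3])
```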

Contributor Author


I don't see how that helps. At the time of insertion, we don't trivially know in this function whether there is an N-to-1 map, and assuming one would make us always do extra legwork on the codegen side.

src/matmuladd.jl Outdated

function find_cse_expr(x, state)
idx = findfirst(y -> nameof(lhs(y)) == nameof(x), state.sorted_exprs)
isnothing(idx) ? nothing : (; expr = rhs(state.sorted_exprs[idx]), x)
Member


Why does this need to return a NamedTuple with x if x is an input to the function? I'm pretty sure returning Any is better than returning @NamedTuple{expr::Any, x::BasicSymbolic{T}}.

Contributor Author


This can certainly be removed for now, it was part of the symbol counting

Contributor Author


this is removed

src/matmuladd.jl Outdated
A, B = arguments(rhs(m))
Cs = filter(x -> nameof(x) != nameof(mul_val), plus_args)
validate_mul_shapes(A, B, Cs...) || return nothing
return (; A, B, Cs, mul_candidate = m, plus_candidate = c, mul_idx, plus_idx, pattern="A*B + C")
Member


Could this be a struct instead of a NamedTuple? Easier to document and less likely to miss or typo a field. Additionally, what role does pattern serve here?
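A concrete struct along the suggested lines might look like this (a sketch only; the name MatMulAddMatch and the field types are assumptions based on the NamedTuple above, and the PR may choose different ones):

```julia
# Hypothetical replacement for the NamedTuple match result.
struct MatMulAddMatch
    A::Any                 # first matrix operand
    B::Any                 # second matrix operand
    Cs::Vector{Any}        # additive candidates for the mul! target
    mul_candidate::Any     # the A*B assignment
    plus_candidate::Any    # the + assignment
    mul_idx::Int
    plus_idx::Int
end
```

If multiple optimization passes need the pattern tag from the NamedTuple, it could become a field or a type parameter here rather than a string.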

Contributor Author


I needed a simple way to ID what optimization was detected when we have multiple passes assuming there is a fairly generic implementation for a code transformation.

Contributor Author


This is now a struct

))

is_array(x) = symtype(x) <: AbstractArray
const ARRAY_RULES = (
Member


This rule isn't exactly correct, and is untested. This won't result in the expression LinearAlgebra.mul!(copy(c), a, b, 1, 1) as I think you intended. Instead, it will error since copy is not implemented for BasicSymbolic. Disregarding that, mul! will trace since it is also not registered. Even if it were registered, simplify does not take codegen into consideration. It is simply a set of rules that transform a declarative expression into an equivalent heuristically "simpler" declarative expression.
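For reference, the expression this rule is aiming for relies on the standard 5-argument LinearAlgebra.mul! contract: mul!(C, A, B, α, β) overwrites C with A*B*α + C*β without allocating a temporary for A*B. A concrete numeric sketch:

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
C = ones(2, 2)

# In-place C = 1*(A*B) + 1*C; this is what mul!(copy(c), a, b, 1, 1)
# would compute numerically, if the symbolic rule could emit it.
mul!(C, A, B, 1, 1)
# C == [20.0 23.0; 44.0 51.0]
```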

Contributor Author

@DhairyaLGandhi DhairyaLGandhi Oct 28, 2025


This needs Symbolics.jl#dg/arr, which adds the @register_array_symbolic rules and actually does the transformation. The test request is fair, though I am not sure how to go about it since it needs the Symbolics branch first, so perhaps this can be part of a follow-on once that is released?

Member


The registration can be done manually in methods.jl. It doesn't really need the macro. But this doesn't address the fact that this rule is conflating codegen with symbolic representation.

Contributor Author


I see that, we can get rid of it then

Contributor Author


done fwiw

Comment on lines 159 to 177
import .Code
import .Code: CSEState, Let, Assignment, Func, MakeArray,
MakeSparseArray, AtIndex, DestructuredArgs, SpawnFetch,
LiteralExpr, BasicSymbolic, ForLoop, SetArray, MakeTuple,
lhs, rhs
include("matmuladd.jl")

# Updated mul5_cse2 that uses the rule system
function mul5_cse2(expr, state::CSEState)

# Try to apply optimization rules
optimized = apply_optimization_rules(expr, state)
if optimized !== nothing
return optimized
end

# If no optimization applied, return original expression
return expr
end
Member


matmuladd.jl should be part of code.jl, alongside all other codegen aspects.

Contributor Author


code.jl is set up as a module in a single file, and it would be better to refactor it as a subpackage rather than bloat the file too much.

Member


Can't include("matmuladd.jl") just be moved to code.jl? The source doesn't need to be inlined.

Contributor Author


Done

test_optimization(expr4, A, B, C, D)

expr5 = sin.(A * B + C + D + C * D)
test_optimization(expr5, A, B, C, D)
Member


Try

P = A + B
Q = B + C
R = C / D
P * Q + R

Does this correctly turn into mul!(R, P, Q, 1, 1)?

Contributor Author


So the basic answer is we currently filter out some variables in our detection to find the C.

@@ -67,13 +74,17 @@ function detect_matmul_add_pattern(expr::Code.Let, state::Code.CSEState)
-    Cset = Set(filter(!is_cse_var, reduce(vcat,getproperty.(match_data_, :Cs))))
+    Cset = Set(Iterators.flatten(getproperty.(match_data_, :Cs)))

I need a simplified version which correctly identifies which Cs are valid for our transformation.

Fwiw with this applied:

:(function (A, B, C, D)
      #= /home/dhairyalgandhi/arpa/jsmo/clone/Symbolics.jl/src/build_function.jl:147 =# @inbounds begin
              #= /home/dhairyalgandhi/arpa/jsmo/clone/Symbolics.jl/src/build_function.jl:147 =#
              begin
                  #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:507 =#
                  #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:508 =#
                  #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:509 =#
                  begin
                      var"##cse#1" = (/)(C, D)
                      var"##cse#2" = (+)(B, A)
                      var"##cse#3" = (+)(B, C)
                      var"##mul5_temp#554" = (copy)(var"##cse#1")
                      var"##mul5_temp#554" = (mul!)(var"##mul5_temp#554", var"##cse#2", var"##cse#3", 1, 1)
                      var"##mul5_temp#554" = var"##mul5_temp#554"
                      var"##mul5_temp#554"
                  end
              end
          end
  end)

@DhairyaLGandhi
Contributor Author

Re: \ -> ldiv!: Is that worth it? We often see out-of-place versions of functions being simple wrappers around the in-place versions with a buffer allocation, similar to:

function f!(buf, ...)
  ...
end

function f(...)
    buf = similar(...)
    f!(buf, ...)
    buf
end

The value of hoisting allocations in these cases only materializes when buf is used purely as a write cache, which is harder to assume. We can use Bumper.jl to do the allocations only once, but we need to be careful to actually write into the buffer before we use it.
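The write-cache caveat can be made concrete: hoisting the buffer is only safe when every call fully overwrites it before reading it. A generic sketch (f!, run, and the loop bounds are hypothetical names for illustration):

```julia
# Example in-place kernel: writes every element of buf, never reads stale data.
function f!(buf, x)
    buf .= 2 .* x
    return buf
end

# Hoisted-buffer pattern: allocate once, reuse across iterations.
function run(x, n)
    buf = similar(x)       # single allocation, outside the loop
    total = 0.0
    for _ in 1:n
        f!(buf, x)         # fully overwrites buf each call
        total += sum(buf)  # read only after the full overwrite
    end
    return total
end

x = rand(4)
total = run(x, 3)
# total == 3 * sum(2 .* x)
```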
