Closed
Changes from 13 commits
30 commits
270d17a
chore: add 5 arg mul detection
DhairyaLGandhi Oct 22, 2025
caf54ab
chore: fix imports
DhairyaLGandhi Oct 22, 2025
7aade81
chore: manage other operations
DhairyaLGandhi Oct 22, 2025
6c2dee1
test: matmuladd opt
DhairyaLGandhi Oct 23, 2025
3c28cdc
test: correct filename
DhairyaLGandhi Oct 23, 2025
08c601b
chore: filter out Const with iscall
DhairyaLGandhi Oct 23, 2025
d15ee05
test: rm Symbolics dep
DhairyaLGandhi Oct 23, 2025
e8eec80
chore: revert return symbol type
DhairyaLGandhi Oct 23, 2025
53ee3de
feat: safely substitute variables in IR
DhairyaLGandhi Oct 27, 2025
337e9b8
chore: rm tree walking
DhairyaLGandhi Oct 27, 2025
6abdbd9
chore: reduce on N-to-1 mapping on transformation
DhairyaLGandhi Oct 27, 2025
468e8b7
test: add downstream code execution test
DhairyaLGandhi Oct 27, 2025
138a2c4
chore: check v issym
DhairyaLGandhi Oct 27, 2025
742842c
chore: improve matching handling + parametrize some structs
DhairyaLGandhi Oct 29, 2025
8aade35
test: add complex cases with upstream/ downstream code sanitation che…
DhairyaLGandhi Oct 29, 2025
a0c5967
chore: handle chains of muls
DhairyaLGandhi Oct 29, 2025
2a31c78
chore: remove shape checking
DhairyaLGandhi Oct 29, 2025
cacd14c
chore: minor cleanups
DhairyaLGandhi Oct 29, 2025
e8cda29
test: add test case for chained mul
DhairyaLGandhi Oct 29, 2025
451c62e
chore: undo unnecessary commits
DhairyaLGandhi Oct 31, 2025
6ef44a0
chore: move optimization to code.jl
DhairyaLGandhi Oct 31, 2025
6b0925e
chore: revert array expression simplification by rules
DhairyaLGandhi Oct 31, 2025
182cb04
test: missing collect
DhairyaLGandhi Oct 31, 2025
283f172
chore: simplify pattern finding rules
DhairyaLGandhi Nov 3, 2025
fee1da5
test: add atol to opt test
DhairyaLGandhi Nov 3, 2025
180e465
chore: add docs for 5-arg mul optimization
DhairyaLGandhi Nov 3, 2025
f8ab20b
chore: use add_worker
DhairyaLGandhi Nov 3, 2025
75447f5
chore: revert isadd change
DhairyaLGandhi Nov 3, 2025
32288b1
chore: avoid gensym
DhairyaLGandhi Nov 3, 2025
920260d
chore: simplify some accesses
DhairyaLGandhi Nov 3, 2025
19 changes: 19 additions & 0 deletions src/SymbolicUtils.jl
@@ -156,6 +156,25 @@ export substitute
include("substitute.jl")

include("code.jl")
import .Code
import .Code: CSEState, Let, Assignment, Func, MakeArray,
MakeSparseArray, AtIndex, DestructuredArgs, SpawnFetch,
LiteralExpr, BasicSymbolic, ForLoop, SetArray, MakeTuple,
lhs, rhs
include("matmuladd.jl")

# Updated mul5_cse2 that uses the rule system
function mul5_cse2(expr, state::CSEState)

# Try to apply optimization rules
optimized = apply_optimization_rules(expr, state)
if optimized !== nothing
return optimized
end

# If no optimization applied, return original expression
return expr
end
Member

matmuladd.jl should be part of code.jl, alongside all other codegen aspects.

Contributor Author

code.jl is set up as a module in a single file. It would be better to refactor it into a subpackage than to bloat the file further.

Member

Can't include("matmuladd.jl") just be moved to code.jl? The source doesn't need to be inlined.

Contributor Author

Done


PrecompileTools.@recompile_invalidations begin
include("despecialize.jl")
6 changes: 4 additions & 2 deletions src/code.jl
@@ -13,7 +13,8 @@ import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, a
symtype, sorted_arguments, metadata, isterm, term, maketerm, unwrap_const,
ArgsT, Const, SymVariant, _is_array_of_symbolics, _is_tuple_of_symbolics,
ArrayOp, isarrayop, IdxToAxesT, ROArgsT, shape, Unknown, ShapeVecT, BSImpl,
- search_variables!, _is_index_variable, RangesT, IDXS_SYM, is_array_shape
+ search_variables!, _is_index_variable, RangesT, IDXS_SYM, is_array_shape,
+ vartype, symtype
using Moshi.Match: @match
import SymbolicIndexingInterface: symbolic_type, NotSymbolic

@@ -492,6 +493,8 @@ Func
toexpr_kw(f, st) = Expr(:kw, toexpr(f, st).args...)

function toexpr(f::Func, st)
# @show st
# @show f.args
funkyargs = get_rewrites(vcat(f.args, map(lhs, f.kwargs)))
union_rewrites!(st.rewrites, funkyargs)
dargs = filter(x->x isa DestructuredArgs, f.args)
@@ -1019,7 +1022,6 @@ function cse!(expr::BasicSymbolic{T}, state::CSEState) where {T}
return sym
end
end

end
end

219 changes: 219 additions & 0 deletions src/matmuladd.jl
@@ -0,0 +1,219 @@
# Pattern-based optimization templates for CSE
struct OptimizationRule
name::String
detector::Function
transformer::Function
priority::Int
end
Member

Why not parametrize detector and transformer?
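
To illustrate the question above: a field declared as the abstract type `Function` is not concretely typed, whereas type parameters record the concrete type of each callable. The following is a minimal sketch, not the PR's code; `UntypedRule` and `TypedRule` are hypothetical names, and `iseven`/`identity` are stand-ins for the real detector and transformer.

```julia
# Fields typed as the abstract `Function`: the compiler cannot know
# which method will be called through them.
struct UntypedRule
    name::String
    detector::Function
    transformer::Function
    priority::Int
end

# Hypothetical parametrized variant: the concrete type of each callable
# becomes part of the rule's type.
struct TypedRule{D, F}
    name::String
    detector::D
    transformer::F
    priority::Int
end

untyped = UntypedRule("demo", iseven, identity, 10)
typed = TypedRule("demo", iseven, identity, 10)

untyped_is_concrete = isconcretetype(fieldtype(typeof(untyped), :detector))
typed_is_concrete = isconcretetype(fieldtype(typeof(typed), :detector))
```

One trade-off to keep in mind: a `Vector` holding rules with different parameters is itself abstractly typed, so a parametrized rule type probably pairs better with a tuple of rules than with a default `rules=[...]` vector.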


function find_cse_expr(x, state)
idx = findfirst(y -> nameof(lhs(y)) == nameof(x), state.sorted_exprs)
isnothing(idx) ? nothing : (; expr = rhs(state.sorted_exprs[idx]), x)
Member

Why does this need to return a NamedTuple with x if x is an input to the function? I'm pretty sure returning Any is better than returning @NamedTuple{expr::Any, x::BasicSymbolic{T}}.

Contributor Author

This can certainly be removed for now, it was part of the symbol counting

Contributor Author

this is removed

end

function is_cse_var(x)
startswith(string(nameof(x)), "##cse")
end

function validate_mul_shapes(A, B, C)
[shape(A)[1], shape(B)[2]] == shape(C)
end

function validate_mul_shapes(A, B, C...)
return true
[shape(A)[1], shape(B)[2]] == shape(first(C))
end

function detect_matmul_add_pattern(expr::Code.Let, state::Code.CSEState)
mul_candidates_idx = findall(expr.pairs) do x
iscall(rhs(x)) || return false
args = arguments(rhs(x))
all_arrays = all(y -> y <: AbstractArray, symtype.(args))
is_mul = operation(rhs(x)) === *
all_arrays && is_mul
Member

This could be phrased as operation(rhs(x)) === * || return false for an early exit, since checking symtypes is relatively more expensive. Also, array multiplication with a scalar coefficient (e.g. k * A * B * C for scalar k) is represented as a Term with * as the operation and the scalar coefficient (which may itself be an expression) as the first argument. The check here will miss this case.

EDIT: Actually, this entire condition could just be isterm(x) && operation(x) === * && symtype(x) <: AbstractArray.

Contributor Author

will do, I wasn't sure how robust type propagation is with expressions or if we will be better served by ensuring all the arguments are explicitly arrays. Since we didn't handle scalars I wanted to explicitly exit for those cases.

Contributor Author

Is this correct though? Can we assume that the optimizations we are running will be applicable for when the symtype is the only thing we are relying on? For example, we don't generate the broadcast/ allocation statement for when C is a scalar. Is it better to be a bit conservative for the first iteration?

Member

Sure. If it helps, multiplying a bunch of scalars and arrays is guaranteed to result in a Term for which args[1] is the scalar coefficient. If the coefficient is absent, then args[1] is just the first array.

end
mul_candidates = expr.pairs[mul_candidates_idx]

plus_candidates_idx = findall(expr.pairs) do x
iscall(rhs(x)) || return false
args = arguments(rhs(x))
all_arrays = all(y -> y <: AbstractArray, symtype.(args))
is_plus = operation(rhs(x)) === +
all_arrays && is_plus
end
Comment on lines 73 to 79
Member

This check could be isadd(x) && symtype(x) <: AbstractArray

Contributor Author

I seem to get isadd(x) = false and symtype(x) = Assignment for every x. operation(x) = Assignment as well. So it seems like I will need to operate on the rhs at the least.

plus_candidates = expr.pairs[plus_candidates_idx]

mul_vals = lhs.(mul_candidates)
candidates = map(plus_candidates_idx, plus_candidates) do p_idx, p
map(mul_candidates_idx, mul_vals) do m_idx, m_v
if nameof(m_v) in nameof.(arguments(rhs(p)))
Member

Why does this need to check nameof? Why not any(isequal(m_v), arguments(rhs(p)))? Additionally, computing pargs = Set(arguments(rhs(p))) outside the inner map can potentially save a lot of time.

nameof is also not always a valid operation.

(m_idx, m_v) => (p_idx, expr.pairs[p_idx])
end
end
end
candidates = filter(!isnothing, reduce(vcat, candidates))
Comment on lines 82 to 90
Member

In fact, this entire block seems to be dead code? The result is not used anywhere

Contributor Author

Yeah this block can be removed, there is a potentially faster and better check which maps the candidates to the Assignments, but that is not implemented yet.

Contributor Author

this block is now used, and the previous pattern matcher has been removed


pattern = map(plus_candidates_idx, plus_candidates) do plus_idx, c
plus_args = arguments(rhs(c))
mul_pattern = map(mul_candidates_idx, mul_candidates) do mul_idx, m
mul_val = lhs(m)

if nameof(mul_val) in nameof.(plus_args)
A, B = arguments(rhs(m))
Cs = filter(x -> nameof(x) != nameof(mul_val), plus_args)
validate_mul_shapes(A, B, Cs...) || return nothing
Member

I would really avoid doing this honestly. SU has a bunch of code validating shapes and this is just a lot of unnecessary work due to the array splat.

Contributor Author

we can avoid this. If you notice validate_mul_shapes currently short circuits back to true because the shape information suddenly stopped being available here. Not sure why. This would still need to do some validation to check for eltypes to allocate buffers appropriately.

return (; A, B, Cs, mul_candidate = m, plus_candidate = c, mul_idx, plus_idx, pattern="A*B + C")
Member

Could this be a struct instead of a NamedTuple? Easier to document and less likely to miss or typo a field. Additionally, what role does pattern serve here?

Contributor Author

I needed a simple way to ID what optimization was detected when we have multiple passes assuming there is a fairly generic implementation for a code transformation.

Contributor Author

This is now a struct

end
end
filter(!isnothing, mul_pattern)
end
isempty(pattern) ? nothing : pattern
end

function transform_to_mul5_assignment(expr, match_data_, state::Code.CSEState)
Cset = Set(filter(!is_cse_var, reduce(vcat,getproperty.(match_data_, :Cs))))
plus_candidates_idx = getproperty.(match_data_, :plus_idx)

final_temps = []

m_ = map(match_data_) do match_data

A, B = match_data.A, match_data.B
C = pop!(Cset)
T = vartype(C)
Comment on lines 110 to 119
Member

Suggested change
- function transform_to_mul5_assignment(expr, match_data_, state::Code.CSEState)
- Cset = Set(filter(!is_cse_var, reduce(vcat,getproperty.(match_data_, :Cs))))
- plus_candidates_idx = getproperty.(match_data_, :plus_idx)
- final_temps = []
- m_ = map(match_data_) do match_data
- A, B = match_data.A, match_data.B
- C = pop!(Cset)
- T = vartype(C)
+ function transform_to_mul5_assignment(expr::BasicSymbolic{T}, match_data_, state::Code.CSEState) where {T}
+ Cset = Set(filter(!is_cse_var, reduce(vcat,getproperty.(match_data_, :Cs))))
+ plus_candidates_idx = getproperty.(match_data_, :plus_idx)
+ final_temps = BasicSymbolic{T}[]
+ m_ = map(match_data_) do match_data
+ A, B = match_data.A, match_data.B
+ C = pop!(Cset)

Contributor Author

I believe expr::Let


# Create temporary variable for the result
temp_var_sym = gensym("mul5_temp")
Member

I'm not so sure of the gensym. This creates unnecessary entries in the hashconsing cache. Could this instead increment a counter outside this loop?
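
A minimal sketch of the counter idea, assuming a name built from a fixed prefix and a running index is acceptable. `make_temp_name` is a hypothetical helper, not part of SymbolicUtils, and the sketch only covers name generation, not how the resulting `Sym` would be constructed.

```julia
# Hypothetical helper: deterministic temporary names from a running counter
# instead of a fresh gensym per match.
make_temp_name(i::Integer; prefix::AbstractString = "mul5_temp") =
    Symbol(prefix, "_", i)

# One name per match, numbered in match order.
temp_names = [make_temp_name(i) for i in 1:3]
```

Unlike `gensym`, such names are not guaranteed to be unique, so the prefix would have to be one that user code cannot plausibly produce.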

temp_var = Sym{T}(temp_var_sym; type=symtype(C))

copy_call = Term{T}(copy, [C]; type=symtype(C))
mul_call = Term{T}(LinearAlgebra.mul!,
[temp_var, A, B, Const{T}(1), Const{T}(1)];
type=symtype(C))
Comment on lines 124 to 137
Member

Alternatively since we're not running CSE on the result of this pass, all of these can pass unsafe = true as a kwarg to avoid hashconsing.


# Add assignments to CSE state
copy_assignment = Assignment(temp_var, copy_call)
mul_assignment = Assignment(temp_var, mul_call) # This overwrites temp_var with mul! result
final_assignment = Assignment(temp_var, temp_var)
push!(final_temps, temp_var)

[copy_assignment, mul_assignment, final_assignment]
end
m = m_ |> Base.Fix1(reduce, vcat)

transformed_idxs = getproperty.(match_data_, :plus_idx)
Member

Suggested change
- transformed_idxs = getproperty.(match_data_, :plus_idx)
+ transformed_idxs = plus_candidates_idx

No need to compute the same thing twice.

Contributor Author

Removed

substitution_map = get_substitution_map(match_data_, m_)
rm_idxs = getproperty.(match_data_, :mul_idx)
transformations = Dict()
map(transformed_idxs, m_) do i, mm
Member

In fact, could this and related loops just iterate over match_data_ and access the relevant field? It avoids a lot of unnecessary allocations.
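
The difference can be sketched with plain NamedTuples standing in for the match records. The field names `mul_idx` and `plus_idx` come from the diff; the values and the `plus_to_mul` dictionary are made up for illustration.

```julia
# Stand-in match records; the real ones also carry symbolic expressions.
matches = [(mul_idx = 1, plus_idx = 3), (mul_idx = 2, plus_idx = 5)]

# Broadcast style: one temporary vector is allocated per extracted field.
plus_idxs = getproperty.(matches, :plus_idx)
mul_idxs = getproperty.(matches, :mul_idx)

# Loop style: visit each match once and read its fields directly.
plus_to_mul = Dict{Int, Int}()
for m in matches
    plus_to_mul[m.plus_idx] = m.mul_idx
end
```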

bank(transformations, i, mm)
end

new_pairs = []
for (i, e) in enumerate(expr.pairs)
if i in transformed_idxs
push!(new_pairs, transformations[i]...)
Member

Suggested change
- push!(new_pairs, transformations[i]...)
+ append!(new_pairs, transformations[i])
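
For reference, both forms produce the same contents. The difference is that `push!` with a splat expands the collection into a varargs call, while `append!` receives it as a single argument. A small sketch with integers standing in for the assignments:

```julia
items = [10, 20, 30]

with_splat = Int[]
push!(with_splat, items...)   # every element becomes a separate argument

with_append = Int[]
append!(with_append, items)   # the collection is passed as one argument

same_result = with_splat == with_append
```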

@show e
Member

Suggested change
- @show e

elseif i in rm_idxs
push!(new_pairs, nothing)
Member

Why push! here only to filter it out later?

Contributor Author

Good catch, I needed to keep track of whether I had made any mistakes reconstructing the pairs. This will be removed before we are ready to merge

else
push!(new_pairs, e)
end
end
new_pairs = filter(!isnothing, new_pairs)

push!(state.sorted_exprs, m...)
Member

Suggested change
- push!(state.sorted_exprs, m...)
+ append!(state.sorted_exprs, m)

temp_var = last(m).lhs
new_let = Code.Let(new_pairs, expr.body, expr.let_block)
apply_substitution_map(new_let, substitution_map)
end

function get_substitution_map(match_data, transformations)
dic = Dict()
@assert length(match_data) == length(transformations)

plus_idxs = getproperty.(match_data, :plus_idx)

map(match_data, transformations) do m, t
bank(dic, m.plus_candidate.lhs, t[end].lhs)
end
Member

This should use a for loop instead of a map. The latter does a bunch of extra work and allocates unnecessarily. plus_idxs also doesn't need to exist. It is unused.

Contributor Author

Fair enough

dic
end

function bank(dic, key, value)
if haskey(dic, key)
dic[key] = vcat(dic[key], value)
else
dic[key] = value
Comment on lines +188 to +190
Member

Suggested change
- dic[key] = vcat(dic[key], value)
- else
- dic[key] = value
+ push!(dic[key], value)
+ else
+ dic[key] = copy(value)

Allocates once instead of every time a key is pushed.

Contributor Author

I don't see how that helps? At the time of insertion, we don't actually know if there is a N-to-1 map trivially in this function. And assuming that will make us always do some extra legwork on the codegen side.
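
One way to read the two positions in this thread, as a sketch rather than the PR's implementation: if every key always maps to a vector, values can be pushed in place without `vcat`, at the cost the author mentions, namely that single values are also wrapped and the consumer must always handle a vector. `bank!` below is a hypothetical variant with concretely typed keys and values.

```julia
# Hypothetical variant of `bank`: every key maps to a vector that grows in place.
function bank!(dic::Dict{K, Vector{V}}, key::K, value::V) where {K, V}
    # get! inserts an empty vector the first time a key is seen.
    push!(get!(() -> V[], dic, key), value)
    return dic
end

d = Dict{Symbol, Vector{Int}}()
bank!(d, :p, 1)
bank!(d, :p, 2)   # second value for the same key: no vcat, no new vector
bank!(d, :q, 3)   # a single value is still stored as a one-element vector
```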

end
end

function apply_substitution_map(expr::Code.Let, substitution_map::Dict)
substitute_in_ir(expr, substitution_map)
end

function substitute_in_ir(s::Symbol, substitution_map::Dict)
get(substitution_map, s, s)
end

function substitute_in_ir_base(s, substitution_map::Dict)
if haskey(substitution_map, s)
v = substitution_map[s]
if issym(v)
v
else
+(v...)
Member

This is a list of expressions, right?

Suggested change
- +(v...)
+ SymbolicUtils.add_worker(vartype(first(v)), v...)

Even better if substitution_map is concretely typed and the vartype can be lifted from there.

Member

This one still needs to be addressed. Splatting is probably dynamic dispatch here, and will run the suggested replacement anyway. I'm assuming v is an indexable collection of symbolics/numbers/arrays of numbers.

Contributor Author

This is done now

end
else
s
end
end

function substitute_in_ir(expr, substitution_map::Dict)
if iscall(expr)
new_args = map(arguments(expr)) do arg
substitute_in_ir(arg, substitution_map)
Member

Why is substitute_in_ir not just SymbolicUtils.substitute? You can even cache the substituter with subs = Substituter{false}(substitution_map) and just call map(subs, arguments(expr)).

Contributor Author

As far as I could tell, SymbolicUtils.substitute doesn't operate on IR; it works on BasicSymbolic expressions directly. This function handles substitution in IR code.

end
return Code.Term{Code.vartype(expr)}(operation(expr), new_args; type=Code.symtype(expr))
Member

Suggested change
- return Code.Term{Code.vartype(expr)}(operation(expr), new_args; type=Code.symtype(expr))
+ return Term{vartype(expr)}(operation(expr), new_args; type=symtype(expr))

Contributor Author

This is also done

elseif issym(expr)
substitute_in_ir_base(expr, substitution_map)
else
expr
end
end

function substitute_in_ir(x::Code.Assignment, substitution_map::Dict)
new_lhs = substitute_in_ir(Code.lhs(x), substitution_map)
new_rhs = substitute_in_ir(Code.rhs(x), substitution_map)
return Code.Assignment(new_lhs, new_rhs)
end

function substitute_in_ir(expr::Code.Let, substitution_map::Dict)
isempty(substitution_map) && return expr

new_pairs = map(expr.pairs) do p
substitute_in_ir(p, substitution_map)
end
new_body = substitute_in_ir(expr.body, substitution_map)
return Code.Let(new_pairs, new_body, expr.let_block)
end

const MATMUL_ADD_RULE = OptimizationRule(
"MatMul+Add",
detect_matmul_add_pattern,
transform_to_mul5_assignment,
10
)

Base.isempty(l::Code.Let) = isempty(l.pairs)

# Apply optimization rules during CSE
function apply_optimization_rules(expr, state::Code.CSEState, rules=[MATMUL_ADD_RULE])
for rule in sort(rules, by=r->r.priority, rev=true)
match_data = reduce(vcat, rule.detector(expr, state))
if match_data !== nothing # || !isempty(match_data)
return rule.transformer(expr, match_data, state)
end
end
return nothing
end
13 changes: 12 additions & 1 deletion src/simplify_rules.jl
@@ -145,10 +145,19 @@ const NUMBER_SIMPLIFIER = RestartedChain((
If(is_operation(^), Chain(POW_RULES)),
))

is_array(x) = symtype(x) <: AbstractArray
const ARRAY_RULES = (
Member

This rule isn't exactly correct, and is untested. This won't result in the expression LinearAlgebra.mul!(copy(c), a, b, 1, 1) as I think you intended. Instead, it will error since copy is not implemented for BasicSymbolic. Disregarding that, mul! will trace since it is also not registered. Even if it were registered, simplify does not take codegen into consideration. It is simply a set of rules that transform a declarative expression into an equivalent heuristically "simpler" declarative expression.

Contributor Author (@DhairyaLGandhi, Oct 28, 2025)

This needs Symbolics.jl#dg/arr which adds the @register_array_symbolic rules, and actually does the transformation. Test is fair though I am not sure about how to go about it since it needs the Symbolics branch first, so perhaps this can be a part of a follow on once that is released?

Member

The registration can be done manually in methods.jl. It doesn't really need the macro. But this doesn't address the fact that this rule is conflating codegen with symbolic representation.

Contributor Author

I see that, we can get rid of it then

Contributor Author

done fwiw

@rule( ~a::is_array * ~b::is_array + ~c::is_array => begin
tmp = copy(~c)
LinearAlgebra.mul!(tmp, ~a, ~b, 1, 1)
end),
)

const TRIG_EXP_SIMPLIFIER = Chain(TRIG_EXP_RULES)

const BOOLEAN_SIMPLIFIER = Chain(BOOLEAN_RULES)

const ARRAY_SIMPLIFIER = Chain(ARRAY_RULES)

function get_default_simplifier(; kw...)
IfElse(has_trig_exp,
@@ -159,7 +168,9 @@ function get_default_simplifier(; kw...)
Postwalk(Chain((If(x->symtype(x) <: Number,
NUMBER_SIMPLIFIER),
If(x->symtype(x) <: Bool,
- BOOLEAN_SIMPLIFIER)))
+ BOOLEAN_SIMPLIFIER),
+ If(x -> symtype(x) <: AbstractArray,
+ ARRAY_SIMPLIFIER)))
; kw...))
end

72 changes: 72 additions & 0 deletions test/mul5_opt.jl
@@ -0,0 +1,72 @@
using SymbolicUtils
using SymbolicUtils.Code
import SymbolicUtils as SU
using LinearAlgebra
using Test

# Helper function to check if optimization was applied
function has_mul5_optimization(ir)
if ir isa Code.Let
return any(ir.pairs) do assignment
rhs_expr = Code.rhs(assignment)
if SU.iscall(rhs_expr)
op = SU.operation(rhs_expr)
return op === LinearAlgebra.mul!
end
false
end
end
return false
end

# Helper function to build and evaluate both versions
function test_optimization(expr, args...)
cse_ir = SU.Code.cse(expr)
state = SU.Code.CSEState()
optimized_ir = SU.mul5_cse2(cse_ir, state)

# Check if optimization was applied
has_optimization = has_mul5_optimization(optimized_ir)
@test has_optimization

f_cse_expr = Func(collect(args), [], cse_ir)
f_cse = eval(toexpr(f_cse_expr))

f_opt_expr = Func(collect(args), [], optimized_ir)
f_opt = eval(toexpr(f_opt_expr))

test_A = randn(3, 3)
test_B = randn(3, 3)
test_C = randn(3, 3)
test_D = randn(3, 3)

# Get concrete test args
test_args = if length(args) == 3
(test_A, test_B, test_C)
else
(test_A, test_B, test_C, test_D)
end

# Evaluate both versions
result_cse = invokelatest(f_cse, test_args...)
result_opt = invokelatest(f_opt, test_args...)

# Assert correctness
@test isapprox(result_cse, result_opt, rtol=1e-10)
end

@testset "Mul5 Optimization Tests" begin
@syms A[1:3, 1:3] B[1:3, 1:3] C[1:3, 1:3] D[1:3, 1:3]

expr1 = A * B + C
test_optimization(expr1, A, B, C)

expr2 = A * B + C + D
test_optimization(expr2, A, B, C, D)

expr4 = A * B + C + D + C * D # multiple correct patterns
test_optimization(expr4, A, B, C, D)

expr5 = sin.(A * B + C + D + C * D)
test_optimization(expr5, A, B, C, D)
Member

Try

P = A + B
Q = B + C
R = C / D
P * Q + R

Does this correctly turn into mul!(R, P, Q, 1, 1)?

Contributor Author

So the basic answer is we currently filter out some variables in our detection to find the C.

@@ -67,13 +74,17 @@ function detect_matmul_add_pattern(expr::Code.Let, state::Code.CSEState)
-    Cset = Set(filter(!is_cse_var, reduce(vcat,getproperty.(match_data_, :Cs))))
+    Cset = Set(Iterators.flatten(getproperty.(match_data_, :Cs)))

I need a simplified version which correctly identifies which Cs are valid for our transformation.

Fwiw with this applied:

:(function (A, B, C, D)
      #= /home/dhairyalgandhi/arpa/jsmo/clone/Symbolics.jl/src/build_function.jl:147 =# @inbounds begin
              #= /home/dhairyalgandhi/arpa/jsmo/clone/Symbolics.jl/src/build_function.jl:147 =#
              begin
                  #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:507 =#
                  #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:508 =#
                  #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:509 =#
                  begin
                      var"##cse#1" = (/)(C, D)
                      var"##cse#2" = (+)(B, A)
                      var"##cse#3" = (+)(B, C)
                      var"##mul5_temp#554" = (copy)(var"##cse#1")
                      var"##mul5_temp#554" = (mul!)(var"##mul5_temp#554", var"##cse#2", var"##cse#3", 1, 1)
                      var"##mul5_temp#554" = var"##mul5_temp#554"
                      var"##mul5_temp#554"
                  end
              end
          end
  end)
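
The copy-then-`mul!` sequence in the generated code above can be checked numerically outside the symbolic machinery. The sketch below uses small concrete matrices chosen only for illustration (with `D` invertible so that `C / D` is defined) and relies on the documented behaviour of the 5-argument `LinearAlgebra.mul!`, which computes `P*Q*α + buf*β` in place.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
C = [2.0 0.0; 0.0 2.0]
D = [1.0 0.0; 0.0 4.0]   # invertible, so C / D is well defined

P = A + B
Q = B + C
R = C / D

# Reference result from the plain expression.
expected = P * Q + R

# Same steps as the generated code: copy the addend, then accumulate into it.
buf = copy(R)
mul!(buf, P, Q, 1, 1)    # buf = P*Q*1 + buf*1

matches_reference = buf ≈ expected
```

Because the accumulation happens in a copy, `R` itself is left untouched. Skipping the copy would overwrite the addend, which is only safe if nothing else in the generated function reads it afterwards.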

end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -30,5 +30,8 @@ using Pkg, Test, SafeTestsets
@safetestset "Recursive utilities" begin include("recursive_utils.jl") end
@safetestset "Misc" begin include("misc.jl") end
@safetestset "Method library" begin include("methods.jl") end

# Optimization
@safetestset "MatmulAdd Optimization" begin include("mul5_opt.jl") end
end
end