Add 5-arg mul! Detection #805
Conversation
Benchmark Results (Julia LTS): time benchmarks, memory benchmarks
AayushSabharwal left a comment:
It seems the broad strokes of the implementation are here, but I have a few overarching concerns:
- There's a lot that can be refactored in the core codegen to make it more efficient and type-stable. That's not within the scope of this PR, but I feel we should ensure that anything we're adding now is at least as type-stable as reasonably possible. A lot of the data structures here can be more concretely typed. This also helps convey what they store/represent when passed around.
- In a similar vein, there are a significant number of extra allocations that can be avoided without much effort.
- Could the functions (and some core data structures) be documented to make this easier to refactor in the future? It's fine if AI/Claude writes the docs, as long as they're human-reviewed and the commits are marked as being written by AI.
src/matmuladd.jl
Outdated
iscall(rhs(x)) || return false
args = arguments(rhs(x))
all_arrays = all(y -> y <: AbstractArray, symtype.(args))
is_mul = operation(rhs(x)) === *
all_arrays && is_mul
This could be phrased as operation(rhs(x)) === * || return false for an early exit, since checking symtypes is relatively more expensive. Also, array multiplication with a scalar coefficient (i.e. k * A * B * C for scalar k) is phrased as a Term with * as the operation and the scalar coefficient (which may itself be an expression) as the first argument. The check here will miss this case.
EDIT: Actually, this entire condition could just be isterm(x) && operation(x) === * && symtype(x) <: AbstractArray.
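For reference, a minimal sketch of the suggested check (the helper name is hypothetical, and `ex` stands for the right-hand side of the assignment, as discussed further down):

```julia
# Hypothetical helper sketching the suggested check; `ex` would be rhs(x) here.
function is_array_mul(ex)
    # early exit on the cheap structural checks before touching symtype
    iscall(ex) && operation(ex) === (*) || return false
    symtype(ex) <: AbstractArray
end
```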
Will do. I wasn't sure how robust type propagation is with expressions, or whether we would be better served by ensuring all the arguments are explicitly arrays. Since we didn't handle scalars, I wanted to explicitly exit for those cases.
Is this correct though? Can we assume that the optimizations we are running will be applicable when the symtype is the only thing we are relying on? For example, we don't generate the broadcast/allocation statement when C is a scalar. Is it better to be a bit conservative for the first iteration?
Sure. If it helps, multiplying a bunch of scalars and arrays is guaranteed to result in a Term for which args[1] is the scalar coefficient. If the coefficient is absent, then args[1] is just the first array.
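As a rough illustration of that convention (not code from the PR; `ex` stands for the `*` expression):

```julia
# Split a `*` Term into a possible scalar coefficient and the array factors,
# relying on the convention described above that args[1] holds the coefficient.
args = arguments(ex)
coeff, arrays = if symtype(args[1]) <: AbstractArray
    1, args                # no scalar coefficient present
else
    args[1], args[2:end]   # scalar (possibly symbolic) coefficient comes first
end
```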
plus_candidates_idx = findall(expr.pairs) do x
    iscall(rhs(x)) || return false
    args = arguments(rhs(x))
    all_arrays = all(y -> y <: AbstractArray, symtype.(args))
    is_plus = operation(rhs(x)) === +
    all_arrays && is_plus
end
This check could be isadd(x) && symtype(x) <: AbstractArray
I seem to get isadd(x) = false and symtype(x) = Assignment for every x. operation(x) = Assignment as well. So it seems like I will need to operate on the rhs at the least.
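Something along these lines, applied to the right-hand side, would then be the equivalent check (a sketch, not the final code):

```julia
# x is an Assignment, so the structural checks go on its right-hand side.
is_array_add(x) = isadd(rhs(x)) && symtype(rhs(x)) <: AbstractArray
```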
src/matmuladd.jl
Outdated
mul_vals = lhs.(mul_candidates)
candidates = map(plus_candidates_idx, plus_candidates) do p_idx, p
    map(mul_candidates_idx, mul_vals) do m_idx, m_v
        if nameof(m_v) in nameof.(arguments(rhs(p)))
Why does this need to check nameof? Why not any(isequal(m_v), arguments(rhs(p)))? Additionally, pargs = Set(arguments(rhs(p))) outside the inner map can potentially save a lot of time.
nameof is also not always a valid operation.
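A sketch of what the hoisted version could look like, reusing the surrounding variables from the snippet below:

```julia
# Sketch: build the argument set once per plus-candidate and compare with
# structural equality instead of calling `nameof` in the inner loop.
candidates = map(plus_candidates_idx, plus_candidates) do p_idx, p
    pargs = Set(arguments(rhs(p)))   # hoisted out of the inner map
    map(mul_candidates_idx, mul_vals) do m_idx, m_v
        if m_v in pargs
            (m_idx, m_v) => (p_idx, expr.pairs[p_idx])
        end
    end
end
```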
candidates = map(plus_candidates_idx, plus_candidates) do p_idx, p
    map(mul_candidates_idx, mul_vals) do m_idx, m_v
        if nameof(m_v) in nameof.(arguments(rhs(p)))
            (m_idx, m_v) => (p_idx, expr.pairs[p_idx])
        end
    end
end
candidates = filter(!isnothing, reduce(vcat, candidates))
In fact, this entire block seems to be dead code? The result is not used anywhere.
Yeah, this block can be removed. There is a potentially faster and better check which maps the candidates to the Assignments, but that is not implemented yet.
This block is now used, and the previous pattern matcher has been removed.
    dic[key] = vcat(dic[key], value)
else
    dic[key] = value
Suggested change:
-    dic[key] = vcat(dic[key], value)
-else
-    dic[key] = value
+    push!(dic[key], value)
+else
+    dic[key] = copy(value)
Allocates once instead of every time a key is pushed.
I don't see how that helps? At the time of insertion, we don't actually know in this function whether there is an N-to-1 map. And assuming that will make us always do some extra legwork on the codegen side.
src/matmuladd.jl
Outdated
function find_cse_expr(x, state)
    idx = findfirst(y -> nameof(lhs(y)) == nameof(x), state.sorted_exprs)
    isnothing(idx) ? nothing : (; expr = rhs(state.sorted_exprs[idx]), x)
Why does this need to return a NamedTuple with x if x is an input to the function? I'm pretty sure returning Any is better than returning @NamedTuple{expr::Any, x::BasicSymbolic{T}}.
This can certainly be removed for now; it was part of the symbol counting.
This is removed.
src/matmuladd.jl
Outdated
A, B = arguments(rhs(m))
Cs = filter(x -> nameof(x) != nameof(mul_val), plus_args)
validate_mul_shapes(A, B, Cs...) || return nothing
return (; A, B, Cs, mul_candidate = m, plus_candidate = c, mul_idx, plus_idx, pattern="A*B + C")
Could this be a struct instead of a NamedTuple? Easier to document and less likely to miss or typo a field. Additionally, what role does pattern serve here?
I needed a simple way to identify which optimization was detected when we have multiple passes, assuming there is a fairly generic implementation for a code transformation.
This is now a struct
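For reference, a struct along these lines (field names taken from the NamedTuple in the diff above; the types are illustrative, not necessarily what the PR uses):

```julia
# Illustrative match-result struct replacing the ad-hoc NamedTuple.
struct MatMulAddMatch
    A::Any                 # left factor of the product
    B::Any                 # right factor of the product
    Cs::Vector{Any}        # additive terms usable as the mul! accumulator
    mul_candidate::Any     # the assignment whose rhs is A * B
    plus_candidate::Any    # the assignment whose rhs is the sum
    mul_idx::Int
    plus_idx::Int
    pattern::String        # e.g. "A*B + C", identifying the detected pattern
end
```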
src/simplify_rules.jl
Outdated
))

is_array(x) = symtype(x) <: AbstractArray
const ARRAY_RULES = (
This rule isn't exactly correct, and is untested. This won't result in the expression LinearAlgebra.mul!(copy(c), a, b, 1, 1) as I think you intended. Instead, it will error since copy is not implemented for BasicSymbolic. Disregarding that, mul! will trace since it is also not registered. Even if it were registered, simplify does not take codegen into consideration. It is simply a set of rules that transform a declarative expression into an equivalent heuristically "simpler" declarative expression.
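To illustrate the distinction, a purely declarative rule rewrites one symbolic expression into another symbolic expression, without introducing codegen constructs such as copy or mul! (this example is made up for illustration, not taken from the PR):

```julia
using SymbolicUtils

# A declarative rewrite: both sides are symbolic expressions, no mutation.
r = @rule (~x * ~y) + (~x * ~z) => ~x * (~y + ~z)
```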
This needs Symbolics.jl#dg/arr, which adds the @register_array_symbolic rules and actually does the transformation. The test point is fair, though I am not sure how to go about it since it needs the Symbolics branch first; perhaps this can be part of a follow-on once that is released?
The registration can be done manually in methods.jl. It doesn't really need the macro. But this doesn't address the fact that this rule is conflating codegen with symbolic representation.
I see that; we can get rid of it then.
Done, fwiw.
src/SymbolicUtils.jl
Outdated
import .Code
import .Code: CSEState, Let, Assignment, Func, MakeArray,
    MakeSparseArray, AtIndex, DestructuredArgs, SpawnFetch,
    LiteralExpr, BasicSymbolic, ForLoop, SetArray, MakeTuple,
    lhs, rhs
include("matmuladd.jl")

# Updated mul5_cse2 that uses the rule system
function mul5_cse2(expr, state::CSEState)
    # Try to apply optimization rules
    optimized = apply_optimization_rules(expr, state)
    if optimized !== nothing
        return optimized
    end

    # If no optimization applied, return original expression
    return expr
end
matmuladd.jl should be part of code.jl, alongside all other codegen aspects.
code.jl is set up as a module in a single file, and it would be better to refactor it into a subpackage rather than bloat that file further.
Can't include("matmuladd.jl") just be moved to code.jl? The source doesn't need to be inlined.
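I.e., something like this inside the Code module (a sketch of the suggested placement, not the actual diff):

```julia
# code.jl
module Code
# ... existing codegen definitions (CSEState, Assignment, lhs/rhs, ...) ...

include("matmuladd.jl")  # picks up the module's names without re-importing them

end # module Code
```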
Done
test_optimization(expr4, A, B, C, D)

expr5 = sin.(A * B + C + D + C * D)
test_optimization(expr5, A, B, C, D)
Try

P = A + B
Q = B + C
R = C / D
P * Q + R

Does this correctly turn into mul!(R, P, Q, 1, 1)?
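A hedged sketch of how that case could be exercised with the test helper used elsewhere in this PR (test_optimization; the array sizes are arbitrary):

```julia
using Symbolics

# Symbolic array variables; sizes chosen only so the shapes are compatible.
@variables A[1:3, 1:3] B[1:3, 1:3] C[1:3, 1:3] D[1:3, 1:3]

P = A + B
Q = B + C
R = C / D
expr = P * Q + R

# test_optimization is the helper from this PR's tests.
test_optimization(expr, A, B, C, D)
```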
So the basic answer is that we currently filter out some variables in our detection to find the C.

@@ -67,13 +74,17 @@ function detect_matmul_add_pattern(expr::Code.Let, state::Code.CSEState)
-    Cset = Set(filter(!is_cse_var, reduce(vcat, getproperty.(match_data_, :Cs))))
+    Cset = Set(Iterators.flatten(getproperty.(match_data_, :Cs)))

I need a simplified version which correctly identifies which Cs are valid for our transformation.

Fwiw, with this applied:
:(function (A, B, C, D)
      #= /home/dhairyalgandhi/arpa/jsmo/clone/Symbolics.jl/src/build_function.jl:147 =# @inbounds begin
          #= /home/dhairyalgandhi/arpa/jsmo/clone/Symbolics.jl/src/build_function.jl:147 =#
          begin
              #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:507 =#
              #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:508 =#
              #= /home/dhairyalgandhi/arpa/jsmo/clone/SymbolicUtils.jl/src/code.jl:509 =#
              begin
                  var"##cse#1" = (/)(C, D)
                  var"##cse#2" = (+)(B, A)
                  var"##cse#3" = (+)(B, C)
                  var"##mul5_temp#554" = (copy)(var"##cse#1")
                  var"##mul5_temp#554" = (mul!)(var"##mul5_temp#554", var"##cse#2", var"##cse#3", 1, 1)
                  var"##mul5_temp#554" = var"##mul5_temp#554"
                  var"##mul5_temp#554"
              end
          end
      end
  end)
Re:

function f!(buf, ...)
    ...
end
function f(...)
    buf = similar(...)
    f!(buf, ...)
    buf
end

The value in hoisting allocations in these cases is also only when …
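A concrete sketch of that pattern specialized to the matmul-add case (names are illustrative, not from the PR):

```julia
using LinearAlgebra

# In-place kernel: the caller provides the buffer.
function f!(buf, A, B, C)
    copyto!(buf, C)          # seed the accumulator with C
    mul!(buf, A, B, 1, 1)    # buf = A*B*1 + buf*1, no intermediate allocation
    return buf
end

# Out-of-place wrapper: the only allocation is hoisted here.
function f(A, B, C)
    buf = similar(C)
    return f!(buf, A, B, C)
end
```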