Skip to content

Commit 0a380ed

Browse files
committed
Enables full (nested) associative-commutative pattern matching for + and * operators.
1 parent aab293a commit 0a380ed

File tree

3 files changed

+59
-8
lines changed

3 files changed

+59
-8
lines changed

src/matchers.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66
# 3. Callback: takes arguments Dictionary × Number of elements matched
77
#
88
function matcher(val::Any)
9-
iscall(val) && return term_matcher(val)
9+
matcher(val, false)
10+
end
11+
12+
# `fullac_flag == true` enables fully nested associative-commutative pattern matching
13+
function matcher(val::Any, fullac_flag)
14+
iscall(val) && return term_matcher(val, fullac_flag)
1015
function literal_matcher(next, data, bindings)
1116
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
1217
end
1318
end
1419

15-
function matcher(slot::Slot)
20+
function matcher(slot::Slot, fullac_flag) # fullac_flag unused but needed to keep the interface uniform
1621
function slot_matcher(next, data, bindings)
1722
!islist(data) && return
1823
val = get(bindings, slot.name, nothing)
@@ -56,7 +61,7 @@ function trymatchexpr(data, value, n)
5661
end
5762
end
5863

59-
function matcher(segment::Segment)
64+
function matcher(segment::Segment, fullac_flag) # fullac_flag unused but needed to keep the interface uniform
6065
function segment_matcher(success, data, bindings)
6166
val = get(bindings, segment.name, nothing)
6267

@@ -84,8 +89,8 @@ function matcher(segment::Segment)
8489
end
8590
end
8691

87-
function term_matcher(term)
88-
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
92+
function term_matcher(term, fullac_flag = false)
93+
matchers = (matcher(operation(term), fullac_flag), map(a -> matcher(a, fullac_flag), arguments(term))...,)
8994
function term_matcher(success, data, bindings)
9095

9196
!islist(data) && return nothing
@@ -103,6 +108,23 @@ function term_matcher(term)
103108
end
104109
end
105110

106-
loop(car(data), bindings, matchers) # Try to eat exactly one term
111+
if !(fullac_flag && iscall(term) && operation(term) in ((+), (*)))
112+
loop(car(data), bindings, matchers) # Try to eat exactly one term
113+
else # try all permutations of `car(data)` to see if a match is possible
114+
data1 = car(data)
115+
args = arguments(data1)
116+
op = operation(data1)
117+
data_arg_perms = permutations(args)
118+
result = nothing
119+
T = symtype(data)
120+
for perm in data_arg_perms
121+
data_permuted = Term{T}(op, perm)
122+
result = loop(data_permuted, bindings, matchers) # Try to eat exactly one term
123+
if !(result isa Nothing)
124+
break
125+
end
126+
end
127+
return result
128+
end
107129
end
108130
end

src/rule.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,30 @@ whether the predicate holds or not.
297297
298298
_In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side
299299
of an expression.
300+
301+
**Full (nested) associative-commutative matching**:
302+
303+
@rule LHS => RHS fullac
304+
305+
creates a rule that fully respects associative-commutative (AC) operations. Unlike `@acrule LHS => RHS` which only considers AC properties of the top-level function, here we impose AC properties on all subexpressions.
306+
307+
```
308+
julia> @syms a b;
309+
310+
julia> r = @rule ~a + ~a*~b => ~a * (1+~b) fullac;
311+
312+
julia> r(b + a*b)
313+
(1 + a)*b
314+
315+
```
300316
"""
301-
macro rule(expr)
317+
macro rule(expr, option...)
302318
@assert expr.head == :call && expr.args[1] == :(=>)
319+
fullac = false
320+
if length(option) > 0
321+
@assert option[1] == :fullac "@rule only accepts one option `fullac` after the rule itself"
322+
fullac = true
323+
end
303324
lhs = expr.args[2]
304325
rhs = rewrite_rhs(expr.args[3])
305326
keys = Symbol[]
@@ -310,7 +331,7 @@ macro rule(expr)
310331
lhs_pattern = $(lhs_term)
311332
Rule($(QuoteNode(expr)),
312333
lhs_pattern,
313-
matcher(lhs_pattern),
334+
matcher(lhs_pattern, $fullac),
314335
__MATCHES__ -> $(makeconsequent(rhs)),
315336
rule_depth($lhs_term))
316337
end

test/rewrite.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ end
4343
@eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, [])
4444
end
4545

46+
@testset "Full associative-commutative matching" begin
47+
@eqtest (@rule ~a + ~a*~b => ~a * (1+~b) fullac)(a + a*b) == a * (1+b)
48+
@eqtest (@rule ~a + ~a*~b => ~a * (1+~b) fullac)(b + a*b) == b * (1+a) # fails with @acrule
49+
@eqtest (@rule ~a*~b + ~a => ~a * (1+~b) fullac)(b + a*b) == b * (1+a) # fails with @acrule
50+
@eqtest (@rule ~a*~b + ~a*~c => ~a * (~b+~c) fullac)(a*b + a*c) == a * (b+c)
51+
@eqtest (@rule ~a*~b + ~a*~c => ~a * (~b+~c) fullac)(a*b + b*c) == b * (a+c) # fails with @acrule
52+
end
53+
4654
using SymbolicUtils: @capture
4755

4856
@testset "Capture form" begin

0 commit comments

Comments
 (0)