|
6 | 6 |
|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
9 | | -from collections import Counter |
10 | | - |
11 | 9 | from egglog import * |
12 | 10 |
|
13 | 11 |
|
14 | 12 | class Math(Expr): |
15 | 13 | def __init__(self, x: i64Like) -> None: ... |
| 14 | + def __add__(self, other: MathLike) -> Math: ... |
| 15 | + def __radd__(self, other: MathLike) -> Math: ... |
| 16 | + def __mul__(self, other: MathLike) -> Math: ... |
| 17 | + def __rmul__(self, other: MathLike) -> Math: ... |
16 | 18 |
|
17 | 19 |
|
18 | | -@function |
19 | | -def square(x: Math) -> Math: ... |
20 | | - |
21 | | - |
22 | | -@ruleset |
23 | | -def math_ruleset(i: i64): |
24 | | - yield rewrite(square(Math(i))).to(Math(i * i)) |
25 | | - |
| 20 | +MathLike = Math | i64Like |
| 21 | +converter(i64, Math, Math) |
26 | 22 |
|
27 | | -egraph = EGraph() |
28 | 23 |
|
29 | | -xs = MultiSet(Math(1), Math(2), Math(3)) |
30 | | -egraph.register(xs) |
| 24 | +@function |
| 25 | +def sum(xs: MultiSetLike[Math, MathLike]) -> Math: ... |
31 | 26 |
|
32 | | -egraph.check(xs == MultiSet(Math(1), Math(3), Math(2))) |
33 | | -egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3))) |
34 | 27 |
|
35 | | -assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1}) |
| 28 | +@function |
| 29 | +def product(xs: MultiSetLike[Math, MathLike]) -> Math: ... |
36 | 30 |
|
37 | 31 |
|
38 | | -inserted = MultiSet(Math(1), Math(2), Math(3), Math(4)) |
39 | | -egraph.register(inserted) |
40 | | -egraph.check(xs.insert(Math(4)) == inserted) |
41 | | -egraph.check(xs.contains(Math(1))) |
42 | | -egraph.check(xs.not_contains(Math(4))) |
43 | | -assert Math(1) in xs |
44 | | -assert Math(4) not in xs |
| 32 | +@function |
| 33 | +def square(x: Math) -> Math: ... |
45 | 34 |
|
46 | | -egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3))) |
47 | 35 |
|
48 | | -assert egraph.extract(xs.length()).value == 3 |
49 | | -assert len(xs) == 3 |
| 36 | +x = constant("x", Math) |
| 37 | +expr1 = 2 * (x + 3) |
| 38 | +expr2 = 6 + 2 * x |
50 | 39 |
|
51 | | -egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2)) |
52 | 40 |
|
53 | | -egraph.check(MultiSet(Math(1)).pick() == Math(1)) |
| 41 | +@ruleset |
| 42 | +def math_ruleset(a: Math, b: Math, c: Math, i: i64, j: i64, xs: MultiSet[Math], ys: MultiSet[Math], zs: MultiSet[Math]): |
| 43 | + yield rewrite(a + b).to(sum(MultiSet(a, b))) |
| 44 | + yield rewrite(a * b).to(product(MultiSet(a, b))) |
| 45 | + # 0 or 1 elements sums/products also can be extracted back to numbers |
| 46 | + yield rule(a == sum(xs), xs.length() == i64(1)).then(a == xs.pick()) |
| 47 | + yield rule(a == product(xs), xs.length() == i64(1)).then(a == xs.pick()) |
| 48 | + yield rewrite(sum(MultiSet[Math]())).to(Math(0)) |
| 49 | + yield rewrite(product(MultiSet[Math]())).to(Math(1)) |
| 50 | + # distributive rule (a * (b + c) = a*b + a*c) |
| 51 | + yield rule( |
| 52 | + b == product(ys), |
| 53 | + a == sum(xs), |
| 54 | + ys.contains(a), |
| 55 | + ys.length() > 1, |
| 56 | + zs == ys.remove(a), |
| 57 | + ).then( |
| 58 | + b == sum(xs.map(lambda x: product(zs.insert(x)))), |
| 59 | + ) |
| 60 | + # constants |
| 61 | + yield rule( |
| 62 | + a == sum(xs), |
| 63 | + b == Math(i), |
| 64 | + xs.contains(b), |
| 65 | + ys == xs.remove(b), |
| 66 | + c == Math(j), |
| 67 | + ys.contains(c), |
| 68 | + ).then( |
| 69 | + a == sum(ys.remove(c).insert(Math(i + j))), |
| 70 | + ) |
| 71 | + yield rule( |
| 72 | + a == product(xs), |
| 73 | + b == Math(i), |
| 74 | + xs.contains(b), |
| 75 | + ys == xs.remove(b), |
| 76 | + c == Math(j), |
| 77 | + ys.contains(c), |
| 78 | + ).then( |
| 79 | + a == product(ys.remove(c).insert(Math(i * j))), |
| 80 | + ) |
54 | 81 |
|
55 | | -mapped = xs.map(square) |
56 | | -egraph.register(mapped) |
57 | | -egraph.run(math_ruleset) |
58 | | -egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9))) |
59 | 82 |
|
60 | | -egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))) |
| 83 | +egraph = EGraph() |
| 84 | +egraph.register(expr1, expr2) |
| 85 | +egraph.run(math_ruleset.saturate()) |
| 86 | +egraph.check(expr1 == expr2) |
0 commit comments