Skip to content

Commit bada063

Browse files
committed
Define coeff function
1 parent b88fb66 commit bada063

File tree

5 files changed

+35
-31
lines changed

5 files changed

+35
-31
lines changed

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module SymbolicUtils
55

66
using DocStringExtensions
77

8-
export @syms, term, showraw, hasmetadata, getmetadata, setmetadata, name
8+
export @syms, term, showraw, hasmetadata, getmetadata, setmetadata, name, coeff
99

1010
using TermInterface
1111
using DataStructures

src/inspect.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
1212
string(x.impl.val)
1313
elseif isadd(x)
1414
string(exprtype(x),
15-
(scalar = x.impl.coeff, coeffs = Tuple(k => v for (k, v) in x.impl.dict)))
15+
(scalar = coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict)))
1616
elseif ismul(x)
1717
string(exprtype(x),
18-
(scalar = x.impl.coeff, powers = Tuple(k => v for (k, v) in x.impl.dict)))
18+
(scalar = coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict)))
1919
elseif isdiv(x) || ispow(x)
2020
string(exprtype(x))
2121
else

src/polyform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ function quick_mulpow(x, y)
514514
den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base])
515515
delete!(d, y.impl.base)
516516
end
517-
return _Mul(symtype(x), x.impl.coeff, d), den
517+
return _Mul(symtype(x), coeff(x), d), den
518518
else
519519
return x, y
520520
end

src/types.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ function name(x::BasicSymbolic)
6868
x.impl.name
6969
end
7070

71+
function coeff(x::BasicSymbolic)
72+
x.impl.coeff
73+
end
74+
7175
# Same but different error messages
7276
@noinline error_on_type() = error("Internal error: unreachable reached!")
7377
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -293,7 +297,7 @@ function _isequal(a, b, E)
293297
if E === SYM
294298
nameof(a) === nameof(b)
295299
elseif E === ADD || E === MUL
296-
coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict)
300+
coeff_isequal(coeff(a), coeff(b)) && isequal(a.impl.dict, b.impl.dict)
297301
elseif E === DIV
298302
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
299303
elseif E === POW
@@ -337,7 +341,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
337341
h = s.hash[]
338342
!iszero(h) && return h
339343
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
340-
h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt)))
344+
h′ = hash(hashoffset, hash(coeff(s), hash(s.impl.dict, salt)))
341345
s.hash[] = h′
342346
return h′
343347
elseif E === DIV
@@ -444,7 +448,7 @@ const Rat = Union{Rational, Integer}
444448

445449
function ratcoeff(x)
446450
if ismul(x)
447-
ratcoeff(x.impl.coeff)
451+
ratcoeff(coeff(x))
448452
elseif x isa Rat
449453
(true, x)
450454
else
@@ -455,7 +459,7 @@ ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y
455459
ratio(x::Rat,y::Rat) = x//y
456460
function maybe_intcoeff(x)
457461
if ismul(x)
458-
coeff = x.impl.coeff
462+
coeff = coeff(x)
459463
if coeff isa Rational && isone(denominator(coeff))
460464
_Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata)
461465
else
@@ -537,7 +541,7 @@ function toterm(t::BasicSymbolic{T}) where {T}
537541
return t
538542
elseif E === ADD || E === MUL
539543
args = BasicSymbolic[]
540-
push!(args, t.impl.coeff)
544+
push!(args, coeff(t))
541545
for (k, coeff) in t.impl.dict
542546
push!(
543547
args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k]))
@@ -562,7 +566,7 @@ function makeadd(sign, coeff, xs...)
562566
d = Dict{BasicSymbolic, Any}()
563567
for x in xs
564568
if isadd(x)
565-
coeff += x.impl.coeff
569+
coeff += coeff(x)
566570
_merge!(+, d, x.impl.dict, filter = _iszero)
567571
continue
568572
end
@@ -572,7 +576,7 @@ function makeadd(sign, coeff, xs...)
572576
end
573577
if ismul(x)
574578
k = _Mul(symtype(x), 1, x.impl.dict)
575-
v = sign * x.impl.coeff + get(d, k, 0)
579+
v = sign * coeff(x) + get(d, k, 0)
576580
else
577581
k = x
578582
v = sign + get(d, x, 0)
@@ -593,7 +597,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}())
593597
elseif x isa Number
594598
coeff *= x
595599
elseif ismul(x)
596-
coeff *= x.impl.coeff
600+
coeff *= coeff(x)
597601
_merge!(+, d, x.impl.dict, filter = _iszero)
598602
else
599603
v = 1 + get(d, x, 0)
@@ -1219,10 +1223,10 @@ function +(a::SN, b::SN)
12191223
!issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata
12201224
if isadd(a) && isadd(b)
12211225
return _Add(
1222-
add_t(a, b), a.impl.coeff + b.impl.coeff, _merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
1226+
add_t(a, b), coeff(a) + coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
12231227
elseif isadd(a)
12241228
coeff, dict = makeadd(1, 0, b)
1225-
return _Add(add_t(a, b), a.impl.coeff + coeff, _merge(+, a.impl.dict, dict, filter = _iszero))
1229+
return _Add(add_t(a, b), coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero))
12261230
elseif isadd(b)
12271231
return b + a
12281232
end
@@ -1236,7 +1240,7 @@ function +(a::Number, b::SN)
12361240
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
12371241
iszero(a) && return b
12381242
if isadd(b)
1239-
_Add(add_t(a, b), a + b.impl.coeff, b.impl.dict)
1243+
_Add(add_t(a, b), a + coeff(b), b.impl.dict)
12401244
else
12411245
_Add(add_t(a, b), makeadd(1, a, b)...)
12421246
end
@@ -1254,15 +1258,15 @@ function -(a::SN)
12541258
return term(-, a)
12551259
end
12561260
if isadd(a)
1257-
_Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict))
1261+
_Add(sub_t(a), -coeff(a), mapvalues((_, v) -> -v, a.impl.dict))
12581262
else
12591263
_Add(sub_t(a), makeadd(-1, 0, a)...)
12601264
end
12611265
end
12621266
function -(a::SN, b::SN)
12631267
(!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b)
12641268
if isadd(a) && isadd(b)
1265-
_Add(sub_t(a, b), a.impl.coeff - b.impl.coeff, _merge(-, a.impl.dict, b.impl.dict, filter = _iszero))
1269+
_Add(sub_t(a, b), coeff(a) - coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero))
12661270
else
12671271
a + (-b)
12681272
end
@@ -1289,16 +1293,16 @@ function *(a::SN, b::SN)
12891293
elseif isdiv(b)
12901294
_Div(a * b.impl.num, b.impl.den)
12911295
elseif ismul(a) && ismul(b)
1292-
_Mul(mul_t(a, b), a.impl.coeff * b.impl.coeff,
1296+
_Mul(mul_t(a, b), coeff(a) * coeff(b),
12931297
_merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
12941298
elseif ismul(a) && ispow(b)
12951299
if b.impl.exp isa Number
12961300
_Mul(mul_t(a, b),
1297-
a.impl.coeff,
1301+
coeff(a),
12981302
_merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp),
12991303
filter = _iszero))
13001304
else
1301-
_Mul(mul_t(a, b), a.impl.coeff,
1305+
_Mul(mul_t(a, b), coeff(a),
13021306
_merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero))
13031307
end
13041308
elseif ispow(a) && ismul(b)
@@ -1321,7 +1325,7 @@ function *(a::Number, b::SN)
13211325
elseif isone(-a) && isadd(b)
13221326
# -1(a+b) -> -a - b
13231327
T = promote_symtype(+, typeof(a), symtype(b))
1324-
_Add(T, b.impl.coeff * a,
1328+
_Add(T, coeff(b) * a,
13251329
Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict))
13261330
else
13271331
_Mul(mul_t(a, b), makemul(a, b)...)
@@ -1346,7 +1350,7 @@ function ^(a::SN, b)
13461350
elseif b isa Number && b < 0
13471351
_Div(1, a^(-b))
13481352
elseif ismul(a) && b isa Number
1349-
coeff = unstable_pow(a.impl.coeff, b)
1353+
coeff = unstable_pow(coeff(a), b)
13501354
_Mul(promote_symtype(^, symtype(a), symtype(b)),
13511355
coeff, mapvalues((k, v) -> b * v, a.impl.dict))
13521356
else

test/basics.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -344,18 +344,18 @@ end
344344
@testset "div" begin
345345
@syms x::SafeReal y::Real
346346
@test issym((2x / 2y).impl.num)
347-
@test (2x / 3y).impl.num.impl.coeff == 2
348-
@test (2x / 3y).impl.den.impl.coeff == 3
349-
@test (2x / -3x).impl.num.impl.coeff == -2
350-
@test (2x / -3x).impl.den.impl.coeff == 3
351-
@test (2.5x / 3x).impl.num.impl.coeff == 2.5
352-
@test (2.5x / 3x).impl.den.impl.coeff == 3
353-
@test (x / 3x).impl.den.impl.coeff == 3
347+
@test coeff((2x / 3y).impl.num) == 2
348+
@test coeff((2x / 3y).impl.den) == 3
349+
@test coeff((2x / -3x).impl.num) == -2
350+
@test coeff((2x / -3x).impl.den) == 3
351+
@test coeff((2.5x / 3x).impl.num) == 2.5
352+
@test coeff((2.5x / 3x).impl.den) == 3
353+
@test coeff((x / 3x).impl.den) == 3
354354

355355
@syms x y
356356
@test issym((2x / 2y).impl.num)
357-
@test (2x / 3y).impl.num.impl.coeff == 2
358-
@test (2x / 3y).impl.den.impl.coeff == 3
357+
@test coeff((2x / 3y).impl.num) == 2
358+
@test coeff((2x / 3y).impl.den) == 3
359359
@test (2x / -3x) == -2 // 3
360360
@test (2.5x / 3x).impl.num == 2.5
361361
@test (2.5x / 3x).impl.den == 3

0 commit comments

Comments
 (0)