diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 05ea4678..8c296876 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -236,14 +236,14 @@ function benchmark_utilities() [get_set_constants!(ex) for ex in exs], seconds = 10.0, setup = ( - operators = $operators; - ntrees = 100; - n = 20; - n_features = 5; - n_params = 3; - n_param_classes = 10; - rng = Random.MersenneTwister(0); - exs = [ + operators=($operators); + ntrees=100; + n=20; + n_features=5; + n_params=3; + n_param_classes=10; + rng=Random.MersenneTwister(0); + exs=[ let tree = gen_random_tree_fixed_size( n, operators, n_features, Float32, ParametricNode, rng ) diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index 30a987d3..33ae789f 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -90,7 +90,7 @@ function split_eq( op, args, operators::AbstractOperatorEnum, - ::Type{N}=Node; + (::Type{N})=Node; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, # Deprecated: varMap=nothing, @@ -255,7 +255,7 @@ end function symbolic_to_node( eqn::SymbolicUtils.Symbolic, operators::AbstractOperatorEnum, - ::Type{N}=Node; + (::Type{N})=Node; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, # Deprecated: varMap=nothing, diff --git a/ext/DynamicExpressionsZygoteExt.jl b/ext/DynamicExpressionsZygoteExt.jl index 5654c27e..41cad035 100644 --- a/ext/DynamicExpressionsZygoteExt.jl +++ b/ext/DynamicExpressionsZygoteExt.jl @@ -6,7 +6,7 @@ import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGrad function _zygote_gradient(op::F, ::Val{1}) where {F} return ZygoteGradient{F,1,1}(op) end -function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side} +function _zygote_gradient(op::F, ::Val{2}, (::Val{side})=Val(nothing)) where {F,side} # side should be either nothing (for both), 1, or 2 @assert side === nothing || side in (1, 2) return ZygoteGradient{F,2,side}(op) diff --git a/src/EvaluationHelpers.jl b/src/EvaluationHelpers.jl index 950445e7..2faa37a0 100644 --- a/src/EvaluationHelpers.jl +++ b/src/EvaluationHelpers.jl @@ -86,7 +86,8 @@ to every constant in the expression. - `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation, the gradient, and whether the evaluation completed as normal (or encountered a nan or inf). """ -Base.adjoint(tree::AbstractExpressionNode) = +function Base.adjoint(tree::AbstractExpressionNode) ((args...; kws...) -> _grad_evaluator(tree, args...; kws...)) +end end diff --git a/src/Expression.jl b/src/Expression.jl index 9e7325a6..e81a5f56 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -520,7 +520,7 @@ end function copy_into!(::Nothing, src::AbstractExpression) return copy(src) end -function allocate_container(::AbstractExpression, ::Union{Nothing,Integer}=nothing) +function allocate_container(::AbstractExpression, (::Union{Nothing,Integer})=nothing) return nothing end # COV_EXCL_STOP diff --git a/src/ExpressionAlgebra.jl b/src/ExpressionAlgebra.jl index f5dcd7b4..944bec42 100644 --- a/src/ExpressionAlgebra.jl +++ b/src/ExpressionAlgebra.jl @@ -107,32 +107,28 @@ the operator is unary (1) or binary (2). macro declare_expression_operator(op, arity) @assert arity ∈ (1, 2) if arity == 1 - return esc( - quote - $op(l::AbstractExpression) = $(apply_operator)($op, l) - end, - ) + return esc(quote + $op(l::AbstractExpression) = $(apply_operator)($op, l) + end) elseif arity == 2 - return esc( - quote - function $op(l::AbstractExpression, r::AbstractExpression) - return $(apply_operator)($op, l, r) - end - function $op(l::T, r::AbstractExpression{T}) where {T} - return $(apply_operator)($op, l, r) - end - function $op(l::AbstractExpression{T}, r::T) where {T} - return $(apply_operator)($op, l, r) - end - # Convenience methods for Number types - function $op(l::Number, r::AbstractExpression{T}) where {T} - return $(apply_operator)($op, l, r) - end - function $op(l::AbstractExpression{T}, r::Number) where {T} - return $(apply_operator)($op, l, r) - end - end, - ) + return esc(quote + function $op(l::AbstractExpression, r::AbstractExpression) + return $(apply_operator)($op, l, r) + end + function $op(l::T, r::AbstractExpression{T}) where {T} + return $(apply_operator)($op, l, r) + end + function $op(l::AbstractExpression{T}, r::T) where {T} + return $(apply_operator)($op, l, r) + end + # Convenience methods for Number types + function $op(l::Number, r::AbstractExpression{T}) where {T} + return $(apply_operator)($op, l, r) + end + function $op(l::AbstractExpression{T}, r::Number) where {T} + return $(apply_operator)($op, l, r) + end + end) end end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 6e18c418..5de462c3 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -94,7 +94,7 @@ given the output of this function. Also return metadata that can will be used in the `set_scalar_constants!` function. """ function get_scalar_constants( - tree::AbstractExpressionNode{T}, ::Type{BT}=get_number_type(T) + tree::AbstractExpressionNode{T}, (::Type{BT})=get_number_type(T) ) where {T,BT} refs = filter_map( is_node_constant, node -> Ref(node), tree, Base.RefValue{typeof(tree)} @@ -160,7 +160,7 @@ end # as we trace over the node we are indexing on. preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false -function index_constant_nodes(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T} +function index_constant_nodes(tree::AbstractExpressionNode, (::Type{T})=UInt16) where {T} # Essentially we copy the tree, replacing the values # with indices constant_index = Ref(T(0)) diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 96b84d00..9a2ed005 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -378,14 +378,12 @@ defined. macro extend_operators(operators, kws...) ex = _extend_operators(operators, false, kws, __module__) expected_type = AbstractOperatorEnum - return esc( - quote - if !isa($(operators), $expected_type) - error("You must pass an operator enum to `@extend_operators`.") - end - $ex - end, - ) + return esc(quote + if !isa($(operators), $expected_type) + error("You must pass an operator enum to `@extend_operators`.") + end + $ex + end) end """ @@ -399,14 +397,12 @@ and `internal` which is default `false`. macro extend_operators_base(operators, kws...) ex = _extend_operators(operators, true, kws, __module__) expected_type = AbstractOperatorEnum - return esc( - quote - if !isa($(operators), $expected_type) - error("You must pass an operator enum to `@extend_operators_base`.") - end - $ex - end, - ) + return esc(quote + if !isa($(operators), $expected_type) + error("You must pass an operator enum to `@extend_operators_base`.") + end + $ex + end) end """ diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 854e28d7..272560a3 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -302,7 +302,7 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T} elseif leaf.is_parameter Node(T; feature=leaf.parameter) else - Node(T; feature=leaf.feature + num_params) + Node(T; feature=(leaf.feature + num_params)) end, branch -> branch.op, (op, children...) -> Node(; op, children), diff --git a/src/Parse.jl b/src/Parse.jl index 10d121db..6fe64785 100644 --- a/src/Parse.jl +++ b/src/Parse.jl @@ -95,13 +95,13 @@ macro parse_expression(ex, kws...) return esc( :($(parse_expression)( $(Meta.quot(ex)); - operators=$(parsed_kws.operators), + operators=($(parsed_kws.operators)), binary_operators=nothing, unary_operators=nothing, - variable_names=$(parsed_kws.variable_names), - node_type=$(parsed_kws.node_type), - expression_type=$(parsed_kws.expression_type), - evaluate_on=$(parsed_kws.evaluate_on), + variable_names=($(parsed_kws.variable_names)), + node_type=($(parsed_kws.node_type)), + expression_type=($(parsed_kws.expression_type)), + evaluate_on=($(parsed_kws.evaluate_on)), $(parsed_kws.extra_metadata)..., )), ) @@ -188,8 +188,8 @@ end "You must specify the operators using either `operators`, or `binary_operators` and `unary_operators`" ) operators = :($(OperatorEnum)(; - binary_operators=$(binops === nothing ? :(Function[]) : binops), - unary_operators=$(unaops === nothing ? :(Function[]) : unaops), + binary_operators=($(binops === nothing ? :(Function[]) : binops)), + unary_operators=($(unaops === nothing ? :(Function[]) : unaops)), )) else @assert (binops === nothing && unaops === nothing) diff --git a/src/Random.jl b/src/Random.jl index bc3b546b..0e10b4b0 100644 --- a/src/Random.jl +++ b/src/Random.jl @@ -42,8 +42,9 @@ end Sample a node from a tree according to the default sampler `NodeSampler(; tree)`. """ -rand(rng::AbstractRNG, tree::Union{AbstractNode,AbstractExpression}) = +function rand(rng::AbstractRNG, tree::Union{AbstractNode,AbstractExpression}) rand(rng, NodeSampler(; tree)) +end """ rand(rng::AbstractRNG, sampler::NodeSampler) diff --git a/src/precompile.jl b/src/precompile.jl index d16bc6b7..16c7043c 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,17 +1,15 @@ import PrecompileTools: @compile_workload, @setup_workload macro ignore_domain_error(ex) - return esc( - quote - try - $ex - catch e - if !(e isa DomainError) - rethrow(e) - end + return esc(quote + try + $ex + catch e + if !(e isa DomainError) + rethrow(e) end - end, - ) + end + end) end """ @@ -21,8 +19,7 @@ Test all combinations of the given operators and types. Useful for precompilatio """ function test_all_combinations(; binary_operators, unary_operators, turbo, types) for binops in binary_operators, - unaops in unary_operators, - use_turbo in turbo, + unaops in unary_operators, use_turbo in turbo, T in types length(binops) == 0 && length(unaops) == 0 && continue diff --git a/test/test_deprecations.jl b/test/test_deprecations.jl index fc554a6a..29ecc672 100644 --- a/test/test_deprecations.jl +++ b/test/test_deprecations.jl @@ -24,23 +24,23 @@ end if VERSION >= v"1.9" @test_logs (:warn, r"Node\(d, c, v\) is deprecated.*") ( - n = Node(1, true, 1.0 + 0im); @assert (n.val isa ComplexF64) + n=Node(1, true, 1.0 + 0im); @assert (n.val isa ComplexF64) ) @test_logs (:warn, r"Node\(T, d, c, v\) is deprecated.*") ( - n = Node(Float32, 1, true, 1.0 + 0im); @assert (n.val isa Float32) + n=Node(Float32, 1, true, 1.0 + 0im); @assert (n.val isa Float32) ) @test_logs (:warn, r"Node\(T, d, c, v, f\) is deprecated.*") ( - n = Node(Float32, 1, false, nothing, 1); @assert (n.feature == 1) + n=Node(Float32, 1, false, nothing, 1); @assert (n.feature == 1) ) @test_logs (:warn, r"Node\(d, c, v, f, o, l\) is deprecated.*") ( - x1 = Node(; feature=1); - n = Node(1, true, nothing, 1, 3, x1); + x1=Node(; feature=1); + n=Node(1, true, nothing, 1, 3, x1); @assert (n.op == 3 && n.l === x1) ) @test_logs (:warn, r"Node\(d, c, v, f, o, l, r\) is deprecated.*") ( - x1 = Node(; feature=1); - x2 = Node(; feature=2); - n = Node(2, true, nothing, 1, 1, x1, x2); + x1=Node(; feature=1); + x2=Node(; feature=2); + n=Node(2, true, nothing, 1, 1, x1, x2); @assert (n.op == 1 && n.l === x1 && n.r === x2) ) end diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index f744bdf5..6f1e60e0 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -103,24 +103,27 @@ end @test repr(tree) == "cos(cos(3.0))" tree = convert(Node{T}, tree) truth = cos(cos(T(3.0f0))) - @test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, EvalOptions(; turbo)).x[1] ≈ - truth + @test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval( + tree, [zero(T)]', cos, cos, EvalOptions(; turbo) + ).x[1] ≈ truth # op(, ) tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0)) @test repr(tree) == "3.0 + 4.0" tree = convert(Node{T}, tree) truth = T(3.0f0) + T(4.0f0) - @test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), EvalOptions(; turbo)).x[1] ≈ - truth + @test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval( + tree, [zero(T)]', (+), EvalOptions(; turbo) + ).x[1] ≈ truth # op(op(, )) tree = Node(1, Node(1, Node(; val=3.0f0), Node(; val=4.0f0))) @test repr(tree) == "cos(3.0 + 4.0)" tree = convert(Node{T}, tree) truth = cos(T(3.0f0) + T(4.0f0)) - @test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), EvalOptions(; turbo)).x[1] ≈ - truth + @test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval( + tree, [zero(T)]', cos, (+), EvalOptions(; turbo) + ).x[1] ≈ truth # Test for presence of NaNs: operators = OperatorEnum(; diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 467c6226..c1326b9a 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -78,8 +78,8 @@ m.frozen = !m.frozen @test n != m # Try out an interface for freezing parts of an expression -freeze!(n) = (n.frozen = true; n) -thaw!(n) = (n.frozen = false; n) +freeze!(n) = (n.frozen=true; n) +thaw!(n) = (n.frozen=false; n) ex = parse_expression( :(x + $freeze!(sin($thaw!(y + 2.1))));