Skip to content

Commit 9c28360

Browse files
committed
feat: assume :terminates_locally for most loops
1 parent b705e46 commit 9c28360

File tree

6 files changed

+35
-26
lines changed

6 files changed

+35
-26
lines changed

src/Evaluate.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ..NodeModule:
66
AbstractExpressionNode, constructorof, get_children, get_child, with_type_parameters
77
import ..StringsModule: string_tree
88
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
9-
import ..UtilsModule: fill_similar, counttuple, ResultOk
9+
import ..UtilsModule: fill_similar, counttuple, ResultOk, @finite
1010
import ..NodeUtilsModule: is_constant
1111
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
1212
import ..ValueInterfaceModule: is_valid, is_valid_array
@@ -295,7 +295,7 @@ end
295295
# Fast general implementation of `cumulators[1] .= op.(cumulators[1], cumulators[2], ...)`
296296
quote
297297
Base.Cartesian.@nexprs($N, i -> cumulator_i = cumulators[i])
298-
@inbounds @simd for j in eachindex(cumulator_1)
298+
@finite @inbounds @simd for j in eachindex(cumulator_1)
299299
cumulator_1[j] = Base.Cartesian.@ncall($N, op, i -> cumulator_i[j])::T
300300
end # COV_EXCL_LINE
301301
return ResultOk(cumulator_1, true)
@@ -581,7 +581,7 @@ function deg1_l2_ll0_lr0_eval(
581581
@return_on_nonfinite_val(eval_options, val_ll, cX)
582582
feature_lr = get_child(get_child(tree, 1), 2).feature
583583
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
584-
@inbounds @simd for j in axes(cX, 2)
584+
@finite @inbounds @simd for j in axes(cX, 2)
585585
x_l = op_l(val_ll, cX[feature_lr, j])::T
586586
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
587587
cumulator[j] = x
@@ -592,7 +592,7 @@ function deg1_l2_ll0_lr0_eval(
592592
val_lr = get_child(get_child(tree, 1), 2).val
593593
@return_on_nonfinite_val(eval_options, val_lr, cX)
594594
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
595-
@inbounds @simd for j in axes(cX, 2)
595+
@finite @inbounds @simd for j in axes(cX, 2)
596596
x_l = op_l(cX[feature_ll, j], val_lr)::T
597597
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
598598
cumulator[j] = x
@@ -602,7 +602,7 @@ function deg1_l2_ll0_lr0_eval(
602602
feature_ll = get_child(get_child(tree, 1), 1).feature
603603
feature_lr = get_child(get_child(tree, 1), 2).feature
604604
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
605-
@inbounds @simd for j in axes(cX, 2)
605+
@finite @inbounds @simd for j in axes(cX, 2)
606606
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
607607
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
608608
cumulator[j] = x
@@ -630,7 +630,7 @@ function deg1_l1_ll0_eval(
630630
else
631631
feature_ll = get_child(get_child(tree, 1), 1).feature
632632
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
633-
@inbounds @simd for j in axes(cX, 2)
633+
@finite @inbounds @simd for j in axes(cX, 2)
634634
x_l = op_l(cX[feature_ll, j])::T
635635
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
636636
cumulator[j] = x
@@ -659,7 +659,7 @@ function deg2_l0_r0_eval(
659659
val_l = get_child(tree, 1).val
660660
@return_on_nonfinite_val(eval_options, val_l, cX)
661661
feature_r = get_child(tree, 2).feature
662-
@inbounds @simd for j in axes(cX, 2)
662+
@finite @inbounds @simd for j in axes(cX, 2)
663663
x = op(val_l, cX[feature_r, j])::T
664664
cumulator[j] = x
665665
end # COV_EXCL_LINE
@@ -669,7 +669,7 @@ function deg2_l0_r0_eval(
669669
feature_l = get_child(tree, 1).feature
670670
val_r = get_child(tree, 2).val
671671
@return_on_nonfinite_val(eval_options, val_r, cX)
672-
@inbounds @simd for j in axes(cX, 2)
672+
@finite @inbounds @simd for j in axes(cX, 2)
673673
x = op(cX[feature_l, j], val_r)::T
674674
cumulator[j] = x
675675
end # COV_EXCL_LINE
@@ -678,7 +678,7 @@ function deg2_l0_r0_eval(
678678
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
679679
feature_l = get_child(tree, 1).feature
680680
feature_r = get_child(tree, 2).feature
681-
@inbounds @simd for j in axes(cX, 2)
681+
@finite @inbounds @simd for j in axes(cX, 2)
682682
x = op(cX[feature_l, j], cX[feature_r, j])::T
683683
cumulator[j] = x
684684
end # COV_EXCL_LINE
@@ -697,14 +697,14 @@ function deg2_l0_eval(
697697
if get_child(tree, 1).constant
698698
val = get_child(tree, 1).val
699699
@return_on_nonfinite_val(eval_options, val, cX)
700-
@inbounds @simd for j in eachindex(cumulator)
700+
@finite @inbounds @simd for j in eachindex(cumulator)
701701
x = op(val, cumulator[j])::T
702702
cumulator[j] = x
703703
end # COV_EXCL_LINE
704704
return ResultOk(cumulator, true)
705705
else
706706
feature = get_child(tree, 1).feature
707-
@inbounds @simd for j in eachindex(cumulator)
707+
@finite @inbounds @simd for j in eachindex(cumulator)
708708
x = op(cX[feature, j], cumulator[j])::T
709709
cumulator[j] = x
710710
end # COV_EXCL_LINE
@@ -723,14 +723,14 @@ function deg2_r0_eval(
723723
if get_child(tree, 2).constant
724724
val = get_child(tree, 2).val
725725
@return_on_nonfinite_val(eval_options, val, cX)
726-
@inbounds @simd for j in eachindex(cumulator)
726+
@finite @inbounds @simd for j in eachindex(cumulator)
727727
x = op(cumulator[j], val)::T
728728
cumulator[j] = x
729729
end # COV_EXCL_LINE
730730
return ResultOk(cumulator, true)
731731
else
732732
feature = get_child(tree, 2).feature
733-
@inbounds @simd for j in eachindex(cumulator)
733+
@finite @inbounds @simd for j in eachindex(cumulator)
734734
x = op(cumulator[j], cX[feature, j])::T
735735
cumulator[j] = x
736736
end # COV_EXCL_LINE

src/EvaluateDerivative.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module EvaluateDerivativeModule
22

33
import ..NodeModule: AbstractExpressionNode, constructorof, get_children
44
import ..OperatorEnumModule: OperatorEnum
5-
import ..UtilsModule: fill_similar, ResultOk2
5+
import ..UtilsModule: fill_similar, ResultOk2, @finite
66
import ..ValueInterfaceModule: is_valid_array
77
import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, NodeIndex
88
import ..EvaluateModule:
@@ -105,7 +105,7 @@ end
105105
dx_cumulator_i = dx_cumulators[i]
106106
end)
107107
diff_op = _zygote_gradient(op, Val(N))
108-
@inbounds @simd for j in eachindex(x_cumulator_1)
108+
@finite @inbounds @simd for j in eachindex(x_cumulator_1)
109109
x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j])
110110
Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall(
111111
$N, diff_op, i -> x_cumulator_i[j]
@@ -346,7 +346,7 @@ end
346346
d_cumulator_i = d_cumulators[i]
347347
end)
348348
diff_op = _zygote_gradient(op, Val($N))
349-
@inbounds @simd for j in eachindex(x_cumulator_1)
349+
@finite @inbounds @simd for j in eachindex(x_cumulator_1)
350350
x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j])
351351
Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall(
352352
$N, diff_op, i -> x_cumulator_i[j]

src/NodeUtils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module NodeUtilsModule
22

3-
using ..UtilsModule: Nullable
3+
using ..UtilsModule: Nullable, @finite
44

55
import ..NodeModule:
66
AbstractNode,
@@ -108,7 +108,7 @@ function get_scalar_constants(
108108
else
109109
vals = Vector{BT}(undef, count_scalar_constants(tree))
110110
i = firstindex(vals)
111-
for ref in refs
111+
@finite for ref in refs
112112
i = pack_scalar_constants!(vals, i, ref[].val::T)
113113
end
114114
return vals, refs
@@ -123,13 +123,13 @@ Set the constants in a tree, in depth-first order. The function
123123
"""
124124
function set_scalar_constants!(tree::AbstractExpressionNode{T}, constants, refs) where {T}
125125
if T <: Number
126-
@inbounds for i in eachindex(refs, constants)
126+
@finite @inbounds for i in eachindex(refs, constants)
127127
refs[i][].val = constants[i]
128128
end
129129
else
130130
nums_i = 1
131131
refs_i = 1
132-
while nums_i <= length(constants) && refs_i <= length(refs)
132+
@finite while nums_i <= length(constants) && refs_i <= length(refs)
133133
ix, v = unpack_scalar_constants(constants, nums_i, refs[refs_i][].val::T)
134134
refs[refs_i][].val = v
135135
nums_i = ix

src/ParametricExpression.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
88
using ..ExpressionModule:
99
AbstractExpression, Metadata, with_contents, with_metadata, unpack_metadata
1010
using ..ChainRulesModule: NodeTangent
11-
using ..UtilsModule: Nullable, set_nan!
11+
using ..UtilsModule: Nullable, set_nan!, @finite
1212

1313
import ..NodeModule:
1414
constructorof,
@@ -241,15 +241,15 @@ function _get_constants_array(parameter_refs, ::Type{BT}) where {BT}
241241
size = sum(count_scalar_constants, parameter_refs)
242242
flat = Vector{BT}(undef, size)
243243
ix = 1
244-
for p in parameter_refs
244+
@finite for p in parameter_refs
245245
ix = pack_scalar_constants!(flat, ix, p)
246246
end
247247
return flat
248248
end
249249

250250
function _set_constants_array!(parameter_refs, flat)
251251
ix, i = 1, 1
252-
while ix <= length(flat) && i <= length(parameter_refs)
252+
@finite while ix <= length(flat) && i <= length(parameter_refs)
253253
ix, parameter_refs[i] = unpack_scalar_constants(flat, ix, parameter_refs[i])
254254
i += 1
255255
end
@@ -378,7 +378,7 @@ function eval_tree_array(
378378
@assert length(classes) == size(X, 2)
379379
@assert maximum(classes) <= size(ex.metadata.parameters, 2) # TODO: Remove when comfortable
380380
parameters = ex.metadata.parameters
381-
indexed_parameters = [
381+
indexed_parameters = @finite [
382382
parameters[i_parameter, classes[i_row]] for
383383
i_parameter in eachindex(axes(parameters, 1)), i_row in eachindex(classes)
384384
]

src/Strings.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module StringsModule
22

3-
using ..UtilsModule: deprecate_varmap
3+
using ..UtilsModule: deprecate_varmap, @finite
44
using ..OperatorEnumModule: AbstractOperatorEnum
55
using ..NodeModule: AbstractExpressionNode, tree_mapreduce, max_degree
66

@@ -122,7 +122,7 @@ function combine_op_with_inputs(op, args::Vararg{Any,D})::Vector{Char} where {D}
122122
# "op(l, r)"
123123
out = copy(op)
124124
push!(out, '(')
125-
for i in 1:(D - 1)
125+
@finite for i in 1:(D - 1)
126126
append!(out, strip_brackets(args[i]))
127127
push!(out, ',')
128128
push!(out, ' ')

src/Utils.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,13 @@ function set_nan!(out)
7575
return nothing
7676
end
7777

78+
"""
79+
@finite ex
80+
81+
Wraps `ex` in a `Base.@assume_effects :terminates_locally` block.
82+
"""
83+
macro finite(ex)
84+
return esc(:(Base.@assume_effects :terminates_locally $ex))
85+
end
86+
7887
end

0 commit comments

Comments
 (0)