Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,99 @@ function test_matvec_jump_matrix_times_jump_vector_gradient()
return
end

# Evaluates the gradient of `loss` at `x_in` after poisoning the objective's
# `reverse_storage` with a sentinel value.
#
# An `ArrayDiff` model is built with `loss` as its objective, wrapped in an
# `MOI.Nonlinear.Evaluator` over all variables of `model`, and initialized with
# `[:Grad]`. Every entry of the objective expression's `reverse_storage` is then
# overwritten with the sentinel before `eval_objective_gradient` is called.
#
# Returns the tuple `(expr, g, sentinel)`: the evaluator's objective expression,
# the computed gradient, and the sentinel itself. The reverse pass only
# initializes the root slot, so callers can compare ranges of
# `expr.reverse_storage` against `sentinel` to see which tape ranges were never
# written.
function _grad_with_sentinel(model, loss, x_in)
    sentinel = 1.2345e10
    ad_mode = ArrayDiff.Mode{Vector{Float64}}()
    nlp = ArrayDiff.model(ad_mode)
    MOI.Nonlinear.set_objective(nlp, JuMP.moi_function(loss))
    variables = map(JuMP.index, JuMP.all_variables(model))
    evaluator = MOI.Nonlinear.Evaluator(nlp, ad_mode, variables)
    MOI.initialize(evaluator, [:Grad])
    expr = evaluator.backend.objective.expr
    # Poison the storage after `initialize` and before the gradient call, so
    # that anything still equal to the sentinel afterwards was left untouched.
    fill!(expr.reverse_storage, sentinel)
    grad = zero(x_in)
    MOI.eval_objective_gradient(evaluator, grad, x_in)
    return expr, grad, sentinel
end

# Checks that every constant block's slice of `reverse_storage` is untouched.
#
# Collects the indices of all nodes in `expr.nodes` whose type is
# `ArrayDiff.NODE_VALUE_BLOCK`, requires that at least one exists, and then
# tests that each one's range of `expr.reverse_storage` (as reported by
# `ArrayDiff._storage_range`) still holds `sentinel` in every entry.
function _assert_value_block_slots_preserved(expr, sentinel)
    value_block_ids = [
        k for (k, node) in pairs(expr.nodes) if
        node.type == ArrayDiff.NODE_VALUE_BLOCK
    ]
    # Guard against a vacuous pass when the expression has no constant block.
    @test !isempty(value_block_ids)
    for k in value_block_ids
        slot_range = ArrayDiff._storage_range(expr.sizes, k)
        slots = view(expr.reverse_storage, slot_range)
        @test all(==(sentinel), slots)
    end
    return
end

# Constant right operand: the loss is `sum((W * X) .^ 2)` where only `W` is a
# decision variable. The gradient is compared with the closed form
# `2 * (W * X) * X'`, and `_matmul_reverse!` must skip the gradient w.r.t. `X`,
# so the `reverse_storage` slot of `X` has to keep the pre-filled sentinel.
function test_matmul_constant_rhs_skips_reverse_write()
    X = [
        0.4 -0.2 0.1
        -0.3 0.5 0.2
        0.1 0.1 -0.4
    ]
    W_val = [
        0.6 -0.3 0.4
        -0.1 0.5 0.2
    ]
    m, n = 2, 3
    model = Model()
    @variable(model, W[1:m, 1:n], container = ArrayDiff.ArrayOfVariables)
    loss = sum((W * X) .^ 2)
    expr, grad, sentinel = _grad_with_sentinel(model, loss, vec(W_val))
    expected = vec(2 * (W_val * X) * X')
    @test grad ≈ expected
    _assert_value_block_slots_preserved(expr, sentinel)
    return
end

# `X * W` with `X` a constant matrix. Same check as the constant-RHS test, but
# on the left-operand branch: the gradient is compared with the closed form
# `2 * X' * (X * W)`, and the `reverse_storage` slot of the constant `X` must
# still hold the pre-filled sentinel afterwards.
function test_matmul_constant_lhs_skips_reverse_write()
    X = [0.4 -0.2 0.1; -0.3 0.5 0.2]
    W_val = [
        0.1 -0.2 0.3 0.4
        0.5 -0.1 0.2 -0.3
        -0.4 0.1 0.5 0.2
    ]
    # Take the variable's dimensions from the evaluation point so the two
    # cannot drift apart (the row count of `X` is not needed here).
    b, c = size(W_val)
    model = Model()
    @variable(model, W[1:b, 1:c], container = ArrayDiff.ArrayOfVariables)
    loss = sum((X * W) .^ 2)
    expr, g, sentinel = _grad_with_sentinel(model, loss, vec(W_val))
    @test g ≈ vec(2 * X' * (X * W_val))
    _assert_value_block_slots_preserved(expr, sentinel)
    return
end

# Negative control for the constant-operand tests. With `W1 * W2`, where both
# operands are variables, the matmul reverse pass must write to both operand
# slots, so no entry of either slot may still equal the sentinel. This guards
# against a regression where the skip predicate fires for non-constant nodes.
function test_matmul_both_variables_overwrites_reverse()
    W1_val = [
        0.4 -0.2 0.1
        -0.3 0.5 0.2
    ]
    W2_val = [
        0.1 -0.2 0.3 0.4
        0.5 -0.1 0.2 -0.3
        -0.4 0.1 0.5 0.2
    ]
    m, n, p = 2, 3, 4
    model = Model()
    @variable(model, W1[1:m, 1:n], container = ArrayDiff.ArrayOfVariables)
    @variable(model, W2[1:n, 1:p], container = ArrayDiff.ArrayOfVariables)
    loss = sum((W1 * W2) .^ 2)
    x_in = vcat(vec(W1_val), vec(W2_val))
    expr, _, sentinel = _grad_with_sentinel(model, loss, x_in)
    variable_block_ids = [
        k for (k, node) in pairs(expr.nodes) if
        node.type == ArrayDiff.NODE_VARIABLE_BLOCK
    ]
    # One variable block is expected per matrix operand.
    @test length(variable_block_ids) == 2
    for k in variable_block_ids
        slot_range = ArrayDiff._storage_range(expr.sizes, k)
        slots = view(expr.reverse_storage, slot_range)
        @test all(!=(sentinel), slots)
    end
    return
end

end # module

# Executed when this file is loaded: calls `runtests` from the `TestJuMP` module.
# NOTE(review): `runtests` is defined outside this excerpt; presumably it runs
# every `test_*` function in the module — confirm.
TestJuMP.runtests()
Loading