diff --git a/test/JuMP.jl b/test/JuMP.jl index 7c71562..8ab0620 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -1201,6 +1201,99 @@ function test_matvec_jump_matrix_times_jump_vector_gradient() return end +# Pre-fills `reverse_storage` with a sentinel, runs `eval_objective_gradient`, +# then returns the evaluator's objective expression so callers can inspect +# whether specific tape ranges were touched. The reverse pass only initializes +# the root slot, so any slot left at the sentinel value was never written. +function _grad_with_sentinel(model, loss, x_in) + mode = ArrayDiff.Mode{Vector{Float64}}() + ad = ArrayDiff.model(mode) + MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss)) + evaluator = MOI.Nonlinear.Evaluator( + ad, + mode, + JuMP.index.(JuMP.all_variables(model)), + ) + MOI.initialize(evaluator, [:Grad]) + expr = evaluator.backend.objective.expr + sentinel = 1.2345e10 + fill!(expr.reverse_storage, sentinel) + g = zero(x_in) + MOI.eval_objective_gradient(evaluator, g, x_in) + return expr, g, sentinel +end + +function _assert_value_block_slots_preserved(expr, sentinel) + const_ks = + findall(node -> node.type == ArrayDiff.NODE_VALUE_BLOCK, expr.nodes) + @test !isempty(const_ks) + for k in const_ks + rng = ArrayDiff._storage_range(expr.sizes, k) + @test all(==(sentinel), expr.reverse_storage[rng]) + end + return +end + +# `W * X` with `X` a constant matrix. `_matmul_reverse!` must skip the gradient +# w.r.t. `X`, leaving its `reverse_storage` slot at the pre-filled sentinel. +function test_matmul_constant_rhs_skips_reverse_write() + m, n = 2, 3 + X = [0.4 -0.2 0.1; -0.3 0.5 0.2; 0.1 0.1 -0.4] + W_val = [0.6 -0.3 0.4; -0.1 0.5 0.2] + model = Model() + @variable(model, W[1:m, 1:n], container = ArrayDiff.ArrayOfVariables) + loss = sum((W * X) .^ 2) + expr, g, sentinel = _grad_with_sentinel(model, loss, vec(W_val)) + @test g ≈ vec(2 * (W_val * X) * X') + _assert_value_block_slots_preserved(expr, sentinel) + return +end + +# `X * W` with `X` a constant matrix. Same check on the left-operand branch. +function test_matmul_constant_lhs_skips_reverse_write() + a, b, c = 2, 3, 4 + X = [0.4 -0.2 0.1; -0.3 0.5 0.2] + W_val = [ + 0.1 -0.2 0.3 0.4 + 0.5 -0.1 0.2 -0.3 + -0.4 0.1 0.5 0.2 + ] + model = Model() + @variable(model, W[1:b, 1:c], container = ArrayDiff.ArrayOfVariables) + loss = sum((X * W) .^ 2) + expr, g, sentinel = _grad_with_sentinel(model, loss, vec(W_val)) + @test g ≈ vec(2 * X' * (X * W_val)) + _assert_value_block_slots_preserved(expr, sentinel) + return +end + +# Negative control: with `W1 * W2` (both variables) the matmul reverse must +# write to both operand slots, so neither retains the sentinel. This guards +# against a regression where the skip predicate fires for non-constant nodes. +function test_matmul_both_variables_overwrites_reverse() + m, n, p = 2, 3, 4 + W1_val = [0.4 -0.2 0.1; -0.3 0.5 0.2] + W2_val = [ + 0.1 -0.2 0.3 0.4 + 0.5 -0.1 0.2 -0.3 + -0.4 0.1 0.5 0.2 + ] + model = Model() + @variable(model, W1[1:m, 1:n], container = ArrayDiff.ArrayOfVariables) + @variable(model, W2[1:n, 1:p], container = ArrayDiff.ArrayOfVariables) + loss = sum((W1 * W2) .^ 2) + x_in = [vec(W1_val); vec(W2_val)] + expr, _, sentinel = _grad_with_sentinel(model, loss, x_in) + var_ks = + findall(node -> node.type == ArrayDiff.NODE_VARIABLE_BLOCK, expr.nodes) + @test length(var_ks) == 2 + for k in var_ks + rng = ArrayDiff._storage_range(expr.sizes, k) + @test !any(==(sentinel), expr.reverse_storage[rng]) + end + return +end + end # module TestJuMP.runtests()