# # AutoDiff API

# The goal of this tutorial is to give users already familiar with automatic
# differentiation (AD) an overview of the Enzyme differentiation API for the
# following differentiation modes:
# * Reverse mode
# * Forward mode
# * Forward over reverse mode
# * Vector Forward over reverse mode
# # Defining a function
# Enzyme differentiates arbitrary multivariate vector functions as the most
# general case in automatic differentiation
# ```math
# f: \mathbb{R}^n \rightarrow \mathbb{R}^m, y = f(x)
# ```
# For simplicity we define a vector function with ``m=1``. However, this
# tutorial can easily be applied to arbitrary ``m \in \mathbb{N}``.
using Enzyme

function f(x::Array{Float64}, y::Array{Float64})
    y[1] = x[1] * x[1] + x[2] * x[1]
    return nothing
end;

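# As a quick sanity check (an addition to this tutorial, not required for what
# follows), evaluating the primal once at `x = [2.0, 2.0]` should give
# `y[1] == 2.0 * 2.0 + 2.0 * 2.0 == 8.0`. A `let` block keeps these temporary
# variables local so the definitions below are unaffected.
let x = [2.0, 2.0], y = [0.0]
    f(x, y)
    y[1] == 8.0
end
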
# # Reverse mode
# The reverse model in AD is defined as
# ```math
# \begin{aligned}
# y &= f(x) \\
# \bar{x} &= \bar{y} \cdot \nabla f(x)
# \end{aligned}
# ```
# The bar denotes an adjoint variable. Note that executing AD in reverse mode
# computes both ``y`` and the adjoint ``\bar{x}``.
x = [2.0, 2.0]
bx = [0.0, 0.0]
y = [0.0]
by = [1.0];

# Enzyme stores the value and adjoint of a variable in an object of type
# `Duplicated`, where the first element represents the value and the second the
# adjoint. Evaluating the reverse model using Enzyme is done via the following
# call.
Enzyme.autodiff(Reverse, f, Duplicated(x, bx), Duplicated(y, by));
# This yields the gradient of `f` in `bx` at the point `x = [2.0, 2.0]`. `by` is called the seed and has
# to be set to ``1.0`` in order to compute the gradient. Let's save the gradient for later.
g = copy(bx)

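# Since ``f(x) = x_1^2 + x_2 x_1``, the analytic gradient is
# ``\nabla f(x) = [2 x_1 + x_2, x_1]``, so at `x = [2.0, 2.0]` we expect
# `[6.0, 2.0]` (a small check added on top of the original tutorial).
bx == [6.0, 2.0]
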
# # Forward mode
# The forward model in AD is defined as
# ```math
# \begin{aligned}
# y &= f(x) \\
# \dot{y} &= \nabla f(x) \cdot \dot{x}
# \end{aligned}
# ```
# To obtain the first element of the gradient using the forward model we have to
# seed ``\dot{x}`` with ``\dot{x} = [1.0, 0.0]``.
x = [2.0, 2.0]
dx = [1.0, 0.0]
y = [0.0]
dy = [0.0];
# In the forward mode the second element of `Duplicated` stores the tangent.
Enzyme.autodiff(Forward, f, Duplicated(x, dx), Duplicated(y, dy));

# We can now verify that indeed the reverse mode and forward mode yield the same
# result for the first component of the gradient. Note that to acquire the full
# gradient one needs to execute the forward model a second time with the seed
# `dx` set to `[0.0, 1.0]`.

# Let's verify whether the reverse and forward model agree.
g[1] == dy[1]

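# For completeness, here is a sketch of that second forward pass. The seed
# vectors `dx2` and `dy2` are introduced here (they are not part of the original
# tutorial) so that `dx` and `dy` above stay untouched.
dx2 = [0.0, 1.0]
dy2 = [0.0]
Enzyme.autodiff(Forward, f, Duplicated(x, dx2), Duplicated(y, dy2));
g[2] == dy2[1]
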
# # Forward over reverse
# The forward over reverse (FoR) model is obtained by applying the forward model
# to the reverse model, using the chain rule for the product in the adjoint statement.
# ```math
# \begin{aligned}
# y &= f(x) \\
# \dot{y} &= \nabla f(x) \cdot \dot{x} \\
# \bar{x} &= \bar{y} \cdot \nabla f(x) \\
# \dot{\bar{x}} &= \bar{y} \cdot \nabla^2 f(x) \cdot \dot{x} + \dot{\bar{y}} \cdot \nabla f(x)
# \end{aligned}
# ```
# To obtain the first column/row of the Hessian ``\nabla^2 f(x)`` we have to
# seed ``\dot{\bar{y}}`` with ``[0.0]``, ``\bar{y}`` with ``[1.0]``, and ``\dot{x}`` with ``[1.0, 0.0]``.

y = [0.0]
x = [2.0, 2.0]

dy = [0.0]
dx = [1.0, 0.0]

bx = [0.0, 0.0]
by = [1.0]
dbx = [0.0, 0.0]
dby = [0.0]

Enzyme.autodiff(
    Forward,
    (x, y) -> Enzyme.autodiff_deferred(f, x, y),
    Duplicated(Duplicated(x, bx), Duplicated(dx, dbx)),
    Duplicated(Duplicated(y, by), Duplicated(dy, dby)),
)

# The FoR model also computes the forward model from before, giving us again the first component of the gradient.
g[1] == dy[1]
# In addition we now have the first row/column of the Hessian.
dbx[1] == 2.0
dbx[2] == 1.0

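# As a check added here (not part of the original tutorial): the Hessian of
# ``f(x) = x_1^2 + x_2 x_1`` is constant,
# ```math
# \nabla^2 f(x) = \begin{pmatrix} 2 & 1 \\ 1 & 0 \end{pmatrix},
# ```
# so seeding ``\dot{x} = [1.0, 0.0]`` should reproduce its first column in `dbx`.
dbx == [2.0, 1.0]
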
# # Vector forward over reverse
# The vector FoR model allows us to propagate several tangents at once through the
# second-order model, so that the full Hessian can be acquired in a single autodiff
# call. The multiple tangents are organized in tuples. Following the same seeding strategy as before, we now seed
# in both directions, `vdx[1] = [1.0, 0.0]` and `vdx[2] = [0.0, 1.0]`. These tuples have to be put into a `BatchDuplicated` type.
y = [0.0]
x = [2.0, 2.0]

vdy = ([0.0], [0.0])
vdx = ([1.0, 0.0], [0.0, 1.0])

bx = [0.0, 0.0]
by = [1.0]
vdbx = ([0.0, 0.0], [0.0, 0.0])
vdby = ([0.0], [0.0]);

# The `BatchDuplicated` objects are constructed using the broadcast operator
# on our tuples of `Duplicated` for the tangents.
Enzyme.autodiff(
    Forward,
    (x, y) -> Enzyme.autodiff_deferred(f, x, y),
    BatchDuplicated(Duplicated(x, bx), Duplicated.(vdx, vdbx)),
    BatchDuplicated(Duplicated(y, by), Duplicated.(vdy, vdby)),
);

# Again we obtain the first-order gradient.
g[1] == vdy[1][1]
# We now have the first row/column of the Hessian
vdbx[1][1] == 2.0

vdbx[1][2] == 1.0
# as well as the second row/column.
vdbx[2][1] == 1.0

vdbx[2][2] == 0.0
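
# As a final check (added here, not part of the original tutorial), the two
# batched adjoint tangents can be assembled into the full Hessian, which for
# our ``f`` should equal ``\begin{pmatrix} 2 & 1 \\ 1 & 0 \end{pmatrix}``.
# The variable `H` is introduced only for this illustration.
H = hcat(vdbx...)
H == [2.0 1.0; 1.0 0.0]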