Commit 14255e0

Fix #3186 - create leaf/non-leaf/requires_grad/retain_grad tutorial
This was originally bundled with PR #3389, but it has now been broken into two separate tutorials after discussion with the PyTorch team.

4 files changed: +342 -0 lines changed

"""
2+
Understanding requires_grad, retain_grad, Leaf, and Non-leaf tensors
3+
====================================================================
4+
5+
**Author:** `Justin Silver <https://github.com/j-silv>`__
6+
7+
This tutorial explains the subtleties of ``requires_grad``,
8+
``retain_grad``, leaf, and non-leaf tensors using a simple example.
9+
10+
Before starting, make sure you understand `tensors and how to manipulate
11+
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
12+
A basic knowledge of `how autograd
13+
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
14+
would also be useful.
15+
16+
"""
17+
18+
19+
######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


######################################################################
# Next, we instantiate a simple network to focus on the gradients. This
# will be an affine layer, followed by a ReLU activation, and ending with
# an MSE loss between the prediction and label tensors.
#
# .. math::
#
#    \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
#
# .. math::
#
#    L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
#
# Note that ``requires_grad=True`` is necessary for the parameters
# (``W`` and ``b``) so that PyTorch tracks operations involving those
# tensors. We’ll discuss this further in a later
# `section <#requires-grad>`__.
#

# tensor setup
x = torch.ones(1, 3)                      # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)  # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)  # bias with shape: (1, 2)
y = torch.ones(1, 2)                      # output with shape: (1, 2)

# forward pass
z = (x @ W) + b               # pre-activation with shape: (1, 2)
y_pred = F.relu(z)            # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)  # scalar loss


######################################################################
# Leaf vs. non-leaf tensors
# -------------------------
#
# After running the forward pass, PyTorch autograd has built up a `dynamic
# computational
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
# which is shown below. This is a `Directed Acyclic Graph
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
# keeps a record of input tensors (leaf nodes), all subsequent operations
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
# The graph is used to compute gradients for each tensor starting from the
# graph roots (outputs) to the leaves (inputs) using the `chain
# rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from calculus:
#
# .. math::
#
#    \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
#
# .. math::
#
#    \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
#    \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
#    \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
#    \cdots \cdot
#    \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
#
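# For example, specialized to our network, the gradient of the loss with
# respect to ``W`` factors as shown below (a sketch of the same chain
# rule; autograd performs this composition for us during the backward
# pass):
#
# .. math::
#
#    \frac{\partial L}{\partial \mathbf{W}} =
#    \frac{\partial L}{\partial \mathbf{y}_{\text{pred}}} \cdot
#    \frac{\partial \mathbf{y}_{\text{pred}}}{\partial \mathbf{z}} \cdot
#    \frac{\partial \mathbf{z}}{\partial \mathbf{W}}
#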
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png
#    :alt: Computational graph after forward pass
#
#    Computational graph after forward pass
#
# PyTorch considers a node to be a *leaf* if it is not the result of a
# tensor operation with at least one input having ``requires_grad=True``
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
# programmatically by probing the ``is_leaf`` attribute of the tensors:
#

# prints True because new tensors are leaves by convention
print(f"{x.is_leaf=}")

# prints False because tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf=}")


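######################################################################
# The same check applies to the rest of our tensors (a quick extra
# sanity pass, not strictly needed for what follows):
#

print(f"{W.is_leaf=}")       # prints True: created directly, even though requires_grad=True
print(f"{y_pred.is_leaf=}")  # prints False: result of F.relu(z)
print(f"{loss.is_leaf=}")    # prints False: result of F.mse_loss(y_pred, y)

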
######################################################################
# The distinction between leaf and non-leaf determines whether the
# tensor’s gradient will be stored in the ``grad`` property after the
# backward pass, and thus be usable for `gradient
# descent <https://en.wikipedia.org/wiki/Gradient_descent>`__. We’ll cover
# this in more detail in the `following section <#retain-grad>`__.
#
# Let’s now investigate how PyTorch calculates and stores gradients for
# the tensors in its computational graph.
#


######################################################################
# ``requires_grad``
# -----------------
#
# To build the computational graph which can be used for gradient
# calculation, we need to pass in the ``requires_grad=True`` parameter to
# a tensor constructor. By default, the value is ``False``, and thus
# PyTorch does not track gradients on any created tensors. To verify this,
# try not setting ``requires_grad``, re-run the forward pass, and then run
# backpropagation. You will see:
#
# ::
#
#    >>> loss.backward()
#    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
#
# This error means that autograd can’t backpropagate to any leaf tensors
# because ``loss`` is not tracking gradients. If you need to change the
# property, you can call ``requires_grad_()`` on the tensor (notice the \_
# suffix).
#
# We can sanity check which nodes require gradient calculation, just like
# we did above with the ``is_leaf`` attribute:
#

print(f"{x.requires_grad=}")  # prints False because requires_grad=False by default
print(f"{W.requires_grad=}")  # prints True because we set requires_grad=True in constructor
print(f"{z.requires_grad=}")  # prints True because tensor is a non-leaf node


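######################################################################
# As a brief aside, the in-place ``requires_grad_()`` method mentioned
# above flips this flag on an existing tensor. A minimal sketch on a
# throwaway tensor ``t`` (not part of our network):
#

t = torch.zeros(2, 2)         # leaf tensor, requires_grad=False by default
t.requires_grad_()            # in-place toggle, equivalent to t.requires_grad_(True)
print(f"{t.requires_grad=}")  # prints True

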
######################################################################
# It’s useful to remember that a non-leaf tensor has
# ``requires_grad=True`` by definition, since backpropagation would fail
# otherwise. If the tensor is a leaf, then it will only have
# ``requires_grad=True`` if it was specifically set by the user. Another
# way to phrase this is that if at least one of the inputs to the
# operation that created a tensor requires the gradient, then that tensor
# will require the gradient as well.
#
# There are two exceptions to this rule (both are sketched in the short
# example at the end of this section):
#
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
#    ``requires_grad=True`` for its parameters (see
#    `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
#
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which tensors have their ``grad`` field populated, which
# is the topic of the next section.
#


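######################################################################
# Here is the sketch promised above, illustrating the two exceptions
# (``lin`` and ``tmp`` are throwaway objects, not part of our network):
#

# exception 1: parameters registered on an nn.Module require gradients by default
lin = nn.Linear(3, 2)
print(f"{lin.weight.requires_grad=}")  # prints True

# exception 2: a context manager can locally disable gradient tracking
with torch.no_grad():
    tmp = (x @ W) + b
print(f"{tmp.requires_grad=}")         # prints False even though W requires gradients

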
######################################################################
# ``retain_grad``
# ---------------
#
# To actually perform optimization (e.g. SGD, Adam, etc.), we need to run
# the backward pass so that we can extract the gradients.
#

loss.backward()


######################################################################
# Calling ``backward()`` populates the ``grad`` field of all leaf tensors
# which had ``requires_grad=True``. The ``grad`` is the gradient of the
# loss with respect to the tensor we are probing. Before running
# ``backward()``, this attribute is set to ``None``.
#

print(f"{W.grad=}")
print(f"{b.grad=}")


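######################################################################
# As a rough sketch of how these gradients would be consumed by gradient
# descent (a single plain SGD step with an arbitrary learning rate of
# ``0.1``; we update copies so the rest of the tutorial is unaffected):
#

lr = 0.1
with torch.no_grad():        # parameter updates should not be tracked by autograd
    W_new = W - lr * W.grad  # W.grad holds the gradient of the loss w.r.t. W
    b_new = b - lr * b.grad
print(f"{W_new=}")
print(f"{b_new=}")

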
######################################################################
# You might be wondering about the other tensors in our network. Let’s
# check the remaining leaf nodes:
#

# prints all None because requires_grad=False
print(f"{x.grad=}")
print(f"{y.grad=}")


######################################################################
# The gradients for these tensors haven’t been populated because we did
# not explicitly tell PyTorch to calculate their gradient
# (``requires_grad=False``).
#
# Let’s now look at an intermediate non-leaf node:
#

print(f"{z.grad=}")


######################################################################
# PyTorch returns ``None`` for the gradient and also warns us that a
# non-leaf node’s ``grad`` attribute is being accessed. Although autograd
# has to calculate intermediate gradients for backpropagation to work, it
# assumes you don’t need to access the values afterwards. To change this
# behavior, we can use the ``retain_grad()`` function on a tensor. This
# tells the autograd engine to populate that tensor’s ``grad`` after
# calling ``backward()``.
#

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# have to zero out gradients otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad=}")
print(f"{b.grad=}")
print(f"{z.grad=}")
print(f"{y_pred.grad=}")
print(f"{loss.grad=}")


######################################################################
# We get the same result for ``W.grad`` as before. Also note that because
# the loss is scalar, the gradient of the loss with respect to itself is
# simply ``1.0``.
#
# If we look at the state of the computational graph now, we see that the
# ``retains_grad`` attribute has changed for the intermediate tensors. By
# convention, this attribute will print ``False`` for any leaf node, even
# if it requires its gradient.
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png
#    :alt: Computational graph after backward pass
#
#    Computational graph after backward pass
#
# If you call ``retain_grad()`` on a leaf node that requires its gradient,
# it results in a no-op. If we call ``retain_grad()`` on a node that has
# ``requires_grad=False``, PyTorch actually throws an error, since it
# can’t store the gradient if it is never calculated.
#
# ::
#
#    >>> x.retain_grad()
#    RuntimeError: can't retain_grad on Tensor that has requires_grad=False
#


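######################################################################
# We can also confirm the ``retains_grad`` flags programmatically, just
# as we did earlier with ``is_leaf`` and ``requires_grad`` (a quick
# check, not required for the rest of the tutorial):
#

print(f"{W.retains_grad=}")  # prints False: leaf nodes report False by convention
print(f"{z.retains_grad=}")  # prints True because we called z.retain_grad()

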
######################################################################
# Summary table
# -------------
#
# Using ``retain_grad()`` and ``retains_grad`` only makes sense for
# non-leaf nodes, since the ``grad`` attribute will already be populated
# for leaf tensors that have ``requires_grad=True``. By default, these
# non-leaf nodes do not retain (store) their gradient after
# backpropagation. We can change that by rerunning the forward pass,
# telling PyTorch to store the gradients, and then performing
# backpropagation.
#
# The following table summarizes the above discussion and can be used as
# a reference. These scenarios are the only ones that are valid for
# PyTorch tensors.
#
# +-------------+-------------------+------------------+--------------------------------------------------+------------------------------------+
# | ``is_leaf`` | ``requires_grad`` | ``retains_grad`` | ``requires_grad_()``                             | ``retain_grad()``                  |
# +=============+===================+==================+==================================================+====================================+
# | ``True``    | ``False``         | ``False``        | sets ``requires_grad`` to ``True`` or ``False``  | raises a ``RuntimeError``          |
# +-------------+-------------------+------------------+--------------------------------------------------+------------------------------------+
# | ``True``    | ``True``          | ``False``        | sets ``requires_grad`` to ``True`` or ``False``  | no-op                              |
# +-------------+-------------------+------------------+--------------------------------------------------+------------------------------------+
# | ``False``   | ``True``          | ``False``        | no-op                                            | sets ``retains_grad`` to ``True``  |
# +-------------+-------------------+------------------+--------------------------------------------------+------------------------------------+
# | ``False``   | ``True``          | ``True``         | no-op                                            | no-op                              |
# +-------------+-------------------+------------------+--------------------------------------------------+------------------------------------+
#


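######################################################################
# As a quick programmatic cross-check of the table, we can inspect one
# tensor per row (``z2`` below is a throwaway non-leaf tensor created
# only to illustrate the third row):
#

z2 = x @ W  # fresh non-leaf tensor: requires gradients, but grad is not retained

print(f"{x.is_leaf=},  {x.requires_grad=},  {x.retains_grad=}")    # row 1
print(f"{W.is_leaf=},  {W.requires_grad=},  {W.retains_grad=}")    # row 2
print(f"{z2.is_leaf=}, {z2.requires_grad=}, {z2.retains_grad=}")   # row 3
print(f"{z.is_leaf=},  {z.requires_grad=},  {z.retains_grad=}")    # row 4

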
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we covered when and how PyTorch computes gradients for
# leaf and non-leaf tensors. By using ``retain_grad``, we can access the
# gradients of intermediate tensors within autograd’s computational graph.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.), please
# use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#


######################################################################
# References
# ----------
#
# -  `A Gentle Introduction to
#    torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# -  `Automatic Differentiation with
#    torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial>`__
# -  `Autograd
#    mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
#
