Back-propagation, an introduction
Given the sheer number of backpropagation tutorials on the internet, is there really need for another? One of us (Sanjeev) recently taught backpropagation in undergrad AI and couldn’t find any account he was happy with. So here’s our exposition, together with some history and context, as well as a few advanced notions at the end. This article assumes the reader knows the definitions of gradients and neural networks.
What is backpropagation?
It is the basic algorithm in training neural nets, apparently independently rediscovered several times in the 1970-80’s (e.g., see Werbos’ Ph.D. thesis and book, and Rumelhart et al.). Some related ideas existed in control theory in the 1960s. (One reader points out another independent rediscovery, the Baur-Strassen lemma from 1983.)
Backpropagation gives a fast way to compute the sensitivity of the output of a neural network to all of its parameters while keeping the inputs of the network fixed: specifically, it computes all the partial derivatives $\partial f/\partial w_i$, where $f$ denotes the network output and $w_i$ is the $i$-th parameter.
Note that backpropagation computes the gradient exactly, but properly training neural nets needs many more tricks than just backpropagation. Understanding backpropagation is useful for appreciating some advanced tricks.
The importance of backpropagation derives from its efficiency. Assuming node operations take unit time, the running time is linear: specifically, it is $O(\text{Network Size}) = O(V + E)$, where $V$ is the number of nodes in the network and $E$ is the number of edges.
Backpropagation can be efficiently implemented using highly parallel vector operations available in today’s GPUs (Graphics Processing Units), which play an important role in the recent neural nets revolution.
Side Note: Expert readers will recognize that in the standard accounts of neural net training, the actual quantity of interest is the gradient of the training loss, which happens to be a simple function of the network output. But the above phrasing is fully general since one can simply add a new output node to the network that computes the training loss from the old output. Then the quantity of interest is indeed the gradient of this new output with respect to network parameters.
Problem Setup
Backpropagation applies only to acyclic networks with directed edges. (Later we briefly sketch its use on networks with cycles.)
Without loss of generality, acyclic networks can be visualized as being structured in numbered layers, with nodes in the $(t+1)$-st layer receiving their inputs only from nodes in layers $t$ and below.
We start with a simple claim that reduces the problem of computing the gradient to the problem of computing partial derivatives with respect to the nodes:
Claim 1: To compute the desired gradient with respect to the parameters, it suffices to compute $\partial f/\partial u$ for every node $u$.
Let’s be clear what $\partial f/\partial u$ means: keeping the network’s inputs and parameters fixed, imagine perturbing the value of node $u$ alone and letting the change propagate through the nodes above it; then $\partial f/\partial u$ is the rate at which the output $f$ changes under this infinitesimal perturbation.
Claim 1 is a direct application of the chain rule; let’s illustrate it for a simple neural net (we address more general networks later). Suppose node $u$ computes a weighted sum of the outputs of the nodes $z_1, z_2, \dots, z_m$ feeding into it, say $u = w_1 z_1 + w_2 z_2 + \cdots + w_m z_m$ (with any nonlinearity applied at the next node up). Then by the chain rule,

$$\frac{\partial f}{\partial w_1} = \frac{\partial f}{\partial u}\cdot\frac{\partial u}{\partial w_1} = \frac{\partial f}{\partial u}\cdot z_1.$$
Hence, we see that having computed $\partial f/\partial u$, we can compute $\partial f/\partial w_1$ with a single extra multiplication, and similarly for the other parameters feeding into $u$.
Multivariate Chain Rule
Towards computing the derivatives with respect to the nodes, we first recall the multivariate Chain rule, which handily describes the relationships between these partial derivatives (depending on the graph structure).
Suppose a variable $f$ is a function of variables $u_1, u_2, \dots, u_n$, each of which in turn depends on a variable $z$. Then

$$\frac{\partial f}{\partial z} = \sum_{j=1}^{n} \frac{\partial f}{\partial u_j}\cdot\frac{\partial u_j}{\partial z}.$$
This is a direct generalization of the chain-rule computation above, and a sub-case of eqn. (11) in this description of the chain rule.
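As a quick sanity check of this formula, here is a tiny Python snippet (our own toy example) comparing the chain-rule sum against a finite-difference estimate:

```python
# Toy check of the multivariate chain rule.
#   u1 = z**2,  u2 = 3*z,  f = u1 * u2  (so f = 3*z**3 and df/dz = 9*z**2).
z = 1.5
u1, u2 = z**2, 3*z
df_du1, df_du2 = u2, u1            # partials of f = u1*u2
du1_dz, du2_dz = 2*z, 3.0          # partials of u1, u2 with respect to z
chain_rule = df_du1*du1_dz + df_du2*du2_dz

eps = 1e-6
f = lambda z: (z**2) * (3*z)
numeric = (f(z + eps) - f(z - eps)) / (2*eps)
print(chain_rule, numeric)         # both approximately 20.25
```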
This formula is perfectly suitable for our cases. Below is the same example as we used before but with a different focus and numbering of the nodes.
We see that once we have computed the derivatives with respect to all the nodes above a node $z$, we can compute the derivative with respect to $z$ itself via a weighted sum, where the weights involve the local partial derivatives $\partial u_j/\partial z$, which are typically easy to compute. This brings us to the question of how we measure running time; for book-keeping, we make the following
Basic assumption: If $u$ is a node at level $t+1$ and $z$ is any node at level $\leq t$ whose output is an input to $u$, then computing the local partial derivative $\partial u/\partial z$ takes unit time on our computer.
Naive feedforward algorithm (not efficient!)
It is useful to first point out the naive quadratic time algorithm implied by the chain rule. Most authors skip this trivial version, which we think is analogous to teaching sorting using only quicksort, and skipping over the less efficient bubblesort.
The naive algorithm is to compute $\partial u_i/\partial u_j$ for every pair of nodes where $u_i$ is at a higher level than $u_j$. Of course, among these $O(V^2)$ values (where $V$ is the number of nodes) are also the desired $\partial f/\partial u_i$ for all $i$, since $f$ is itself the value of the output node.
This computation can be done in feedforward fashion. If such values have been obtained for every $u_j$ in layers up to and including level $t$, then one can express (by inspecting the multivariate chain rule) the value $\partial u_{\ell}/\partial u_j$ for a node $u_{\ell}$ at level $t+1$ as a weighted combination of the values $\partial u_i/\partial u_j$ over the nodes $u_i$ that feed into $u_{\ell}$. For a fixed $j$, this takes time proportional to the number of edges $E$; repeating over all $V$ choices of $j$ gives total running time $O(VE)$.
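A minimal Python sketch of this naive algorithm is below (the data structures `inputs` and `local_grad` are hypothetical stand-ins for the network; `local_grad[u][z]` holds the local partial $\partial u/\partial z$ assumed computable in unit time):

```python
def naive_feedforward_derivatives(nodes_topo, inputs, local_grad):
    """Return D[u][j] = (total) du/du_j for every node j at or below u.

    nodes_topo : nodes listed in topological (lower-to-higher layer) order
    inputs[u]  : the nodes whose outputs feed directly into u
    local_grad[u][z] : local partial du/dz for a direct input z of u
    """
    D = {u: {u: 1.0} for u in nodes_topo}      # du/du = 1
    for u in nodes_topo:                       # feedforward pass
        for z in inputs[u]:                    # each edge (z, u)
            for j, dz_dj in D[z].items():      # every node j at or below z
                # multivariate chain rule: du/dj += (du/dz) * (dz/dj)
                D[u][j] = D[u].get(j, 0.0) + local_grad[u][z] * dz_dj
    return D                                   # O(V*E) work overall
```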
Backpropagation (Linear Time)
The more efficient backpropagation, as the name suggests, computes the partial derivatives in the reverse direction. Messages are passed in one wave backwards from higher-numbered layers to lower-numbered layers. (Some presentations of the algorithm describe it as dynamic programming.)
Messaging protocol: The node $u$ receives a message along each outgoing edge from the node at the other end of that edge. It sums these messages to get a number $S$ (if $u$ is the output of the entire net, then define $S = 1$), and then it sends the following message to any node $z$ adjacent to it at a lower level:

$$S\cdot \frac{\partial u}{\partial z}.$$
Clearly, the amount of work done by each node is proportional to its degree, and thus the overall work is the sum of the node degrees. Summing all node degrees counts each edge twice, and thus the overall work is $O(\text{Network Size}) = O(V + E)$.
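Here is a matching Python sketch of the protocol, using the same hypothetical data structures as in the naive algorithm above:

```python
def backprop(nodes_topo, inputs, local_grad, output):
    """Return S[z] = df/dz for every node z, where f is the output node."""
    S = {u: 0.0 for u in nodes_topo}
    S[output] = 1.0                            # df/df = 1 at the output node
    for u in reversed(nodes_topo):             # higher layers first
        for z in inputs[u]:
            # message sent from u down the edge (z, u): S[u] * du/dz
            S[z] += S[u] * local_grad[u][z]
    return S                                   # one unit of work per edge
```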
To prove correctness, we prove the following:
Main Claim: At each node $z$, the computed value $S$ is exactly $\frac{\partial f}{\partial z}$.
Base Case: At the output layer this is true, since $\partial f/\partial f = 1$.
Inductive case: Suppose the claim is true for layers $t+1$ and higher, and let $z$ be a node at layer $t$ whose outgoing edges go to nodes $u_1, u_2, \dots, u_m$ at levels $t+1$ or higher. By the inductive hypothesis, $z$ receives the message $\frac{\partial f}{\partial u_j}\cdot\frac{\partial u_j}{\partial z}$ from each $u_j$. Thus by the multivariate chain rule,

$$S = \sum_{j=1}^{m} \frac{\partial f}{\partial u_j}\cdot\frac{\partial u_j}{\partial z} = \frac{\partial f}{\partial z},$$

which completes the induction and proves the Main Claim.
Auto-differentiation
Since the exposition above used almost no details about the network or the operations that the nodes perform, it extends to any computation that can be organized as an acyclic graph in which each node computes a differentiable function of its incoming neighbors. This observation underlies many auto-differentiation packages such as autograd or tensorflow: they allow computing the gradient of the output of such a computation with respect to the network parameters.
We first observe that Claim 1 continues to hold in this very general setting. This is without loss of generality because we can view the parameters associated with the edges as also sitting on nodes (namely, leaf nodes). This can be done via a simple transformation of the network; for a single node it is shown in the picture below, and one would continue this transformation in the rest of the network feeding into $z_1, z_2, \dots$, etc.
Then we can use the messaging protocol to compute the derivatives with respect to the nodes, as long as the local partial derivatives can be computed efficiently. We note that the algorithm can be implemented in a fairly modular manner: for every node $u$, it suffices to specify (a) how $u$ depends on its incoming nodes, say $z_1, z_2, \dots, z_n$, and (b) how to compute the outgoing message $S\cdot\frac{\partial u}{\partial z_j}$ for each incoming node $z_j$.
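One possible way such a modular interface could look in Python (class and method names here are ours, not those of any particular package):

```python
import math

class Node:
    def forward(self, *xs):
        """(a) compute the node's output from its inputs."""
        raise NotImplementedError
    def backward(self, S, *xs):
        """(b) given S = df/d(output of this node), return the message
        S * d(output)/dx for each input x."""
        raise NotImplementedError

class Multiply(Node):                  # u = x * y
    def forward(self, x, y):
        return x * y
    def backward(self, S, x, y):
        return S * y, S * x            # (du/dx)*S, (du/dy)*S

class Sigmoid(Node):                   # u = 1 / (1 + exp(-x))
    def forward(self, x):
        return 1.0 / (1.0 + math.exp(-x))
    def backward(self, S, x):
        u = self.forward(x)
        return (S * u * (1.0 - u),)    # (du/dx)*S
```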
Extension to vector messages: In fact, (b) can be done efficiently in more general settings where we allow the output of each node in the network to be a vector (or even a matrix or tensor) instead of a single real number. Here we need to replace the product $\frac{\partial u}{\partial z_j}\cdot S$ by $\frac{\partial u}{\partial z_j}[S]$, the result of applying the linear operator (Jacobian) $\frac{\partial u}{\partial z_j}$ to $S$.
For example, as illustrated below, suppose the node $U$ computes the matrix product $U = ZW$ of two matrix-valued nodes $Z \in \mathbb{R}^{d_1\times d_2}$ and $W \in \mathbb{R}^{d_2\times d_3}$, so that $U \in \mathbb{R}^{d_1\times d_3}$. Writing out the Jacobian $\frac{\partial U}{\partial Z}$ explicitly would require a matrix with $d_1 d_3 \times d_1 d_2$ entries, yet the message itself is simply $\frac{\partial U}{\partial Z}[S] = S W^{\top}$, a single matrix product.
Such vector operations can also be implemented efficiently using today’s GPUs.
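For instance, here is a small NumPy sketch (our own example) of rule (b) for the matrix-product node above; the message $SW^{\top}$ is checked against a finite-difference directional derivative:

```python
import numpy as np

# U = Z @ W; the backward messages never form the full Jacobian:
#   dU/dZ [S] = S @ W.T     and     dU/dW [S] = Z.T @ S
d1, d2, d3 = 4, 5, 6
Z, W = np.random.randn(d1, d2), np.random.randn(d2, d3)
S = np.random.randn(d1, d3)                 # message arriving at node U

msg_to_Z = S @ W.T                          # shape (d1, d2), same as Z
msg_to_W = Z.T @ S                          # shape (d2, d3), same as W

# sanity check: <S, dU/dZ applied to a direction V> should equal <S @ W.T, V>
eps, V = 1e-6, np.random.randn(d1, d2)
lhs = np.sum(S * (((Z + eps*V) @ W - Z @ W) / eps))
print(np.allclose(lhs, np.sum(msg_to_Z * V), atol=1e-3))   # True
```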
Notable Extensions
1) Allowing weight tying. In many neural architectures, the designer wants to force many network units such as edges or nodes to share the same parameter. For example, in convolutional neural nets, the same filter has to be applied all over the image, which implies reusing the same parameter for a large set of edges between the two layers.
For simplicity, suppose two parameters $w_1$ and $w_2$ are constrained to share the same value $w$, i.e., $w_1 = w_2 = w$. By the chain rule, $\frac{\partial f}{\partial w} = \frac{\partial f}{\partial w_1} + \frac{\partial f}{\partial w_2}$: we simply compute the partial derivatives with respect to $w_1$ and $w_2$ as if they were independent parameters, and then add them up to obtain the derivative with respect to the shared parameter.
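A toy numeric illustration (our own example, $f(w_1, w_2) = w_1 x + w_2 y$ with $w_1 = w_2 = w$):

```python
x, y, w = 2.0, 3.0, 0.5
df_dw1 = x                  # gradient treating w1 as independent
df_dw2 = y                  # gradient treating w2 as independent
df_dw = df_dw1 + df_dw2     # gradient w.r.t. the shared value w

# check: with the tie, f(w) = w*(x + y), so df/dw = x + y
print(df_dw == x + y)       # True
```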
2) Backpropagation on networks with loops. The above exposition assumed the network is acyclic. Many cutting-edge applications such as machine translation and language understanding use networks with directed loops (e.g., recurrent neural networks). These architectures (all examples of the “differentiable computing” paradigm mentioned below) can get complicated and may involve operations on a separate memory as well as mechanisms to shift attention to different parts of data and memory.
Networks with loops are trained using gradient descent as well, via back-propagation through time: the network is unrolled over a finite number of time steps into an acyclic graph containing replicated copies of the same network. These replicas share the weights (weight tying!), so the gradient can be computed as above. In practice, exploding or vanishing gradients may hinder convergence; such issues are addressed by clipping the gradient or by re-parameterization techniques such as long short-term memory.
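To make this concrete, here is a toy Python sketch of back-propagation through time for a one-dimensional recurrent net $h_{t+1} = \tanh(w\, h_t + x_t)$ with loss $h_T$ (a bare-bones example of ours; real implementations work with vectors, batches, and richer losses):

```python
import math

def bptt_grad(w, xs, h0=0.0):
    """d(loss)/dw for the unrolled net, with w tied across all time steps."""
    # forward pass: unroll and remember the hidden states
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + x))
    # backward pass: S = d(loss)/dh_t; the gradient of w sums over replicas
    S, dw = 1.0, 0.0                        # d(loss)/dh_T = 1
    for t in reversed(range(len(xs))):
        pre = w * hs[t] + xs[t]
        local = 1.0 - math.tanh(pre) ** 2   # derivative of tanh at pre
        dw += S * local * hs[t]             # this time step's copy of w
        S = S * local * w                   # message passed back to h_t
    return dw

print(bptt_grad(0.7, [0.1, -0.2, 0.3]))
```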
The fact that the gradient can be computed efficiently for such general networks with loops has motivated neural net models with memory or even data structures (see for example neural Turing machines and differentiable neural computer). Using gradient descent, one can optimize over a family of parameterized networks with loops to find the best one that solves a certain computational task (on the training examples). The limits of these ideas are still being explored.
3) Hessian-vector product in linear time. It is possible to generalize backprop to enable 2nd-order optimization in “near-linear” time, not just gradient descent, as shown in recent independent manuscripts of Carmon et al. and Agarwal et al. (NB: Tengyu is a coauthor on this one.) One essential step is to compute the product of the Hessian matrix and a vector, for which Pearlmutter ’93 gave an efficient algorithm. Here we show how to do this in $O(\text{Network Size})$ time using the ideas above.
Claim (informal): Suppose an acyclic network with $V$ nodes and $E$ edges has output $f$ and leaves (input nodes) $z_1, z_2, \dots, z_m$. Then there exists a network of size $O(V + E)$ that has $z_1, z_2, \dots, z_m$ as input nodes and $\frac{\partial f}{\partial z_1}, \dots, \frac{\partial f}{\partial z_m}$ as output nodes.
The proof of the Claim follows in straightforward fashion from implementing the message passing protocol as an acyclic circuit.
Next we show how to compute $\nabla^2 f(z)\cdot v$ for a fixed vector $v$. Let $g(z) = \langle \nabla f(z), v\rangle$, a scalar-valued function of $z$. By the Claim, $g$ can be computed by a network of size $O(V + E)$: take the network computing the partial derivatives $\frac{\partial f}{\partial z_i}$ and append an inner product with $v$. Applying the Claim again to this new network, $\nabla g(z)$ can also be computed by a network of size $O(V + E)$. Note that by construction, $\nabla g(z) = \nabla^2 f(z)\cdot v$, so we have computed the Hessian-vector product in $O(V + E)$ time.
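This double application of backprop is exactly what auto-differentiation packages carry out when asked to differentiate twice. For example, a sketch using the autograd package mentioned earlier (the function `f` below is just a stand-in for the network output):

```python
import autograd.numpy as np
from autograd import grad

def f(z):                                # stand-in for the network's output
    return np.sum(np.tanh(z) ** 2) + z[0] * z[1]

v = np.array([1.0, -2.0, 0.5])           # the fixed vector v
g = lambda z: np.dot(grad(f)(z), v)      # g(z) = <grad f(z), v>
hvp = grad(g)                            # grad g(z) = Hessian(f)(z) @ v

z0 = np.array([0.3, -0.1, 0.7])
print(hvp(z0))                           # Hessian-vector product at z0
```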
That’s all!
Please write your comments on this exposition and whether it can be improved.