Deep Learning through the lens of Operators
Updated: Feb 15
TL;DR - Viewing deep networks as approximations of fixed point iterations yields arbitrarily deep networks with fixed memory costs.
This blog is based primarily on the following three NeurIPS papers:
In recent years, a growing body of work has sought to study and implement extremely deep neural networks. These networks can have hundreds of layers and achieve strong performance. Unfortunately, greater depth requires more memory, so network depth is, to some degree, limited by the memory of the devices used to train the networks. This blog overviews some recent works that alleviate this memory problem (for weight-tied networks*) by casting machine learning tasks as fixed point problems. Unlike in common feedforward networks, here the network weights are used implicitly to define a fixed point condition. This approach effectively yields a weight-tied feedforward network of infinite depth, and in some instances it has obtained state-of-the-art results for a fixed number of parameters.
Forward propagation can take several forms; in each case, one uses an iterative procedure to find a fixed point of the network. During training, one can use implicit differentiation to obtain the gradient of the loss function. Implicit differentiation amounts to solving a linear system involving the Jacobian of the network operator. This is of great practical use because it avoids backpropagating through all the layers, so memory and depth are no longer linked. The price is a trade-off: less memory in exchange for more computation. We dive into this in more detail below.
Training deep networks usually requires additional memory for each additional layer.
Viewing network layers as operators, we re-frame the problem as finding a fixed point.
Implicit differentiation lets training use a fixed amount of memory at the cost of additional computation.
Fixed Point Iterations
Here we consider a mapping that takes as input a signal u and data d and outputs another signal. Given input data d, the goal is to find u* such that feeding u* into the mapping returns u* itself, i.e.,
$$u^\star = T_\Theta(u^\star; d).$$
The T with the subscript denotes the operator, and the subscript Θ denotes the network weights that parameterize it. Deep Equilibrium Models (DEQs) propose finding this fixed point with a quasi-Newton method (Broyden's method). For the special case where T consists of a single layer, Monotone Operator Equilibrium Networks (MONs) present operator splitting methods (forward-backward and Peaceman-Rachford splitting) for finding the fixed point.
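To make the operator splitting idea concrete, here is a minimal NumPy sketch of forward-backward splitting for a single-layer, MON-style network z = ReLU(Wz + Ux + b), written in the DEQ/MON notation (z for u, x for d). It assumes the weight parameterization W = (1 - m)I - AᵀA + B - Bᵀ from the MON paper, which makes I - W strongly monotone with parameter m. The dimensions, random weights, m = 0.5, the step size α = m/L² (half of the 2m/L² convergence bound), and the function names are illustrative choices, not code from the papers. ReLU acts as the proximal ("backward") step, and the affine update is the "forward" step.

```python
import numpy as np

def mon_weight(A, B, m):
    # MON-style parameterization W = (1 - m) I - A^T A + B - B^T,
    # which makes I - W strongly monotone with parameter m
    n = A.shape[1]
    return (1.0 - m) * np.eye(n) - A.T @ A + B - B.T

def forward_backward(W, U, b, x, m, tol=1e-10, max_iter=100_000):
    # Solve z = relu(W z + U x + b) by forward-backward splitting:
    # z <- relu((1 - alpha) z + alpha (W z + U x + b))
    n = W.shape[0]
    L = np.linalg.norm(np.eye(n) - W, 2)  # Lipschitz constant of I - W
    alpha = m / L**2                       # inside the convergence range alpha <= 2m / L^2
    z = np.zeros(n)
    for k in range(max_iter):
        z_next = np.maximum(0.0, (1 - alpha) * z + alpha * (W @ z + U @ x + b))
        if np.linalg.norm(z_next - z) < tol:
            return z_next, k + 1
        z = z_next
    return z, max_iter

rng = np.random.default_rng(0)
n, p, m = 20, 5, 0.5
A = 0.5 * rng.standard_normal((n, n)) / np.sqrt(n)
B = 0.5 * rng.standard_normal((n, n)) / np.sqrt(n)
U = rng.standard_normal((n, p))
b = rng.standard_normal(n)
x = rng.standard_normal(p)

W = mon_weight(A, B, m)
z_star, iters = forward_backward(W, U, b, x, m)
# How far z_star is from satisfying the fixed point condition z = relu(W z + U x + b)
residual = np.linalg.norm(z_star - np.maximum(0.0, W @ z_star + U @ x + b))
print(iters, residual)
```

Each update needs only the current iterate, so memory does not grow with the number of iterations, which is exactly the sense in which iteration count plays the role of depth.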
Backpropagation Through Fixed Points
This is where the key idea of these papers comes in: it enables training with a fixed amount of memory, regardless of the depth of the network. For these networks, the aim is to solve a minimization problem of the form
$$\min_{\Theta}\ \ell(u_d^\star, y).$$
Here the subscript d denotes the data input into the network. The first argument $u_d^\star$ of the loss function is a fixed point of a parameterized operator, i.e.,
$$u_d^\star = T_\Theta(u_d^\star; d),$$
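and the second argument y is the provided "true" label. To find the optimal weights, one must compute the derivative of the loss function with respect to the weights. Using the chain rule, we write
$$\frac{d\ell}{d\Theta} = \frac{\partial \ell}{\partial u}\,\frac{d u_d^\star}{d\Theta},$$
where $\partial \ell/\partial u$ is evaluated at $(u_d^\star, y)$.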
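We then apply the chain rule again, together with the fact that $u_d^\star$ is a fixed point of $T_\Theta$: differentiating both sides of the fixed point equation with respect to $\Theta$ yields
$$\frac{d u_d^\star}{d\Theta} = \frac{\partial T_\Theta}{\partial \Theta} + \frac{\partial T_\Theta}{\partial u}\,\frac{d u_d^\star}{d\Theta},$$
where the partial derivatives of $T_\Theta$ are evaluated at $(u_d^\star; d)$.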
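Assuming all the computations are "nice" (in particular, that $I - \partial T_\Theta/\partial u$ is invertible, which holds, for example, when $T_\Theta$ is a contraction in $u$), we can solve for $d u_d^\star/d\Theta$ and substitute it into the chain rule expression to obtain
$$\frac{d\ell}{d\Theta} = \frac{\partial \ell}{\partial u}\left(I - \frac{\partial T_\Theta}{\partial u}\right)^{-1}\frac{\partial T_\Theta}{\partial \Theta}.$$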
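For a small network, the gradient formula $d\ell/d\Theta = \partial\ell/\partial u\,(I - \partial T_\Theta/\partial u)^{-1}\,\partial T_\Theta/\partial \Theta$ can be checked numerically. The NumPy sketch below assumes a hypothetical weight-tied layer $T_\Theta(u; d) = \tanh(Wu + Ud + b)$, takes $\Theta$ to be just $W$, and uses a squared loss; none of these choices come from the papers, and the function names are made up for illustration. It finds $u^\star$ by plain fixed point iteration, forms the Jacobian $\partial T_\Theta/\partial u = \mathrm{diag}(s)W$ explicitly ($s$ is tanh′ at the pre-activation), solves the linear system with `np.linalg.solve`, and compares the result with central finite differences that re-solve the fixed point for every perturbed weight.

```python
import numpy as np

def T(W, U, b, u, d):
    # Hypothetical weight-tied layer: T_Theta(u; d) = tanh(W u + U d + b)
    return np.tanh(W @ u + U @ d + b)

def fixed_point(W, U, b, d, tol=1e-13, max_iter=10_000):
    # Forward pass: plain fixed point (Picard) iteration u <- T(u; d)
    u = np.zeros(W.shape[0])
    for _ in range(max_iter):
        u_next = T(W, U, b, u, d)
        if np.linalg.norm(u_next - u) < tol:
            return u_next
        u = u_next
    return u

def loss(u, y):
    # Hypothetical squared loss l(u, y) = 0.5 ||u - y||^2
    return 0.5 * np.sum((u - y) ** 2)

def implicit_grad_W(W, U, b, d, y):
    # dl/dW = dl/du (I - dT/du)^{-1} dT/dW, all evaluated at the fixed point u*
    u_star = fixed_point(W, U, b, d)
    s = 1.0 - np.tanh(W @ u_star + U @ d + b) ** 2  # tanh' at the pre-activation
    J = s[:, None] * W                               # dT/du = diag(s) W
    g = u_star - y                                   # (dl/du)^T for the squared loss
    n = W.shape[0]
    w = np.linalg.solve((np.eye(n) - J).T, g)        # w^T = dl/du (I - J)^{-1}
    return np.outer(w * s, u_star)                   # w^T dT/dW, shaped like W

rng = np.random.default_rng(0)
n, p = 8, 3
W = rng.standard_normal((n, n))
W *= 0.5 / np.linalg.norm(W, 2)  # ||W||_2 = 0.5, so T is a contraction in u
U = rng.standard_normal((n, p))
b = rng.standard_normal(n)
d = rng.standard_normal(p)
y = rng.standard_normal(n)

grad = implicit_grad_W(W, U, b, d, y)

# Check against central finite differences, re-solving the fixed point each time
eps = 1e-6
grad_fd = np.zeros_like(W)
for i in range(n):
    for j in range(n):
        E = np.zeros_like(W)
        E[i, j] = eps
        grad_fd[i, j] = (loss(fixed_point(W + E, U, b, d), y)
                         - loss(fixed_point(W - E, U, b, d), y)) / (2 * eps)

print(np.max(np.abs(grad - grad_fd)))
```

Note that `implicit_grad_W` only uses $u^\star$, never the intermediate forward iterates, so nothing from the forward loop has to be stored for the backward pass.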
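Although in practice we may not be able to compute the inverse $(I - \partial T_\Theta/\partial u)^{-1}$ directly, we never need it explicitly: it suffices to solve the linear system $w^\top (I - \partial T_\Theta/\partial u) = \partial \ell/\partial u$ for the row vector $w^\top$ and then compute $d\ell/d\Theta = w^\top\, \partial T_\Theta/\partial \Theta$. Various iterative solvers can approximate $w$ using only vector-Jacobian products, and none of these steps requires storing the intermediate iterates of the forward pass. Thus, we trade memory cost for computational cost. This is illustrated in the DEQ paper by the following figure.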
(Notational note: In DEQ, the authors use z in place of u, x in place of d, and f in place of T.)
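The iterative solvers mentioned above can be as simple as a fixed point iteration on the linear system itself. Writing $g = (\partial \ell/\partial u)^\top$ and $J = \partial T_\Theta/\partial u$, the system $w^\top(I - J) = g^\top$ is equivalent to $w = g + J^\top w$, which converges when $J$ is a contraction. The minimal NumPy sketch below is one such solver, chosen for illustration and not necessarily what the papers use; the `vjp` closure, the random stand-ins for $s$ and $g$, and the function names are hypothetical. The point is that $J$ is touched only through the product $J^\top w$, and a dense solve is used only to check the answer.

```python
import numpy as np

def adjoint_solve(vjp, g, tol=1e-12, max_iter=10_000):
    # Solve w^T (I - J) = g^T, i.e. w = g + J^T w, by fixed point iteration.
    # J is accessed only through vjp(w) = J^T w; only w is kept in memory.
    w = g.copy()
    for _ in range(max_iter):
        w_next = g + vjp(w)
        if np.linalg.norm(w_next - w) < tol:
            return w_next
        w = w_next
    return w

# Hypothetical setup: J = diag(s) W, as for a tanh layer with ||W||_2 = 0.5
rng = np.random.default_rng(1)
n = 8
W = rng.standard_normal((n, n))
W *= 0.5 / np.linalg.norm(W, 2)
s = rng.uniform(0.0, 1.0, n)  # stands in for tanh' at the pre-activation
g = rng.standard_normal(n)    # stands in for (dl/du)^T at the fixed point

def vjp(w):
    # J^T w = W^T diag(s) w, computed without forming J
    return W.T @ (s * w)

w_iter = adjoint_solve(vjp, g)

# Dense solve, used here only to check the iterative answer
J = s[:, None] * W
w_direct = np.linalg.solve((np.eye(n) - J).T, g)
print(np.max(np.abs(w_iter - w_direct)))
```

In a real network, `vjp` would come from reverse-mode autodiff through a single application of $T_\Theta$ (for example via `jax.vjp`), so backward memory scales with one layer rather than with the number of forward iterations.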
Below is a table snapshot from the MON paper showing state-of-the-art results for implicit-depth networks, outperforming Neural ODE-based approaches.
We will not go into the details of the next example; briefly, it concerns a natural language processing task, and the results show a greatly reduced memory footprint during training (rightmost column).
Together with Neural ODEs and their variants (to be discussed in a later blog), these limit-type approaches have both brought a new perspective to deep learning and opened the door to building arbitrarily deep networks. Additionally, although the DEQ paper did not provide substantial theory, the MON paper introduced conditions ensuring that the fixed point problem is well-defined and has a unique solution. This paves the way for stronger theoretical guarantees in deep learning. Stay tuned.
As always, I am happy to discuss any comments, suggestions, criticisms, and/or questions about this post in the comments below.
* Here "weight-tied" means the same weights are used in every hidden layer of the network.