Projecting to Manifolds via Unsupervised Learning
Updated: Feb 1
H. Heaton*, S. Wu Fung*, A.T. Lin*, S. Osher, W. Yin.
Preprint available here
TL;DR: We present a new algorithm that projects signals onto the low-dimensional manifolds that efficiently represent true data.
Inverse problems attempt to recover a signal from noisy and/or incomplete measurements. Often, the true signal is estimated as the solution to an optimization problem. That is, given data d, we assume that the true signal u* we wish to recover is the solution to the minimization problem

$$u^\star \in \operatorname*{argmin}_{u} \; \ell(u; d) + R(u).$$
The first term above, called a loss function, measures how well the signal u complies with the provided measurements d. The second term, called a regularizer, measures how well the features of u match those of a typical true signal. There are many ways to choose the loss function and regularizer; the aim is to choose them so that solutions to the above minimization problem model true signals as closely as possible. We restrict our present interest to the special case where the true signal is contained in a small "nice" set M (described as a manifold below). Under this assumption, the "ideal" regularizer is zero on the nice set and infinite everywhere else. In this case, our minimization problem becomes the constrained optimization problem

$$u^\star \in \operatorname*{argmin}_{u \in \mathcal{M}} \; \ell(u; d),$$

where ℓ(u; d) denotes the loss function.
Contribution: This post describes the first theoretically justified method that leverages available data (rather than an analytic formula) to project signals onto the manifold of true data. This enables us to provably approximate solutions to the constrained optimization problem above.
High-dimensional data often have a low-dimensional structure (e.g., sparsity or smoothness), commonly called a manifold M, that we can exploit.
By solving an unsupervised learning problem, we can approximate the function that gives the distance between our sample data and the manifold.
We can project samples onto the manifold by iteratively updating the samples using our distance function estimate and then generating new distance function estimates for the updated samples.
Projecting samples onto the manifold enables us to use standard optimization methods (e.g., projected gradient descent or ADMM) to solve the constrained optimization problem above.
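To make the last point concrete, here is a minimal sketch of projected gradient descent. It assumes a setting where the projection is known exactly: M is taken to be the upper unit half circle (as in the toy example later in this post), and the hypothetical loss ℓ(u) = ½(a·u − b)² vanishes on the line a·u = b. In the method described here, the exact `project_half_circle` would be replaced by the learned projection; all function names are illustrative only.

```python
import numpy as np


def project_half_circle(u):
    """Exact projection onto M = {(cos t, sin t) : 0 <= t <= pi}.

    Stand-in for a learned projection. The origin (equidistant from all of M)
    is mapped to (0, 1) by convention.
    """
    x, y = u
    if y >= 0:
        r = np.hypot(x, y)
        return np.array([0.0, 1.0]) if r == 0 else u / r
    # Below the x-axis, the closest point of M is the nearer endpoint.
    return np.array([1.0, 0.0]) if x >= 0 else np.array([-1.0, 0.0])


def projected_gradient_descent(grad, project, u0, step, iters=500):
    """Minimize a smooth loss over M via u <- P_M(u - step * grad(u))."""
    u = project(np.asarray(u0, dtype=float))
    for _ in range(iters):
        u = project(u - step * grad(u))
    return u


# Hypothetical loss l(u) = 0.5 * (a.u - b)^2, which is zero exactly on the line a.u = b.
a, b = np.array([1.0, 1.0]), 1.2


def grad(u):
    return (a @ u - b) * a


u_star = projected_gradient_descent(grad, project_half_circle, u0=[-1.0, 0.0], step=0.25)
print(u_star)  # a point on the half circle that also satisfies x + y = 1.2
```

Any splitting method that only needs a projection onto M, such as ADMM, can use a learned projection operator in the same plug-in fashion.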
Low Dimensional Manifolds
The current era of big data has given rise to many problems that suffer from the curse of dimensionality (check out this article). To translate the high-dimensional signals found in practice into interpretable visualizations, dimensionality reduction techniques have been introduced (e.g., PCA, Isomap, Laplacian eigenmaps, and t-SNE). It is often assumed that high-dimensional signals can be represented compactly due to redundancies (e.g., adjacent pixels in an image are often correlated). This compact representation is loosely called a "manifold" in the machine learning literature (n.b. this usage does not always coincide with the formal mathematical definition). However, knowing that the true data have a concise representation is not enough; the harder question is how to approximate this manifold in a useful way. It turns out that, to leverage the manifold's structure in constrained optimization algorithms, what we need is a projection operator, which maps an input signal to the closest point on the manifold. Thus, our task is more or less complete if we can use data samples to learn how to perform these projections.
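The projection operator can be illustrated with the crudest data-driven approximation: return the closest of a finite set of samples of M. This nearest-neighbor lookup is not the method of this post (it can only output points already in the sample set, and its accuracy is tied to sampling density). It is a hypothetical baseline, sketched here with M assumed to be the upper unit half circle so that the exact projection z/‖z‖ is available for comparison.

```python
import numpy as np


def nearest_sample_projection(z, samples):
    """Approximate P_M(z) by the closest of a finite set of samples of M.

    A deliberately crude, data-only baseline: it needs no formula for M,
    but its accuracy is limited by how densely the samples cover M.
    """
    dists = np.linalg.norm(samples - z, axis=1)
    return samples[np.argmin(dists)]


# Hypothetical setup: M is the upper unit half circle, sampled at 2000 angles.
t = np.linspace(0.0, np.pi, 2000)
samples = np.stack([np.cos(t), np.sin(t)], axis=1)

z = np.array([0.6, 1.3])
approx = nearest_sample_projection(z, samples)
exact = z / np.linalg.norm(z)  # exact projection, valid because z lies above the x-axis
print(np.linalg.norm(approx - exact))  # small, and shrinks as the sampling gets denser
```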
Estimating Distances to Manifolds
The core computational task is to approximate a function that gives the distance between input samples and the manifold of true data. This distance function happens to be the solution to an unsupervised learning problem, which we can approximate using a parameterized function in the form of a special type of neural network J_θ (namely, one that is 1-Lipschitz and gradient norm preserving). The task is then to find the optimal weights θ for this network, i.e.,

$$\theta^\star \in \operatorname*{argmin}_{\theta} \; \mathbb{E}_{u \sim \mathbb{P}_{\text{true}}}\big[J_\theta(u)\big] - \mathbb{E}_{u \sim \mathbb{P}_{\text{noisy}}}\big[J_\theta(u)\big].$$
The first expectation above is over the distribution of true signals, while the second is over the distribution of noisy signals. Note that we need neither the same number of true and noisy samples nor a direct correspondence between noisy and true signals (unlike supervised learning). For the network J_{θ*} with the optimal weights θ* found by solving the above problem, we obtain the relation

$$J_{\theta^\star}(u) = \operatorname{dist}(u, \mathcal{M}) := \inf_{v \in \mathcal{M}} \|u - v\|$$

for u drawn from the distribution of noisy signals.
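A one-parameter toy version of the training problem shows why the optimal network recovers a distance function. In this sketch (all assumptions, not the paper's setup), M is a line through the origin, the "network" is the hypothetical family J(u) = |⟨w, u⟩| with a unit vector w (which is 1-Lipschitz with gradient norm 1 wherever it is differentiable), and a grid search over the angle of w replaces stochastic gradient training. Minimizing the empirical objective should select w normal to the line, so that J(u) is exactly the distance from u to M.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: M is the line through the origin with direction v (30 degrees).
v = np.array([np.cos(np.pi / 6), np.sin(np.pi / 6)])
n = np.array([-v[1], v[0]])  # unit normal to M (120 degrees)

# True and noisy samples are drawn independently, in different amounts,
# with no pairing between them (the unsupervised setting).
true_samples = rng.normal(size=(1500, 1)) * v
noisy_samples = rng.normal(size=(2500, 1)) * v + 0.3 * rng.normal(size=(2500, 1)) * n


def J(angle, u):
    """Hypothetical 1-Lipschitz, gradient-norm-preserving family: |<w, u>| with unit w."""
    w = np.array([np.cos(angle), np.sin(angle)])
    return np.abs(u @ w)


def objective(angle):
    """Empirical E_true[J] - E_noisy[J], to be minimized over the parameter."""
    return J(angle, true_samples).mean() - J(angle, noisy_samples).mean()


# Brute-force search over the single parameter; angles in [0, pi) suffice
# because |<w, u>| is unchanged when w is replaced by -w.
angles = np.linspace(0.0, np.pi, 720, endpoint=False)
best = angles[np.argmin([objective(ang) for ang in angles])]
print(np.degrees(best))  # should be near 120, i.e. J(u) = |<n, u>| = dist(u, M)
```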
This nifty result, which says that the trained network recovers the distance to the manifold, allows us to iteratively update our noisy samples, as described below.
With an estimate of the distance function in hand, we can approximate projections using a variant of gradient descent with an anchoring term. We approximate the projection of a point z onto the manifold by generating a sequence that starts at z and follows the update formula
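After enough steps, one of our iterates approximates the projection, i.e.,

$$P_{\mathcal{M}}(z) \approx u^k \quad \text{for sufficiently large } k,$$

where P_M(z) denotes the point on M closest to z and u^k denotes the k-th iterate.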
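The sketch below illustrates an anchored (Halpern-style) iteration of this kind. Its update rule, u^{k+1} = β_k z + (1 − β_k)(u^k − α J(u^k)∇J(u^k)) with β_k = 1/(k + 2), is an assumption made for illustration, not necessarily the exact formula from the paper. It relies on the fact that when J is the exact distance function and is differentiable at u, the point u − J(u)∇J(u) is the projection of u. Here the exact distance to the upper unit half circle stands in for a trained network, gradients come from finite differences rather than autodiff, and the retraining between updates is omitted.

```python
import numpy as np


def dist_half_circle(u):
    """Distance from u to M = {(cos t, sin t) : 0 <= t <= pi}; stands in for a trained J_theta."""
    x, y = u
    if y >= 0:
        return abs(np.hypot(x, y) - 1.0)
    # Below the x-axis, the nearest point of M is one of the endpoints (1, 0) or (-1, 0).
    return min(np.hypot(x - 1.0, y), np.hypot(x + 1.0, y))


def grad_fd(f, u, h=1e-6):
    """Central finite-difference gradient (a trained network would use autodiff instead)."""
    g = np.zeros_like(u)
    for i in range(u.size):
        e = np.zeros_like(u)
        e[i] = h
        g[i] = (f(u + e) - f(u - e)) / (2.0 * h)
    return g


def anchored_projection(z, J, alpha=1.0, iters=2000):
    """Hypothetical anchored iteration approximating the projection of z onto M.

    u^{k+1} = beta_k * z + (1 - beta_k) * (u^k - alpha * J(u^k) * grad J(u^k)),
    with beta_k = 1 / (k + 2). The anchor weight beta_k shrinks to zero, so later
    iterates are dominated by the distance-reducing step.
    """
    z = np.asarray(z, dtype=float)
    u = z.copy()
    for k in range(iters):
        beta = 1.0 / (k + 2)
        step = u - alpha * J(u) * grad_fd(J, u)
        u = beta * z + (1.0 - beta) * step
    return u


z = np.array([1.2, 0.9])
u_hat = anchored_projection(z, dist_half_circle)
print(u_hat)  # approaches z / |z| = (0.8, 0.6) as iters grows
```

With an exact distance function the anchor only slows things down; its role in the actual method is to stabilize the iteration when J is a learned, imperfect estimate.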
We illustrate the above theory with numerical examples. The first is a toy problem in which the manifold is a half circle. Samples from the region around the manifold (blue dots) are given along with samples of the manifold itself (red dots). Training on these samples with our scheme, we obtained a neural network that approximates the projection onto the manifold. We then used this network to solve an optimization problem that required finding points on the intersection of a feasible line and the manifold. The landscape of the learned distance function is also shown (darker shading indicates lower distance to the manifold). Readers are welcome to run the code for this toy example online via Google Colab here.
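For readers who want to set up a similar toy problem outside the notebook, the sketch below generates the two sample sets and evaluates the exact distance landscape that a trained network should approximate. The noise model (a point of M plus Gaussian noise with standard deviation 0.3), the sample counts, and the grid are assumptions chosen for illustration, not the settings used in the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_half_circle(m):
    """'Red dots': m samples of M = {(cos t, sin t) : 0 <= t <= pi}."""
    t = rng.uniform(0.0, np.pi, size=m)
    return np.stack([np.cos(t), np.sin(t)], axis=1)


def sample_near_manifold(m, sigma=0.3):
    """'Blue dots': hypothetical noisy samples, a point of M plus Gaussian noise."""
    return sample_half_circle(m) + sigma * rng.normal(size=(m, 2))


def dist_half_circle(points):
    """Vectorized distance to M (the function a trained J_theta should approximate)."""
    x, y = points[:, 0], points[:, 1]
    radial = np.abs(np.hypot(x, y) - 1.0)  # valid where y >= 0
    endpoint = np.minimum(np.hypot(x - 1.0, y), np.hypot(x + 1.0, y))  # valid where y < 0
    return np.where(y >= 0, radial, endpoint)


red = sample_half_circle(200)
blue = sample_near_manifold(500)  # the counts need not match; no pairing is used

# Distance landscape on a grid (lower values = closer to M), e.g. for a contour plot.
xs, ys = np.meshgrid(np.linspace(-2, 2, 81), np.linspace(-1.5, 2, 71))
grid = np.stack([xs.ravel(), ys.ravel()], axis=1)
landscape = dist_half_circle(grid).reshape(xs.shape)
print(landscape.shape, landscape.min())
```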
We also solved an optimization problem to perform CT image reconstruction. Our reconstructions (taken from our paper) are shown below alongside those of comparable methods. Our approach is currently state-of-the-art among unsupervised learning techniques for these tasks.
This work admits numerous natural extensions, including improving the computational efficiency by performing the bulk of the projection computations in a low dimensional space and then mapping back up to the high dimensional space. (Imagine working in the latent space of an auto-encoder.) Readers should also check out the related work RED-PRO.
As always, I am happy to discuss any comments, suggestions, criticisms, and/or questions about this post in the comments below.
* equal contribution