Deep Neural Networks are biased, at initialisation, towards simple functions
And why is this a very important step in understanding why they work?
This is a follow-up post to Neural networks are fundamentally Bayesian, which presented a candidate theory of generalisation from [1], based on ideas in [2,3]. The reading order is not too important — this post predominantly summarises material in [2,3], which study why neural networks are biased at initialisation towards simple functions (those with low Kolmogorov complexity). This work motivated the results in [1], which explain how this bias at initialisation translates into good generalisation in the real world, and thus why neural networks enjoy the success they do! Thus, this post is best thought of as a prequel.
We begin by asking the following question:
What functions do neural networks express before being trained?
More specifically, what is the probability that a neural network will express a function f after being randomly initialised? We will call this quantity P(f),¹ and take the random initialisation to be i.i.d. Gaussian (although [2] suggests that P(f) is not too sensitive to the type of initialisation).
In general, changing one parameter in a neural network by a small amount will affect the raw output of the network — so, generically, there is a 1–1 correspondence between the parameters of a neural network and the function it expresses². However, for many problems, we are not interested in the raw output of the network. This is best explained with an example problem.
¹Note P(f) is denoted Pᵦ(f) in Neural networks are fundamentally Bayesian.
²If we take the domain to be the entire vector space over which the network is defined.
Classification problem on MNIST
Consider the problem of correctly classifying images of the handwritten digits 0–9 (the MNIST dataset). We are clearly not interested in the raw output of the network — we only care about its final (discrete) decision. Let us imagine, for simplicity, that we want to only classify the images as even or odd numbers. This can be done with a neural network with a single output neuron, and thresholding at zero — a positive output implies an even number, a negative output implies an odd number.
So, if we consider a subset of m images in MNIST, which we call S, then the network N models a function
f : S → {0, 1}ᵐ,
where 1 corresponds to even and 0 to odd. This is because we ultimately (mostly) care about the post-threshold output of the network, not the raw outputs — in which case small changes to the parameters of N may not change the function expressed (i.e. the classification will not change). The notation {0, 1}ᵐ denotes that each of the m images is mapped to either 0 or 1.
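To make this concrete, here is a minimal NumPy sketch of the setup above (not the code used in [1,2]): a fully connected network with two hidden layers and a single output, initialised with i.i.d. Gaussian weights and thresholded at zero, turns a set S of m inputs into a binary string in {0, 1}ᵐ. The architecture, the widths, and the random ‘images’ standing in for an MNIST subset are all illustrative choices.

```python
import numpy as np

def random_net_function(X, rng, width=40):
    """One i.i.d. Gaussian initialisation of a 2-hidden-layer ReLU net with a single
    output; returns the thresholded outputs on X as a bitstring (1 = even, 0 = odd)."""
    d = X.shape[1]
    W1 = rng.normal(0, 1 / np.sqrt(d), (d, width))
    W2 = rng.normal(0, 1 / np.sqrt(width), (width, width))
    w3 = rng.normal(0, 1 / np.sqrt(width), (width, 1))
    h = np.maximum(X @ W1, 0)          # hidden layer 1 (ReLU)
    h = np.maximum(h @ W2, 0)          # hidden layer 2 (ReLU)
    out = (h @ w3).ravel()             # raw output of the single output neuron
    return ''.join('1' if o > 0 else '0' for o in out)   # threshold at zero

rng = np.random.default_rng(0)
S = rng.random((100, 28 * 28))         # stand-in for m = 100 MNIST images (28x28 pixels)
print(random_net_function(S, rng))     # one sample of f, e.g. '0110...'; length 100
```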
We can then ask, what is P(f) for these functions?
See Figures 1 and 2 below for two visualisations of P(f), for the system discussed above. P(f) was estimated by sampling 10⁷ different random initialisations of a 2-hidden-layer fully connected network, using 100 images from MNIST.
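For a rough idea of how such an estimate is obtained, the sketch below reuses random_net_function and S from the snippet above: draw many independent initialisations and count how often each bitstring occurs. (The experiments in [2] use 10⁷ samples and real MNIST images; far fewer samples are used here so it runs quickly.)

```python
from collections import Counter

n_samples = 10_000                   # [2] uses 10^7; reduced here for speed
counts = Counter(random_net_function(S, rng) for _ in range(n_samples))

for f, c in counts.most_common(5):   # the five most frequently sampled functions
    print(f"P(f) ≈ {c / n_samples:.4f}  for f = {f[:20]}...")
```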
It is clear from Figures 1 and 2 that there is a huge range in P(f). Results in [2] suggest that, for the above problem, P(f) spans more than 30 orders of magnitude. But why does this matter?
It matters because, for a set of 100 images, there are 2¹⁰⁰ ≈ 10³⁰ different possible functions (i.e. possible ways of classifying each image as even or odd). Without information about N, we might assume that each function is equally likely, meaning P(f) ≈ 10⁻³⁰ (this would be the case if each image were classified by an unbiased coin flip).
Given that P(f) can be as large as 0.05, neural networks are clearly not unbiased at initialisation. Instead, there is a strong bias towards certain types of functions at initialisation — before any training has taken place.
Is there any reason to expect this?
There is a theorem, originally due to Levin and repurposed for input–output maps in [4], which, when applied to neural networks [3], states that for the map from the parameters of a neural network N to the function expressed by N, the following holds:
P(f) ≤ 2⁻ᴷ⁽ᶠ⁾⁺ᴼ⁽¹⁾,
where K(f) is the Kolmogorov complexity of the function f, and the O(1) term is independent of f but depends on N. There are some further conditions that need to be satisfied for the bound to hold for neural networks, but empirical evidence [3] plus a few theoretical results [2] indicate that it does hold, and that it is non-vacuous.
In essence, this says that complex functions must have low P(f), while simple functions can have large P(f) — and will, provided the bound is reasonably tight. However, Kolmogorov complexity is uncomputable — so proving this for general architectures and datasets would be, at best, non-trivial. Instead, a very clever experiment in [3] allows us to test this upper bound empirically. The results of this experiment are shown in Figure 3, where a fully connected network models functions of the form:
f : {0,1}ⁿ → {0,1},
chosen because a suitable complexity measure exists — see [3] for details.
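The sketch below gives a rough flavour of this kind of test (it is not the experiment from [3]). It reuses random_net_function from earlier, fed with all 2⁷ Boolean inputs, and uses the zlib-compressed length of the 128-bit output string as a crude, computable stand-in for the Lempel–Ziv-based complexity measure used in [3].

```python
import itertools, zlib
from collections import Counter

# all 2^7 = 128 points of the Boolean hypercube, as network inputs
X_bool = np.array(list(itertools.product([0, 1], repeat=7)), dtype=float)

def complexity_proxy(bits):
    # crude computable stand-in for K(f): compressed length of the output string
    return len(zlib.compress(bits.encode()))

n_samples = 50_000
bool_counts = Counter(random_net_function(X_bool, rng) for _ in range(n_samples))

# frequently sampled functions should tend to have a lower complexity proxy
for f, c in bool_counts.most_common(5):
    print(f"P(f) ≈ {c / n_samples:.4f}   complexity proxy = {complexity_proxy(f)}")
```

The compressed length of a 128-bit string is a blunt instrument compared to the measure in [3], but the qualitative trend — frequently sampled functions being simpler — is the thing to look for.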
Evidently, P(f) is exponentially larger for simple f. There are functions that lie well below the bound, but it is argued in [7] that (very informally) there is a limit on how many functions can lie further than a given distance below it.
We call this a simplicity bias — because P(f) is higher for simple functions.
There is substantial further evidence from [1,2,3] that this simplicity bias is a general property of neural networks across different architectures and datasets³. For example, similar experiments were performed on MNIST and CIFAR-10 [1,3], where the CSR complexity measure was used to approximate K(f). There is also an analytic result in [2] showing that perceptrons acting on the Boolean hypercube are biased towards low-entropy (and thus simple) functions.
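CSR (the critical sample ratio) roughly measures what fraction of the data points have a nearby perturbation that changes the network’s decision — the more such ‘critical’ samples, the more complicated the decision boundary. The sketch below is a simplified stand-in for that measure (random perturbations in an L∞ box, rather than the search procedure used in [1,3]); `predict` is assumed to be any function mapping a batch of inputs to 0/1 labels, such as a thresholded network like the ones above.

```python
import numpy as np

def critical_sample_ratio(predict, X, r=0.1, n_trials=50, seed=1):
    """Fraction of points in X whose predicted label flips under some random
    perturbation within an L-infinity box of radius r (a crude CSR estimate)."""
    rng = np.random.default_rng(seed)
    base = predict(X)                             # labels of the unperturbed points
    critical = np.zeros(len(X), dtype=bool)
    for _ in range(n_trials):
        noise = rng.uniform(-r, r, size=X.shape)  # random point in the box around each x
        critical |= (predict(X + noise) != base)  # label flipped at least once => critical
    return critical.mean()
```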
In summary:
- P(f) is the probability of the network randomly initialising to some function f.
- Theoretical and experimental results strongly suggest that P(f) is exponentially larger for simple functions than for complex functions.
Also note that functions with large P(f) have greater ‘volumes’ in parameter-space (see [1,2] for details). This is intuitive — if you are more likely to sample some function at random, then a larger region of parameter-space must map to it, and it therefore has a greater volume in parameter-space.
³Bear in mind that the function is defined relative to a dataset, as it specifies its domain and co-domain.
But is this a good thing for learning?
It is thought [5] that real-world data has an underlying simple description. For example, when we read handwritten digits we do not worry too much about precise details — if it’s a single oval-like shape, it’s probably a zero. We don’t need to take the exact value of every pixel into account.
If real-world data that we are interested in learning really does have a simple underlying description, then simple functions will generalise better than complex functions.
Consider a supervised learning problem — a training set S from MNIST, containing m examples. Then, an ideal learning agent would be able to calculate the Kolmogorov complexity of all functions⁴ from the images of the digits (i.e. from the pixel values) to the classifications of the digits. It would then throw out all functions that did not correctly predict all m examples in S. Finally, of these functions, it would choose the one with the lowest Kolmogorov complexity.
In other words, an ideal learning agent would choose the simplest function that fits the training data.
⁴Defined relative to some UTM (see [3] for details).
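As a toy illustration of this idealised procedure, here is a sketch with two large caveats: Kolmogorov complexity is replaced by a crude compression-based proxy, and the function space is shrunk to all 2¹⁶ Boolean functions on 4 bits so that it can be enumerated. The training examples are made up purely for illustration.

```python
import itertools, zlib

inputs = list(itertools.product([0, 1], repeat=4))            # all 16 possible inputs
train = {(0, 0, 0, 0): 0, (1, 1, 1, 1): 1, (0, 1, 0, 1): 0}   # hypothetical training set S

def complexity_proxy(bits):
    return len(zlib.compress(bits.encode()))                  # crude stand-in for K(f)

best = None
for labels in itertools.product([0, 1], repeat=len(inputs)):  # every f: {0,1}^4 -> {0,1}
    f = dict(zip(inputs, labels))
    if all(f[x] == y for x, y in train.items()):              # keep only f consistent with S
        bits = ''.join(map(str, labels))
        if best is None or complexity_proxy(bits) < complexity_proxy(best):
            best = bits

print("simplest consistent function found:", best)
```

No real learner can do this exactly — Kolmogorov complexity is uncomputable and the function space is astronomically large — which is what makes the next question interesting.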
So do neural networks work like this?
At this point, if you have read Neural networks are fundamentally Bayesian, you can stop reading, as you already know the answer! If you haven’t, then please check it out, as it:
- Explains how we could, in principle, use random sampling of neural networks to learn in a Bayesian fashion (and what this means), such that after training the network expresses a function f with probability proportional to P(f) — a rough sketch of this sampling procedure is given just after this list.
- Summarises substantial empirical evidence in [1] which demonstrates that stochastic optimisers like SGD act in a very similar way to the hypothetical learning agent discussed in the previous point.
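For intuition, here is a minimal sketch of that ‘learning by random sampling’ idea, reusing random_net_function from earlier: keep re-initialising until the network happens to fit the training labels exactly. The accepted functions are then distributed according to P(f) restricted to functions consistent with the data — i.e. a Bayesian posterior with a 0–1 likelihood. (This is hopelessly inefficient in practice; it is the hypothetical sampler that [1] compares SGD against, not a usable training algorithm.)

```python
def sample_bayesian_posterior(X_train, y_train, rng, max_tries=1_000_000):
    """Rejection sampling: draw functions from the prior P(f) until one fits the data."""
    target = ''.join(str(int(y)) for y in y_train)
    for _ in range(max_tries):
        f = random_net_function(X_train, rng)   # one sample from the prior P(f)
        if f == target:                         # accept only if it fits the training labels
            return f
    return None                                 # gave up — the whole point: this is very slow
```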
Thus, neural networks generalise well because P(f) is much larger for simple functions, which generalise better.
Conclusion
This has been a very brief summary of the main results in [2,3]. There are also further experiments demonstrating that P(f) is not too sensitive to the choice of initialisation, and showing that the observed simplicity bias appears on real-world datasets (e.g. CIFAR-10 and MNIST), using a complexity measure called CSR.
One final point concerns the strength of the inductive bias: we only have an upper bound, which does not by itself guarantee that the probability really does vary exponentially with complexity. If the bias were too weak, then we would not get good generalisation with high probability.
The PAC-Bayes bound provides a probabilistic bound on generalisation. Applications of this bound in [1,6] show that, for cutting-edge architectures on real-world datasets, the simplicity bias in P(f) is sufficient to guarantee good generalisation. This will be the subject of a future post!
Finally, if you think I have missed anything or said anything inaccurate, please let me know. Also note that this is my interpretation of work done with a number of co-authors, and while I believe it to accurately approximate their views, it may not always be a perfect representation!
References
[1] C. Mingard, G. Valle-Pérez, J. Skalse, A. Louis. Is SGD a Bayesian Sampler? Well, almost. (2020) https://arxiv.org/abs/2006.15191
[2] C. Mingard, J. Skalse, G. Valle-Pérez, D. Martinez-Rubio, V. Mikulik, A. Louis. Neural Networks are a priori biased towards low entropy functions. (2019) https://arxiv.org/abs/1909.11522
[3] G. Valle-Pérez, C. Camargo, A. Louis. Deep learning generalizes because the parameter-function map is biased towards simple functions. (2018) https://arxiv.org/abs/1805.08522
[4] K. Dingle, C. Q. Camargo, A. A. Louis. Input–output maps are strongly biased towards simple outputs. (2018) Nature Communications, 9(1):761.
[5] J. Schmidhuber. Discovering problem solutions with low Kolmogorov complexity and high generalization capability. (1994) Machine Learning: Proceedings of the Twelfth International Conference.
[6] Guillermo Valle-Pérez, Ard A. Louis. Generalization bounds for deep learning (2020). https://arxiv.org/abs/2012.04115
[7] K. Dingle, G. Valle Pérez, A. A. Louis. Generic predictions of output probability based on complexities of inputs and outputs. (2020) Scientific Reports, 10:4415. https://doi.org/10.1038/s41598-020-61135-7