We show that the error of iteratively magnitude-pruned networks empirically follows a scaling law with interpretable coefficients that depend on the architecture and task. We functionally approximate the error of the pruned networks, showing it is predictable in terms of an invariant tying width, depth, and pruning level, such that networks of vastly different pruned densities are interchangeable.
We demonstrate the accuracy of this approximation over orders of magnitude in depth, width, dataset size, and density. We show that the functional form holds (i.e., generalizes) for large-scale data (e.g., ImageNet) and architectures (e.g., ResNets).
As neural networks become ever larger and costlier to train, our findings suggest a framework for reasoning conceptually and analytically about a standard method for unstructured pruning.
In summary, our contributions are as follows:
We develop a scaling law that accurately estimates the error when pruning a single network with IMP.
We observe and characterize an invariant that allows error-preserving interchangeability among depth, width, and pruning density.
Using this invariant, we extend our single-network scaling law into a joint scaling law that predicts the error of all members of a network family at all dataset sizes and all pruning densities.
In doing so, we demonstrate that there is structure to the behavior of the error of iteratively magnitude-pruned networks that we can capture explicitly with a simple functional form and interpretable parameters.
Our scaling law enables a framework for reasoning analytically about IMP, allowing us to answer our motivating question and similar questions about pruning.
Figure 1: Relationship between density and error when pruning CIFAR-10 ResNets. w varies, l = 20, n = N (left). Low-error plateau, power-law region, and high-error plateau when l = 20, w = 1, n = N (center). Visualizing Equation 1 and the roles of its free parameters (right).
In Figure 1 (left), we plot the error of these pruned networks for CIFAR-10 ResNets with depth l = 20 and different widths w. All of these curves follow a similar pattern:
Low-error plateau: The densest networks (the right part of the curves) have error similar to that of the unpruned network, ε_np(w). We call this the "low-error plateau".
Power-law region: When pruned further, error increases linearly on the logarithmic axes of the figure.
Linear behavior on logarithmic axes is the signature of a power law, in which error relates to density via an exponent γ and a coefficient c: ε(d, w) ≈ c · d^(−γ). In particular, the slope of the line on the logarithmic axes is −γ.
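As a concrete illustration of this relationship, the exponent γ and coefficient c can be recovered from (density, error) measurements in the power-law region by a linear fit in log-log space. The function and the synthetic data below are our own sketch, not the paper's fitting procedure:

```python
import numpy as np

def fit_power_law(density, error):
    """Fit error ≈ c * density**(-gamma) by least squares in log-log space.

    `density` and `error` are arrays of measurements assumed to lie in the
    power-law region, between the two plateaus.
    """
    log_d = np.log(np.asarray(density))
    log_e = np.log(np.asarray(error))
    # A power law is a line in log-log space: log(error) = log(c) - gamma * log(density).
    slope, intercept = np.polyfit(log_d, log_e, deg=1)
    return -slope, np.exp(intercept)  # (gamma, c)

# Synthetic check: data generated from a known power law is recovered exactly.
d = np.logspace(-3, 0, 20)   # densities from 0.1% to 100%
e = 0.3 * d ** (-0.5)        # error with c = 0.3, gamma = 0.5
gamma, c = fit_power_law(d, e)
```

In practice one would first restrict the measurements to the linear segment of the curve, since points on either plateau would bias the fit.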
High-error plateau: When pruned further still, error flattens again; we call this the "high-error plateau" and denote its error ε↑.
Figure 2: Estimated (blue dots) and actual (solid lines) error when pruning CIFAR-10 ResNets. w varies, l = 20, n = N (left). Estimated versus actual error for the same networks (center). Estimated versus actual error for all CIFAR-10 ResNet configurations (right).
Equation 1: the functional form of our pruning scaling law. It interpolates, with gradual transitions, between the low-error plateau ε_np, the power-law region c · d^(−γ), and the high-error plateau ε↑; Figure 1 (right) visualizes the roles of its free parameters.
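To make the three-regime shape concrete, here is an illustrative smooth interpolation between the two plateaus and the power-law segment. This is not the paper's exact Equation 1; it is a sketch that reproduces the qualitative shape in Figure 1, with a parameter `k` (our own addition) controlling how gradual the transitions are:

```python
import numpy as np

def predicted_error(d, eps_np, eps_up, c, gamma, k=4.0):
    """Illustrative smooth interpolation of the three regimes.

    NOT the paper's exact functional form: a sketch that plateaus at
    eps_np for dense networks, follows the power law c * d**(-gamma)
    in between, and plateaus at eps_up for very sparse networks.
    """
    power = c * np.asarray(d, dtype=float) ** (-gamma)
    # Soft-min with the high-error plateau: caps the power law at eps_up.
    capped = (power ** (-k) + eps_up ** (-k)) ** (-1.0 / k)
    # Soft-max with the low-error plateau: floors the curve at eps_np.
    return (capped ** k + eps_np ** k) ** (1.0 / k)

d = np.logspace(-4, 0, 9)  # densities from 0.01% to 100%
e = predicted_error(d, eps_np=0.08, eps_up=0.9, c=0.02, gamma=0.6)
```

Plotted on logarithmic axes, this curve shows the flat-linear-flat pattern described above: error near eps_np at high density, a straight segment of slope −γ, and saturation at eps_up at extreme sparsity.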
Studying this optimization problem reveals a useful insight about, in this case, the CIFAR-10 ResNets. In the pruning literature, it is typical to report the minimum density at which the pruned network matches the error ε_np(l, w) of the unpruned network (e.g., the IMP procedure of Han et al., 2015). However, our scaling law suggests this is not the smallest model that achieves error ε_np(l, w). Instead, it is better to train a larger network with depth l′ and width w′ and prune it until its error reaches ε_np(l, w), despite the fact that this error is higher than ε_np(l′, w′). This analytic result parallels and extends the findings of Li et al. (2020) on NLP tasks. However, unlike Li et al. (2020), our scaling law suggests that starting too large is detrimental for the CIFAR-10 ResNets, leading to a higher parameter count at error ε_k.
Figure 8 (left) illustrates this behavior concretely: it shows the error predicted by our scaling law for CIFAR-10 ResNets of varying widths. The dotted black line shows the minimal parameter count at which we predict each error is achievable. Importantly, none of the low-error plateaus intersects this black dotted line, meaning a model cannot be minimal until it has been pruned to the point where its error increases. This occurs because the transitions of our functional form are gradual. On the other hand, if we start with a model that is too large, it will no longer be on the black line by the time it has been pruned to the point where its error reaches ε_np(l, w); this behavior occurs because error decreases as a function of the invariant m* rather than the parameter count m, and because m is not proportional to m*. In Figure 8 (right), we plot the same information from the actual CIFAR-10 data and see the same phenomena occur in practice. The difference between the estimated and actual optimal parameter counts is no more than 25%.
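The search for the minimal parameter count at a target error ε_k can be made concrete with a brute-force scan over widths and densities. Everything below is a hypothetical sketch: `toy_error` and `toy_params` are stand-ins of our own construction, not the paper's fitted law or the actual ResNet parameter counts:

```python
import numpy as np

def min_params_at_error(eps_k, widths, predicted_error, param_count,
                        densities=None):
    """Return the smallest pruned parameter count achieving error <= eps_k.

    `predicted_error(w, d)` and `param_count(w)` are hypothetical stand-ins
    for a fitted scaling law and an architecture's parameter count.
    """
    if densities is None:
        densities = np.logspace(-4, 0, 200)
    best = None
    for w in widths:
        for d in densities:
            if predicted_error(w, d) <= eps_k:
                m = d * param_count(w)  # parameters remaining after pruning
                if best is None or m < best[0]:
                    best = (m, w, d)
    return best  # (pruned parameter count, width, density), or None

# Toy assumptions: wider networks have a lower unpruned-error plateau,
# but their power-law coefficient shrinks more slowly than their
# parameter count grows.
def toy_error(w, d):
    eps_np = 0.1 / w  # unpruned-error plateau for width w
    return max(eps_np, 0.05 * d ** -0.5 / w ** 0.5)

def toy_params(w):
    return 270_000 * w ** 2  # parameter count roughly quadratic in width

best = min_params_at_error(0.1, widths=[1, 2, 4, 8],
                           predicted_error=toy_error, param_count=toy_params)
m_opt, w_opt, d_opt = best
```

Under these toy assumptions the search settles on the smallest width, illustrating the regime described above in which starting too large leaves the model off the minimal-parameter frontier at the target error.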
Figure 8: Estimated error as width varies for the CIFAR-10 ResNets (left). Actual error as width varies for the CIFAR-10 ResNets (right).
The dotted black line is the minimal number of parameters necessary to reach each error ε_k among all of the pruned networks. Reaching this point requires starting with a particular lower-error network (purple) and pruning it until its error increases to ε_k. Starting too large (pink) will miss this point.