What is grokking in machine learning?
In some cases, neural networks achieve perfect generalization well beyond the point of overfitting by picking up a pattern in the data. In a potentially groundbreaking study, OpenAI researchers (Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, Vedant Misra) investigated how neural networks generalize on small, algorithmically generated datasets. The team examined how generalization depends on dataset size and found that smaller datasets require far more optimization before the network generalizes.
The generalization of overparameterized neural networks has long piqued the curiosity of the machine learning community because it runs counter to insights drawn from classical learning theory. The researchers showed that networks trained on small, algorithmically generated datasets tend to exhibit unusual generalization patterns, decoupled from performance on the training set, more visibly than networks trained on datasets derived from natural data. The experiments can be replicated on a single GPU.
What is Grokking
Suppose you train an overparameterized neural network (one with more parameters than the number of data points in your dataset) beyond the point where it has memorized the training data (as indicated by a low training loss and a high validation loss). In this case, the network can suddenly learn to generalize, as indicated by a rapidly decreasing validation loss; this is what the authors call "grokking". Practitioners typically stop training at the first sign of overfitting (a growing gap between training and validation loss). Grokking also goes against traditional statistical wisdom, which recommends using under-parameterized models to force the model to learn the rule (and thus generalize to new situations).
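One way to make this definition concrete is to compare when training accuracy saturates with when validation accuracy does. The sketch below is illustrative only: the function names are hypothetical, the curves are assumed to be lists of accuracies logged at regular evaluation steps, and the 99% threshold is an assumed cut-off for "fitted" and "generalized".

```python
from typing import Optional, Sequence


def first_step_above(accuracies: Sequence[float], threshold: float) -> Optional[int]:
    """Return the index of the first logged step whose accuracy reaches `threshold`, or None."""
    for step, acc in enumerate(accuracies):
        if acc >= threshold:
            return step
    return None


def grokking_gap(train_acc: Sequence[float], val_acc: Sequence[float],
                 threshold: float = 0.99) -> Optional[int]:
    """Number of logged steps between fitting the training set and generalizing.

    Returns None if either curve never reaches the threshold within the log.
    A large positive gap is the signature of grokking; a gap near zero means
    training and validation accuracy rose together.
    """
    fit_step = first_step_above(train_acc, threshold)
    gen_step = first_step_above(val_acc, threshold)
    if fit_step is None or gen_step is None:
        return None
    return gen_step - fit_step
```

A run that memorizes early and generalizes much later yields a large gap, while a run on a larger dataset, where both curves rise together, yields a gap close to zero.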
Source: Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets
In their paper "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets", the authors report several findings about grokking and generalization:
- Neural networks can generalize by filling in the empty slots in tables of various binary operations.
- Validation accuracy can jump sharply from chance level to perfect generalization long after severe overfitting. This is called "grokking".
- The authors present data efficiency curves for a variety of binary operations.
- Empirically, as the size of the dataset decreases, the amount of optimization needed for generalization increases rapidly.
- Weight decay is particularly effective at improving generalization on grokking tasks.
- The symbol embeddings learned by these networks sometimes reveal recognizable structure of the mathematical objects they represent.
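To see what "filling in the empty slots of a binary operation table" means, the sketch below builds the full table for one example operation and hides part of it. The choice of modular addition, the modulus p = 97, and the uniform random split are assumptions made for illustration; this is not the authors' code, and it omits how equations are tokenized for the network.

```python
import random
from typing import List, Tuple

Example = Tuple[int, int, int]  # (a, b, c) meaning a ∘ b = c


def modular_addition_table(p: int) -> List[Example]:
    """Every equation a + b = c (mod p) for a, b in 0..p-1: a full p x p table."""
    return [(a, b, (a + b) % p) for a in range(p) for b in range(p)]


def split_table(examples: List[Example], train_fraction: float,
                seed: int = 0) -> Tuple[List[Example], List[Example]]:
    """Randomly keep `train_fraction` of the table cells for training.

    The remaining cells play the role of the "empty slots" the network
    must fill in; they form the validation set.
    """
    rng = random.Random(seed)
    shuffled = examples[:]
    rng.shuffle(shuffled)
    n_train = int(train_fraction * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:]


# Usage: hide half of the 97 x 97 table.
train, val = split_table(modular_addition_table(97), train_fraction=0.5)
```

Lowering `train_fraction` shrinks the dataset, which is the knob behind the observation that smaller datasets need more optimization before the network generalizes.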
Deep learning practitioners are used to seeing only small improvements in validation accuracy once the validation loss stops decreasing. A double descent of the validation loss has been observed in some cases, but it is considered unusual.
The researchers found improved generalization after initial overfitting for a range of models, optimizers, and dataset sizes. They noted that this behavior is typical for all binary operations at dataset sizes close to the minimum size at which the network generalizes within the allocated optimization budget. For larger datasets, the training and validation curves tend to track each other.
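The "minimum dataset size at which the network generalizes within the allocated optimization budget" can be read off a sweep of training runs. The helper below is a hypothetical sketch: it assumes you have trained at several training-data fractions and recorded, for each run, the step at which it generalized (for example via a validation-accuracy threshold), or None if it never did. The numbers in the usage line are made up.

```python
from typing import Dict, Optional


def min_generalizing_fraction(steps_to_generalize: Dict[float, Optional[int]],
                              budget: int) -> Optional[float]:
    """Smallest training fraction that generalized within `budget` steps.

    `steps_to_generalize` maps a training-data fraction to the number of
    optimization steps the run needed to generalize, or None if it never did.
    Returns None if no run generalized within the budget.
    """
    ok = [frac for frac, steps in steps_to_generalize.items()
          if steps is not None and steps <= budget]
    return min(ok) if ok else None


# Usage with hypothetical sweep results (fraction -> steps, None = never generalized).
runs = {0.2: None, 0.3: 90_000, 0.5: 8_000, 0.8: 1_000}
smallest = min_generalizing_fraction(runs, budget=100_000)
```

Note that the answer depends on the budget: a tighter budget pushes the minimum fraction up, which mirrors the trade-off between data and optimization steps described in the text.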
Earlier studies trained convolutional neural networks and compared a wide variety of generalization and complexity measures to identify which ones predict generalization performance. Flatness-based measures, which assess how sensitive the trained network is to perturbations of its parameters, were found to be the most predictive. Power et al. therefore hypothesized that grokking may be caused by SGD noise pushing the optimization toward flatter, simpler solutions that generalize better.
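A simple flatness-style measure perturbs every parameter with random noise and records how much the loss rises on average. The sketch below is one such measure, chosen for illustration; published flatness measures differ in their details, and the noise scale `sigma`, the sample count, and the toy quadratic losses are assumptions, not values from the studies mentioned above.

```python
import random
from typing import Callable, List

Params = List[float]


def perturbation_sharpness(loss_fn: Callable[[Params], float], params: Params,
                           sigma: float = 0.01, n_samples: int = 200,
                           seed: int = 0) -> float:
    """Average increase in loss when Gaussian noise N(0, sigma^2) is added to every parameter.

    Flat minima barely change under perturbation (small value); sharp minima
    change a lot (large value).
    """
    rng = random.Random(seed)
    base = loss_fn(params)
    total = 0.0
    for _ in range(n_samples):
        noisy = [w + rng.gauss(0.0, sigma) for w in params]
        total += loss_fn(noisy) - base
    return total / n_samples


def quadratic_loss(curvature: float) -> Callable[[Params], float]:
    """Toy loss curvature * ||w||^2 with its minimum at w = 0."""
    return lambda w: curvature * sum(x * x for x in w)


# Two minima with the same (zero) loss but different curvature.
flat = perturbation_sharpness(quadratic_loss(1.0), [0.0] * 10)
sharp = perturbation_sharpness(quadratic_loss(100.0), [0.0] * 10)
print(flat < sharp)  # → True
```

Both toy minima fit the "training data" equally well (zero loss), yet the measure separates them; the hypothesis above is that SGD noise favors the kind of solution with the smaller value.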
Additionally, the researchers noticed an interesting phenomenon: the number of optimization steps required to reach a given level of performance increases rapidly as the training dataset shrinks. Since this offers a way to trade extra computation for performance when data is scarce, future work could investigate whether the effect also occurs on other datasets.