Mar 06, 2025
February Papers: Learning to Scale
Written By:
Luke Prince, Luka Ribar, Paul Balanca, Luke Hudlass-Galley
Welcome to Papers of the Month! This time around, our monthly selection of ML papers revolves around the central theme of scale – and learning how to scale efficiently. Scaling laws for LLMs, multi-scale quantisation training and scaling test-time compute: it’s a rich buffet!
The first paper, Distillation Scaling Laws, presents a thorough study of distillation for Language Models, with the aim of estimating how student performance scales as a function of model size and amount of distillation data used – offering very useful insights, in an era where distillation pre-training of LLMs is becoming more and more widespread to improve “capability per watt”.
The problem of computational efficiency and cost reduction is also at the heart of Matryoshka Quantisation, DeepMind’s solution for training a quantised model that can then be easily served at different lower numerical precisions, by leveraging the nested structure of integer data types. And if you are a quantisation geek like we are, make sure to also read our summary of ParetoQ, a new unified framework to investigate the scaling laws that govern the trade-off between quantised model size and accuracy in extremely low-bit regimes.
Finally, we jump from training scaling laws to scaling up test-time compute, with a paper that introduces a recurrent block into the LLM, allowing the model to perform iterative reasoning in latent space at test time – improving its performance without verbalising its intermediate thoughts.
We hope you enjoy this month’s papers as much as we did! If you have thoughts or questions, please reach out to us at .
Here’s our summary of this month’s chosen papers:
Distillation Scaling Laws
Authors: Dan Busbridge et al. (Apple)
When should you distil a small model from a larger one? Is compute better used to train a small model from scratch? The authors demonstrate that distillation is a better use of compute if a teacher model can be reused multiple times (e.g., for distilling many models, or for long term inference deployment) AND student compute budget is sufficiently limited.
Chasing ever more capable language models by parameter scaling is a double-edged sword, as the cost of compute-optimal training and deployment scales quadratically with parameter count. When accounting for inference costs, deep learning practitioners often choose to pre-train a smaller model for longer, although this produces diminishing capability returns in the limit of large datasets.
A well-known alternative is to distil the capabilities of a larger “teacher” model into a smaller “student” model. A standard technique for doing so is to minimise the KL-divergence between the teacher’s and student’s predicted token distributions. Despite the technique being well known, a compute-efficient distillation strategy has not yet been characterised. This is possibly due to some counter-intuitive relationships between the capabilities of teacher and student models. One such relationship is that highly capable teachers can distil worse students than weaker teachers do, a phenomenon known as the “capacity gap”, depicted in the following figure (where M = D/N is the tokens-per-parameter ratio).
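To make this concrete, here is a minimal PyTorch sketch of such a distillation objective (the function name and the optional temperature are our own illustrative choices, not details from the paper):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    """KL divergence between the teacher's and student's next-token distributions.

    Both logits tensors have shape [batch, sequence, vocab].
    """
    vocab = student_logits.size(-1)
    s_log_probs = F.log_softmax(student_logits.reshape(-1, vocab) / temperature, dim=-1)
    t_probs = F.softmax(teacher_logits.reshape(-1, vocab) / temperature, dim=-1)
    # KL(teacher || student), averaged over all token positions; the T^2 factor
    # keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(s_log_probs, t_probs, reduction="batchmean") * temperature ** 2
```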
To account for the capacity gap, the authors propose a scaling law for distillation. They use this scaling law to analyse and derive an optimal distillation strategy under different assumed use cases for teacher models and different student training budgets.
Given the dependence of the authors’ scaling law on teacher parameter count N_T, teacher training tokens D_T, student parameter count N_S, and student training tokens D_S, the authors design two experiments to estimate the parameters of their scaling law by isolating the contributions of teacher scaling and student scaling.
First, they fix the ratio of teacher training tokens to parameters (D_T/N_T = 20), following Chinchilla-optimal pre-training, then vary student parameters and training tokens for a given set of compute budgets.
The second set of experiments does the reverse: fix the ratio of student training tokens to parameters (examining many ratios >= 20) and vary teacher parameters and training tokens for a given set of compute budgets.
Using their derived scaling law, the authors model the total compute cost of distillation and examine the optimal distillation strategy under different assumptions about the relative costs of teacher training and inference.
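As a rough illustration of this kind of accounting (using the common ~6ND training and ~2ND inference FLOP approximations rather than the paper’s exact cost model), the bookkeeping looks something like:

```python
def distillation_flops(n_teacher: float, d_teacher: float,
                       n_student: float, d_student: float,
                       amortise_teacher: bool = True) -> float:
    """Approximate total FLOPs attributed to distilling a student model.

    Assumes the standard ~6*N*D estimate for training a model with N parameters
    on D tokens, and ~2*N*D for inference (the teacher forward passes that
    produce the distillation targets).
    """
    student_training = 6 * n_student * d_student
    teacher_inference = 2 * n_teacher * d_student
    teacher_training = 0.0 if amortise_teacher else 6 * n_teacher * d_teacher
    return student_training + teacher_inference + teacher_training

# Example: distilling a 1B-parameter student on 100B tokens from a
# Chinchilla-optimal 10B-parameter teacher whose training cost is amortised.
cost = distillation_flops(10e9, 200e9, 1e9, 100e9, amortise_teacher=True)
```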
Their first finding is that, when there is sufficient data or compute available for training a student, pre-training outperforms distillation. However, as student models get larger, the crossover point at which student pre-training becomes preferable requires increasingly more compute. This finding holds only when the cost of pre-training the teacher model can be amortised. Their second finding is that, when the cost of pre-training the teacher cannot be amortised, pre-training is always preferable to distillation.
This is a useful reference for designing recipes that produce efficient model families. It is clear that distillation can be useful in compute-constrained settings, but it still requires pre-training highly capable large models where we expect a clear return on the pre-training investment. This raises the question of how to incentivise pre-training of large models so that the model distilleries of the world can benefit from it! The third axis of compute-optimal test-time scaling muddies the waters even further: it will be important to understand to what extent small distilled models benefit from additional inference compute through chain-of-thought or retrieval.
Full paper:
Matryoshka Quantization
Authors: Pranav Nair et al. (Google DeepMind)
The authors showcase a method for training a single quantised model (from a pre-trained checkpoint) that can be used at different precision levels. They achieve this by simultaneously training the model to work at int8, int4, and int2 precisions, optimising (at the same time) the eight, four, and two most significant bits of the integer representation, respectively. They show that this approach can be used in conjunction with learning-based quantisation methods, leading to minimal degradation at the 8-bit and 4-bit precision levels (compared to optimising for a single precision level), and significantly improving the int2 baseline.
Quantisation methods are broadly classified into two categories: learning-free and learning-based. Learning-free methods independently quantise the model’s layers by minimising each layer’s output error on a small calibration dataset, and generally do not involve backpropagation. Due to this, they are computationally cheaper, but tend to perform worse at very low-bit quantisation. Learning-based methods tend to perform better at low-bit precision, at the cost of the more computationally expensive learning through backpropagation.
The authors focus on two gradient descent-based methods that they show can be used with their Matryoshka Quantisation (MatQuant) approach: quantisation-aware training (QAT) and OmniQuant.
QAT learns the quantised weights by minimising an end-to-end cross-entropy loss through gradient descent. The loss can be defined either in terms of next-token prediction on a training dataset, or in terms of the difference between the outputs of the original and quantised models.
As the parameter quantisation function is non-differentiable, training is usually conducted by applying a straight-through estimator to the quantised parameters, i.e., the quantisation function is treated as the identity during the backward pass.
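A minimal sketch of this straight-through trick for a generic symmetric quantiser (our own illustration, not the paper’s exact QAT recipe):

```python
import torch

def fake_quantise_ste(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Symmetric per-tensor fake quantisation with a straight-through estimator."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.detach().abs().max() / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    # Forward pass uses the quantised weights; backward pass treats the
    # rounding as the identity, so gradients flow straight through to w.
    return w + (w_q - w).detach()
```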
Like the “learning-free” methods, OmniQuant finds the quantised weights by minimising the output difference (between the original and the quantised model) of each layer independently; however, it does so by introducing additional scaling and shifting parameters that are optimised through backpropagation. As such, it is computationally cheaper than QAT, since it only optimises these additional parameters.
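A rough sketch of the idea of training only a small set of auxiliary quantisation parameters while the original weights stay frozen (a simplified learnable-clipping illustration, not OmniQuant’s full recipe):

```python
import torch

class LearnableClipQuantiser(torch.nn.Module):
    """Quantises a frozen weight tensor; only the clipping parameter is trained."""

    def __init__(self, bits: int = 4):
        super().__init__()
        self.qmax = 2 ** (bits - 1) - 1
        # Learnable clipping strength; sigmoid keeps the clip ratio in (0, 1).
        self.gamma = torch.nn.Parameter(torch.zeros(()))

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        clip_val = torch.sigmoid(self.gamma) * w.abs().max().detach()
        scale = clip_val / self.qmax
        x = torch.minimum(torch.maximum(w, -clip_val), clip_val) / scale
        x_rounded = x + (torch.round(x) - x).detach()  # STE on the rounding only
        return x_rounded * scale

# The per-layer objective would then be the mismatch between the original and
# quantised layer outputs on a small calibration set, optimised over gamma only.
```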
The quantisation methods discussed above optimise a model for a single predefined precision. The proposed technique, MatQuant, instead produces an integer-quantised model that can be used at different precision levels during inference. This is done by adding an individual loss term for each of the 8/4/2 most significant bits of the integer representation, which can be used in conjunction with the standard learning-based methods. To obtain an r-bit representation from the full c-bit one, the numbers are first right-shifted by c − r bits, then left-shifted by the same amount (with appropriate rounding). The individual per-precision losses are then summed, with scaling factors controlling the relative magnitude of each term in the loss.
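The slicing operation itself is simple; here is a small sketch (our own illustration, assuming an unsigned int8 representation and hypothetical loss weights):

```python
import torch

def slice_to_r_bits(q_uint8: torch.Tensor, r: int) -> torch.Tensor:
    """Keep the r most significant bits of an unsigned 8-bit quantised weight,
    mapping the result back onto the original grid: right-shift by 8 - r with
    rounding, then left-shift by the same amount."""
    shift = 8 - r
    sliced = torch.round(q_uint8.float() / 2 ** shift)   # "right shift" with rounding
    sliced = torch.clamp(sliced, 0, 2 ** r - 1)          # stay within the r-bit range
    return sliced * 2 ** shift                           # "left shift" back

# MatQuant-style objective (schematic): one loss term per precision, e.g.
#   loss = sum(lambda_r * task_loss(model_using(slice_to_r_bits(q, r)))
#              for r, lambda_r in [(8, w8), (4, w4), (2, w2)])
```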
This scheme changes the obtained quantised weight distribution compared to the baseline techniques: in Figure 1c, the weights tend to have larger magnitudes.
Table 1. Downstream task results using MatQuant with OmniQuant. MatQuant shows slight degradation at 3/4/6/8-bit integer precisions, but showcases improvement for int2.
The authors test their method using both the OmniQuant and QAT approaches, optimising for 8, 4, and 2-bit integer precisions; the baselines are the two methods applied at a single precision level. The methods were tested using Gemma and Mistral models on a variety of downstream tasks.
Main observations: MatQuant shows only slight degradation at int8 and int4 precisions, but improves on the baseline at int2.