我們知道在大語言模型(Large Language Model, LLM)中,存在所謂的尺度擴(kuò)展規(guī)律(Scaling Laws) [2],如Fig 1所示,即是:
LLM的性能會(huì)隨著模型的參數(shù)量、模型的訓(xùn)練量、模型的訓(xùn)練數(shù)據(jù)量的增加而增加
Fig 1. 大模型中的尺度擴(kuò)展規(guī)律,測(cè)試集損失隨著模型訓(xùn)練量、訓(xùn)練集數(shù)據(jù)量、模型參數(shù)量的增加而遞減(即是模型性能遞增)。
我們也知道模型的參數(shù)量、模型的訓(xùn)練量和模型的訓(xùn)練數(shù)據(jù)量都會(huì)影響到最終的計(jì)算預(yù)算(可以用FLOPs計(jì)算),因此LLM的性能可以說和計(jì)算預(yù)算直接掛鉤,這也是Fig 1 左圖所表示的。我們不禁會(huì)有個(gè)疑問,給定了模型的計(jì)算預(yù)算C,我們應(yīng)該怎么均衡模型參數(shù)量N和預(yù)訓(xùn)練的Token數(shù)量D,才能使得模型的預(yù)訓(xùn)練損失L最小化呢?我們期待得到最優(yōu)的模型參數(shù)Nopt和最優(yōu)的預(yù)訓(xùn)練Token數(shù)量Dopt,可以使得預(yù)訓(xùn)練損失最小,正如公式(1)所示。
(1)Nopt(C),Dopt(C)=arg?minN,D s.t. FLOPs(N,D)=CL(N,D)
作者探索這個(gè)規(guī)律的方法論也很直接,作者步進(jìn)遍歷了一遍不同的模型尺寸(從70M到16B參數(shù)量),也步進(jìn)遍歷了一遍預(yù)訓(xùn)練數(shù)據(jù)Token數(shù)量(從5B到400B),最終跑了超過400個(gè)組合的數(shù)據(jù)點(diǎn),不得不說有算力真的可以為所欲為。從直觀上看越大尺寸的模型需要越多訓(xùn)練的Token,當(dāng)然我們需要研究具體的比例,作者采用了三種不同的方法去找這個(gè)比例關(guān)系。
固定模型尺寸下的性能分析
這種方法是分別固定住模型尺寸(從70M到10B多個(gè)模型尺寸都需要實(shí)驗(yàn)),然后觀察訓(xùn)練了不同數(shù)量的Tokens數(shù)量后,在每一個(gè)節(jié)點(diǎn)時(shí)哪一個(gè)模型尺寸能夠達(dá)到最小的訓(xùn)練損失。如Fig 2 左圖 所示, 這里有些地方需要解釋。首先這里的橫坐標(biāo)是浮點(diǎn)計(jì)算量FLOPs,在不同模型尺寸下,相同的FLOPs能訓(xùn)練的Token數(shù)量是不同的,因此才會(huì)出現(xiàn)Fig 2左圖中同一個(gè)FLOPs中,大尺寸模型損失比小尺寸模型還大的情況。從Fig 2 左圖中,我們能發(fā)現(xiàn)在不同的FLOPs下,到達(dá)最小損失的模型尺寸是不一樣的(不太容易看出來,在左圖中是灰色點(diǎn),它們形成了一個(gè)包絡(luò)線),不同的FLOPs在對(duì)應(yīng)尺寸模型下能夠折算成訓(xùn)練過的Token數(shù)量,因此能夠畫出Fig 2 中圖和右圖,橫坐標(biāo)是FLOPs,縱坐標(biāo)是達(dá)到最小損失(也就是左圖的灰色點(diǎn))時(shí)的模型尺寸和過了的Tokens數(shù)。換句話說,F(xiàn)ig 2中圖和右圖就是給定計(jì)算預(yù)算C下的最佳模型尺寸Nopt和訓(xùn)練數(shù)據(jù)量Dopt,我們發(fā)現(xiàn)有Nopt∝Ca,Dopt∝Cb,通過實(shí)驗(yàn)可以算出a=0.50,b=0.50。
Fig 2. 訓(xùn)練曲線包絡(luò)。左側(cè)展示了我們所有不同的運(yùn)行情況。我們啟動(dòng)了一系列模型尺寸,從70M到10B,每個(gè)模型針對(duì)四個(gè)不同的余弦循環(huán)周期長(zhǎng)度。從這些曲線中,我們提取了每 FLOP 最小損失的包絡(luò)線,我們利用這些點(diǎn)來估計(jì)給定計(jì)算預(yù)算下的最佳模型尺寸(中間)和最佳訓(xùn)練 token 數(shù)量(右側(cè))。綠色顯示了基于訓(xùn)練 Gopher(5.76 × 10²³ FLOP)所用 FLOP 數(shù)量的最佳模型尺寸和訓(xùn)練 token 數(shù)量的預(yù)測(cè)。
固定計(jì)算預(yù)算下的性能分析
第一種方法的計(jì)算量FLOPs沒有固定,在此方法中我們固定計(jì)算量C(也就是所謂的IsoFLOP),分析等量計(jì)算下的最佳模型參數(shù)量Nopt。同時(shí),在知道了每個(gè)實(shí)驗(yàn)固定的計(jì)算量,和在此之下的最佳模型參數(shù)量后,也就可以反推訓(xùn)練Token數(shù)量。實(shí)驗(yàn)如Fig 3 左圖所示,可以發(fā)現(xiàn)在不同的固定計(jì)算量下(從6×1018到3×1021 FLOPs),遍歷不同尺寸的模型能夠發(fā)現(xiàn)在某些尺寸處會(huì)存在明顯的低谷,這個(gè)低谷就是在固定計(jì)算預(yù)算情況下的最佳模型參數(shù)量,由此也能繪制出Fig 3 中圖和右圖,繪制邏輯如第一種方法所述。不難發(fā)現(xiàn)同樣有Nopt∝Ca,Dopt∝Cb這個(gè)規(guī)律,算出a=0.49,b=0.51。
Fig 3. 等量浮點(diǎn)運(yùn)算曲線(IsoFLOP Curves):針對(duì)不同模型規(guī)模,通過調(diào)整訓(xùn)練令牌(token)數(shù)量,使得最終總浮點(diǎn)運(yùn)算量(FLOPs)保持恒定,并設(shè)置余弦周期長(zhǎng)度以匹配目標(biāo)FLOPs量。研究發(fā)現(xiàn),損失函數(shù)會(huì)出現(xiàn)一個(gè)明顯低谷(如左圖),這表明在給定FLOPs計(jì)算預(yù)算下,存在一個(gè)最優(yōu)的待訓(xùn)練模型?;谶@些低谷位置,我們推算出更大模型的最優(yōu)參數(shù)規(guī)模與令牌數(shù)量(中圖和右圖)。圖中綠色部分展示了在Gopher模型計(jì)算預(yù)算下,最優(yōu)模型的參數(shù)與令牌數(shù)量估計(jì)值。
對(duì)參數(shù)化損失函數(shù)進(jìn)行擬合
在第1和2中方法中已經(jīng)積累了很多最小損失L下的FLOPs(Nopt,Dopt)=C的數(shù)據(jù)點(diǎn)了,我們不妨把損失拆解為三大部分如公式(2)所示,其中第一項(xiàng)E為不可約損失,也就是自然文本的熵,是不可繼續(xù)減少的最基礎(chǔ)的損失項(xiàng)。第二項(xiàng)為(不完美的)參數(shù)量為N的Transformer模型訓(xùn)練過程中產(chǎn)生的損失(因?yàn)閰?shù)量N總是有限,也就是不完美的,因此總是在理想損失E的基礎(chǔ)上有超額損失),第三項(xiàng)則是(不完美的)訓(xùn)練數(shù)據(jù)量D下(因?yàn)橛?xùn)練數(shù)據(jù)量D不可能是無限的)的
產(chǎn)生的超額損失。
(2)L^(N,D)?E+ANα+BDβ
作者采用L-BFGS算法去最小化所謂的Huber loss(因?yàn)閿?shù)據(jù)點(diǎn)只有400多個(gè),這個(gè)loss作者說對(duì)離群點(diǎn)比較穩(wěn)健)去進(jìn)行估計(jì)(A,B,E,α,β),筆者也沒細(xì)究,讀者有興趣的可以翻閱 [3] 和 [4]。最終估計(jì)出來的參數(shù)為:
(3)E=1.69,A=406.4,B=410.7,α=0.34,β=0.28
在LLM Scaling Law的論文 [2] 中提出了一個(gè)估算: FLOPs(N,D)≈6ND,借此可以將公式(2)進(jìn)行變形,得到公式(4)
其中(4)Nopt(C)=G(C6)a,Dopt(C)=G−1(C6)b,其中G=(αAβB)1α+β,a=βα+β,b=αα+β
作者算得a=0.46,b=0.54,具體過程請(qǐng)自行參考原文。
給定計(jì)算量下的最優(yōu)設(shè)計(jì)
Fig 4是將以上三種預(yù)測(cè)方法繪制成計(jì)算量——最佳模型尺寸估計(jì)曲線圖
,其中那貼上了一些之前工作的估計(jì) [2] 和一些模型的對(duì)比,如Gopher(280B參數(shù)量)、GPT-3(175B參數(shù)量)和Megatron-NLG (530B)參數(shù)量。從圖中能發(fā)現(xiàn):
- 方法1和方法2估計(jì)出來的曲線基本上貼合,方法3估計(jì)出的模型尺寸在計(jì)算預(yù)算小的時(shí)候和前兩者基本貼合,但在大計(jì)算預(yù)算下會(huì)偏小些,不過也不會(huì)差距特別大。
- 主流的大模型,如Gopher、GPT3等在對(duì)應(yīng)的計(jì)算預(yù)算下,模型尺寸明顯偏大,基本上是貼著 [2] 的曲線走的。
為了證明本文提出的估計(jì)方法更佳準(zhǔn)確,作者在方法1和2中對(duì)齊Gopher的計(jì)算預(yù)算(大概是5.76×1023 FLOPs),找到了最佳模型尺寸,約是70B,作者將這個(gè)訓(xùn)練出來的模型稱之為Chinchilla,需要將這個(gè)模型的性能和Gopher進(jìn)行公平對(duì)比。注意到在方法1和2中,從Fig 2和Fig 3的右圖中可以找出給定預(yù)算下的最佳訓(xùn)練Token數(shù)量,對(duì)于Chinchilla來說是1.4-1.5T左右,因此Dopt/Nopt≈20。
作者在相當(dāng)多語言下游任務(wù)的基準(zhǔn)上進(jìn)行了測(cè)試,都發(fā)現(xiàn)Chinchilla對(duì)比Gopher存在普遍優(yōu)勢(shì),在一些任務(wù)中甚至超過了Megatron-NLG 530B模型。這些實(shí)驗(yàn)過于冗長(zhǎng),筆者就不展示細(xì)節(jié)了。
筆者讀后感
這篇論文的意義在于告訴我們,在給定了計(jì)算預(yù)算下,是存在一個(gè)最優(yōu)的模型尺寸和訓(xùn)練數(shù)據(jù)量的,他們存在一個(gè)比例(Dopt≈20Nopt),越大的模型就需要越多數(shù)據(jù)進(jìn)行訓(xùn)練,才能發(fā)揮出模型最優(yōu)的性能。這篇論文的發(fā)表時(shí)間比較早,是2022年,現(xiàn)在已經(jīng)有很多工作證實(shí)了在推理中進(jìn)行復(fù)雜策略可以有效提高模型性能 [5,6],并且這些推理策略同樣也存在Scaling Law。這意味著計(jì)算預(yù)算不僅可以花在預(yù)訓(xùn)練上,而且可以花在推理時(shí)的Scaling,這也是這篇文章沒有考慮到的點(diǎn)。當(dāng)然,在 [6] 中作者也承認(rèn),推理時(shí)的Scaling并非是萬能的,而是:
推理時(shí)計(jì)算與預(yù)訓(xùn)練計(jì)算并非一對(duì)一“可互換”。對(duì)于模型能力范圍內(nèi)的簡(jiǎn)單和中等難度問題,或者在推理(實(shí)時(shí)性)要求較低的情況下,測(cè)試時(shí)計(jì)算可以輕松彌補(bǔ)額外的預(yù)訓(xùn)練。然而,對(duì)于超出基礎(chǔ)模型能力范圍的具有挑戰(zhàn)性的問題,或者在推理(實(shí)時(shí)性)要求較高的情況下,預(yù)訓(xùn)練可能更有效于提升性能。也就是說預(yù)訓(xùn)練的地位并不是通過推理時(shí)的Scaling就可以替代的,預(yù)訓(xùn)練中分配一定量的計(jì)算預(yù)算對(duì)于全方面提高LLM的性能是必須的。結(jié)合了模型訓(xùn)練、模型推理的更為綜合的最優(yōu)配比,應(yīng)該是值得去研究的。
Reference
[1]. Hoffmann, Jordan, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas et al. "Training compute-optimal large language models." arXiv preprint arXiv:2203.15556 (2022).
[2]. Kaplan, Jared, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. "Scaling laws for neural language models." arXiv preprint arXiv:2001.08361 (2020).
[3]. J. Nocedal. Updating Quasi-Newton Matrices with Limited Storage. Mathematics of Computation, 35(151):773–782, 1980. ISSN 0025-5718. doi: 10.2307/2006193. URL https://www.jstor.org/stable/2006193 aka L-BFGS
[4]. P. J. Huber. Robust Estimation of a Location Parameter. The Annals of Mathematical Statistics, 35 (1):73–101, Mar. 1964. ISSN 0003-4851, 2168-8990. doi: 10.1214/aoms/1177703732. URL https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full. aka Huber loss
[5]. https://fesianxu.github.io/2025/03/02/test-time-scaling-laws-20250302/, 《大模型推理時(shí)的尺度擴(kuò)展定律》
[6]. Snell, Charlie, Jaehoon Lee, Kelvin Xu, and Aviral Kumar. "Scaling llm test-time compute optimally can be more effective than scaling model parameters." arXiv preprint arXiv:2408.03314 (2024).