為了提高LLM底座的通用能力,通常預(yù)訓(xùn)練數(shù)據(jù)都會包含有各種領(lǐng)域的數(shù)據(jù),比如The Pile [2] 就是一個800GB大小的,涵蓋了22個不同領(lǐng)域的常用預(yù)訓(xùn)練數(shù)據(jù)集,如Fig 1所示。對于LLM預(yù)訓(xùn)練而言(采用next token prediction的自回歸方式進(jìn)行預(yù)訓(xùn)練),不同領(lǐng)域數(shù)據(jù)的配比很重要,之前的工作大部分基于啟發(fā)式的方法拍定配比(比如均勻采樣,或者根據(jù)不同領(lǐng)域數(shù)據(jù)的數(shù)據(jù)量大小進(jìn)行采樣),由于不同領(lǐng)域數(shù)據(jù)的學(xué)習(xí)難度各不相同,這些啟發(fā)式的配比方法不能保證LLM在預(yù)訓(xùn)練階段充分利用數(shù)據(jù)。本文旨在利用一個小規(guī)模的代理模型(280M參數(shù)量),通過Group DRO (Distributionally Robust Optimization) [3-4] [腳注1] 的方式去尋找最佳的數(shù)據(jù)配比,然后在全量參數(shù)的LLM(8B參數(shù)量)上進(jìn)行預(yù)訓(xùn)練,通過這種方式,作者發(fā)現(xiàn)預(yù)訓(xùn)練速度能加速2.6倍,并且最終能獲得6.5%的效果提升,最終整個方法被稱之為 DoReMi (Domain Reweighting with Minimax Optimization),接下來讓我們具體看看這個方法是怎么實(shí)施的。
Fig 1. The Pile數(shù)據(jù)集中包含有22個不同領(lǐng)域的數(shù)據(jù)。
首先,我們先理解一個點(diǎn):LLM在通過自回歸進(jìn)行預(yù)訓(xùn)練的時候,其最理想狀態(tài)應(yīng)該是能擬合不同領(lǐng)域數(shù)據(jù)的分布,也就是各個領(lǐng)域數(shù)據(jù)的困惑度都最小化 [腳注2]。因此,如果我們能在“監(jiān)控”預(yù)訓(xùn)練過程中的各個領(lǐng)域數(shù)據(jù)中,通過損失大小確定擬合程度最差的樣本來自于哪個領(lǐng)域,然后適當(dāng)提高這個領(lǐng)域數(shù)據(jù)的采樣率,那么理論上LLM就能有更多機(jī)會去學(xué)習(xí)較難擬合的領(lǐng)域的數(shù)據(jù)分布,從而以更小的訓(xùn)練代價實(shí)現(xiàn)更好的擬合效果。當(dāng)然,完整的LLM參數(shù)量動輒10億(B)規(guī)模,訓(xùn)練一次的成本較高,我們假設(shè)參數(shù)量較小的(億級別,如280M)LLM模型的預(yù)訓(xùn)練趨勢和完整參數(shù)量(8B)的LLM模型是類似的,因此可以共享相同的數(shù)據(jù)配比,我們稱此處的較小參數(shù)量模型為代理模型,通過代理模型我們能找出最佳數(shù)據(jù)配比,然后給完整參數(shù)模型進(jìn)行預(yù)訓(xùn)練,整個過程可以參考Fig 2.示意。
由于需要判斷擬合程度最差的樣本,這是一個“比較”的過程,比較的一方是代理模型(proxy model),而比較的另一方則是一個參考模型(reference model) ,參考模型指的是在默認(rèn)的啟發(fā)式數(shù)據(jù)配比下,用和代理模型同樣參數(shù)量的模型(也就是280M),預(yù)訓(xùn)練得到的模型。
Fig 2. DoReMi的基本框架示意圖,通過一個代理模型對預(yù)訓(xùn)練數(shù)據(jù)集合中不同領(lǐng)域數(shù)據(jù)的采樣率進(jìn)行重新參數(shù)化,得到最佳數(shù)據(jù)配比后去訓(xùn)練最終的LLM。
至此,我們需要了解的DoReMi的基本思路已經(jīng)介紹完了,我們開始深入細(xì)節(jié)。首先先進(jìn)行形式化表示,假設(shè)預(yù)訓(xùn)練數(shù)據(jù)集中具有k個域的數(shù)據(jù),表示為Di,i=1,?,k,每個域的權(quán)重記為α∈Δk,可以理解為在k個域中的采樣概率分布,也就是∑ikαi=1。那么訓(xùn)練數(shù)據(jù)的分布可以表示為:
(1)Pα=∑i=1kαi⋅unif(Di)
其中unif(D)=1|D|∑x∈Dδx是一個在數(shù)據(jù)集合D上的均勻分布,如果在x=x′ 的時候 δx(x′)=1,否則δx(x′)=0。整個DoReMi的過程,可以描述為:
第一步,訓(xùn)練一個參考模型
首先采用初始化的數(shù)據(jù)配比αref先訓(xùn)練一個280M參數(shù)量的參考模型pref,此處的數(shù)據(jù)配比可以根據(jù)啟發(fā)式的方式選取,比如采用數(shù)據(jù)集中不同領(lǐng)域數(shù)據(jù)的數(shù)據(jù)量比例作為配比,見Fig 1.的baseline一列。
第二步,通過Group DRO的方式訓(xùn)練一個代理模型并且得到新的域權(quán)重
有了參考模型pref之后,就可以開始訓(xùn)練代理模型和學(xué)習(xí)新的域權(quán)重了(也即是學(xué)習(xí)得到α¯),整個過程采用DRO-LM的框架 [5] ,采用Group DRO優(yōu)化器進(jìn)行訓(xùn)練,其中的θ為代理模型的參數(shù)。整個框架的公式如公式(2)所示,不難看出,這是一個minmax
過程,其優(yōu)化目標(biāo)正如我們前文討論的,優(yōu)化(體現(xiàn)在min目標(biāo)上)各域上最差擬合程度樣本(體現(xiàn)在max目標(biāo)上)的損失,我們著重講解下這個公式。
(2)minθmaxα∈ΔkL(θ,α):=∑i=1kαi⋅[1∑x∈Di|x|∑x∈Di?θ(x)−?ref(x)]
其中的?θ(x)=−log?pθ(x)和?ref(x)=−log?pref(x)為代理模型和參考模型的負(fù)對數(shù)似然概率 [腳注3],|x|是當(dāng)前樣本x的token數(shù)量。?θ(x)−?ref(x) 可視為是超額損失(excess loss),這個損失度量了對于代理模型而言,當(dāng)前樣本x的可優(yōu)化空間。超額損失越大,說明該樣本需要更多訓(xùn)練才能達(dá)到參考模型的效果;超額損失越小,則有兩種可能,第一種可能是?ref很大,這表明了這個是一個高熵的樣本,其生成會趨向于均勻分布,因此難以學(xué)習(xí),第二種可能是?θ很小,這說明對于樣本x而言,代理模型已經(jīng)學(xué)習(xí)得足夠好了,屬于簡單樣本,這兩種情況下都需要適當(dāng)減少權(quán)重,減少該域樣本的采樣。首先從這個minmax
的內(nèi)層目標(biāo)(也即是max
)看起,此時代理模型不進(jìn)行參數(shù)更新,優(yōu)化項是α∈Δk, 就是根據(jù)超額損失的大小去啟發(fā)式更新權(quán)重,然后看到外層的min
目標(biāo),此時優(yōu)化項是代理模型參數(shù)θ,也即是從內(nèi)層找到了最大的超額損失后,嘗試讓代理模型去擬合這個超額損失。通過交替地進(jìn)行最小最大優(yōu)化,從而訓(xùn)練得到一個代理模型和新的域權(quán)重α¯。 最終的域權(quán)重,則是從每一步的訓(xùn)練過程中進(jìn)行累計平均得到,即是α¯=1T∑t=1Tαt。
這個過程,用算法描述的形式表述則是:
輸入:各個域的數(shù)據(jù)D={D1,?,Dk},訓(xùn)練步數(shù)T,batch size大小b和更新步長η ,平滑系數(shù)c∈[0,1],本實(shí)現(xiàn)采用的c=10−3。
- 初始化代理模型參數(shù)θ0
- 初始化域權(quán)重α0=1k1∈Rk
對于t從1到T開始循環(huán):
- 從 Pu 中采樣一個大小為 b 的小批量B={x1,?,xb},其中u=1k1
- 令|x|為樣本x的token長度
- 計算每個域i∈{1,?,k}的超額損失(?∗,j(x)是第j個token的損失),此處的max(?,0)就是在實(shí)現(xiàn)公式(2)里面提到內(nèi)層
max
過程,需要保證超額損失的非負(fù)性,其中的?∗,j(x)=−log?p∗(xj|x1,?,xj−1)。(3)λt[i]←1|x|∑x∈B∩Di∑j=1|x|max(?θt−1,j(x)−?ref,j(x),0) - 啟發(fā)式地更新域權(quán)重(指數(shù)更新): α′←αt−1⊙exp?(ηλt)
- 更新歸一化和平滑的域權(quán)重:αt←(1−c)αt′∑i=1kαt′[i]+cu,αt∈Rk, 此處采用平滑更新的方式,是希望整個域權(quán)重更新的過程更加平滑(因?yàn)槊看沃灰姷搅水?dāng)前batch的數(shù)據(jù),因此可能存在噪聲),最好是在先驗(yàn)分布u的基礎(chǔ)上進(jìn)行增量更新。
- 使用目標(biāo)L(θt−1,αt)更新代理模型的參數(shù)θt(可以采用Adam優(yōu)化器)。
結(jié)束循環(huán)
返回:α¯=1T∑t=1Tαt
第三步,用新的權(quán)重訓(xùn)練一個完整的LLM
采用第二步得到的α¯構(gòu)建新的訓(xùn)練分布Pα¯,從中采樣數(shù)據(jù)去預(yù)訓(xùn)練最終需要的完整參數(shù)量的LLM(本實(shí)驗(yàn)中是8B)。
迭代式的DoReMi
第一輪迭代的DoReMi的αref是啟發(fā)式得到的,不是最優(yōu)的選擇,因此整個DoReMi過程是可以迭代進(jìn)行的,即是當(dāng)?shù)谝惠喌械玫搅?span data-eeimg="1" data-tex="\bar{\alpha}_1">α¯1之后,可以將第二輪迭代的αref:=α¯1,然后重復(fù)整個DoReMi的過程,直到域權(quán)重收斂為止,在本文收斂的定義是||α¯−αref||∞<10−3,此處的無窮范數(shù)即是α¯和αref差的最大值。從作者的經(jīng)驗(yàn)看,在GLaM數(shù)據(jù)集上只需要3輪迭代即可收斂。
介紹完了整個DoReMi的操作過程,我們看下實(shí)驗(yàn)結(jié)果。作者是在The Pile和GLaM [6] 這兩個預(yù)訓(xùn)練數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練的,The Pile的情況前文介紹了,GLaM是一個具有8個域的文本數(shù)據(jù)集,由于這個數(shù)據(jù)集的域權(quán)重根據(jù)下游任務(wù)的效果而調(diào)節(jié)得到,因此GLaM的域權(quán)重可以視為是真實(shí)標(biāo)簽,由此檢驗(yàn)DoReMi的域權(quán)重調(diào)整結(jié)果。從Fig 1.中可以看出,經(jīng)過DoReMi后,在Pile數(shù)據(jù)集上不同領(lǐng)域的權(quán)重已經(jīng)有了很大的變化,有些域的權(quán)重存在大幅度的下降,然而如果我們看到DoReMi (280M -> 8B) 在Pile數(shù)據(jù)集上保留驗(yàn)證集上所有域的困惑度,如Fig 3.所示,則會發(fā)現(xiàn)DoReMi在所有域上的困惑度都是明顯下降。這并不是很符合直覺,因?yàn)槟承┯颍ㄈ?code>Arxiv、PubMed central
)的權(quán)重下降了很多,意味著LLM預(yù)訓(xùn)練過程中采樣到這些域數(shù)據(jù)的幾率下降了,為什么還能得到困惑度的下降呢?
一種可能性是,正如在前文討論的,超額損失低的那部分樣本都是難樣本(接近均勻分布)或者簡單樣本,簡單樣本容易學(xué)習(xí),不需要那么多樣本支持學(xué)習(xí),而難樣本則由于難度太高,即便提高了樣本量也無法學(xué)習(xí)出來,因此降低采樣也沒帶來效果損失。并且很幸運(yùn)的,提高了其他中度難度樣本的采樣比例后,讓模型泛化能力得到了進(jìn)步,因此在各個域數(shù)據(jù)上的表現(xiàn)都沒折損,都有提升。
Fig 3. 對比基線,采用了DoReMi后在Pile數(shù)據(jù)集得到所有域的保留驗(yàn)證集上都得到了困惑度的明顯下降。
讓我們看看DoReMi在下游任務(wù)上的表現(xiàn),從Fig 4. (a) 中能發(fā)現(xiàn),在利用The Pile預(yù)訓(xùn)練集的情況下,采用了DoReMi后,在所有的訓(xùn)練步中,性能(下游任務(wù)指標(biāo),用的是精準(zhǔn)匹配的比例,筆者覺得可能類似ROUGE的指標(biāo))都能得到持續(xù)的提升。從Fig 4. (b) 是采用GLaM數(shù)據(jù)集作為預(yù)訓(xùn)練的結(jié)果,有以下結(jié)論:
- 采用多輪迭代的DoReMi(round 1 vs round 2),采用多輪迭代的效果會持續(xù)比單輪的好。
- 采用了單輪的DoReMi效果沒有基線好,可能是由于GLaM數(shù)據(jù)集本身的域只有8個,DoReMi的發(fā)揮空間不如The Pile數(shù)據(jù)集。而采用了多輪DoReMi的效果能超越基線,可能說明對于域比較少的數(shù)據(jù)集,需要多輪迭代才能得到較好效果。
- 采用多輪迭代的DoReMi,其效果接近最佳權(quán)重(通過下游任務(wù)調(diào)優(yōu)得到)的效果。
再看到Fig 4. (c), 這是在GLaM數(shù)據(jù)集中多輪迭代DoReMi的權(quán)重和最佳權(quán)重的對比,可以發(fā)現(xiàn)采用了DoReMi后,的確權(quán)重都趨向于了最佳權(quán)重了,這也證實(shí)了DoReMi的合理性和有效性。
Fig 4. DoReMi在下游任務(wù)中的模型性能對比
作者同樣做了消融試驗(yàn),去回答DoReMi中幾個關(guān)鍵問題:
- Q1:是否挑選最難的樣本或者最簡單的樣本,而不是超額損失最大的樣本效果會更好?A1:當(dāng)前的挑選標(biāo)準(zhǔn)是超額損失,即是?θ(x)−?ref(x),如果挑選標(biāo)準(zhǔn)變成最難樣本,即是?θ(x),或者挑選最簡單樣本,即是−?ref(x),試驗(yàn)效果會如何呢?見Fig 5. 右側(cè)所示,我們發(fā)現(xiàn)采用了最簡單樣本或者最難樣本的試驗(yàn),效果都不如DoReMi,并且采用了最簡單樣本的方案,對比基線都遠(yuǎn)遠(yuǎn)落后。這說明了,學(xué)習(xí)最簡單樣本明顯不是一個好主意,這使得大模型的底座能力缺失嚴(yán)重,單純學(xué)習(xí)最難的樣本也能提升LLM的能力,但是光是關(guān)注最難的樣本,而忽略了中等難度的樣本,則也不是最優(yōu)的方案。這個也和前面分析困惑度的試驗(yàn)結(jié)論遙相呼應(yīng)了。
Fig 5. 在 The Pile 數(shù)據(jù)集上訓(xùn)練的模型的平均下游準(zhǔn)確度。左側(cè): 在 DoReMi 中將參考/代理模型的規(guī)模從 70M 增加到 280M,可以提高 8B 主模型的下游準(zhǔn)確度,但這一趨勢在 1B 代理模型中并未繼續(xù)。我們假設(shè) Group DRO 優(yōu)化器在更大規(guī)模的代理模型中表現(xiàn)不佳。右側(cè): 僅針對最難或最容易的領(lǐng)域
- Q2:提高代理模型的參數(shù)尺寸,是否能獲得最后效果的提升?A2:考慮采用不同的代理模型尺寸,如70M、150M、280M、1B參數(shù)量,而最終模型的尺寸仍然是8B,是否會觀察到scaling law呢?如Fig 5. 左側(cè)圖片所示,適當(dāng)提升代理模型尺寸(
70M -> 150M -> 280M
)可以提高最終模型的效果,但是當(dāng)代理模型尺寸達(dá)到一定大?。ㄈ?B)后,反而出現(xiàn)了性能下降。因此對于代理模型而言,也并不是尺寸越大越好。 - Q3:如果代理模型達(dá)到和最終模型同樣的尺寸,代理模型和最終模型的效果對比如何?A3:這個問題其實(shí)也很符合直覺,代理模型和最終模型采用的采樣策略是不同的(損失重參數(shù)化 vs 重采樣)。作者嘗試將代理模型和最終模型的參數(shù)量都設(shè)置為相同(為了試驗(yàn)對比公平),然后對比基線、DoReMi (x -> x)和代理模型的表現(xiàn),如Fig 6所示,我們發(fā)現(xiàn)采用了代理模型的表現(xiàn)都低于最終的主模型,并且隨著模型尺寸增大,性能差別則越大。并且在1B規(guī)模的代理模型中,甚至性能還不如基線(但是其DoReMi結(jié)果還是比基線好),這意味即便代理模型沒有訓(xùn)練得很好,在整個DoReMi體系下仍然能提升最終模型的效果。
Fig 6. DoReMi 主模型和相同規(guī)模代理模型的困惑度,盡管 1B 代理模型的質(zhì)量相對較低,但由此產(chǎn)生的領(lǐng)域權(quán)重仍然能夠改善主模型。
Reference
[1]. Xie, Sang Michael, Hieu Pham, Xuanyi Dong, Nan Du, Hanxiao Liu, Yifeng Lu, Percy S. Liang, Quoc V. Le, Tengyu Ma, and Adams Wei Yu. "Doremi: Optimizing data mixtures speeds up language model pretraining." Advances in Neural Information Processing Systems 36 (2024). aka DoReMi
[2]. Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, Shawn Presser, and Connor Leahy. The pile: An 800gb dataset of diverse text for language modeling. arXiv, 2020. aka The Pile
[3]. Arkadi Nemirovski, Anatoli Juditsky, Guanghui Lan, and Alexander Shapiro. Robust stochastic approximation approach to stochastic programming. SIAM Journal on optimization, 19(4):1574–1609, 2009.
[4]. Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In International Conference on Learning Representations (ICLR), 2020.
[5]. Yonatan Oren, Shiori Sagawa, Tatsunori Hashimoto, and Percy Liang. Distributionally robust language modeling. In Empirical Methods in Natural Language Processing (EMNLP), 2019. aka DRO-LM
[6]. Nan Du, Yanping Huang, Andrew M. Dai, Simon Tong, Dmitry Lepikhin, Yuanzhong Xu, M. Krikun, Yanqi Zhou, Adams Wei Yu, Orhan Firat, Barret Zoph, Liam Fedus, Maarten Bosma, Zongwei Zhou, Tao Wang, Yu Emma Wang, Kellie Webster, Marie Pellat, Kevin Robinson, K. Meier-Hellstern, Toju Duke, Lucas Dixon, Kun Zhang, Quoc V. Le, Yonghui Wu, Zhifeng Chen, and Claire Cui. GLaM: Efficient scaling of language models with mixture-of-experts. arXiv, 2021. aka GLaM
腳注區(qū)域:
[腳注1]: Group DRO 的關(guān)鍵在于通過最小化最壞情況下的損失來優(yōu)化領(lǐng)域權(quán)重,從而使得模型在所有領(lǐng)域上都能達(dá)到較好的性能。
[腳注2]: 在自然語言處理(NLP)中,困惑度(Perplexity)是評估語言模型性能的一個重要指標(biāo)。它衡量模型對測試數(shù)據(jù)的預(yù)測能力,具體計算方法如下:對于一個給定的詞序列 W=(w1,w2,?,wT),困惑度 PP(W) 的計算公式為:PP(W)=P(w1,w2,?,wT)−1T,其中的P(w1,?,wT)為語言模型建模的聯(lián)合概率分布,而T為序列長度。低困惑度: 如果語言模型能夠準(zhǔn)確預(yù)測詞序列,那么它給出的聯(lián)合概率會較高,從而導(dǎo)致困惑度較低。這意味著模型對測試數(shù)據(jù)的預(yù)測能力較強(qiáng)。高困惑度: 如果模型對詞序列的預(yù)測能力較差,給出的聯(lián)合概率會較低,導(dǎo)致困惑度較高
[腳注3]: 此處計算的方法是,給定一個樣本x={x1,x2,?,xN},其中N表示序列長度,xi表示token,那么p(xi|xi−1,xi−2,?,x1)就是當(dāng)前token xi被預(yù)測正確的概率。通過求對數(shù)似然和,能得到log?p(x),也即是對每個token的對數(shù)概率進(jìn)行加和。這代表了這個序列x被當(dāng)前語言模型采樣出來的概率。