本文轉(zhuǎn)自徐飛翔的“一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss”
版權(quán)聲明:本文為博主原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請(qǐng)附上原文出處鏈接和本聲明。
ranking loss函數(shù):度量學(xué)習(xí)
不像其他損失函數(shù),比如交叉熵?fù)p失和均方差損失函數(shù),這些損失的設(shè)計(jì)目的就是學(xué)習(xí)如何去直接地預(yù)測(cè)標(biāo)簽,或者回歸出一個(gè)值,又或者是在給定輸入的情況下預(yù)測(cè)出一組值,這是在傳統(tǒng)的分類(lèi)任務(wù)和回歸任務(wù)中常用的。ranking loss的目的是去預(yù)測(cè)輸入樣本之間的相對(duì)距離。這個(gè)任務(wù)經(jīng)常也被稱(chēng)之為度量學(xué)習(xí)(metric learning)。
在訓(xùn)練集上使用ranking loss函數(shù)是非常靈活的,我們只需要一個(gè)可以衡量數(shù)據(jù)點(diǎn)之間的相似度度量就可以使用這個(gè)損失函數(shù)了。這個(gè)度量可以是二值的(相似/不相似)。比如,在一個(gè)人臉驗(yàn)證數(shù)據(jù)集上,我們可以度量某個(gè)兩張臉是否屬于同一個(gè)人(相似)或者不屬于同一個(gè)人(不相似)。通過(guò)使用ranking loss函數(shù),我們可以訓(xùn)練一個(gè)CNN網(wǎng)絡(luò)去對(duì)這兩張臉是否屬于同一個(gè)人進(jìn)行推斷。(當(dāng)然,這個(gè)度量也可以是連續(xù)的,比如余弦相似度。)
在使用ranking loss的過(guò)程中,我們首先從這兩張(或者三張,見(jiàn)下文)輸入數(shù)據(jù)中提取出特征,并且得到其各自的嵌入表達(dá)(embedded representation,譯者:見(jiàn)[1]中關(guān)于數(shù)據(jù)嵌入的理解)。然后,我們定義一個(gè)距離度量函數(shù)用以度量這些表達(dá)之間的相似度,比如說(shuō)歐式距離。最終,我們訓(xùn)練這個(gè)特征提取器,以對(duì)于特定的樣本對(duì)(sample pair)產(chǎn)生特定的相似度度量。
盡管我們并不需要關(guān)心這些表達(dá)的具體值是多少,只需要關(guān)心樣本之間的距離是否足夠接近或者足夠遠(yuǎn)離,但是這種訓(xùn)練方法已經(jīng)被證明是可以在不同的任務(wù)中都產(chǎn)生出足夠強(qiáng)大的表征的。
ranking loss的表達(dá)式
正如我們一開(kāi)始所說(shuō)的,ranking loss有著很多不同的別名,但是他們的表達(dá)式卻是在眾多設(shè)置或者場(chǎng)景中都是相同的并且是簡(jiǎn)單的。我們主要針對(duì)以下兩種不同的設(shè)置,進(jìn)行兩種類(lèi)型的ranking loss的辨析
- 使用一對(duì)的訓(xùn)練數(shù)據(jù)點(diǎn)(即是兩個(gè)一組)
- 使用三元組的訓(xùn)練數(shù)據(jù)點(diǎn)(即是三個(gè)數(shù)據(jù)點(diǎn)一組)
這兩種設(shè)置都是在訓(xùn)練數(shù)據(jù)樣本中進(jìn)行距離度量比較。
成對(duì)樣本的ranking loss
在這個(gè)設(shè)置中,由訓(xùn)練樣本中采樣到的正樣本和負(fù)樣本組成的兩種樣本對(duì)作為訓(xùn)練輸入使用。正樣本對(duì)(?,
?)由兩部分組成,一個(gè)錨點(diǎn)樣本
?和 另一個(gè)和之標(biāo)簽相同的樣本
,這個(gè)樣本
與錨點(diǎn)樣本在我們需要評(píng)價(jià)的度量指標(biāo)上應(yīng)該是相似的(經(jīng)常體現(xiàn)在標(biāo)簽一樣);負(fù)樣本對(duì)
由一個(gè)錨點(diǎn)樣本
?和一個(gè)標(biāo)簽不同的樣本
組成,
?在度量上應(yīng)該和
不同。(體現(xiàn)在標(biāo)簽不一致)
現(xiàn)在,我們的目標(biāo)就是學(xué)習(xí)出一個(gè)特征表征,這個(gè)表征使得正樣本對(duì)中的度量距離盡可能的小,而在負(fù)樣本對(duì)中,這個(gè)距離應(yīng)該要大于一個(gè)人為設(shè)定的超參數(shù)——閾值
。成對(duì)樣本的ranking loss強(qiáng)制樣本的表征在正樣本對(duì)中擁有趨向于0的度量距離,而在負(fù)樣本對(duì)中,這個(gè)距離則至少大于一個(gè)閾值。用
分別表示這些樣本的特征表征,我們可以有以下的式子:
對(duì)于正樣本對(duì)來(lái)說(shuō),這個(gè)loss隨著樣本對(duì)輸入到網(wǎng)絡(luò)生成的表征之間的距離的減小而減少,增大而增大,直至變成0為止。
對(duì)于負(fù)樣本來(lái)說(shuō),這個(gè)loss只有在所有負(fù)樣本對(duì)的元素之間的表征的距離都大于閾值 的時(shí)候才能變成0。當(dāng)實(shí)際負(fù)樣本對(duì)的距離小于閾值的時(shí)候,這個(gè)loss就是個(gè)正值,因此網(wǎng)絡(luò)的參數(shù)能夠繼續(xù)更新優(yōu)化,以便產(chǎn)生更適合的表征。這個(gè)項(xiàng)的loss最大值不會(huì)超過(guò)
,在
的時(shí)候取得。這里設(shè)置閾值的目的是,當(dāng)某個(gè)負(fù)樣本對(duì)中的表征足夠好,體現(xiàn)在其距離足夠遠(yuǎn)的時(shí)候,就沒(méi)有必要在該負(fù)樣本對(duì)中浪費(fèi)時(shí)間去增大這個(gè)距離了,因此進(jìn)一步的訓(xùn)練將會(huì)關(guān)注在其他更加難分別的樣本對(duì)中。
假設(shè)用分別表示樣本對(duì)兩個(gè)元素的表征,
是一個(gè)二值的數(shù)值,在輸入的是負(fù)樣本對(duì)時(shí)為0,正樣本對(duì)時(shí)為1,距離
是歐式距離,我們就能有最終的loss函數(shù)表達(dá)式:
三元組樣本對(duì)的ranking loss
三元組樣本對(duì)的ranking loss稱(chēng)之為triplet loss。在這個(gè)設(shè)置中,與二元組不同的是,輸入樣本對(duì)是一個(gè)從訓(xùn)練集中采樣得到的三元組。這個(gè)三元組 由一個(gè)錨點(diǎn)樣本
?,一個(gè)正樣本
?,一個(gè)負(fù)樣本
組成。其目標(biāo)是錨點(diǎn)樣本與負(fù)樣本之間的距離
與錨點(diǎn)樣本和正樣本之間的距離
之差大于一個(gè)閾值
,可以表示為:
在訓(xùn)練過(guò)程中,對(duì)于一個(gè)可能的三元組,我們的triplet loss可能有三種情況:
- “簡(jiǎn)單樣本”的三元組(easy triplet):
。在這種情況中,在嵌入空間(譯者:指的是以嵌入特征作為表征的歐幾里德空間,空間的每個(gè)基底都是一個(gè)特征維,一般是賦范空間)中,對(duì)比起正樣本來(lái)說(shuō),負(fù)樣本和錨點(diǎn)樣本已經(jīng)有足夠的距離了(即是大于
)。此時(shí)loss為0,網(wǎng)絡(luò)參數(shù)將不會(huì)繼續(xù)更新。
- “難樣本”的三元組(hard triplet):
。在這種情況中,負(fù)樣本比起正樣本,更接近錨點(diǎn)樣本,此時(shí)loss為正值(并且比
大),網(wǎng)絡(luò)可以繼續(xù)更新。
- “半難樣本”的三元組(semi-hard triplet):
。在這種情況下,負(fù)樣本到錨點(diǎn)樣本的距離比起正樣本來(lái)說(shuō),雖然是大于后者,但是并沒(méi)有大于設(shè)定的閾值
,此時(shí)loss仍然為正值,但是小于
,此時(shí)網(wǎng)絡(luò)可以繼續(xù)更新。
負(fù)樣本的挑選
在訓(xùn)練中使用Triplet loss的一個(gè)重要選擇就是我們需要對(duì)負(fù)樣本進(jìn)行挑選,稱(chēng)之為負(fù)樣本選擇(negative selection)或者三元組采集(triplet mining)。選擇的策略會(huì)對(duì)訓(xùn)練效率和最終性能結(jié)果有著重要的影響。一個(gè)明顯的策略就是:簡(jiǎn)單的三元組應(yīng)該盡可能被避免采樣到,因?yàn)槠鋖oss為0,對(duì)優(yōu)化并沒(méi)有任何幫助。
第一個(gè)可供選擇的策略是離線三元組采集(offline triplet mining),這意味著在訓(xùn)練的一開(kāi)始或者是在每個(gè)世代(epoch)之前,就得對(duì)每個(gè)三元組進(jìn)行定義(也即是采樣)。第二種策略是在線三元組采集(online triplet mining),這種方案意味著在訓(xùn)練中的每個(gè)批次(batch)中,都得對(duì)三元組進(jìn)行動(dòng)態(tài)地采樣,這種方法經(jīng)常具有更高的效率和更好的表現(xiàn)。
然而,最佳的負(fù)樣本采集方案是高度依賴(lài)于任務(wù)特性的。但是在本篇文中不會(huì)在此深入討論,因?yàn)楸疚闹皇菍?duì)ranking loss的不同別名的綜述并且討論而已??梢詤⒖糩2]以對(duì)負(fù)樣本采樣進(jìn)行更深的了解。
ranking loss的別名們~名兒可真多啊
ranking loss家族正如以上介紹的,在不同的應(yīng)用中都有廣泛應(yīng)用,然而其表達(dá)式都是大同小異的,雖然他們?cè)诓煌墓ぷ髦忻麅翰⒉灰恢?,這可真讓人頭疼。在這里,我嘗試對(duì)為什么采用不同的別名,進(jìn)行解釋?zhuān)?/p>
- ranking loss:這個(gè)名字來(lái)自于信息檢索領(lǐng)域,在這個(gè)應(yīng)用中,我們期望訓(xùn)練一個(gè)模型對(duì)項(xiàng)目(items)進(jìn)行特定的排序。比如文件檢索中,對(duì)某個(gè)檢索項(xiàng)目的排序等。
- Margin loss:這個(gè)名字來(lái)自于一個(gè)事實(shí)——我們介紹的這些loss都使用了邊界去比較衡量樣本之間的嵌入表征距離,見(jiàn)Fig 2.3
- Contrastive loss:我們介紹的loss都是在計(jì)算類(lèi)別不同的兩個(gè)(或者多個(gè))數(shù)據(jù)點(diǎn)的特征嵌入表征。這個(gè)名字經(jīng)常在成對(duì)樣本的ranking loss中使用。但是我從沒(méi)有在以三元組為基礎(chǔ)的工作中使用這個(gè)術(shù)語(yǔ)去進(jìn)行表達(dá)。
- Triplet loss:這個(gè)是在三元組采樣被使用的時(shí)候,經(jīng)常被使用的名字。
- Hinge loss:也被稱(chēng)之為max-margin objective。通常在分類(lèi)任務(wù)中訓(xùn)練SVM的時(shí)候使用。他有著和SVM目標(biāo)相似的表達(dá)式和目的:都是一直優(yōu)化直到到達(dá)預(yù)定的邊界為止。
Siamese 網(wǎng)絡(luò)和 Triplet網(wǎng)絡(luò)
Siamese網(wǎng)絡(luò)(Siamese Net)和Triplet網(wǎng)絡(luò)(Triplet Net)分別是在成對(duì)樣本和三元組樣本 ranking loss采用的情況下訓(xùn)練模型。
Siamese網(wǎng)絡(luò)
這個(gè)網(wǎng)絡(luò)由兩個(gè)相同并且共享參數(shù)的CNN網(wǎng)絡(luò)(兩個(gè)網(wǎng)絡(luò)都有相同的參數(shù))組成。這些網(wǎng)絡(luò)中的每一個(gè)都處理著一個(gè)圖像并且產(chǎn)生對(duì)于的特征表達(dá)。這兩個(gè)表達(dá)之間會(huì)進(jìn)行比較,并且計(jì)算他們之間的距離。然后,一個(gè)成對(duì)樣本的ranking loss將會(huì)作為損失函數(shù)進(jìn)行訓(xùn)練模型。
我們用 表示這個(gè)CNN網(wǎng)絡(luò),我們有Siamese網(wǎng)絡(luò)的損失函數(shù)如:
Triplet網(wǎng)絡(luò)
這個(gè)基本上和Siamese網(wǎng)絡(luò)的思想相似,但是損失函數(shù)采用了Triplet loss,因此整個(gè)網(wǎng)絡(luò)有三個(gè)分支,每個(gè)分支都是一個(gè)相同的,并且共享參數(shù)的CNN網(wǎng)絡(luò)。同樣的,我們能有Triplet網(wǎng)絡(luò)的損失函數(shù)表達(dá)為:
在多模態(tài)檢索中使用ranking loss
根據(jù)我的研究,在涉及到圖片和文本的多模態(tài)檢索任務(wù)中,使用了Triplet ranking loss。訓(xùn)練數(shù)據(jù)由若干有著相應(yīng)文本標(biāo)注的圖片組成。任務(wù)目的是學(xué)習(xí)出一個(gè)特征空間,模型嘗試將圖片特征和相對(duì)應(yīng)的文本特征都嵌入到這個(gè)特征空間中,使得可以將彼此的特征用于在跨模態(tài)檢索任務(wù)中(舉個(gè)例子,檢索任務(wù)可以是給定了圖片,去檢索出相對(duì)應(yīng)的文字描述,那么既然在這個(gè)特征空間里面文本和圖片的特征都是相近的,體現(xiàn)在距離近上,那么就可以直接將圖片特征作為文本特征啦~當(dāng)然實(shí)際情況沒(méi)有那么簡(jiǎn)單)。為了實(shí)現(xiàn)這個(gè),我們首先從孤立的文本語(yǔ)料庫(kù)中,學(xué)習(xí)到文本嵌入信息(word embeddings),可以使用如同Word2Vec或者GloVe之類(lèi)的算法實(shí)現(xiàn)。隨后,我們針對(duì)性地訓(xùn)練一個(gè)CNN網(wǎng)絡(luò),用于在與文本信息的同一個(gè)特征空間中,嵌入圖片特征信息。
具體來(lái)說(shuō),實(shí)現(xiàn)這個(gè)的第一種方法可以是:使用交叉熵?fù)p失,訓(xùn)練一個(gè)CNN去直接從圖片中預(yù)測(cè)其對(duì)應(yīng)的文本嵌入。結(jié)果還不錯(cuò),但是使用Triplet ranking loss能有更好的結(jié)果。
使用Triplet ranking loss的設(shè)置如下:我們使用已經(jīng)學(xué)習(xí)好了文本嵌入(比如GloVe模型),我們只是需要學(xué)習(xí)出圖片表達(dá)。因此錨點(diǎn)樣本是一個(gè)圖片,正樣本
是一個(gè)圖片對(duì)應(yīng)的文本嵌入,負(fù)樣本
是其他無(wú)關(guān)圖片樣本的對(duì)應(yīng)的文本嵌入。為了選擇文本嵌入的負(fù)樣本,我們探索了不同的在線負(fù)樣本采集策略。在多模態(tài)檢索這個(gè)問(wèn)題上使用三元組樣本采集而不是成對(duì)樣本采集,顯得更加合乎情理,因?yàn)槲覀兛梢圆唤@式的類(lèi)別區(qū)分(比如沒(méi)有l(wèi)abel信息)就可以達(dá)到目的。在給定了不同的圖片后,我們可能會(huì)有需要簡(jiǎn)單三元組樣本,但是我們必須留意與難樣本的采樣,因?yàn)椴杉降碾y負(fù)樣本有可能對(duì)于當(dāng)前的錨點(diǎn)樣本,也是成立的(雖然標(biāo)簽的確不同,但是可能很相似。)
在該實(shí)驗(yàn)設(shè)置中,我們只訓(xùn)練了圖像特征表征。對(duì)于第個(gè)圖片樣本,我們用
表示這個(gè)CNN網(wǎng)絡(luò)提取出的圖像表征,然后用
分別表示正文本樣本和負(fù)文本樣本的GloVe嵌入特征表達(dá),我們有:
在這種實(shí)驗(yàn)設(shè)置下,我們對(duì)比了triplet ranking loss和交叉熵?fù)p失的一些實(shí)驗(yàn)的量化結(jié)果。我不打算在此對(duì)實(shí)驗(yàn)細(xì)節(jié)寫(xiě)過(guò)多的筆墨,其實(shí)驗(yàn)細(xì)節(jié)設(shè)置和[4,5]一樣?;緛?lái)說(shuō),我們對(duì)文本輸入進(jìn)行了一定的查詢(xún),輸出是對(duì)應(yīng)的圖像。當(dāng)我們?cè)谏缃痪W(wǎng)絡(luò)數(shù)據(jù)上進(jìn)行半監(jiān)督學(xué)習(xí)的時(shí)候,我們對(duì)通過(guò)文本檢索得到的圖片進(jìn)行某種形式的評(píng)估。采用了Triplet ranking loss的結(jié)果遠(yuǎn)比采用交叉熵?fù)p失的結(jié)果好。
深度學(xué)習(xí)框架中的ranking loss層
Caffe
- Constrastive loss layer
- pycaffe triplet ranking loss layer
PyTorch
- CosineEmbeddingLoss
- MarginRankingLoss
- TripletMarginLoss
TensorFlow
- contrastive_loss
- triplet_semihard_loss
Reference
[1]. https://blog.csdn.net/LoseInVain/article/details/88373506
[2]. https://omoindrot.github.io/triplet-loss
[3]. https://github.com/adambielski/siamese-triplet
[4]. https://arxiv.org/abs/1901.02004
[5]. https://gombru.github.io/2018/08/01/learning_from_web_data/